├── .gitignore
├── LICENSE.md
├── README.md
├── environment.yml
├── inference.py
├── phases.md
├── quobert
├── __init__.py
├── dataprocessing
│ ├── mturk
│ │ ├── mturkify.py
│ │ └── sample-mturk-context.py
│ ├── postprocessing
│ │ ├── process_res.py
│ │ └── speakers_offset.py
│ ├── preprocessing
│ │ ├── bootstrap_EM.py
│ │ ├── extract_entities.py
│ │ ├── features.py
│ │ ├── merge.py
│ │ ├── sampling.py
│ │ └── sampling_uncased.py
│ ├── pygtrie.py
│ └── run.sh
├── model
│ ├── __init__.py
│ ├── eval.py
│ ├── fit.py
│ └── model.py
└── utils
│ ├── __init__.py
│ ├── _utils.py
│ └── data
│ ├── __init__.py
│ ├── dataloader.py
│ └── dataset.py
├── quootstrap
├── README.md
├── extract_quotations.sh
├── lib
│ └── spinn3r-client-3.4.05-edit.jar
├── pom.xml
├── quootstrap.tar.gz
├── resources
│ ├── config.properties
│ └── seedPatterns.txt
└── src
│ └── main
│ └── java
│ └── ch
│ └── epfl
│ └── dlab
│ ├── quootstrap
│ ├── ConfigManager.java
│ ├── ContextExtractor.java
│ ├── DatasetLoader.java
│ ├── Dawg.java
│ ├── Exporter.java
│ ├── ExporterArticle.java
│ ├── ExporterContext.java
│ ├── ExporterSpeakers.java
│ ├── GroundTruthEvaluator.java
│ ├── HashTrie.java
│ ├── HashTriePatternMatcher.java
│ ├── Hashed.java
│ ├── LineageInfo.java
│ ├── MultiCounter.java
│ ├── NameDatabaseWikiData.java
│ ├── ParquetDatasetLoader.java
│ ├── Pattern.java
│ ├── PatternExtractor.java
│ ├── PatternMatcher.java
│ ├── QuotationExtraction.java
│ ├── Sentence.java
│ ├── SimplePatternMatcher.java
│ ├── SpeakerAlias.java
│ ├── Spinn3rDatasetLoader.java
│ ├── Spinn3rDocument.java
│ ├── Spinn3rTextDatasetLoader.java
│ ├── StaticRules.java
│ ├── Token.java
│ ├── Trie.java
│ ├── TriePatternMatcher.java
│ ├── TupleExtractor.java
│ └── Utils.java
│ └── spinn3r
│ ├── EntryWrapper.java
│ ├── Tokenizer.java
│ ├── TokenizerImpl.java
│ └── converter
│ ├── AbstractDecoder.java
│ ├── CombinedDecoder.java
│ ├── Decompressor.java
│ ├── EntryWrapperBuilder.java
│ ├── ProtoToJson.java
│ ├── README.md
│ ├── SpinnerDecoder.java
│ └── Stopwatch.java
├── setup.cfg
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | runs/
132 | quobert/evaluation/
133 | report/
134 | *.ipynb
135 |
136 | quootstrap/target/
137 | quootstrap/.*
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 DLAB
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: base
2 | channels:
3 | - defaults
4 | - pytorch
5 | - conda-forge
6 | dependencies:
7 | - python=3.7
8 | - cudatoolkit=10.1
9 | - pip
10 | - pip:
11 | - tensorboard
12 | - torch===1.4.0
13 | - pandas
14 | - pyarrow
15 | - transformers==2.6
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import logging
4 | import os
5 |
6 | import torch
7 |
8 | from quobert.model import BertForQuotationAttribution, evaluate
9 | from quobert.utils.data import ParquetDataset
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "--model_dir",
17 | default=None,
18 | type=str,
19 | required=True,
20 | help="Path to pre-trained model",
21 | )
22 | parser.add_argument(
23 | "--output_dir",
24 | default=None,
25 | type=str,
26 | required=True,
27 | help="The output directory where the model results will be written.",
28 | )
29 | parser.add_argument(
30 | "--inference_dir",
31 | default=None,
32 | type=str,
33 | required=True,
34 | help="The input inference directory. Should contain (.gz.parquet) files",
35 | )
36 | parser.add_argument(
37 | "--per_gpu_eval_batch_size",
38 | default=256,
39 | type=int,
40 | help="Batch size per GPU/CPU for Inference.",
41 | )
42 |
43 | args = parser.parse_args()
44 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45 | args.n_gpu = torch.cuda.device_count()
46 |
47 | logging.basicConfig(
48 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
49 | datefmt="%m/%d/%Y %H:%M:%S",
50 | level=logging.INFO,
51 | )
52 |
53 | model = BertForQuotationAttribution.from_pretrained(args.model_dir)
54 | model.to(args.device)
55 |
56 | logger.info(f"Started loading the dataset from {args.inference_dir}")
57 | files = sorted(glob.glob(os.path.join(args.inference_dir, "**.gz.parquet")))
58 |
59 | for i, f in enumerate(files):
60 | dataset = ParquetDataset(f)
61 | args.output_file = os.path.join(args.output_dir, f"results_{i:04d}.csv")
62 | evaluate(args, model, dataset, has_target=False)
63 |
--------------------------------------------------------------------------------
/phases.md:
--------------------------------------------------------------------------------
1 | # Spinn3r character encoding: description of Phases A through E
2 |
3 | Quotebank was extracted from news articles that had been collected from the media aggregation service Spinn3r (now called [Datastreamer](https://www.datastreamer.io)).
4 | The Spinn3r data was collected over the course of more than a decade.
5 | During this time, the client-side code used for collecting the data changed several times, and various character-encoding-related issues led to different representations of the original text at different times. Most issues relate to capital letters and non-ASCII characters.
6 |
7 | This document, although not entirely conclusive, is the result of an "archæological" endeavor to reconstruct the history of the various character encodings used at different times during Spinn3r data collection.
8 | Based on our insights, we divide the 12 years spanned by the Spinn3r corpus into five phases (Phases A through E), detailed below.
9 | Non-ASCII characters are most relevant for non-English text; capitalization, however, matters for English as well.
10 | For most users, the key takeaways of this document are these:
11 |
12 | 1. Text was lowercased in Phases A, B, and C, whereas the original capitalization was maintained in Phases D and E.
13 | 2. Non-ASCII characters are properly represented only in Phase E.
14 |
15 | (This document is based on an initial write-up made on 6 June 2014.)
16 |
17 |
18 | ## Phase A (until 2010-07-13)
19 |
20 | Spinn3r's (probably UTF-8-encoded) data was read as Latin-1 (a.k.a. ISO-8859-1). UTF-8 may use several bytes per character, whereas Latin-1 always uses exactly one byte per character. That is, a single non-ASCII character from the original data now looks like two characters. For instance, Unicode code point U+00E4 ("Latin small letter a with diaeresis", a.k.a. "ä") is represented by the two-byte code C3A4 in UTF-8. Reading the bytes C3A4 as Latin-1 results in the two-character sequence "Ã¤", since C3 encodes "Ã" in Latin-1, and A4, "¤".
21 | Then, **lowercasing** was performed on the garbled text, making it even more garbled. For instance, "Ã¤" became "ã¤".
22 | Finally, the data was written to disk as UTF-8.
23 |
24 | **Approximate solution:**
25 | Take the debugging table from [http://www.i18nqa.com/debug/utf8-debug.html](https://web.archive.org/web/20210228174408/http://www.i18nqa.com/debug/utf8-debug.html), look for the garbled and lower-cased sequences and replace them by their original character.
26 | Note that the garbling is not bijective, but since most of the garbled sequences are highly unlikely (e.g., "ã¤"), this should be mostly fine.
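The following minimal Python sketch (not part of this repository) illustrates the idea. It programmatically builds the equivalent of the debugging table linked above, assuming the garbling was exactly "UTF-8 bytes decoded as Latin-1, then lowercased":

```python
# Hypothetical sketch of the approximate Phase-A repair described above.
def build_phase_a_table(codepoints=range(0x80, 0x800)):
    """Map each character's garbled, lowercased form back to the original character."""
    table = {}
    for cp in codepoints:
        original = chr(cp)
        garbled = original.encode("utf-8").decode("latin-1").lower()
        table.setdefault(garbled, original)
    return table

def repair_phase_a(text, table=None):
    """Greedily replace garbled sequences; the mapping is not bijective, hence only approximate."""
    table = table or build_phase_a_table()
    for garbled in sorted(table, key=len, reverse=True):
        text = text.replace(garbled, table[garbled])
    return text

# Example: repair_phase_a("mã¤dchen") == "mädchen"
```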
27 |
28 |
29 | ## Phase B (2010-07-14 to 2010-07-26)
30 |
31 | For just about two weeks, the data seems to have been read as UTF-8 and written as Latin-1 (i.e., the other way round from Phase A).
32 | Non-Latin-1 characters appear as "?". However, there also seem to be a few cases garbled as in Phase A.
33 | All text was **lowercased** in this phase.
34 |
35 | **Approximate solution:**
36 | Simply read the data as Latin-1.
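For instance, a minimal hypothetical sketch (assuming the raw Phase-B files are read in Python; the filename is a placeholder):

```python
# Hypothetical sketch: Phase-B output was written as Latin-1, so decode it as such.
with open("spinn3r_phase_b.txt", "rb") as f:  # placeholder filename
    text = f.read().decode("latin-1")  # non-Latin-1 characters remain "?"
```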
37 |
38 |
39 | ## Phase C (2010-07-27 to 2013-04-28)
40 |
41 | The data was written to disk as ASCII, such that all non-ASCII characters (including Latin-1 characters) appear as "?".
42 | All text was **lowercased** in this phase.
43 |
44 | **Approximate solution:**
45 | None. We simply need to byte (haha...) the bullet and deal with the question marks.
46 |
47 |
48 | ## Phase D (2013-04-29 to 2014-05-21)
49 |
50 | Attempt 1 at fixing the above legacy issues:
51 | capitalization is kept as in the original text obtained from Spinn3r.
52 | However, due to a bad BASH environment variable, text was written to disk as ASCII, such that non-ASCII characters still appear as "?".
53 |
54 |
55 | ## Phase E (since 2014-05-22)
56 |
57 | Attempt 2 at fixing the above legacy issues:
58 | capitalization is kept as in the original text obtained from Spinn3r, and output is now finally written as proper UTF-8 Unicode.
59 |
--------------------------------------------------------------------------------
/quobert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quobert/__init__.py
--------------------------------------------------------------------------------
/quobert/dataprocessing/mturk/sample-mturk-context.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | from pyspark.sql import SparkSession
5 | import pyspark.sql.functions as F
6 |
7 |
8 | def extract(
9 | spark: SparkSession,
10 | *,
11 | mturk_path: str,
12 | output_path: str,
13 | nb_partition: int,
14 | compression: str = "gzip",
15 | ):
16 | df = spark.read.parquet(mturk_path)
17 | df_key = df.withColumn(
18 | "key", F.substring(df.articleUID, 1, 6).cast("int")
19 | )
20 | key_proportions = df_key.groupBy("key").count().orderBy("key").collect()
21 | proportions = {r.key: 100 / r["count"] for r in key_proportions}
22 | selected_keys = df_key.sampleBy("key", proportions, seed=1).withColumn("source", F.lit("normal")).cache()
23 |     print("number of normal quotations", selected_keys.count())
24 |
25 | speaker_proportions = df_key.groupBy("nb_entities").count().orderBy("nb_entities").collect()
26 | proportions = {r.nb_entities: min(1., 100 / r["count"]) if r.nb_entities <= 20 else 0 for r in speaker_proportions}
27 | selected_speakers = df_key.sampleBy("nb_entities", proportions, seed=2).withColumn("source", F.lit("nb_entities")).cache()
28 |     print("number of quotations sampled by number of entities", selected_speakers.count())
29 |
30 | df_filtered = df_key.filter(~df_key.in_quootstrap)
31 | nb_hard = df_filtered.count()
32 | selected_hard = df_filtered.sample(fraction=5000 / nb_hard, seed=3).withColumn("source", F.lit("not_quootstrap")).cache()
33 |     print("number of quotations not in quootstrap", selected_hard.count())
34 | to_evaluate = selected_keys.union(selected_speakers).union(selected_hard).dropDuplicates(['articleUID', 'articleOffset']).dropDuplicates(['context']).cache()
35 | print("total after dropDuplicates", to_evaluate.count())
36 | to_evaluate.coalesce(nb_partition).write.parquet(
37 | os.path.join(output_path, "mturk"),
38 | mode="overwrite",
39 | compression=compression,
40 | )
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument(
46 | "-m",
47 | "--mturk",
48 | type=str,
49 | help="Path to folder with all raw context for mturk (.gz.parquet)",
50 | required=True,
51 | )
52 | parser.add_argument(
53 | "-o", "--output", type=str, help="Path to output folder", required=True
54 | )
55 | parser.add_argument(
56 | "-n",
57 | "--nb_partition",
58 | type=int,
59 |         help="Number of partitions for the output (useful if used with unsplittable compression algorithm). Default=10",
60 | default=10,
61 | )
62 | parser.add_argument(
63 | "--compression",
64 | type=str,
65 |         help="Compression algorithm. Can be any algorithm compatible with Spark Parquet. Default=gzip",
66 | default="gzip",
67 | )
68 | parser.add_argument(
69 | "-l",
70 | "--local",
71 | help="Add if you want to execute locally. The code is expected to be run on a cluster if you run on big files",
72 | action="store_true",
73 | )
74 | args = parser.parse_args()
75 |
76 | print("Starting the Spark Session")
77 | if args.local:
78 | import findspark
79 |
80 | findspark.init()
81 |
82 | spark = (
83 | SparkSession.builder.master("local[24]")
84 | .appName("SampleMTurkLocal")
85 | .config("spark.driver.memory", "16g")
86 | .config("spark.executor.memory", "32g")
87 | .getOrCreate()
88 | )
89 | else:
90 | spark = SparkSession.builder.appName("SampleMTurk").getOrCreate()
91 |
92 |     print("Starting the sampling process")
93 | extract(
94 | spark,
95 | mturk_path=args.mturk,
96 | output_path=args.output,
97 | nb_partition=args.nb_partition,
98 | compression=args.compression,
99 | )
100 |
--------------------------------------------------------------------------------
/quobert/dataprocessing/postprocessing/process_res.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 | import string
4 | from collections import Counter
5 | from os.path import join
6 |
7 | import pyspark.sql.functions as F
8 | from pyspark.sql import SparkSession, Window
9 | from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType
10 |
11 |
12 | @F.udf
13 | def lower(text):
14 | no_punct = "".join(t for t in text if t not in string.punctuation)
15 | one_space = re.sub(r"\W+", " ", no_punct)
16 | return one_space.lower()
17 |
18 |
19 | @F.udf
20 | def longest(quotes):
21 | return sorted(quotes, key=len, reverse=True)[0]
22 |
23 |
24 | @F.udf
25 | def pad_int(idx):
26 | return f"{idx:06d}"
27 |
28 |
29 | def first(col):
30 | return F.first(f"quotes.{col}").alias(col)
31 |
32 |
33 | def get_col(col):
34 | return F.col(f"quotes.{col}").alias(col)
35 |
36 |
37 | @F.udf(returnType=DoubleType())
38 | def normalize(x):
39 | return x / 100.0 if x is not None else None
40 |
41 |
42 | @F.udf
43 | def get_top1(probas):
44 | return probas[0][0]
45 |
46 |
47 | @F.udf(returnType=IntegerType())
48 | def get_article_len(article):
49 | return len(article.split())
50 |
51 |
52 | @F.udf
53 | def get_most_common_name(names):
54 | return sorted(Counter(names).items(), key=lambda x: (-x[1], x[0]))[0][0]
55 |
56 |
57 | @F.udf(returnType=ArrayType(StringType()))
58 | def get_website(array):
59 | return list(dict.fromkeys(x["website"] for x in array))
60 |
61 |
62 | @F.udf(returnType=ArrayType(StringType()))
63 | def get_first_dedup(ids):
64 | return list({x[0] for x in ids})
65 |
66 |
67 | NB_PARTITIONS = 2 ** 11
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument(
70 | "-a",
71 | "--articles",
72 | type=str,
73 | help="Path to folder with all articles (.json.gz)",
74 | required=True,
75 | )
76 | parser.add_argument(
77 | "-c",
78 | "--context",
79 | type=str,
80 | help="Path to folder with all quotes with context (.json.gz)",
81 | required=True,
82 | )
83 | parser.add_argument(
84 | "-s",
85 | "--speakers",
86 | type=str,
87 | help="Path to the augmented speakers folder (.parquet.gz)",
88 | required=True,
89 | )
90 | parser.add_argument(
91 | "-r",
92 | "--results",
93 | type=str,
94 | help="Path to the results from the inference step (.csv)",
95 | required=True,
96 | )
97 | parser.add_argument(
98 | "-o", "--output", type=str, help="Path to output folder", required=True
99 | )
100 | parser.add_argument(
101 | "--partition",
102 | nargs="+",
103 | default=["year", "month"],
104 |     help="Column(s) to partition the output by. A choice of year and month. Order matters. Default 'year month'",
105 | )
106 |
107 | args = parser.parse_args()
108 |
109 | spark = SparkSession.builder.appName("processRes").getOrCreate()
110 | qc = spark.read.json(args.context)
111 | articles = spark.read.json(args.articles)
112 | speakers = spark.read.parquet(args.speakers)
113 | res = spark.read.csv(
114 | args.results,
115 | header=True,
116 | schema="articleUID STRING, articleOffset LONG, rank INT, speaker STRING, proba DOUBLE",
117 | )
118 |
119 | quote_article_link = (
120 | qc.join(articles.select("articleUID", "date", "website", "phase"), on="articleUID")
121 | .groupBy(lower(F.col("quotation")).alias("canonicalQuotation"))
122 | .agg(
123 | F.collect_set("quotation").alias("quotations"),
124 | F.min("date").alias("earliest_date"),
125 | F.min("phase").alias("phase"),
126 | F.count("*").alias("numOccurrences"),
127 | F.collect_list(
128 | F.struct(
129 | F.col("articleUID"),
130 | F.col("articleOffset"),
131 | F.col("quotation"),
132 | F.col("leftContext"),
133 | F.col("rightContext"),
134 | F.col("quotationOffset"),
135 | F.col("leftOffset").alias("contextStart"),
136 | F.col("rightOffset").alias("contextEnd"),
137 | )
138 | ).alias("quotes_link"),
139 | get_website(F.sort_array(F.collect_list(F.struct("date", "website")))).alias(
140 | "urls"
141 | ),
142 | )
143 | .withColumn("quotation", longest(F.col("quotations")))
144 | .withColumn(
145 | "row_nb",
146 | F.row_number().over(
147 | Window.partitionBy(F.to_date("earliest_date")).orderBy("canonicalQuotation")
148 | ),
149 | )
150 | .withColumn(
151 | "quoteID", F.concat_ws("-", F.to_date("earliest_date"), pad_int("row_nb")),
152 | )
153 | .withColumn("month", F.month("earliest_date"))
154 | .withColumn("year", F.year("earliest_date"))
155 | .drop("quotations", "row_nb")
156 | )
157 |
158 | joined_df = qc.join(res, on=["articleUID", "articleOffset"])
159 |
160 | w = Window.partitionBy("canonicalQuotation")
161 | rank_w = Window.partitionBy("canonicalQuotation").orderBy(F.desc("sum(proba)"))
162 | agg_proba = (
163 | joined_df.groupBy(lower(F.col("quotation")).alias("canonicalQuotation"), "qids")
164 | .agg(F.sum("proba"), F.collect_list("speaker").alias("speakers"))
165 | .select(
166 | "*",
167 | F.sum("sum(proba)").over(w).alias("weight"),
168 | F.row_number().over(rank_w).alias("rank"),
169 | get_most_common_name("speakers").alias("speaker"),
170 | )
171 | .withColumn("proba", F.round(F.col("sum(proba)") / F.col("weight"), 4))
172 | .filter("proba >= 1e-4")
173 |     .drop("sum(proba)", "weight")
174 | )
175 |
176 | agg_proba.write.parquet(
177 | join(args.output, "quotebank-cache-proba"),
178 | mode="overwrite",
179 | compression="gzip",
180 | )
181 | agg_proba = spark.read.parquet(join(args.output, "quotebank-cache-proba"))
182 |
183 | top_speaker = agg_proba.filter("rank = 1").select(
184 | "canonicalQuotation",
185 | F.col("speaker").alias("top_speaker"),
186 | F.col("qids").alias("top_speaker_qid"),
187 | F.col("speakers").alias("top_surface_forms"),
188 | )
189 |
190 | probas = (
191 | agg_proba.orderBy("canonicalQuotation", "rank")
192 | .groupBy("canonicalQuotation")
193 | .agg(
194 | F.collect_list(F.struct(F.col("speaker"), F.col("proba"), F.col("qids"))).alias(
195 | "probas"
196 | )
197 | )
198 | )
199 |
200 | final = quote_article_link.join(top_speaker, on="canonicalQuotation").join(
201 | probas, on="canonicalQuotation"
202 | )
203 | final.write.parquet(
204 | join(args.output, "quotebank-cache1"), mode="overwrite", compression="gzip"
205 | )
206 | final = spark.read.parquet(join(args.output, "quotebank-cache1"))
207 |
208 |
209 | SMALL_COLS = [
210 | "quoteID",
211 | "quotation",
212 | F.col("top_speaker").alias("speaker"),
213 | # F.col("top_speaker_qid").alias("qids"),
214 | F.col("earliest_date").alias("date"),
215 | "numOccurrences",
216 | "probas",
217 | "year",
218 | "month",
219 | "urls",
220 | "phase",
221 | ]
222 |
223 | final.select(*SMALL_COLS).repartition(*args.partition).write.partitionBy(
224 | *args.partition
225 | ).json(
226 | join(args.output, "quotes-df"), mode="overwrite", compression="bzip2",
227 | )
228 |
229 | BIG_COLS = [
230 | "quoteID",
231 | "numOccurrences",
232 | F.col("top_speaker").alias("globalTopSpeaker"),
233 | F.col("probas").alias("globalProbas"),
234 | F.explode("quotes_link").alias("quotes"),
235 | ]
236 |
237 | individual_probas = (
238 | res.filter("proba > 0")
239 | .orderBy("articleUID", "articleOffset", "rank")
240 | .groupBy("articleUID", "articleOffset")
241 | .agg(
242 | F.collect_list(
243 | F.struct("speaker", F.round(normalize("proba"), 4).alias("proba"), "qids")
244 | ).alias("localProbas")
245 | )
246 | .withColumn("speaker", get_top1("localProbas"))
247 | )
248 |
249 | df = final.select(*BIG_COLS)
250 | df = df.join(
251 | individual_probas,
252 | on=[
253 | df.quotes.articleUID == individual_probas.articleUID,
254 | df.quotes.articleOffset == individual_probas.articleOffset,
255 | ],
256 | ).drop("articleUID", "articleOffset")
257 |
258 | article_df = df.groupBy(F.col("quotes.articleUID").alias("articleUID")).agg(
259 | F.collect_list(
260 | F.struct(
261 | "quoteID",
262 | "numOccurrences",
263 | get_col("quotation"),
264 | get_col("quotationOffset"),
265 | get_col("contextStart"),
266 | get_col("contextEnd"),
267 | "globalTopSpeaker",
268 | "globalProbas",
269 | F.col("speaker").alias("localTopSpeaker"),
270 | "localProbas",
271 | )
272 | ).alias("quotations"),
273 | )
274 |
275 | article_df = (
276 | article_df.join(articles, on="articleUID")
277 | .join(speakers, on="articleUID")
278 | .withColumn("articleLength", get_article_len("content"))
279 | .withColumnRenamed("articleUID", "articleID")
280 | .withColumnRenamed("website", "url")
281 | .withColumn("year", F.year("date"))
282 | .withColumn("month", F.month("date"))
283 | )
284 |
285 | article_df.repartition(*args.partition).write.partitionBy(*args.partition).parquet(
286 | join(args.output, "article-df"), mode="overwrite", compression="gzip",
287 | )
288 |
--------------------------------------------------------------------------------
/quobert/dataprocessing/postprocessing/speakers_offset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import string
3 |
4 | from pyspark.sql import Row, SparkSession
5 |
6 | from pygtrie import StringTrie # type: ignore # Sideloaded in the spark-submit
7 |
8 | PUNCT = "".join(x for x in string.punctuation if x not in "[]")
9 | BLACK_LIST = {"president", "manager"}
10 |
11 |
12 | def create_trie(names):
13 | trie = StringTrie(delimiter="/")
14 | for name in names:
15 | processed_name = [
16 | x for x in name["name"].split() if x.lower() not in BLACK_LIST
17 | ]
18 | for i in range(len(processed_name)):
19 | trie["/".join(processed_name[i:]).lower()] = (name["name"], name["ids"])
20 | return trie
21 |
22 |
23 | def update_entity(name, qinfo, start, end, out):
24 | if name in out:
25 | out[name]["offsets"].append([start, end])
26 | else:
27 | out[name] = {
28 | "name": name,
29 | "ids": sorted({x[0] for x in qinfo}),
30 | "offsets": [[start, end]],
31 | }
32 |
33 |
34 | def reduce_entities(entities):
35 | out = dict()
36 | for i, (value, start, end) in entities.items():
37 | if isinstance(value, list):
38 | for name, qinfo in value:
39 | update_entity(name, qinfo, start, end, out)
40 | else:
41 | try:
42 | update_entity(value[0], value[1], start, end, out)
43 | except IndexError: # this happens when value = '/', which is rare
44 | print(value, "We had an error here")
45 | return list(out.values())
46 |
47 |
48 | def get_partial_match(trie, key):
49 | match = list(trie[key:])
50 | return match if len(match) > 1 else match[0]
51 |
52 |
53 | def get_entity(trie, key):
54 | entites = get_partial_match(trie, key)
55 | if not type(entites) == tuple:
56 | raise Exception(f"{entites} is not a tuple")
57 | return entites
58 |
59 |
60 | def find_entites(text: str, trie: StringTrie):
61 | tokens = text.split()
62 | start = 0
63 | count = 1 # start at 1, 0 is for the "NO_MATCH"
64 | entities = dict()
65 | for i in range(len(tokens)):
66 | key = "/".join(tokens[start : i + 1]).lower()
67 | if trie.has_subtrie(key): # Not done yet
68 | if i == len(tokens) - 1: # Reached the end of the string
69 | entities[count] = (get_entity(trie, key), start, i + 1)
70 | elif trie.has_key(key): # noqa: W601 # Find a perfect match
71 | entities[count] = (trie[key], start, i + 1)
72 | count += 1
73 | start = i + 1
74 | elif start < i: # Found partial prefix match before this token
75 | old_key = "/".join(tokens[start:i]).lower()
76 | entities[count] = (get_entity(trie, old_key), start, i)
77 | count += 1
78 | if trie.has_node(
79 | tokens[i].lower()
80 | ): # Need to verify that the current token isn't in the Trie
81 | start = i
82 | else:
83 | start = i + 1
84 | else: # No match
85 | start = i + 1
86 | return reduce_entities(entities)
87 |
88 |
89 | def transform(x: Row):
90 | trie = create_trie(x.names)
91 | try:
92 | entities = find_entites(x.content, trie)
93 | except Exception:
94 | return None
95 |
96 | return Row(articleUID=x.articleUID, names=entities,)
97 |
98 |
99 | def speakers_offset(
100 | spark: SparkSession,
101 | *,
102 | articles_path: str,
103 | speakers_path: str,
104 | output_path: str,
105 | nb_partition: int,
106 | compression: str = "gzip",
107 | ):
108 | df = (
109 | spark.read.json(articles_path)
110 | .select("articleUID", "content")
111 | .repartition(nb_partition)
112 | )
113 | speakers = spark.read.json(speakers_path)
114 | joined = df.join(speakers, on="articleUID")
115 |
116 | transformed = joined.rdd.map(transform).filter(lambda x: x is not None).toDF()
117 | transformed.write.parquet(output_path, "overwrite", compression=compression)
118 |
119 |
120 | if __name__ == "__main__":
121 | parser = argparse.ArgumentParser()
122 | parser.add_argument(
123 | "-a",
124 | "--articles",
125 | type=str,
126 | help="Path to the articles (.json)",
127 | required=True,
128 | )
129 | parser.add_argument(
130 | "-s",
131 | "--speakers",
132 | type=str,
133 | help="Path to the speakers folder (.json)",
134 | required=True,
135 | )
136 | parser.add_argument(
137 | "-o",
138 | "--output",
139 | type=str,
140 | help="Path to output folder for the transformed speaker offsets",
141 | required=True,
142 | )
143 | parser.add_argument(
144 | "-l",
145 | "--local",
146 | help="Add if you want to execute locally. The code is expected to be run on a cluster if you run on big files",
147 | action="store_true",
148 | )
149 | parser.add_argument(
150 | "-n",
151 | "--nb_partition",
152 | type=int,
153 |         help="Number of partitions for the output (useful if used with unsplittable compression algorithm). Default=200",
154 | default=200,
155 | )
156 | parser.add_argument(
157 | "--compression",
158 | type=str,
159 |         help="Compression algorithm. Can be any algorithm compatible with Spark Parquet. Default=gzip",
160 | default="gzip",
161 | )
162 | args = parser.parse_args()
163 |
164 | if args.local:
165 | # import findspark
166 |
167 | # findspark.init()
168 |
169 | spark = (
170 | SparkSession.builder.master("local[24]")
171 | .appName("SpeakerOffsetsLocal")
172 | .config("spark.driver.memory", "16g")
173 | .config("spark.executor.memory", "32g")
174 | .getOrCreate()
175 | )
176 | else:
177 | spark = SparkSession.builder.appName("SpeakerOffsets").getOrCreate()
178 |
179 | speakers_offset(
180 | spark,
181 | articles_path=args.articles,
182 | speakers_path=args.speakers,
183 | output_path=args.output,
184 | nb_partition=args.nb_partition,
185 | compression=args.compression,
186 | )
187 |
--------------------------------------------------------------------------------
/quobert/dataprocessing/preprocessing/bootstrap_EM.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from string import punctuation
3 |
4 | import pyspark.sql.functions as F
5 | from pyspark.sql import SparkSession
6 |
7 |
8 | @F.udf()
9 | def process_quote(quote):
10 | return "".join(x for x in quote.lower() if x not in punctuation)
11 |
12 |
13 | @F.udf()
14 | def most_frequent(x):
15 | return max(set(x), key=x.count)
16 |
17 |
18 | def find_exact_match(
19 | spark: SparkSession,
20 | quotes_context_path: str,
21 | quootstrap_path: str,
22 | output_path: str,
23 | nb_partition: int = 200,
24 | compression: str = "gzip",
25 | ):
26 | quootstrap_df = spark.read.json(quootstrap_path)
27 | quotes_context_df = spark.read.json(quotes_context_path)
28 |
29 | q2 = quootstrap_df.select(F.explode("occurrences").alias("occurrence"))
30 | fields_to_keep = [
31 | q2.occurrence.articleUID.alias("articleUID"),
32 | q2.occurrence.articleOffset.alias("articleOffset"),
33 | ]
34 |
35 | attributed_quotes_df = q2.select(*fields_to_keep)
36 |
37 | new_quotes_context_df = quotes_context_df.join(
38 | attributed_quotes_df, on=["articleUID", "articleOffset"], how="left_anti",
39 | )
40 |
41 | quootstrap_df.select(
42 | most_frequent("speaker").alias("speaker"),
43 | process_quote("quotation").alias("uncased_quote"),
44 | ).join(
45 | new_quotes_context_df.withColumn("uncased_quote", process_quote("quotation")),
46 | on="uncased_quote",
47 | ).drop(
48 | "uncased_quote"
49 | ).repartition(
50 | nb_partition
51 | ).write.parquet(
52 | output_path, "overwrite", compression=compression
53 | )
54 |
55 |
56 | if __name__ == "__main__":
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument(
59 | "-q",
60 | "--quootstrap",
61 | type=str,
62 |         help="Path to Quootstrap output (.json)",
63 | required=True,
64 | )
65 | parser.add_argument(
66 | "-c",
67 | "--context",
68 | type=str,
69 | help="Path to folder with all quotes with context (.json.gz)",
70 | required=True,
71 | )
72 | parser.add_argument(
73 | "-o", "--output", type=str, help="Path to output folder", required=True
74 | )
75 | parser.add_argument(
76 | "-l",
77 | "--local",
78 | help="Add if you want to execute locally.",
79 | action="store_true",
80 | )
81 | parser.add_argument(
82 | "-n",
83 | "--nb_partition",
84 | type=int,
85 |         help="Number of partitions for the output (useful if used with unsplittable compression algorithm). Default=200",
86 | default=200,
87 | )
88 | parser.add_argument(
89 | "--compression",
90 | type=str,
91 |         help="Compression algorithm. Can be any algorithm compatible with Spark Parquet. Default=gzip",
92 | default="gzip",
93 | )
94 | args = parser.parse_args()
95 |
96 | print("Starting the Spark Session")
97 | if args.local:
98 | # import findspark
99 |
100 | # findspark.init()
101 |
102 | spark = (
103 | SparkSession.builder.master("local[24]")
104 | .appName("BootstrapLocal")
105 | .config("spark.driver.memory", "32g")
106 | .config("spark.executor.memory", "32g")
107 | .getOrCreate()
108 | )
109 | else:
110 | spark = SparkSession.builder.appName("Bootstrap_EM").getOrCreate()
111 |
112 | find_exact_match(
113 | spark,
114 | quootstrap_path=args.quootstrap,
115 | quotes_context_path=args.context,
116 | output_path=args.output,
117 | nb_partition=args.nb_partition,
118 | compression=args.compression,
119 | )
120 |
--------------------------------------------------------------------------------
/quobert/dataprocessing/preprocessing/features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import re
4 | import signal
5 | import string
6 | from typing import List, Optional, Union
7 |
8 | from pyspark.broadcast import Broadcast
9 | from pyspark.sql import Row, SparkSession
10 | from transformers import AutoTokenizer
11 |
12 | QUOTE_TOKEN = "[QUOTE]"
13 | QUOTE_TARGET_TOKEN = "[TARGET_QUOTE]"
14 | MASK_IDX = 103 # Index in BERT wordpiece
15 |
16 | PUNCT = re.escape("".join(x for x in string.punctuation if x not in "[]"))
17 |
18 |
19 | class TimeoutError(Exception):
20 | pass
21 |
22 |
23 | def handler(signum, frame):
24 | raise TimeoutError()
25 |
26 |
27 | def timeout(func, args=(), kwargs={}, timeout_duration=5, default=None):
28 | # set the timeout handler
29 | signal.signal(signal.SIGALRM, handler)
30 | signal.alarm(timeout_duration)
31 | try:
32 | result = func(*args, **kwargs)
33 | except TimeoutError:
34 | print(f"Timeout Error running {func} with {args} and {kwargs}")
35 | result = default
36 | finally:
37 | signal.alarm(0)
38 |
39 | return result
40 |
41 |
42 | def clean_text(masked_text):
43 | masked_text = re.sub(
44 | r"""(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'".,<>?«»“”‘’]))""",
45 | " ",
46 | masked_text,
47 | )
48 | masked_text = re.sub("[" + PUNCT + "]{4,}", " ", masked_text)
49 | return masked_text
50 |
51 |
52 | def example_to_features(
53 | articleUID: str,
54 | articleOffset: int,
55 | masked_text: str,
56 | speaker: str,
57 | targets: List[int],
58 | entities,
59 | tokenizer: Union[AutoTokenizer, Broadcast],
60 | max_seq_len: int = 320,
61 | pad_to_max_length: bool = True,
62 | ) -> Optional[Row]:
63 | """Transform examples to QuoteFeatures row. Given the context and the speaker,
64 |     extract the start/end offsets that best match the speaker.
65 | Those offsets will be used as targets for the models.
66 |
67 | Args:
68 | articleUID (str): The unique identifier of the associated article
69 | articleOffset (int): The offset (in number of quotes) in the associated article
70 |         masked_text (str): The quotation context, with candidate speaker mentions masked ([MASK])
71 |         targets (List[int]): Target mention indexes; only `targets[0]` is kept. `[-1]` at inference time
72 |         entities: Candidate entities for the quotation; serialized to JSON in the output Row
73 |         speaker (str): The identified speaker (empty string at inference time)
74 |         tokenizer (Union[AutoTokenizer, Broadcast]): The (broadcast) tokenizer used to
75 |             compute the features
76 | max_seq_len (int, optional): Maximum sequence length in tokens, extra tokens will be dropped. Defaults to 320.
77 |         pad_to_max_length (bool, optional): Whether to pad the tokens to `max_seq_len`. Pad on the right. Defaults to True.
78 |
79 | Returns:
80 | Optional[Row]: The features
81 | """
82 | tokenizer = tokenizer.value
83 |
84 | masked_text = timeout(clean_text, args=(masked_text,), default="")
85 | tokenized = timeout(tokenizer.tokenize, args=(masked_text,))
86 | if not tokenized or len(tokenized) > max_seq_len:
87 | # print(len(tokenized), masked_text)
88 | return None
89 | encoded = timeout(tokenizer.encode, args=(tokenized,))
90 |
91 | mask_idx = [0] + [
92 | i for i, idx in enumerate(encoded) if idx == MASK_IDX
93 | ] # indexes of [CLS] and [MASK] token
94 |
95 | # if len(mask_idx) < 2: # This should *NOT* happen
96 | # # print("No mask token in", masked_text)
97 | # return None
98 |
99 | return Row(
100 | uid=articleUID + " " + str(articleOffset),
101 | input_ids=encoded,
102 | mask_idx=mask_idx,
103 | target=targets[0],
104 | entities=json.dumps(entities),
105 | speaker=speaker,
106 | )
107 |
108 |
109 | def transform_to_features(
110 | spark: SparkSession,
111 | *,
112 | transformed_path: str,
113 | tokenizer_model: str,
114 | output_path: str,
115 | nb_partition: int,
116 | compression: str,
117 | kind: str,
118 | ):
119 | """Entire transformation pipeline. Entry point to the process.
120 | Create a tokenizer from the model. Read the merged data and do the transformations.
121 | Finally write the resulting Dataset/Dataframe to the disk
122 |
123 | Args:
124 |         transformed_path (str): Path to the folder containing the transformed (sampled) data
125 | tokenizer_model (str): Model of the tokenizer. Must be supported by `transformers`
126 | output_path (str): Path to the output folder
127 |         nb_partition (int): Number of partitions for the output
128 | compression (str, optional): A parquet compatible compression algorithm. Defaults to 'gzip'.
129 | """
130 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
131 | tokenizer.add_tokens([QUOTE_TOKEN, QUOTE_TARGET_TOKEN])
132 | tokenizer_bc = spark.sparkContext.broadcast(tokenizer)
133 |
134 | def __example_to_features(row: Row) -> Optional[Row]:
135 | return example_to_features(
136 | row.articleUID,
137 | row.articleOffset,
138 | row.masked_text,
139 | row.speaker if kind == "train" else "",
140 | row.targets if kind == "train" else [-1],
141 | row.entities,
142 | tokenizer_bc,
143 | pad_to_max_length=False,
144 | )
145 |
146 | transformed_df = spark.read.parquet(transformed_path)
147 |
148 | output_df = (
149 | transformed_df.rdd.repartition(2 ** 7)
150 | .map(__example_to_features)
151 | .filter(lambda x: x is not None)
152 | .toDF()
153 | )
154 |
155 | if kind == "test":
156 |         output_df = output_df.drop("speaker", "target")  # .coalesce(nb_partition)
157 | output_df.write.parquet(output_path, "overwrite", compression=compression)
158 |
159 |
160 | if __name__ == "__main__":
161 | parser = argparse.ArgumentParser()
162 | parser.add_argument(
163 | "-t",
164 | "--transformed",
165 | type=str,
166 | help="Path to the transformed (sampled) output folder (.parquet)",
167 | required=True,
168 | )
169 | parser.add_argument(
170 | "--tokenizer",
171 | type=str,
172 | help="Name of the pretrained model, default: bert-base-cased",
173 | default="bert-base-cased",
174 | )
175 | parser.add_argument(
176 | "-o", "--output", type=str, help="Path to output folder", required=True
177 | )
178 | parser.add_argument(
179 | "-l",
180 | "--local",
181 | help="Add if you want to execute locally. The code is expected to be run on a cluster if you run on big files",
182 | action="store_true",
183 | )
184 | parser.add_argument(
185 | "-n",
186 | "--nb_partition",
187 | type=int,
188 |         help="Number of partitions for the output (useful if used with unsplittable compression algorithm). Default=50",
189 | default=50,
190 | )
191 | parser.add_argument(
192 | "--compression",
193 | type=str,
194 |         help="Compression algorithm. Can be any algorithm compatible with Spark Parquet. Default=gzip",
195 | default="gzip",
196 | )
197 | parser.add_argument(
198 | "--kind",
199 | type=str,
200 | help="Which kind of data it is to transform (train = with labels, test = without labels)",
201 | required=True,
202 | choices=["train", "test"],
203 | )
204 | args = parser.parse_args()
205 |
206 | print("Starting the Spark Session")
207 | if args.local:
208 | # import findspark
209 |
210 | # findspark.init()
211 |
212 | spark = (
213 | SparkSession.builder.master("local[24]")
214 | .appName("FeaturesExtractorLocal")
215 | .config("spark.driver.memory", "16g")
216 | .config("spark.executor.memory", "32g")
217 | .getOrCreate()
218 | )
219 | else:
220 | spark = SparkSession.builder.appName("FeaturesExtractor").getOrCreate()
221 |
222 | print("Starting the transformation to features")
223 | transform_to_features(
224 | spark,
225 | transformed_path=args.transformed,
226 | tokenizer_model=args.tokenizer,
227 | output_path=args.output,
228 | nb_partition=args.nb_partition,
229 | compression=args.compression,
230 | kind=args.kind,
231 | )
232 |
--------------------------------------------------------------------------------
/quobert/dataprocessing/preprocessing/merge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from urllib.parse import urlparse
3 |
4 | from pyspark.sql import SparkSession
5 | from pyspark.sql.functions import explode, to_date, udf
6 |
7 |
8 | @udf
9 | def get_domain(x):
10 | return urlparse(x).netloc
11 |
12 |
13 | def merge(
14 | spark: SparkSession,
15 | *,
16 | quootstrap_path: str,
17 | quotes_context_path: str,
18 | output_path: str,
19 | nb_partition: int,
20 | compression: str = "gzip",
21 | ):
22 |     """ Merge the Quootstrap output with the quotes + context in order to create an input training set
23 |
24 | Args:
25 | spark (SparkSession): Current spark session
26 | quootstrap_path (str): HDFS path to the Quootstrap output (Q, S) pairs
27 | quotes_context_path (str): HDFS path to the quotes+context
28 | output_path (str): HDFS path to store the output of the merge
29 |         nb_partition (int): Number of partitions for the output
30 | compression (str, optional): A parquet compatible compression algorithm. Defaults to 'gzip'.
31 | """
32 | quootstrap_df = spark.read.json(quootstrap_path)
33 | quotes_context_df = spark.read.json(quotes_context_path)
34 |
35 | # Extract all quotes and speakers from Quootstrap data
36 | q2 = quootstrap_df.select("quotation", explode("occurrences").alias("occurrence"))
37 |
38 | fields_to_keep = [
39 | q2.occurrence.articleUID.alias("articleUID"),
40 | q2.occurrence.articleOffset.alias("articleOffset"),
41 | q2.occurrence.matchedSpeakerTokens.alias("speaker"),
42 | q2.occurrence.extractedBy.alias("pattern"),
43 | get_domain(q2.occurrence.website).alias("domain"),
44 | to_date(q2.occurrence.date).alias("date"),
45 | ]
46 |
47 | attributed_quotes_df = q2.select(*fields_to_keep)
48 |
49 | # Merge df and write parquet files
50 | attributed_quotes_context_df = attributed_quotes_df.join(
51 | quotes_context_df, on=["articleUID", "articleOffset"], how="inner"
52 | )
53 |
54 | attributed_quotes_context_df.repartition(nb_partition).write.parquet(
55 | output_path, "overwrite", compression=compression
56 | )
57 |
58 |
59 | if __name__ == "__main__":
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument(
62 | "-q",
63 | "--quootstrap",
64 | type=str,
65 |         help="Path to Quootstrap output (.json)",
66 | required=True,
67 | )
68 | parser.add_argument(
69 | "-c",
70 | "--context",
71 | type=str,
72 | help="Path to folder with all quotes with context (.json.gz)",
73 | required=True,
74 | )
75 | parser.add_argument(
76 | "-o", "--output", type=str, help="Path to output folder", required=True
77 | )
78 | parser.add_argument(
79 | "-n",
80 | "--nb_partition",
81 | type=int,
82 |         help="Number of partitions for the output (useful if used with unsplittable compression algorithm). Default=200",
83 | default=200,
84 | )
85 | parser.add_argument(
86 | "--compression",
87 | type=str,
88 |         help="Compression algorithm. Can be any algorithm compatible with Spark Parquet. Default=gzip",
89 | default="gzip",
90 | )
91 | parser.add_argument(
92 | "-l",
93 | "--local",
94 | help="Add if you want to execute locally. The code is expected to be run on a cluster if you run on big files",
95 | action="store_true",
96 | )
97 | args = parser.parse_args()
98 |
99 | print("Starting the Spark Session")
100 | if args.local:
101 | import findspark
102 |
103 | findspark.init()
104 |
105 | spark = (
106 | SparkSession.builder.master("local[24]")
107 | .appName("QuoteMergerLocal")
108 | .config("spark.driver.memory", "16g")
109 | .config("spark.executor.memory", "32g")
110 | .getOrCreate()
111 | )
112 | else:
113 | spark = SparkSession.builder.appName("QuoteMerger").getOrCreate()
114 |
115 | print("Starting the merging process")
116 | merge(
117 | spark,
118 | quootstrap_path=args.quootstrap,
119 | quotes_context_path=args.context,
120 | output_path=args.output,
121 | nb_partition=args.nb_partition,
122 | compression=args.compression,
123 | )
124 |
--------------------------------------------------------------------------------
/quobert/dataprocessing/run.sh:
--------------------------------------------------------------------------------
1 | spark-submit \
2 | --master yarn \
3 | --num-executors 16 \
4 | --executor-cores 32 \
5 | --driver-memory 16g \
6 | --executor-memory 64g \
7 | --py-files pygtrie.py \
8 | --conf spark.pyspark.python=python3 \
9 | --conf spark.driver.maxResultSize=0 \
10 | --conf spark.sql.shuffle.partitions=2048 \
11 | --conf spark.executor.memoryOverhead=16g \
12 | --conf spark.blacklist.enabled=true \
13 | --conf spark.reducer.maxReqsInFlight=10 \
14 | --conf spark.shuffle.io.retryWait=10s \
15 | --conf spark.shuffle.io.maxRetries=10 \
16 | --conf spark.shuffle.io.backLog=2048 \
17 | "$@"
18 |
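# Hypothetical usage example (the paths below are placeholders, not part of the repo):
# the wrapper forwards "$@" to spark-submit, so pass the PySpark script and its arguments, e.g.
#   ./run.sh preprocessing/merge.py -q /path/to/quootstrap.json -c /path/to/quotes_context -o /path/to/merged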
--------------------------------------------------------------------------------
/quobert/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .fit import fit
2 | from .eval import evaluate
3 | from .model import BertForQuotationAttribution
4 |
--------------------------------------------------------------------------------
/quobert/model/eval.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import logging
3 | import os
4 | from operator import itemgetter
5 |
6 | import torch
7 | from torch.utils.data import DataLoader, RandomSampler
8 | from tqdm.auto import tqdm
9 |
10 | from quobert.utils.data import collate_batch_eval
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def get_most_probable_entity(proba, entities):
16 | most_probable_entity_idx = proba.argmax().item()
17 | for entity, val in entities.items():
18 | if most_probable_entity_idx in val[0]:
19 | return entity
20 | return "None"
21 |
22 |
23 | def evaluate(args, model, dataset, no_save=False, has_target=True, output_proba=True):
24 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
25 | sampler = RandomSampler(dataset)
26 | loader = DataLoader(
27 | dataset,
28 | sampler=sampler,
29 | collate_fn=collate_batch_eval,
30 | batch_size=args.eval_batch_size,
31 | )
32 |
33 | # multi-gpu evaluate
34 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
35 | model = torch.nn.DataParallel(model)
36 |
37 |     logger.info("*** Start the evaluation process ***")
38 | logger.info(f" Number of samples: {len(dataset)}")
39 | logger.info(f" Batch size: {args.eval_batch_size} using {args.n_gpu} GPU(s)")
40 |
41 | correct_pos, correct_neg = 0, 0
42 | total_pos, total_neg = 0, 0
43 | model.eval()
44 | out = []
45 |
46 | for batch in tqdm(loader, desc="Evaluating"):
47 | batch = {
48 | k: v.to(args.device) if hasattr(v, "to") else v for k, v in batch.items()
49 | }
50 | with torch.no_grad():
51 | scores = model(
52 | input_ids=batch["input_ids"],
53 | mask_ids=batch["mask_ids"],
54 | attention_mask=batch["attention_mask"],
55 | )[0].cpu()
56 |
57 | for proba, entities, speaker, uid in zip(
58 | scores, batch["entities"], batch["speakers"], batch["uid"]
59 | ):
60 | speakers_proba = {
61 | entity: proba[val[0]].sum().item() for entity, val in entities.items()
62 | }
63 | speakers_proba["None"] = proba[0].item()
64 | speakers_proba_sorted = sorted(
65 | speakers_proba.items(), key=itemgetter(1), reverse=True
66 | )
67 |
68 | most_probable_speaker = speakers_proba_sorted[0][0]
69 | most_probable_entity = get_most_probable_entity(proba, entities)
70 |
71 | if has_target:
72 | target_speaker = speaker if speaker in entities else "None"
73 | is_correct = most_probable_speaker == target_speaker
74 | if target_speaker == "None":
75 | total_neg += 1
76 | if is_correct:
77 | correct_neg += 1
78 | else:
79 | total_pos += 1
80 | if is_correct:
81 | correct_pos += 1
82 |
83 | out.append(
84 | (
85 | uid,
86 | speakers_proba_sorted,
87 | most_probable_speaker,
88 | most_probable_entity,
89 | )
90 | )
91 |
92 | if has_target:
93 | EM_neg = correct_neg / total_neg
94 | EM_pos = correct_pos / total_pos
95 | total = total_neg + total_pos
96 | EM = (correct_neg + correct_pos) / (total_neg + total_pos)
97 |         logger.info(f"EM value: {EM:.2%}, total: {total}")
98 |         logger.info(
99 |             f"EM pos: {EM_pos:.2%}, total: {total_pos} ({total_pos / total:.2%})"
100 |         )
101 |         logger.info(
102 |             f"EM neg: {EM_neg:.2%}, total: {total_neg} ({total_neg / total:.2%})"
103 |         )
104 |
105 | if not no_save:
106 | with open(
107 | os.path.join(args.output_file), "w", encoding="utf-8", newline="",
108 | ) as csvfile:
109 | csvwriter = csv.writer(csvfile)
110 | if output_proba:
111 | csvwriter.writerow(
112 | ["articleUID", "articleOffset", "rank", "speaker", "proba"]
113 | )
114 | for uid, speakers_proba, _, _ in out:
115 | articleUID, articleOffset = uid.split()
116 | for i, (speaker, proba) in enumerate(speakers_proba):
117 | csvwriter.writerow(
118 | [
119 | articleUID,
120 | articleOffset,
121 | i,
122 | speaker,
123 | round(proba * 100, 2),
124 | ]
125 | )
126 | else:
127 | csvwriter.writerow(
128 | ["articleUID", "articleOffset", "sum_speaker", "max_speaker"]
129 | )
130 | for uid, _, sum_speaker, max_speaker in out:
131 | articleUID, articleOffset = uid.split()
132 | csvwriter.writerow(
133 | [articleUID, articleOffset, sum_speaker, max_speaker]
134 | )
135 |
--------------------------------------------------------------------------------
/quobert/model/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | from torch.nn import CrossEntropyLoss, Linear
5 | from torch.nn.functional import softmax
6 | from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
7 | from transformers import BertModel, BertPreTrainedModel
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class BertForQuotationAttribution(BertPreTrainedModel):
13 | def __init__(self, config):
14 | super().__init__(config)
15 | self.num_labels = 1
16 |
17 | self.bert = BertModel(config)
18 | self.qa_outputs = Linear(config.hidden_size, self.num_labels)
19 |
20 | self.loss_fn = CrossEntropyLoss()
21 | self.init_weights()
22 |
23 | def forward(
24 | self,
25 | input_ids,
26 | mask_ids,
27 | *,
28 | attention_mask=None,
29 | token_type_ids=None,
30 | position_ids=None,
31 | head_mask=None,
32 | inputs_embeds=None,
33 | targets=None,
34 | ):
35 | outputs = self.bert(
36 | input_ids,
37 | attention_mask=attention_mask,
38 | token_type_ids=token_type_ids,
39 | position_ids=position_ids,
40 | head_mask=head_mask,
41 | inputs_embeds=inputs_embeds,
42 | )
43 |
44 | sequence_output = outputs[0]
45 | logits = [
46 | self.qa_outputs(output[mask[mask >= 0]]).squeeze(-1)
47 | for output, mask in zip(sequence_output, mask_ids)
48 | ]
49 |
50 | proba, _ = pad_packed_sequence(
51 | pack_sequence(
52 | [softmax(logit, dim=0) for logit in logits], enforce_sorted=False
53 | ),
54 | batch_first=True,
55 | total_length=100,
56 | )
57 |
58 | # logger.info(f"logits: {logits},\ntargets: {targets}\nmask_ids: {mask_ids}")
59 | outputs = (proba,) + outputs[2:]
60 | if targets is not None:
61 | loss = torch.stack(
62 | [
63 | self.loss_fn(logit[None, :], target[None])
64 | for logit, target in zip(logits, targets)
65 | ]
66 | ).mean()
67 | correct = torch.tensor(
68 | [
69 | 1 if logit.argmax().item() == target.item() else 0
70 | for logit, target in zip(logits, targets)
71 | ],
72 | device=loss.get_device(),
73 | ).sum()
74 | outputs = (loss, correct,) + outputs
75 | return outputs # (loss, correct), logits, (hidden_states), (attentions)
76 |
--------------------------------------------------------------------------------
/quobert/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._utils import set_seed, get_device
2 |
--------------------------------------------------------------------------------
/quobert/utils/_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import List, TypeVar
3 |
4 | import numpy as np
5 | import torch
6 |
7 | T = TypeVar("T")
8 |
9 |
10 | def set_seed(seed: int):
11 | """
12 | Fix all possible seeds for reproducibility
13 |
14 | Args:
15 | seed (int): number used to set the seeds
16 | """
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed) # type: ignore
21 |
22 |
23 | def get_device() -> torch.device:
24 |     """ Return a CUDA device if available, otherwise the CPU """
25 | return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26 |
27 |
28 | def to_list(tensor: torch.Tensor) -> List[T]:
29 | return tensor.detach().cpu().tolist()
30 |
--------------------------------------------------------------------------------
/quobert/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloader import collate_batch_eval, collate_batch_train
2 | from .dataset import ConcatParquetDataset, ParquetDataset
3 |
--------------------------------------------------------------------------------
/quobert/utils/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import json
2 | from functools import partial
3 | from typing import Dict, List
4 |
5 | import pandas as pd
6 | import torch
7 | from torch.nn.utils.rnn import pad_sequence
8 |
9 |
10 | def __collate_batch(
11 | data: List[pd.Series], *, train: bool, test: bool
12 | ) -> Dict[str, torch.Tensor]:
13 | """
14 | Transform a list of `ConcatParquetDataset` entries into a Dict of tensors that can be fed to a BERT model.
15 |
16 |     This private method is intended to be used with a partial to set `train` and then be fed to `torch.utils.data.DataLoader` as `collate_fn`
17 |
18 | Args:
19 | data (List[pd.Series]): a list of `ConcatParquetDataset` or `ParquetDataset` entries
20 |         train (bool): set to `True` to also return the `targets` tensor (with `test`, entities, speakers and uid are also returned)
21 |
22 | Returns:
23 |         Dict[str, torch.Tensor]: a dict containing `input_ids`, `attention_mask`, `mask_ids`, and, if `train`, the `targets`
24 | """
25 | input_ids = pad_sequence(
26 | [torch.tensor(d.input_ids, dtype=torch.long) for d in data], batch_first=True,
27 | ) # (b_size, max_sentence_len)
28 | attention_mask = input_ids.where(
29 | input_ids == 0, torch.tensor(1)
30 | ) # (b_size, max_sentence_len)
31 |
32 | mask_ids = pad_sequence(
33 | [torch.tensor(d.mask_idx, dtype=torch.long) for d in data],
34 | batch_first=True,
35 | padding_value=-1,
36 | ) # mask for indexes of CLS/MASK tokens
37 |
38 | out = {
39 | "input_ids": input_ids,
40 | "attention_mask": attention_mask,
41 | "mask_ids": mask_ids,
42 | }
43 |
44 | if train:
45 | targets = torch.tensor([d.target for d in data], dtype=torch.long) # (b_size, )
46 | out["targets"] = targets
47 |
48 | if test:
49 | entities = [json.loads(d.entities) for d in data]
50 | speakers = [d.speaker if "speaker" in d else "" for d in data]
51 | uid = [d.uid for d in data]
52 | out.update({"entities": entities, "speakers": speakers, "uid": uid})
53 |
54 | return out
55 |
56 |
57 | collate_batch_train = partial(__collate_batch, train=True, test=False)
58 | collate_batch_eval = partial(__collate_batch, train=False, test=True)
59 |
--------------------------------------------------------------------------------
/quobert/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | from typing import List
3 |
4 | import pandas as pd
5 | from torch.utils.data import ConcatDataset, Dataset
6 |
7 |
8 | class ParquetDataset(Dataset):
9 | """
10 | Parquet Dataset is a wrapper around `torch.utils.data.Dataset`. It loads and serves a parquet DataFrame.
11 |
12 | Args:
13 | parquet_path (str): The path to a single parquet file (possibly compressed). If using multiple files, see `ConcatParquetDataset`
14 | sample_n (int, optional): Number of items to sample from the dataset. If 0, use all items. Defaults to 0.
15 | """
16 |
17 | def __init__(
18 | self, parquet_path: str, sample_n: int = 0
19 | ):
20 | super(ParquetDataset, self).__init__()
21 | self.df = pd.read_parquet(parquet_path)
22 | if sample_n > 0:
23 | self.df = self.df.sample(n=sample_n)
24 |
25 | def __len__(self) -> int:
26 | return len(self.df)
27 |
28 | def __getitem__(self, idx: int) -> pd.Series:
29 | return self.df.iloc[idx]
30 |
31 |
32 | class ConcatParquetDataset(ConcatDataset):
33 | """
34 | Concat Parquet Dataset is a wrapper around `torch.utils.data.ConcatDataset`. It serves multiple `ParquetDataset`
35 |
36 | Args:
37 | datasets (List[ParquetDataset]): a list of `ParquetDataset`
38 | """
39 |
40 | def __init__(self, datasets: List[ParquetDataset]):
41 | super(ConcatParquetDataset, self).__init__(datasets)
42 |
43 | def __getitem__(self, idx: int) -> pd.Series:
44 | if idx < 0:
45 | if -idx > len(self):
46 | raise ValueError(
47 | "absolute value of index should not exceed dataset length"
48 | )
49 | idx = len(self) + idx
50 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
51 | if dataset_idx == 0:
52 | sample_idx = idx
53 | else:
54 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
55 | return self.datasets[dataset_idx][sample_idx]
56 |
--------------------------------------------------------------------------------
/quootstrap/README.md:
--------------------------------------------------------------------------------
1 | # Quootstrap
2 | This folder contains an adapted version of the reference implementation of Quootstrap, as described in the paper [[PDF]](https://dlab.epfl.ch/people/west/pub/Pavllo-Piccardi-West_ICWSM-18.pdf):
3 | > Dario Pavllo, Tiziano Piccardi, Robert West. *Quootstrap: Scalable Unsupervised Extraction of Quotation-Speaker Pairs from Large News Corpora via Bootstrapping*. In *Proceedings of the 12th International Conference on Web and Social Media (ICWSM),* 2018.
4 |
5 | #### Disclaimer
6 |
7 | *This folder contains an improved version of the code from the paper, using Wikidata instead of Freebase for entity linking, and with the option of saving all the quotes with their context and speakers, instead of only those extracted by the rules. For the code of the paper, please check [this branch](https://github.com/epfl-dlab/quootstrap/tree/master).*
8 |
9 | ## How to run
10 |
11 | Go to the **Release** section and download the .zip archive, which contains the executable `quootstrap.jar` as well as all necessary dependencies and configuration files. You can also find a convenient script `extract_quotations.sh` that can be used to run the application on a Yarn cluster. The script runs this command:
12 | ```bash
13 | spark-submit --jars opennlp-tools-1.9.2.jar,spinn3r-client-3.4.05-edit.jar,stanford-corenlp-3.8.0.jar,jsoup-1.10.3.jar,guava-14.0.1.jar \
14 | --num-executors 25 \
15 | --executor-cores 16 \
16 | --driver-memory 128g \
17 | --executor-memory 128g \
18 | --conf "spark.executor.memoryOverhead=32768" \
19 | --class ch.epfl.dlab.quootstrap.QuotationExtraction \
20 | --master yarn \
21 | quootstrap.jar
22 | ```
23 | After tuning the settings to suit your particular configuration, you can run the command as:
24 | ```bash
25 | ./extract_quotations.sh
26 | ```
27 |
28 | ### Setup
29 |
30 | To run our code, you need:
31 | - Java 8
32 | - Spark 2.3
33 | - The entire Spinn3r dataset (available on the hadoop cluster) or your own dataset
34 | - Our dataset of people extracted from Wikidata (available on `/data`)
35 |
36 | ### How to build
37 |
38 | Clone the repository and import it as an Eclipse project. All dependencies are downloaded through Maven. To build the application, generate a .jar file with all source files and run it as explained in the previous section. Alternatively, you can use Spark in local mode for experimenting. Additional instructions on how to extend the project with new functionalities (e.g. support for new datasets) are reported later.
39 |
40 | ## Configuration
41 | The first configuration file is `config.properties`. The most important fields in order to get the application running are:
42 | - `NEWS_DATASET_PATH` specifies the HDFS path of the Spinn3r news dataset
43 | - `PEOPLE_DATASET_PATH` specifies the HDFS path of the Wikidata people list.
44 | - `NEWS_DATASET_LOADER` specifies which loader to use for the dataset.
45 | - `EXPORT_PATH` specifies the HDFS path for the (quotation, speaker) pairs output.
46 | - `CONTEXT_PATH` specifies the HDFS path for the (quotation, context, lang) output.
47 | - `NUM_ITERATIONS=1` specifies the number of iterations of the extraction algorithm. Set it to 1 if you want to run the algorithm only on the seed patterns (iteration 0); between 3 and 5 iterations is otherwise more than enough. **Note:** running more than one iteration is currently not supported anymore, and a fix to make it work with Wikidata would be appreciated.
48 |
49 | Additionally you can change the flow of the application with the following parameters:
50 |
51 | - `DO_QUOTE_ATTRIBUTION` & `EXPORT_RESULTS`: Whether to perform quote attribution and export the results. This is the original output of Quootstrap.
52 | - `EXPORT_CONTEXT`: Whether to export all the quotes and their context. Not present in the original version of Quootstrap.
53 | - `EXPORT_SPEAKERS`: Whether to export all full matches of candidate speakers in a given article. Partial matches are dealt with in Quobert.
54 | - `EXPORT_ARTICLE`: Whether to export all articles in an easier-to-read format (not recommended if your input data is already easily readable).
55 |
56 | The second configuration file is `seedPatterns.txt`, which, as the name suggests, contains the seed patterns that are used in the first iteration, one per line.
57 |
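For illustration, two patterns in the format the extractor understands (`$Q` stands for the quotation and `$S` for the speaker, as echoed by the `extractedBy` field in the export example below) are:

```
$Q , $S said
$Q , $S told
```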
58 |
59 | ## Exporting results
60 | ### Quotation-speaker pairs
61 |
62 | You can save the results as an HDFS text file formatted in JSON, with one record per line. For each record, the full quotation is exported, as well as the full name of the speaker (as reported in the article), his/her Wikidata ID, the confidence value of the tuple, and the occurrences in which the quotation was found. As for the latter, we report the article ID, an incremental offset within the article (which is useful for linking together split quotations), the pattern that extracted the tuple along with its confidence, the website, and the date the article appeared.
63 |
64 | ```json
65 | {
66 | "canonicalQuotation":"action is easy it is very easy comedy is difficult",
67 | "confidence":1.0,
68 | "numOccurrences":6,
69 | "numSpeakers":1,
70 | "occurrences":[
71 | {"articleOffset":0,
72 | "articleUID":"2012031008_00073468_W",
73 | "date":"2012-03-10 08:49:40",
74 | "extractedBy":"$Q , $S said",
75 | "matchedSpeakerTokens":"Akshay Kumar",
76 | "patternConfidence":1.0,
77 | "quotation":"action is easy, it is very easy. comedy is difficult,",
78 | "website":"http://deccanchronicle.com/channels/showbiz/bollywood/action-easy-comedy-toughest-akshay-kumar-887"},
79 | ...
80 | {"articleOffset":0,
81 | "articleUID":"2012031012_00038989_W",
82 | "date":"2012-03-10 12:23:40",
83 | "extractedBy":"$Q , $S told",
84 | "matchedSpeakerTokens":"Akshay Kumar",
85 | "patternConfidence":0.7360332115513023,
86 | "quotation":"action is easy, it is very easy. comedy is difficult,",
87 | "website":"http://hindustantimes.com/Entertainment/Bollywood/Comedy-is-difficult-Akshay-Kumar/Article1-823315.aspx"}
88 | ],
89 | "quotation":"action is easy, it is very easy. comedy is difficult,",
90 | "speaker":["Akshay Kumar"],
91 | "speakerID":["Q233748"]
92 | }
93 | ```
94 | Remarks:
95 | - *articleOffset* can have gaps (but quotations are guaranteed to be ordered correctly).
96 | - *canonicalQuotation* is the internal representation of a particular quotation, used for pattern matching purposes. The string is converted to lower case and punctuation marks are removed.
97 | - As in the example above, the full quotation might differ from the one(s) found in *occurrences* due to the quotation merging mechanism. We always report the longest (and most likely useful) quotation when multiple choices are possible.
98 |
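Since every part file of the export is plain gzipped JSON Lines, it can also be consumed outside Spark. A minimal sketch in plain Java with Gson (the class name and file path are placeholders, not part of the project):

```java
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.zip.GZIPInputStream;

import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

public class ReadQuotationPairs {
    public static void main(String[] args) throws Exception {
        // Each line of a part file is one JSON record as documented above.
        try (BufferedReader in = new BufferedReader(new InputStreamReader(
                new GZIPInputStream(new FileInputStream("part-00000.gz")), StandardCharsets.UTF_8))) {
            String line;
            while ((line = in.readLine()) != null) {
                JsonObject record = new JsonParser().parse(line).getAsJsonObject();
                // "speaker" is an array of names; "quotation" is the full quotation string.
                System.out.println(record.get("speaker") + "\t" + record.get("quotation").getAsString());
            }
        }
    }
}
```
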
99 | ## Adding support for new datasets/formats
100 |
101 | If you want to add support for other datasets/formats, you can provide a concrete implementation for the Java interface `DatasetLoader` and specify its full class name in the `NEWS_DATASET_LOADER` field of the configuration. For each article, you must supply a unique ID (int64/long), the website in which it can be found, and its content in tokenized format, i.e. as a list of strings. We provide an implementation for our JSON Spinn3r dataset in `ch.epfl.dlab.quootstrap.Spinn3rDatasetLoader`, for Parquet dataframes in `ch.epfl.dlab.quootstrap.ParquetDatasetLoader`, and for the Stanford Spinn3r text format in `ch.epfl.dlab.quootstrap.Spinn3rTextDatasetLoader`.
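
As a sketch only (the class name, the tab-separated input layout, and the whitespace tokenization are illustrative assumptions, not something shipped with the project), a custom loader could look like this:

```java
package ch.epfl.dlab.quootstrap;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

// Hypothetical loader for a dump with one article per line: uid <TAB> website <TAB> date <TAB> text
public class TsvDatasetLoader implements DatasetLoader {

    public static class TsvArticle implements Article, Serializable {
        private static final long serialVersionUID = 1L;
        private final String uid, website, date, text;

        public TsvArticle(String uid, String website, String date, String text) {
            this.uid = uid;
            this.website = website;
            this.date = date;
            this.text = text;
        }

        @Override public String getArticleUID() { return uid; }
        @Override public List<String> getArticleContent() {
            // The interface expects tokenized content; whitespace splitting is only a stand-in for a real tokenizer.
            return Arrays.asList(text.split("\\s+"));
        }
        @Override public String getWebsite() { return website; }
        @Override public String getDate() { return date; }
        @Override public String getVersion() { return ""; }
        @Override public String getTitle() { return ""; }
    }

    @Override
    public JavaRDD<Article> loadArticles(JavaSparkContext sc, String datasetPath, Set<String> languageFilter) {
        // languageFilter is ignored in this sketch; a real loader should drop articles in unrequested languages.
        return sc.textFile(datasetPath).map(line -> {
            String[] f = line.split("\t", 4);
            return (Article) new TsvArticle(f[0], f[1], f[2], f[3]);
        });
    }
}
```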
102 |
103 | ## Replacing the tokenizer
104 | If, for any reason (e.g. license, language other than English), you do not want to depend on Stanford PTBTokenizer, you can provide your own implementation of the `ch.epfl.dlab.spinn3r.Tokenizer` interface. You only have to implement two methods: `tokenize` and `untokenize`. Tokenization is one of the least critical steps in our pipeline, and does not impact the final result significantly.
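
For instance, a bare-bones whitespace tokenizer (a sketch; the exact interface signatures are assumed from how `tokenize`/`untokenize` are used in this repository) could be:

```java
package ch.epfl.dlab.spinn3r;

import java.util.Arrays;
import java.util.List;

// Hypothetical drop-in replacement that avoids the Stanford dependency.
public class WhitespaceTokenizer implements Tokenizer {

    @Override
    public List<String> tokenize(String text) {
        // Split on runs of whitespace; a real tokenizer would also split off punctuation.
        return Arrays.asList(text.trim().split("\\s+"));
    }

    @Override
    public String untokenize(List<String> tokens) {
        // Joining with single spaces loses the original spacing around punctuation marks.
        return String.join(" ", tokens);
    }
}
```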
105 |
106 | ## License
107 | We release our work under the MIT license. Third-party components, such as Stanford CoreNLP, are subject to their respective licenses.
108 |
109 | If you use our code and/or data in your research, please cite our paper [[PDF]](https://dlab.epfl.ch/people/west/pub/Pavllo-Piccardi-West_ICWSM-18.pdf):
110 | ```
111 | @inproceedings{quootstrap2018,
112 | title={Quootstrap: Scalable Unsupervised Extraction of Quotation-Speaker Pairs from Large News Corpora via Bootstrapping},
113 | author={Pavllo, Dario and Piccardi, Tiziano and West, Robert},
114 | booktitle={Proceedings of the 12th International Conference on Web and Social Media (ICWSM)},
115 | year={2018}
116 | }
117 | ```
118 |
--------------------------------------------------------------------------------
/quootstrap/extract_quotations.sh:
--------------------------------------------------------------------------------
1 | spark-submit --jars spinn3r-client-3.4.05-edit.jar,stanford-corenlp-3.8.0.jar,jsoup-1.10.3.jar,guava-14.0.1.jar \
2 | --num-executors 25 \
3 | --executor-cores 16 \
4 | --driver-memory 128g \
5 | --executor-memory 128g \
6 | --conf "spark.executor.memoryOverhead=32768" \
7 | --class ch.epfl.dlab.quootstrap.QuotationExtraction \
8 | --master yarn \
9 | quootstrap.jar
--------------------------------------------------------------------------------
/quootstrap/lib/spinn3r-client-3.4.05-edit.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quootstrap/lib/spinn3r-client-3.4.05-edit.jar
--------------------------------------------------------------------------------
/quootstrap/pom.xml:
--------------------------------------------------------------------------------
1 | <project xmlns="http://maven.apache.org/POM/4.0.0"
2 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4 |   <modelVersion>4.0.0</modelVersion>
5 |   <groupId>ch.epfl</groupId>
6 |   <artifactId>ProtoToJson</artifactId>
7 |   <version>0.0.1-SNAPSHOT</version>
8 |
9 |   <properties>
10 |     <maven.compiler.source>1.8</maven.compiler.source>
11 |     <maven.compiler.target>1.8</maven.compiler.target>
12 |   </properties>
13 |
14 |   <dependencies>
15 |     <dependency>
16 |       <groupId>org.apache.spark</groupId>
17 |       <artifactId>spark-sql_2.11</artifactId>
18 |       <version>2.3.1</version>
19 |     </dependency>
20 |     <dependency>
21 |       <groupId>com.google.code.gson</groupId>
22 |       <artifactId>gson</artifactId>
23 |       <version>2.3</version>
24 |     </dependency>
25 |     <dependency>
26 |       <groupId>com.sun.jersey</groupId>
27 |       <artifactId>jersey-server</artifactId>
28 |       <version>1.2</version>
29 |     </dependency>
30 |     <dependency>
31 |       <groupId>org.jsoup</groupId>
32 |       <artifactId>jsoup</artifactId>
33 |       <version>1.10.3</version>
34 |     </dependency>
35 |     <dependency>
36 |       <groupId>edu.stanford.nlp</groupId>
37 |       <artifactId>stanford-corenlp</artifactId>
38 |       <version>3.8.0</version>
39 |     </dependency>
40 |   </dependencies>
41 | </project>
--------------------------------------------------------------------------------
/quootstrap/quootstrap.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quootstrap/quootstrap.tar.gz
--------------------------------------------------------------------------------
/quootstrap/resources/config.properties:
--------------------------------------------------------------------------------
1 | # Basic settings (news dataset path, people list dataset, etc.)
2 | NEWS_DATASET_PATH=/path/to/articles
3 | PEOPLE_DATASET_PATH=/path/to/wikidata_people_ALIVE_FILTERED-NAMES-CLEAN.tsv
4 |
5 | # Number of tokens to be considered as a quotation
6 | MIN_QUOTATION_SIZE=6
7 | MAX_QUOTATION_SIZE=500
8 |
9 | # Case sensitive true == bad idea (no match on articles with broken casing)
10 | CASE_SENSITIVE=false
11 |
12 | # Provide the concrete implementation class name of DatasetLoader
13 | NEWS_DATASET_LOADER=ch.epfl.dlab.quootstrap.Spinn3rTextDatasetLoader
14 |
15 | # Settings for exporting results
16 | EXPORT_RESULTS=true
17 | EXPORT_PATH=/path/to/quootstrap
18 | DO_QUOTE_ATTRIBUTION=true
19 |
20 | # Settings for exporting Article / Speakers
21 | EXPORT_SPEAKERS=true
22 | SPEAKERS_PATH=/path/to/speakers
23 |
24 | # Settings for exporting Articles
25 | EXPORT_ARTICLE=false
26 | ARTICLE_PATH=/path/to/articles
27 |
28 | # Settings for exporting the quotes and context of the quotes
29 | EXPORT_CONTEXT=true
30 | CONTEXT_PATH=/path/to/quotes_context
31 | NUM_PARTITIONS=100
32 |
33 |
34 | ###### UNUSED PARAMS ######
35 |
36 | # Note: Currently, only 1 iteration is supported
37 | NUM_ITERATIONS=1
38 |
39 | # Set to true if you want to use Spark in local mode
40 | LOCAL_MODE=false
41 |
42 | # Hyperparameters
43 | PATTERN_CONFIDENCE_THRESHOLD=0.7
44 | PATTERN_CLUSTERING_THRESHOLDS=0|0.0002|0.001|0.005
45 |
46 | # Quotation merging
47 | ENABLE_QUOTATION_MERGING=false
48 | ENABLE_DEDUPLICATION=true
49 | MERGING_SHINGLE_SIZE=10
50 |
51 | # Cache settings: some frequently used (and immutable) RDDs can be cached on disk
52 | # in order to speed up the execution of the algorithm after the first time.
53 | # Note that the cache must be invalidated manually (by deleting the files)
54 | # if the code or the internal parameters are changed.
55 | ENABLE_CACHE=false
56 | CACHE_PATH=/path/to/cache
57 |
58 | # Note: Currently Evaluation of the results is not supported anymore
59 | # Evaluation settings
60 | GROUND_TRUTH_PATH=ground_truth.json
61 | # Enable the evaluation on the last iteration
62 | ENABLE_FINAL_EVALUATION=false
63 | # Enable the evaluation on intermediate iterations (slower)
64 | ENABLE_INTERMEDIATE_EVALUATION=false
65 |
66 | # Debug settings
67 | # Set to true if you want to dump all new discovered patterns at each iteration
68 | DEBUG_DUMP_PATTERNS=false
69 |
70 | # Set to true if you want to convert the entire input data to lower case (not recommended)
71 | DEBUG_CASE_FOLDING=false
72 |
73 | # Deprecated
74 | LANGUAGE_FILTER=en|uk
75 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ConfigManager.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.FileInputStream;
4 | import java.io.IOException;
5 | import java.util.Arrays;
6 | import java.util.List;
7 | import java.util.Properties;
8 | import java.util.stream.Collectors;
9 |
10 | public final class ConfigManager {
11 |
12 | private static final String CONFIG_FILE = "resources/config.properties";
13 |
14 | private static final ConfigManager INSTANCE = new ConfigManager();
15 |
16 | private final String datasetPath;
17 | private final String namesPath;
18 | private final int numIterations;
19 | private final int minQuotationSize;
20 | private final int maxQuotationSize;
21 | private final boolean caseSensitive;
22 | private final List langFilter;
23 |
24 | private final String newsDatasetLoader;
25 |
26 | private final boolean exportEnabled;
27 | private final String exportPath;
28 | private final boolean doQuoteAttribution;
29 |
30 | private final boolean contextExportEnabled;
31 | private final String contextOutputPath;
32 |
33 | private final int numPartition;
34 |
35 | private final boolean localModeEnabled;
36 |
37 | private final double confidenceThreshold;
38 | private final List clusteringThresholds;
39 |
40 | private final boolean mergingEnabled;
41 | private final boolean doDedup;
42 | private final int mergingShingleSize;
43 |
44 | private final boolean cacheEnabled;
45 | private final String cachePath;
46 |
47 | private final String groundTruthPath;
48 | private final boolean finalEvaluationEnabled;
49 | private final boolean intermediateEvaluationEnabled;
50 |
51 | private final boolean dumpPatternsEnabled;
52 | private final boolean caseFoldingEnabled;
53 |
54 | private String outputPath;
55 |
56 | private final boolean speakersEnabled;
57 | private final String speakersPath;
58 | private final boolean articleEnabled;
59 | private final String articlePath;
60 |
61 | private ConfigManager() {
62 | final Properties prop = new Properties();
63 | try {
64 | prop.load(new FileInputStream(CONFIG_FILE));
65 | } catch (IOException e) {
66 | throw new IllegalArgumentException("Unable to read config file", e);
67 | }
68 |
69 | datasetPath = prop.getProperty("NEWS_DATASET_PATH");
70 | namesPath = prop.getProperty("PEOPLE_DATASET_PATH");
71 | numIterations = Integer.parseInt(
72 | prop.getProperty("NUM_ITERATIONS"));
73 | minQuotationSize = Integer.parseInt(
74 | prop.getProperty("MIN_QUOTATION_SIZE"));
75 | maxQuotationSize = Integer.parseInt(
76 | prop.getProperty("MAX_QUOTATION_SIZE"));
77 | caseSensitive = prop.getProperty("CASE_SENSITIVE").equals("true");
78 | langFilter = Arrays.asList(prop.getProperty("LANGUAGE_FILTER").split("\\|"));
79 |
80 | newsDatasetLoader = prop.getProperty("NEWS_DATASET_LOADER");
81 |
82 | doQuoteAttribution = prop.getProperty("DO_QUOTE_ATTRIBUTION").equals("true");
83 |
84 | exportEnabled = prop.getProperty("EXPORT_RESULTS").equals("true");
85 | exportPath = prop.getProperty("EXPORT_PATH");
86 |
87 | contextExportEnabled = prop.getProperty("EXPORT_CONTEXT").equals("true");
88 | contextOutputPath = prop.getProperty("CONTEXT_PATH");
89 |
90 | speakersEnabled = prop.getProperty("EXPORT_SPEAKERS").equals("true");
91 | speakersPath = prop.getProperty("SPEAKERS_PATH");
92 |
93 | articleEnabled = prop.getProperty("EXPORT_ARTICLE").equals("true");
94 | articlePath = prop.getProperty("ARTICLE_PATH");
95 |
96 | numPartition = Integer.parseInt(
97 | prop.getProperty("NUM_PARTITIONS"));
98 |
99 | localModeEnabled = prop.getProperty("LOCAL_MODE").equals("true");
100 |
101 | confidenceThreshold = Double.parseDouble(
102 | prop.getProperty("PATTERN_CONFIDENCE_THRESHOLD"));
103 | clusteringThresholds = Arrays.asList(prop.getProperty("PATTERN_CLUSTERING_THRESHOLDS").split("\\|"))
104 | .stream()
105 | .map(Double::parseDouble)
106 | .collect(Collectors.toList());
107 |
108 | mergingEnabled = prop.getProperty("ENABLE_QUOTATION_MERGING").equals("true");
109 | doDedup = prop.getProperty("ENABLE_DEDUPLICATION").equals("true");
110 | mergingShingleSize = Integer.parseInt(
111 | prop.getProperty("MERGING_SHINGLE_SIZE"));
112 |
113 | cacheEnabled = prop.getProperty("ENABLE_CACHE").equals("true");
114 | cachePath = prop.getProperty("CACHE_PATH");
115 |
116 | groundTruthPath = prop.getProperty("GROUND_TRUTH_PATH");
117 | finalEvaluationEnabled = prop.getProperty("ENABLE_FINAL_EVALUATION").equals("true");
118 | intermediateEvaluationEnabled = prop.getProperty("ENABLE_INTERMEDIATE_EVALUATION").equals("true");
119 |
120 | dumpPatternsEnabled = prop.getProperty("DEBUG_DUMP_PATTERNS").equals("true");
121 | caseFoldingEnabled = prop.getProperty("DEBUG_CASE_FOLDING").equals("true");
122 |
123 | outputPath = "";
124 | }
125 |
126 | public static ConfigManager getInstance() {
127 | return INSTANCE;
128 | }
129 |
130 | public String getDatasetPath() {
131 | return datasetPath;
132 | }
133 |
134 | public String getNamesPath() {
135 | return namesPath;
136 | }
137 |
138 | public int getNumIterations() {
139 | return numIterations;
140 | }
141 |
142 | public boolean isCaseSensitive() {
143 | return caseSensitive;
144 | }
145 |
146 | public List getLangFilter() {
147 | return langFilter;
148 | }
149 |
150 | public String getLangSuffix() {
151 | return langFilter.stream().sorted().collect(Collectors.joining("-"));
152 | }
153 |
154 | public boolean isLocalModeEnabled() {
155 | return localModeEnabled;
156 | }
157 |
158 | public double getConfidenceThreshold() {
159 | return confidenceThreshold;
160 | }
161 |
162 | public List getClusteringThresholds() {
163 | return clusteringThresholds;
164 | }
165 |
166 | public boolean isCacheEnabled() {
167 | return cacheEnabled;
168 | }
169 |
170 | public String getCachePath() {
171 | return cachePath;
172 | }
173 |
174 | public String getGroundTruthPath() {
175 | return groundTruthPath;
176 | }
177 |
178 | public boolean isFinalEvaluationEnabled() {
179 | return finalEvaluationEnabled;
180 | }
181 |
182 | public boolean isIntermediateEvaluationEnabled() {
183 | return intermediateEvaluationEnabled;
184 | }
185 |
186 | public boolean isDumpPatternsEnabled() {
187 | return dumpPatternsEnabled;
188 | }
189 |
190 | public boolean isCaseFoldingEnabled() {
191 | return caseFoldingEnabled;
192 | }
193 |
194 | public boolean isMergingEnabled() {
195 | return mergingEnabled;
196 | }
197 |
198 | public int getMergingShingleSize() {
199 | return mergingShingleSize;
200 | }
201 |
202 | public String getOutputPath() {
203 | return outputPath;
204 | }
205 |
206 | public void setOutputPath(String outputPath) {
207 | this.outputPath = outputPath;
208 | }
209 |
210 | public boolean isExportEnabled() {
211 | return exportEnabled;
212 | }
213 |
214 | public String getExportPath() {
215 | return exportPath;
216 | }
217 |
218 | public boolean isSpeakersEnabled() {
219 | return speakersEnabled;
220 | }
221 |
222 | public String getSpeakersPath() {
223 | return speakersPath;
224 | }
225 |
226 | public boolean isArticleEnabled() {
227 | return articleEnabled;
228 | }
229 |
230 | public String getArticlePath() {
231 | return articlePath;
232 | }
233 |
234 | public String getNewsDatasetLoader() {
235 | return newsDatasetLoader;
236 | }
237 |
238 | public boolean isContextExportEnabled() {
239 | return contextExportEnabled;
240 | }
241 |
242 | public String getContextOutputPath() {
243 | return contextOutputPath;
244 | }
245 |
246 | public int getNumPartition() {
247 | return numPartition;
248 | }
249 |
250 | public boolean isDoQuoteAttribution() {
251 | return doQuoteAttribution;
252 | }
253 |
254 | public boolean isDoDedupEnabled() {
255 | return doDedup;
256 | }
257 |
258 | public int getMinQuotationSize() {
259 | return minQuotationSize;
260 | }
261 |
262 | public int getMaxQuotationSize() {
263 | return maxQuotationSize;
264 | }
265 | }
266 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ContextExtractor.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collections;
5 | import java.util.List;
6 | import java.util.stream.Collectors;
7 |
8 | import ch.epfl.dlab.spinn3r.Tokenizer;
9 | import ch.epfl.dlab.spinn3r.TokenizerImpl;
10 | import scala.Tuple2;
11 |
12 | public class ContextExtractor {
13 |
14 | public static List<Sentence> extractQuotations(List<String> tokenList, String articleUid) {
15 | String[] openTokens = {"``", "", ""};
16 | String[] closeTokens = {"''", "", ""};
17 |
18 | final int minQuotationSize = 6; // ConfigManager.getInstance().getMinQuotationSize();
19 | final int maxQuotationSize = 500; // ConfigManager.getInstance().getMaxQuotationSize();
20 | final int minContextSize = 50; // At least N tokens in surrounding context
21 | final int maxContextSize = 100; // At most N tokens in surrounding context
22 |
23 | // Extract candidate quotations (without surrounding context)
24 | List<Tuple2<Integer, Integer>> candidateQuotations = new ArrayList<>(); // First token, last token
25 | int expectedToken = -1;
26 | int startIndex = -1;
27 | for (int t = 0; t < tokenList.size(); t++) {
28 | for (int i = 0; i < openTokens.length; i++) {
29 | if (tokenList.get(t).startsWith(openTokens[i])) {
30 | expectedToken = i;
31 | startIndex = t;
32 | break;
33 | }
34 | }
35 | if (expectedToken != -1) {
36 | for (int i = 0; i < closeTokens.length; i++) {
37 | if (tokenList.get(t).equals(closeTokens[i])) {
38 | if (i == expectedToken && t - startIndex >= minQuotationSize && t - startIndex <= maxQuotationSize) {
39 | // Consider only well-formed quotations (i.e. the closing token must match the opening token)
40 | // In addition, when nested quotations are present, only the innermost one is extracted.
41 | candidateQuotations.add(new Tuple2<>(startIndex, t));
42 | }
43 | expectedToken = -1;
44 | startIndex = -1;
45 | break;
46 | }
47 | }
48 | }
49 | }
50 |
51 | Tokenizer tokenizer = new TokenizerImpl();
52 |
53 | // Extract surrounding context from quotation
54 | List quotations = new ArrayList<>();
55 | int articleIdx = 0;
56 | for (Tuple2<Integer, Integer> t : candidateQuotations) {
57 |
58 | List currentSentence = new ArrayList<>();
59 |
60 | List prevContext = new ArrayList<>();
61 | int leftOffset = scanBackward(tokenList, prevContext, t._1 - 1, minContextSize, maxContextSize);
62 | Collections.reverse(prevContext);
63 | prevContext.forEach(token -> {
64 | currentSentence.add(new Token(token, Token.Type.GENERIC));
65 | });
66 |
67 | List quotation = new ArrayList<>();
68 | for (int i = t._1 + 1; i < t._2; i++) {
69 | quotation.add(tokenList.get(i));
70 | }
71 |
72 | String quotationStr = tokenizer.untokenize(quotation.stream()
73 | .filter(x -> !StaticRules.isHtmlTag(x))
74 | .collect(Collectors.toList()));
75 |
76 | currentSentence.add(new Token(quotationStr, Token.Type.QUOTATION));
77 |
78 | List nextContext = new ArrayList<>();
79 | int rightOffset = scanForward(tokenList, nextContext, t._2 + 1, minContextSize, maxContextSize);
80 | nextContext.forEach(token -> {
81 | currentSentence.add(new Token(token, Token.Type.GENERIC));
82 | });
83 |
84 | // Remove HTML tags
85 | // currentSentence.removeIf(x -> StaticRules.isHtmlTag(x.toString()));
86 |
87 | quotations.add(new Sentence(currentSentence, articleUid, articleIdx, t._1 + 1, leftOffset, rightOffset + 1));
88 | articleIdx++;
89 | }
90 |
91 | return quotations;
92 | }
93 |
94 | private static int scanForward(List tokens, List context, int initialIndex, int minSize, int maxSize) {
95 | int finalIndexLower = Math.min(initialIndex + minSize, tokens.size() - 1);
96 |
97 | boolean anotherQuotationFound = false;
98 | for (int i = initialIndex; i < finalIndexLower; i++) {
99 | String token = tokens.get(i);
100 | if (token.equals("``")) {
101 | anotherQuotationFound = true;
102 | } else if (token.equals("''")) {
103 | anotherQuotationFound = false;
104 | }
105 | context.add(token);
106 | }
107 |
108 | int finalIndexUpper = Math.min(initialIndex + maxSize, tokens.size() - 1);
109 | for (int i = finalIndexLower; i < finalIndexUpper; i++) {
110 | String token = tokens.get(i);
111 | context.add(token);
112 |
113 | if (anotherQuotationFound) {
114 | if (token.equals("''")) {
115 | return i;
116 | }
117 | } else {
118 | if (token.equals(".") || token.equals("
") || token.equals("
")) {
119 | return i;
120 | }
121 | }
122 | }
123 | // No stopper found -> revert to short context
124 | while (context.size() > minSize) {
125 | context.remove(context.size() - 1);
126 | finalIndexUpper--;
127 | }
128 |
129 | return finalIndexUpper;
130 | }
131 |
132 | private static int scanBackward(List tokens, List context, int initialIndex, int minSize, int maxSize) {
133 | int finalIndexLower = Math.max(initialIndex - minSize, 0);
134 |
135 | boolean anotherQuotationFound = false;
136 | for (int i = initialIndex; i > finalIndexLower; i--) {
137 | String token = tokens.get(i);
138 | if (token.equals("''")) {
139 | anotherQuotationFound = true;
140 | } else if (token.equals("``")) {
141 | anotherQuotationFound = false;
142 | }
143 | context.add(token);
144 | }
145 |
146 | int finalIndexUpper = Math.max(initialIndex - maxSize, 0);
147 | for (int i = finalIndexLower; i >= finalIndexUpper; i--) {
148 | String token = tokens.get(i);
149 | context.add(token);
150 |
151 | if (anotherQuotationFound) {
152 | if (token.equals("``")) {
153 | return i;
154 | }
155 | } else {
156 | if (token.equals(".") || token.equals("
") || token.startsWith(" revert to short context
162 | while (context.size() > minSize) {
163 | context.remove(context.size() - 1);
164 | finalIndexUpper++;
165 | }
166 |
167 | return finalIndexUpper;
168 | }
169 |
170 | public static Sentence postProcess(Sentence s) {
171 |
172 | List tokens = new ArrayList<>(s.getTokens());
173 |
174 | // Remove HTML tags
175 | tokens.removeIf(x -> x.getType() != Token.Type.QUOTATION && StaticRules.isHtmlTag(x.toString()));
176 |
177 | // Find other quotations other than the main quotation
178 | int startQuot = 0;
179 | int i = 0;
180 | while (i < tokens.size()) {
181 | String t = tokens.get(i).toString();
182 | if (tokens.get(i).getType() == Token.Type.QUOTATION) {
183 | startQuot = -1;
184 | }
185 | if (t.equals("''") && startQuot != -1) {
186 | List sub = tokens.subList(startQuot, i + 1);
187 | sub.clear();
188 | sub.add(new Token("[QUOTE]", Token.Type.GENERIC));
189 | startQuot = -1;
190 | i = startQuot;
191 | } else if ((t.equals("``"))) {
192 | startQuot = i;
193 | }
194 | i++;
195 | }
196 | if (startQuot != -1) {
197 | List sub = tokens.subList(startQuot, tokens.size());
198 | sub.clear();
199 | sub.add(new Token("[QUOTE]", Token.Type.GENERIC));
200 | }
201 |
202 | return new Sentence(tokens, s.getArticleUid(), s.getIndex(), s.getQuotationOffset(), s.getLeftOffset(), s.getRightOffset());
203 | }
204 |
205 | public static Sentence canonicalizeQuotation(Sentence s) {
206 | List tokens = new ArrayList<>(s.getTokens());
207 |
208 | // Canonicalize quotation, and push out final punctuation marks, if present inside quotation
209 | for (int i = 0; i < tokens.size(); i++) {
210 | if (tokens.get(i).getType() == Token.Type.QUOTATION) {
211 | String tokenStr = tokens.get(i).toString().trim();
212 | final String[] patterns = {",", "."};
213 | for (String p : patterns) {
214 | if (tokenStr.endsWith(p)) {
215 | if (i == tokens.size() - 1 || !StaticRules.isPunctuation(tokens.get(i + 1).toString())) {
216 | // Push out
217 | tokens.add(i + 1, new Token(p, Token.Type.GENERIC));
218 | }
219 | }
220 | }
221 | tokens.set(i, new Token(StaticRules.canonicalizeQuotation(tokenStr), Token.Type.QUOTATION));
222 | }
223 | }
224 |
225 | return new Sentence(tokens, s.getArticleUid(), s.getIndex(), s.getQuotationOffset(), s.getLeftOffset(), s.getRightOffset());
226 | }
227 |
228 | }
229 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/DatasetLoader.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.List;
4 | import java.util.Set;
5 |
6 | import org.apache.spark.api.java.JavaRDD;
7 | import org.apache.spark.api.java.JavaSparkContext;
8 |
9 | /**
10 | * This interface is used for loading a news dataset.
11 | * The handling of the format is up to the concrete implementation.
12 | */
13 | public interface DatasetLoader {
14 |
15 | /**
16 | * Load all articles as a Java RDD.
17 | * @param sc the JavaSparkContext used for loading the RDD
18 | * @param datasetPath the path of the dataset.
19 | * @param languageFilter a set that contains the requested language codes.
20 | * @return a JavaRDD of Articles
21 | */
22 | JavaRDD<Article> loadArticles(JavaSparkContext sc, String datasetPath, Set<String> languageFilter);
23 |
24 | /**
25 | * A news article, identified by a unique long identifier.
26 | */
27 | public interface Article {
28 |
29 | /** Get the unique identifier of this article. */
30 | String getArticleUID();
31 |
32 | /** Get the content of this article in tokenized format. */
33 | List<String> getArticleContent();
34 |
35 | /** Get the domain name in which this article was found. */
36 | String getWebsite();
37 |
38 | /** Get the date of this article. */
39 | String getDate();
40 |
41 | /** Get the version of the article */
42 | String getVersion();
43 |
44 | /** Get the title of this article */
45 | String getTitle();
46 |
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Dawg.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collection;
5 | import java.util.Collections;
6 | import java.util.HashMap;
7 | import java.util.List;
8 | import java.util.Map;
9 | import java.util.Optional;
10 |
11 | import com.google.common.collect.HashMultimap;
12 | import com.google.common.collect.Multimap;
13 |
14 | import scala.Tuple2;
15 |
16 | /**
17 | * This class implements a Directed Acyclic Word Graph (DAWG),
18 | * also known as deterministic acyclic finite state automaton (DAFSA).
19 | */
20 | public class Dawg {
21 |
22 | final Node root;
23 | final Multimap allNodes;
24 |
25 | public Dawg() {
26 | root = new Node("");
27 | allNodes = HashMultimap.create();
28 | }
29 |
30 | public Node getRoot() {
31 | return root;
32 | }
33 |
34 | public static Optional convert(List newPattern) {
35 | // Remove leading/trailing asterisks
36 | int startPos = 0;
37 | while (newPattern.size() > startPos && newPattern.get(startPos).getType() == Token.Type.ANY) {
38 | startPos++;
39 | }
40 |
41 | int endPos = newPattern.size();
42 | while (endPos > 0 && newPattern.get(endPos - 1).getType() == Token.Type.ANY) {
43 | endPos--;
44 | }
45 |
46 | if (startPos >= endPos) {
47 | return Optional.empty();
48 | }
49 |
50 | newPattern = newPattern.subList(startPos, endPos);
51 |
52 | newPattern.replaceAll(x -> {
53 | if (x.toString().equals(Pattern.QUOTATION_PLACEHOLDER)) {
54 | return new Token(null, Token.Type.QUOTATION);
55 | } else if (x.toString().equals(Pattern.SPEAKER_PLACEHOLDER)) {
56 | return new Token(null, Token.Type.SPEAKER);
57 | }
58 | return x;
59 | });
60 |
61 | try {
62 | Pattern p = new Pattern(newPattern);
63 |
64 | // Add a constraint on the number of consecutive $* tokens
65 | int maxCount = 0;
66 | final int threshold = 5;
67 | for (Token t : p.getTokens()) {
68 | if (t.getType() == Token.Type.ANY) {
69 | maxCount++;
70 | } else {
71 | maxCount = 0;
72 | }
73 | if (maxCount > threshold) {
74 | return Optional.empty();
75 | }
76 | }
77 |
78 | return Optional.of(p);
79 | } catch (Exception e) {
80 | // Discard invalid patterns
81 | return Optional.empty();
82 | }
83 | }
84 |
85 | public void addAll(List<List<String>> patterns) {
86 | patterns.sort((x, y) -> {
87 | for (int i = 0; i < Math.min(x.size(), y.size()); i++) {
88 | int comp = x.get(i).compareTo(y.get(i));
89 | if (comp != 0) {
90 | return comp;
91 | }
92 | }
93 | return Integer.compare(x.size(), y.size());
94 | });
95 | patterns.forEach(this::add);
96 | if (root.children.size() > 0) {
97 | addOrReplace(root);
98 | }
99 | }
100 |
101 | private void add(List<String> tokens) {
102 | Tuple2<Node, List<String>> tail = getInsertionPoint(tokens);
103 | if (tail._1 == null) {
104 | return;
105 | }
106 | if (tail._1.children.size() > 0) {
107 | addOrReplace(tail._1);
108 | }
109 | addSuffix(tail._1, tail._2);
110 | }
111 |
112 | private Tuple2<Node, List<String>> getInsertionPoint(List<String> tokens) {
113 | Node node = root;
114 | for (int i = 0; i < tokens.size(); i++) {
115 | node.count++;
116 |
117 | String token = tokens.get(i);
118 | Node next = node.children.get(token);
119 | if (next == null) {
120 | return new Tuple2<>(node, tokens.subList(i, tokens.size()));
121 | }
122 | node = next;
123 | }
124 | return new Tuple2<>(null, Collections.emptyList());
125 | }
126 |
127 | /**
128 | * Export this graph into a GraphViz description (for rendering).
129 | * @return a string in GraphViz format
130 | */
131 | public String toDot() {
132 | StringBuilder builder = new StringBuilder();
133 | builder.append("digraph g{\n\trankdir=LR\n");
134 |
135 | allNodes.put("", root);
136 | allNodes.entries().forEach(x -> {
137 | builder.append("\t" + System.identityHashCode(x.getValue())
138 | + " [label=\"" + x.getKey() + " ["+ x.getValue().count + "]\"]\n");
139 | });
140 | allNodes.values().forEach(x -> {
141 | x.children.values().forEach(y -> {
142 | builder.append("\t" + System.identityHashCode(x) + " -> " + System.identityHashCode(y) + "\n");
143 | });
144 | });
145 | allNodes.remove("", root);
146 | builder.append("}\n");
147 | return builder.toString();
148 | }
149 |
150 | private void addSuffix(Node node, List tokens) {
151 | for (String token : tokens) {
152 | Node next = new Node(token);
153 | next.count++;
154 | node.children.put(token, next);
155 | node.lastAdded = next;
156 | node = next;
157 | }
158 | }
159 |
160 | private void addOrReplace(Node node) {
161 | Node child = node.lastAdded;
162 | if (child.children.size() > 0) {
163 | addOrReplace(child);
164 | }
165 |
166 | Collection candidateChildren = allNodes.get(child.word);
167 | for (Node existingChild : candidateChildren) {
168 | if (existingChild.equals(child)) {
169 | existingChild.count += child.count;
170 | node.lastAdded = existingChild;
171 | node.children.put(child.word, existingChild);
172 | return;
173 | }
174 | }
175 |
176 | allNodes.put(child.word, child);
177 | }
178 |
179 | public static class Node {
180 | private final String word;
181 | private final Map children;
182 | private int count;
183 | private Node lastAdded;
184 |
185 | public Node(String word) {
186 | this.word = word;
187 | this.children = new HashMap<>();
188 | this.count = 0;
189 | }
190 |
191 | public String getWord() {
192 | return word;
193 | }
194 |
195 | public int getCount() {
196 | return count;
197 | }
198 |
199 | public Collection getNodes() {
200 | return children.values();
201 | }
202 |
203 | @Override
204 | public int hashCode() {
205 | int hash = word.hashCode();
206 | // Commutative hash function
207 | for (Map.Entry entry : children.entrySet()) {
208 | hash ^= entry.getKey().hashCode() + 31 * System.identityHashCode(entry.getValue());
209 | }
210 | return hash;
211 | }
212 |
213 | @Override
214 | public boolean equals(final Object obj) {
215 | if (this == obj)
216 | return true;
217 |
218 | if (obj instanceof Node) {
219 | final Node node = (Node) obj;
220 | return hashCode() == node.hashCode() && word.equals(node.word) && equalsChildren(node);
221 | }
222 |
223 | return false;
224 | }
225 |
226 | @Override
227 | public String toString() {
228 | return Integer.toHexString(System.identityHashCode(this)) + "(" + word + ", " + count + ")";
229 | }
230 |
231 | private boolean equalsChildren(final Node other) {
232 | if (children.size() != other.children.size()) {
233 | return false;
234 | }
235 |
236 | for (Map.Entry entry : children.entrySet()) {
237 | Node match = other.children.get(entry.getKey());
238 | // The pointer comparison is intentional!
239 | if (match == null || match != entry.getValue()) {
240 | return false;
241 | }
242 | }
243 |
244 | return true;
245 | }
246 |
247 | /**
248 | * Extract all patterns stored in this graph.
249 | * @param all The output list.
250 | * @param current Temporary storage (must be empty).
251 | */
252 | public void dump(List<List<Node>> all, List<Node> current) {
253 | current.add(this);
254 | if (children.isEmpty()) {
255 | all.add(new ArrayList<>(current));
256 | } else {
257 | children.values().forEach(x -> x.dump(all, current));
258 | }
259 | current.remove(current.size() - 1);
260 | }
261 |
262 | }
263 | }
264 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Exporter.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.HashMap;
5 | import java.util.HashSet;
6 | import java.util.Iterator;
7 | import java.util.List;
8 | import java.util.Map;
9 | import java.util.Set;
10 |
11 | import org.apache.hadoop.io.compress.GzipCodec;
12 | import org.apache.spark.api.java.JavaPairRDD;
13 | import org.apache.spark.api.java.JavaRDD;
14 | import org.apache.spark.api.java.JavaSparkContext;
15 |
16 | import com.google.gson.GsonBuilder;
17 | import com.google.gson.JsonArray;
18 | import com.google.gson.JsonObject;
19 | import com.google.gson.JsonPrimitive;
20 |
21 | import scala.Tuple2;
22 | import scala.Tuple3;
23 | import scala.Tuple4;
24 |
25 | public class Exporter {
26 |
27 | private final JavaPairRDD quotationMap;
28 |
29 | /** Map article UID to a tuple (full quotation, website, date) */
30 | private final JavaPairRDD, Tuple3> articles;
31 |
32 | private final int numPartition;
33 |
34 | public Exporter(JavaSparkContext sc, JavaRDD sentences, NameDatabaseWikiData people) {
35 |
36 | final Set langSet = new HashSet<>(ConfigManager.getInstance().getLangFilter());
37 | this.articles = QuotationExtraction.getConcreteDatasetLoader().loadArticles(sc,
38 | ConfigManager.getInstance().getDatasetPath(), langSet)
39 | .mapToPair(x -> new Tuple2<>(new Tuple2<>(x.getWebsite(), x.getDate()), new Tuple2<>(x.getArticleContent(), x.getArticleUID())))
40 | .flatMapValues(x -> ContextExtractor.extractQuotations(x._1(), x._2()))
41 | .mapToPair(x -> new Tuple2<>(x._2.getKey(), new Tuple3<>(x._2.getQuotation(), x._1._1, x._1._2)));
42 |
43 | this.quotationMap = computeQuotationMap(sc);
44 | this.numPartition = ConfigManager.getInstance().getNumPartition();
45 | }
46 |
47 | private JavaPairRDD computeQuotationMap(JavaSparkContext sc) {
48 | Set langSet = new HashSet<>(ConfigManager.getInstance().getLangFilter());
49 |
50 | // Reconstruct quotations (from the lower-case canonical form to the full form)
51 | return QuotationExtraction.getConcreteDatasetLoader().loadArticles(sc,
52 | ConfigManager.getInstance().getDatasetPath(), langSet)
53 | .flatMap(x -> ContextExtractor.extractQuotations(x.getArticleContent(), x.getArticleUID()).iterator())
54 | .mapToPair(x -> new Tuple2<>(StaticRules.canonicalizeQuotation(x.getQuotation()), x.getQuotation()))
55 | .reduceByKey((x, y) -> {
56 | // Out of multiple possibilities, get the longest quotation
57 | if (x.length() > y.length()) {
58 | return x;
59 | } else if (x.length() < y.length()) {
60 | return y;
61 | } else {
62 | // Lexicographical comparison to ensure determinism
63 | return x.compareTo(y) == -1 ? x : y;
64 | }
65 | });
66 | }
67 |
68 | public void exportResults(JavaPairRDD>, LineageInfo>> pairs) {
69 | String exportPath = ConfigManager.getInstance().getExportPath();
70 |
71 | JavaPairRDD, String, String, String>> articleMap = pairs.mapToPair(x -> new Tuple2<>(x._1, x._2._2)) // (canonical quotation, lineage info)
72 | .flatMapValues(x -> {
73 | // (key)
74 | List> values = new ArrayList<>();
75 | for (int i = 0; i < x.getPatterns().size(); i++) {
76 | values.add(x.getSentences().get(i).getKey());
77 | }
78 | return values;
79 | }) // (canonical quotation, key)
80 | .mapToPair(Tuple2::swap) // (key, canonical quotation)
81 | .join(this.articles) // (key, (canonical quotation, (website, date)))
82 | .mapToPair(x -> new Tuple2<>(x._2._1, new Tuple4<>(x._1, x._2._2._1(), x._2._2._2(), x._2._2._3()))); // (canonical quotation, (key, full quotation, website, date))
83 |
84 | pairs // (canonical quotation, (speaker, lineage info))
85 | .join(quotationMap)
86 | .mapValues(x -> new Tuple3<>(x._1._1, x._1._2, x._2)) // (canonical quotation, (speakers, lineage info, full quotation))
87 | .cogroup(articleMap)
88 | .map(t -> {
89 |
90 | String canonicalQuotation = t._1;
91 | Map, Tuple3> articles = new HashMap<>();
92 | t._2._2.forEach(x -> {
93 | articles.put(x._1(), new Tuple3<>(x._2(), x._3(), x._4())); // (key, (full quotation, website, date))
94 | });
95 |
96 | Iterator>, LineageInfo, String>> it = t._2._1.iterator();
97 | if (!it.hasNext()) {
98 | return null;
99 | }
100 | Tuple3>, LineageInfo, String> data = it.next();
101 |
102 | if (data._2().getPatterns().size() != articles.size()) {
103 | return null;
104 | }
105 |
106 | JsonArray ids = new JsonArray();
107 | data._1().stream().map(y -> new JsonPrimitive(y._1)).forEach(y -> ids.add(y));
108 |
109 | JsonArray names = new JsonArray();
110 | data._1().stream().map(y -> new JsonPrimitive(y._2)).forEach(y -> names.add(y));
111 |
112 |
113 | JsonObject o = new JsonObject();
114 | o.addProperty("quotation", data._3());
115 | o.addProperty("canonicalQuotation", canonicalQuotation);
116 | o.addProperty("numSpeakers", data._1().size());
117 | o.add("speaker", names);
118 | o.add("speakerID", ids);
119 | o.addProperty("confidence", data._2().getConfidence()); // Tuple confidence
120 | o.addProperty("numOccurrences", data._2().getPatterns().size());
121 |
122 | JsonArray occurrences = new JsonArray();
123 | for (int i = 0; i < data._2().getPatterns().size(); i++) {
124 | JsonObject occ = new JsonObject();
125 | Tuple2 key = data._2().getSentences().get(i).getKey();
126 | occ.addProperty("articleUID", key._1);
127 | occ.addProperty("articleOffset", data._2().getSentences().get(i).getIndex());
128 | occ.addProperty("extractedBy", data._2().getPatterns().get(i).toString(false));
129 | occ.addProperty("patternConfidence", data._2().getPatterns().get(i).getConfidenceMetric());
130 | occ.addProperty("quotation", articles.get(key)._1());
131 |
132 |
133 | String matchedTokens = String.join(" ", data._2().getAliases().get(i));
134 | occ.addProperty("matchedSpeakerTokens", matchedTokens);
135 | occ.addProperty("website", articles.get(key)._2());
136 | String date = articles.get(key)._3();
137 | if (!date.isEmpty()) {
138 | occ.addProperty("date", date);
139 | }
140 | occurrences.add(occ);
141 | }
142 | o.add("occurrences", occurrences);
143 |
144 | return new GsonBuilder().disableHtmlEscaping().create().toJson(o);
145 | })
146 | .filter(x -> x != null)
147 | .repartition(numPartition)
148 | .saveAsTextFile(exportPath, GzipCodec.class);
149 | }
150 |
151 | }
152 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ExporterArticle.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.IOException;
4 |
5 | import org.apache.hadoop.io.compress.GzipCodec;
6 | import org.apache.spark.api.java.JavaRDD;
7 | import org.apache.spark.api.java.JavaSparkContext;
8 |
9 | import com.google.gson.GsonBuilder;
10 | import com.google.gson.JsonObject;
11 |
12 | import ch.epfl.dlab.quootstrap.DatasetLoader.Article;
13 |
14 | public class ExporterArticle {
15 |
16 | private final JavaRDD allArticles;
17 | // private final int numPartition;
18 |
19 | public ExporterArticle(JavaRDD allArticles) {
20 | this.allArticles = allArticles;
21 | // this.numPartition = ConfigManager.getInstance().getNumPartition();
22 | }
23 |
24 | public void exportResults(String exportPath, JavaSparkContext sc) throws IOException {
25 | allArticles.map(x -> { //repartition(numPartition).
26 | JsonObject o = new JsonObject();
27 | o.addProperty("articleUID", x.getArticleUID());
28 | o.addProperty("content", String.join(" ", x.getArticleContent()));
29 | o.addProperty("date", x.getDate());
30 | o.addProperty("website", x.getWebsite());
31 | o.addProperty("phase", x.getVersion());
32 | o.addProperty("title", x.getTitle());
33 |
34 | return new GsonBuilder().disableHtmlEscaping().create().toJson(o);
35 | }).filter(x -> x != null).saveAsTextFile(exportPath, GzipCodec.class);
36 | }
37 |
38 | }
39 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ExporterContext.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.IOException;
4 | import java.util.List;
5 |
6 | import org.apache.hadoop.io.compress.GzipCodec;
7 | import org.apache.spark.api.java.JavaRDD;
8 | import org.apache.spark.api.java.JavaSparkContext;
9 |
10 | import com.google.gson.GsonBuilder;
11 | import com.google.gson.JsonObject;
12 |
13 | import ch.epfl.dlab.spinn3r.Tokenizer;
14 | import ch.epfl.dlab.spinn3r.TokenizerImpl;
15 |
16 | public class ExporterContext {
17 |
18 | private final JavaRDD sentences;
19 | private final int numPartition;
20 |
21 | public ExporterContext(JavaRDD sentences) {
22 | this.sentences = sentences;
23 | this.numPartition = ConfigManager.getInstance().getNumPartition();
24 | }
25 |
26 | public void exportResults(String exportPath, JavaSparkContext sc) throws IOException {
27 | sentences.repartition(numPartition).map(x -> {
28 | JsonObject o = new JsonObject();
29 | o.addProperty("articleUID", x.getArticleUid());
30 | o.addProperty("articleOffset", x.getIndex());
31 |
32 | String leftContext = "";
33 | String rightContext = "";
34 | String quotation = "";
35 | List tokens = x.getTokens();
36 | Tokenizer tokenizer = new TokenizerImpl();
37 |
38 | for (int i = 0; i < x.getTokenCount(); i++) {
39 | Token t = tokens.get(i);
40 | if (t.getType() == Token.Type.QUOTATION) {
41 | leftContext = tokenizer.untokenize(Token.getStrings(tokens.subList(0, i)));
42 | quotation = t.toString();
43 | rightContext = tokenizer.untokenize(Token.getStrings(tokens.subList(i + 1, x.getTokenCount())));
44 | break;
45 | }
46 | }
47 |
48 | o.addProperty("leftContext", leftContext);
49 | o.addProperty("rightContext", rightContext);
50 | o.addProperty("quotation", quotation);
51 | o.addProperty("quotationOffset", x.getQuotationOffset());
52 | o.addProperty("leftOffset", x.getLeftOffset());
53 | o.addProperty("rightOffset", x.getRightOffset());
54 |
55 | return new GsonBuilder().disableHtmlEscaping().create().toJson(o);
56 | }).filter(x -> x != null).saveAsTextFile(exportPath, GzipCodec.class);
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ExporterSpeakers.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.IOException;
4 | import java.util.stream.Collectors;
5 |
6 | import org.apache.hadoop.io.compress.GzipCodec;
7 | import org.apache.spark.api.java.JavaPairRDD;
8 | import org.apache.spark.api.java.JavaSparkContext;
9 |
10 | import com.google.gson.GsonBuilder;
11 | import com.google.gson.JsonArray;
12 | import com.google.gson.JsonObject;
13 | import com.google.gson.JsonPrimitive;
14 |
15 | import scala.Tuple2;
16 |
17 | public class ExporterSpeakers {
18 |
19 | private final JavaPairRDD> allSpeakers;
20 | private final int numPartition;
21 |
22 | public ExporterSpeakers(JavaPairRDD> allSpeakers) {
23 | this.allSpeakers = allSpeakers;
24 | this.numPartition = ConfigManager.getInstance().getNumPartition();
25 | }
26 |
27 | public void exportResults(String exportPath, JavaSparkContext sc) throws IOException {
28 | allSpeakers.repartition(numPartition).map(x -> {
29 | JsonObject o = new JsonObject();
30 | o.addProperty("articleUID", x._1);
31 | JsonArray names = new JsonArray();
32 | for (SpeakerAlias alias: x._2) {
33 | JsonObject current = new JsonObject();
34 | JsonArray ids = new JsonArray();
35 | JsonArray offsets = new JsonArray();
36 | String current_alias = alias.getAlias().stream().collect(Collectors.joining(" "));
37 | alias.getIds().stream().map(y -> new Tuple2<>(new JsonPrimitive(y._1), new JsonPrimitive(y._2))).forEach(y -> {
38 | JsonArray current_ids = new JsonArray();
39 | current_ids.add(y._1);
40 | current_ids.add(y._2);
41 | ids.add(current_ids);
42 | });
43 | alias.getOffsets().stream().map(y -> new Tuple2<>(new JsonPrimitive(y._1), new JsonPrimitive(y._2))).forEach(y -> {
44 | JsonArray current_offset = new JsonArray();
45 | current_offset.add(y._1);
46 | current_offset.add(y._2);
47 | offsets.add(current_offset);
48 | });
49 | current.addProperty("name", current_alias);
50 | current.add("ids", ids);
51 | current.add("offsets", offsets);
52 | names.add(current);
53 | }
54 | o.add("names", names);
55 |
56 | return new GsonBuilder().disableHtmlEscaping().create().toJson(o);
57 | }).filter(x -> x != null).saveAsTextFile(exportPath, GzipCodec.class);
58 | }
59 |
60 | }
61 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/HashTrie.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.util.ArrayList;
5 | import java.util.Collection;
6 | import java.util.Collections;
7 | import java.util.HashMap;
8 | import java.util.Iterator;
9 | import java.util.List;
10 | import java.util.Locale;
11 | import java.util.Map;
12 | import java.util.Map.Entry;
13 |
14 | import scala.Tuple2;
15 |
16 | public class HashTrie implements Serializable {
17 |
18 | private static final long serialVersionUID = -9164247688163451454L;
19 |
20 | private final Node rootNode;
21 | private final boolean caseSensitive;
22 |
23 | public HashTrie(Iterable, List>>> substrings, boolean caseSensitive) {
24 | this.rootNode = new Node(null, caseSensitive, Collections.emptyList());
25 | this.caseSensitive = caseSensitive;
26 | substrings.forEach(this::insertSubstring);
27 | }
28 |
29 | private void insertSubstring(Tuple2, List>> substring) {
30 | Node current = rootNode;
31 |
32 | for (int i = 0; i < substring._1.size(); i++) {
33 | String token = substring._1.get(i);
34 | String key = caseSensitive ? token : token.toLowerCase(Locale.ROOT);
35 |
36 | Node next = current.findChild(key);
37 | if (next == null) {
38 | next = new Node(token, caseSensitive, Collections.emptyList());
39 | current.children.put(key, next);
40 | }
41 | current = next;
42 | }
43 |
44 | current.terminal = true;
45 | current.addIds(substring._2);
46 | }
47 |
48 | public List> getAllSubstrings() {
49 | List> allSubstrings = new ArrayList<>();
50 |
51 | List currentSubstring = new ArrayList<>();
52 | DFS(rootNode, currentSubstring, allSubstrings);
53 | return allSubstrings;
54 | }
55 |
56 | public Node getRootNode() {
57 | return rootNode;
58 | }
59 |
60 | private void DFS(Node current, List currentSubstring, List> allSubstrings) {
61 | if (current.isTerminal()) {
62 | allSubstrings.add(new ArrayList<>(currentSubstring));
63 | }
64 |
65 | for (Map.Entry next : current) {
66 | currentSubstring.add(next.getKey());
67 | DFS(next.getValue(), currentSubstring, allSubstrings);
68 | currentSubstring.remove(currentSubstring.size() - 1);
69 | }
70 | }
71 |
72 | public static class Node implements Iterable>, Serializable {
73 |
74 | private static final long serialVersionUID = -4344489198225825075L;
75 |
76 | private final Map children;
77 | private final boolean caseSensitive;
78 | private final String value;
79 | private boolean terminal;
80 | private List> ids; // (id, standard name)
81 |
82 | public Node(String value, boolean caseSensitive, Collection> ids) {
83 | this.children = new HashMap<>();
84 | this.caseSensitive = caseSensitive;
85 | this.value = value;
86 | this.terminal = false;
87 | this.ids = new ArrayList<>(ids);
88 | }
89 |
90 | public void addIds(Collection> ids) {
91 | this.ids.addAll(ids);
92 | }
93 |
94 | public boolean hasChildren() {
95 | return !children.isEmpty();
96 | }
97 |
98 | public Node findChild(String token) {
99 | return children.get(caseSensitive ? token : token.toLowerCase(Locale.ROOT));
100 | }
101 |
102 | public boolean isTerminal() {
103 | return terminal;
104 | }
105 |
106 | public String getValue() {
107 | return value;
108 | }
109 |
110 | public List<Tuple2<String, String>> getIds() {
111 | return ids;
112 | }
113 |
114 | @Override
115 | public Iterator<Map.Entry<String, Node>> iterator() {
116 | return Collections.unmodifiableMap(children).entrySet().iterator();
117 | }
118 | }
119 | }
120 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/HashTriePatternMatcher.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collections;
5 | import java.util.List;
6 |
7 | import scala.Tuple2;
8 | import scala.Tuple3;
9 |
10 | public class HashTriePatternMatcher {
11 |
12 | private final HashTrie trie;
13 | private final List<String> currentMatch;
14 | private List<Tuple2<String, String>> currentIds;
15 |
16 | public HashTriePatternMatcher(HashTrie trie) {
17 | this.trie = trie;
18 | currentMatch = new ArrayList<>();
19 | }
20 |
21 | public boolean match(List<Token> tokens) {
22 | List<Tuple3<Integer, List<String>, List<Tuple2<String, String>>>> longestMatches = new ArrayList<>();
23 | boolean result = false;
24 | for (int i = 0; i < tokens.size(); i++) {
25 | currentMatch.clear();
26 | if (matchImpl(tokens, trie.getRootNode(), i)) {
27 | result = true;
28 | longestMatches.add(new Tuple3<>(i, new ArrayList<>(currentMatch), new ArrayList<>(currentIds)));
29 | }
30 | }
31 |
32 | if (result) {
33 | // Get longest match out of multiple matches
34 | int maxLen = Collections.max(longestMatches, (x, y) -> Integer.compare(x._2().size(), y._2().size()))._2()
35 | .size();
36 | longestMatches.removeIf(x -> x._2().size() != maxLen);
37 |
38 | currentMatch.clear();
39 | currentMatch.addAll(longestMatches.get(0)._2());
40 | currentIds = longestMatches.get(0)._3();
41 | }
42 |
43 | return result;
44 | }
45 |
46 | public List<SpeakerAlias> multiMatch(List<Token> tokens) {
47 | ArrayList<SpeakerAlias> matches = new ArrayList<>();
48 | for (int i = 0; i < tokens.size(); i++) {
49 | currentMatch.clear();
50 | if (matchImpl(tokens, trie.getRootNode(), i)) {
51 | Tuple2<Integer, Integer> offset = new Tuple2<>(i, i + currentMatch.size());
52 | SpeakerAlias speaker = new SpeakerAlias(new ArrayList<>(currentMatch), new ArrayList<>(currentIds),
53 | offset);
54 | int idx = matches.indexOf(speaker);
55 | if (idx >= 0) {
56 | matches.get(idx).addOffset(offset);
57 | } else {
58 | matches.add(speaker);
59 | }
60 | i += currentMatch.size() - 1; // Avoid matching subsequences
61 | }
62 | }
63 |
64 | return new ArrayList<>(matches);
65 | }
66 |
67 | public boolean match(Sentence s) {
68 | List<Token> tokens = s.getTokens();
69 | List<Tuple3<Integer, List<String>, List<Tuple2<String, String>>>> longestMatches = new ArrayList<>();
70 | boolean result = false;
71 | for (int i = 0; i < tokens.size(); i++) {
72 | currentMatch.clear();
73 | if (matchImpl(tokens, trie.getRootNode(), i)) {
74 | result = true;
75 | longestMatches.add(new Tuple3<>(i, new ArrayList<>(currentMatch), new ArrayList<>(currentIds)));
76 | }
77 | }
78 |
79 | if (result) {
80 | int maxLen = Collections.max(longestMatches, (x, y) -> Integer.compare(x._2().size(), y._2().size()))._2()
81 | .size();
82 | longestMatches.removeIf(x -> x._2().size() != maxLen);
83 |
84 | // If there are multiple speakers with max length, select the one that is
85 | // nearest to the quotation
86 | currentMatch.clear();
87 | if (longestMatches.size() > 1) {
88 | int quotationIdx = -1;
89 | for (int i = 0; i < tokens.size(); i++) {
90 | if (tokens.get(i).getType() == Token.Type.QUOTATION) {
91 | quotationIdx = i;
92 | break;
93 | }
94 | }
95 | final int qi = quotationIdx;
96 | Tuple3<Integer, List<String>, List<Tuple2<String, String>>> nearest = Collections.min(longestMatches,
97 | (x, y) -> {
98 | int delta1 = Math.abs(x._1() - qi);
99 | int delta2 = Math.abs(y._1() - qi);
100 | return Integer.compare(delta1, delta2);
101 | });
102 | currentMatch.addAll(nearest._2());
103 | currentIds = nearest._3();
104 | } else {
105 | currentMatch.addAll(longestMatches.get(0)._2());
106 | currentIds = longestMatches.get(0)._3();
107 | }
108 | }
109 |
110 | return result;
111 | }
112 |
113 | public SpeakerAlias getLongestMatch() {
114 | return new SpeakerAlias(new ArrayList<>(currentMatch), new ArrayList<>(currentIds), new Tuple2<>(0, 0));
115 | }
116 |
117 | private boolean matchImpl(List<Token> tokens, HashTrie.Node current, int i) {
118 |
119 | if (i == tokens.size()) {
120 | return false;
121 | }
122 |
123 | if (tokens.get(i).getType() != Token.Type.GENERIC) {
124 | return false;
125 | }
126 |
127 | String tokenStr = tokens.get(i).toString();
128 | HashTrie.Node next = current.findChild(tokenStr);
129 | if (next != null) {
130 | currentMatch.add(tokenStr); // next.getValue());
131 | boolean result = false;
132 | if (next.isTerminal()) {
133 | // Match found
134 | currentIds = next.getIds();
135 | result = true;
136 | }
137 |
138 | // Even if a match is found, try to match a longer sequence
139 | if (matchImpl(tokens, next, i + 1) || result) {
140 | return true;
141 | }
142 |
143 | currentMatch.remove(currentMatch.size() - 1);
144 | }
145 |
146 | return false;
147 | }
148 | }
149 |
--------------------------------------------------------------------------------
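Illustrative note (not part of the repository): a minimal sketch of how HashTrie and HashTriePatternMatcher fit together. The example class, the alias tokens, and the (id, standard name) pair below are invented for illustration.

import java.util.Arrays;
import java.util.List;

import scala.Tuple2;

// Hypothetical example class, not present in the repository.
public class TrieMatchSketch {
    public static void main(String[] args) {
        // One speaker alias: the token sequence ["Barack", "Obama"] mapped to one (id, standard name) pair.
        Tuple2<List<String>, List<Tuple2<String, String>>> alias = new Tuple2<>(
                Arrays.asList("Barack", "Obama"),
                Arrays.asList(new Tuple2<>("Q76", "Barack Obama")));

        HashTrie trie = new HashTrie(Arrays.asList(alias), false); // case-insensitive lookup
        HashTriePatternMatcher matcher = new HashTriePatternMatcher(trie);

        // multiMatch scans GENERIC tokens and returns every alias it finds, with its token offsets.
        List<Token> tokens = Token.getTokens(Arrays.asList("President", "Barack", "Obama", "said", "so"));
        for (SpeakerAlias speaker : matcher.multiMatch(tokens)) {
            System.out.println(speaker.getAlias() + " -> " + speaker.getIds() + " at " + speaker.getOffsets());
        }
    }
}

--------------------------------------------------------------------------------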
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Hashed.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 | * Implements a lightweight hashed object that contains
7 | * a 128-bit hash of a string (using MurmurHash3 128bit x64),
8 | * as well as its original length prior to hashing it.
9 | */
10 | public final class Hashed implements Serializable, Comparable<Hashed> {
11 |
12 | private static final long serialVersionUID = -5391976195212991788L;
13 |
14 | private final long h1;
15 | private final long h2;
16 | private final int len;
17 |
18 | public Hashed(String s) {
19 | long h1 = 0;
20 | long h2 = 0;
21 |
22 | final byte[] key = s.getBytes();
23 | final int len = key.length;
24 | final int offset = 0;
25 |
26 | final long c1 = 0x87c37b91114253d5L;
27 | final long c2 = 0x4cf5ad432745937fL;
28 |
29 | int roundedEnd = offset + (len & 0xFFFFFFF0);
30 | for (int i = offset; i < roundedEnd; i += 16) {
31 | long k1 = arrayToLong(key, i);
32 | long k2 = arrayToLong(key, i+8);
33 | k1 *= c1;
34 | k1 = Long.rotateLeft(k1,31);
35 | k1 *= c2;
36 | h1 ^= k1;
37 | h1 = Long.rotateLeft(h1,27);
38 | h1 += h2;
39 | h1 = h1*5 + 0x52dce729;
40 | k2 *= c2;
41 | k2 = Long.rotateLeft(k2,33);
42 | k2 *= c1;
43 | h2 ^= k2;
44 | h2 = Long.rotateLeft(h2,31);
45 | h2 += h1;
46 | h2 = h2*5 + 0x38495ab5;
47 | }
48 |
49 | long k1 = 0;
50 | long k2 = 0;
51 |
52 | switch (len & 15) {
53 | case 15: k2 = (key[roundedEnd+14] & 0xffL) << 48;
54 | case 14: k2 |= (key[roundedEnd+13] & 0xffL) << 40;
55 | case 13: k2 |= (key[roundedEnd+12] & 0xffL) << 32;
56 | case 12: k2 |= (key[roundedEnd+11] & 0xffL) << 24;
57 | case 11: k2 |= (key[roundedEnd+10] & 0xffL) << 16;
58 | case 10: k2 |= (key[roundedEnd+ 9] & 0xffL) << 8;
59 | case 9: k2 |= (key[roundedEnd+ 8] & 0xffL);
60 | k2 *= c2;
61 | k2 = Long.rotateLeft(k2, 33);
62 | k2 *= c1;
63 | h2 ^= k2;
64 | case 8: k1 = ((long)key[roundedEnd+7]) << 56;
65 | case 7: k1 |= (key[roundedEnd+6] & 0xffL) << 48;
66 | case 6: k1 |= (key[roundedEnd+5] & 0xffL) << 40;
67 | case 5: k1 |= (key[roundedEnd+4] & 0xffL) << 32;
68 | case 4: k1 |= (key[roundedEnd+3] & 0xffL) << 24;
69 | case 3: k1 |= (key[roundedEnd+2] & 0xffL) << 16;
70 | case 2: k1 |= (key[roundedEnd+1] & 0xffL) << 8;
71 | case 1: k1 |= (key[roundedEnd ] & 0xffL);
72 | k1 *= c1; k1 = Long.rotateLeft(k1,31); k1 *= c2; h1 ^= k1;
73 | }
74 |
75 | h1 ^= len; h2 ^= len;
76 |
77 | h1 += h2;
78 | h2 += h1;
79 |
80 | h1 = mix(h1);
81 | h2 = mix(h2);
82 |
83 | h1 += h2;
84 | h2 += h1;
85 |
86 | this.h1 = h1;
87 | this.h2 = h2;
88 | this.len = s.length();
89 | }
90 |
91 | private static long mix(long k) {
92 | k ^= k >>> 33;
93 | k *= 0xff51afd7ed558ccdL;
94 | k ^= k >>> 33;
95 | k *= 0xc4ceb9fe1a85ec53L;
96 | k ^= k >>> 33;
97 | return k;
98 | }
99 |
100 | private static long arrayToLong(byte[] buf, int offset) {
101 | return ((long)buf[offset+7] << 56)
102 | | ((buf[offset+6] & 0xffL) << 48)
103 | | ((buf[offset+5] & 0xffL) << 40)
104 | | ((buf[offset+4] & 0xffL) << 32)
105 | | ((buf[offset+3] & 0xffL) << 24)
106 | | ((buf[offset+2] & 0xffL) << 16)
107 | | ((buf[offset+1] & 0xffL) << 8)
108 | | ((buf[offset ] & 0xffL));
109 | }
110 |
111 | @Override
112 | public String toString() {
113 | return String.format("%016X", h1) + String.format("%016X", h2) + ":" + len;
114 | }
115 |
116 | @Override
117 | public boolean equals(Object o) {
118 | if (o instanceof Hashed) {
119 | Hashed h = (Hashed) o;
120 | return h1 == h.h1 && h2 == h.h2 && len == h.len;
121 | }
122 | return false;
123 | }
124 |
125 | @Override
126 | public int hashCode() {
127 | // Return the 32 least significant bits of the hash
128 | return (int) (h1 & 0x00000000ffffffff);
129 | }
130 |
131 | public long hashCode64() {
132 | return h1 ^ h2;
133 | }
134 |
135 | public int getLength() {
136 | return len;
137 | }
138 |
139 | @Override
140 | public int compareTo(Hashed o) {
141 | // First, compare by the original string length...
142 | if (len != o.len) {
143 | return Integer.compare(len, o.len);
144 | }
145 |
146 | // In case of a tie, define a lexicographical order based on the hash
147 | if (h1 != o.h1) {
148 | return Long.compare(h1, o.h1);
149 | }
150 | return Long.compare(h2, o.h2);
151 | }
152 |
153 | }
154 |
--------------------------------------------------------------------------------
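Illustrative note (not part of the repository): a small sketch of how Hashed is intended to behave; the strings are invented for illustration.

// Hypothetical example class, not present in the repository.
public class HashedSketch {
    public static void main(String[] args) {
        Hashed a = new Hashed("the quick brown fox");
        Hashed b = new Hashed("the quick brown fox");
        Hashed c = new Hashed("a different string");

        System.out.println(a.equals(b));        // true: identical content yields the same hash and length
        System.out.println(a.equals(c));        // false (with overwhelming probability)
        System.out.println(a.getLength());      // 19, the length of the original string
        System.out.println(a);                  // 32 hex digits followed by ":" and the length
        System.out.println(a.compareTo(c) > 0); // true: ordering is by original length first, then by hash
    }
}

--------------------------------------------------------------------------------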
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/LineageInfo.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.util.List;
5 | import java.util.stream.Collectors;
6 |
7 | /**
8 | * Instances of this class contain the lineage information associated with a quotation-speaker pair.
9 | *
10 | */
11 | public final class LineageInfo implements Serializable {
12 |
13 | private static final long serialVersionUID = -6380483748689471512L;
14 |
15 | private final List<Pattern> patterns;
16 | private final List<Sentence> sentences;
17 | private final List<List<String>> aliases;
18 | private final double confidence;
19 |
20 | public LineageInfo(List<Pattern> patterns, List<Sentence> sentences, List<List<String>> aliases, double confidence) {
21 | this.patterns = patterns;
22 | this.sentences = sentences;
23 | this.confidence = confidence;
24 | this.aliases = aliases;
25 | }
26 |
27 | public List<Pattern> getPatterns() {
28 | return patterns;
29 | }
30 |
31 | public List<Sentence> getSentences() {
32 | return sentences;
33 | }
34 |
35 | public List<List<String>> getAliases() {
36 | return aliases;
37 | }
38 |
39 | public double getConfidence() {
40 | return confidence;
41 | }
42 |
43 | @Override
44 | public String toString() {
45 | return "{Confidence: " + confidence + ", Patterns: " + patterns + ", Sentences: "
46 | + sentences.stream().map(x -> "<" + x.getKey().toString() + "> " + x.toString()).collect(Collectors.toList()) + "}";
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/MultiCounter.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.util.HashMap;
5 | import java.util.Map;
6 |
7 | import org.apache.spark.util.LongAccumulator;
8 | import org.apache.spark.api.java.JavaSparkContext;
9 |
10 | public class MultiCounter implements Serializable {
11 |
12 | private static final long serialVersionUID = -5606791394577221239L;
13 |
14 | private final Map<String, LongAccumulator> accumulators;
15 |
16 | public MultiCounter(JavaSparkContext sc, String... counters) {
17 | accumulators = new HashMap<>();
18 | for (String counter : counters) {
19 | accumulators.put(counter, sc.sc().longAccumulator());
20 | }
21 | }
22 |
23 | public void increment(String accumulator) {
24 | accumulators.get(accumulator).add(1);
25 | }
26 |
27 | public long getValue(String accumulator) {
28 | return accumulators.get(accumulator).value();
29 | }
30 |
31 | public void dump() {
32 | accumulators.forEach((k, v) -> System.out.println(k + ": " + v.value()));
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
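Illustrative note (not part of the repository): a hypothetical sketch of MultiCounter in a local Spark job; the counter names and the toy RDD are invented for illustration.

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

// Hypothetical example class, not present in the repository.
public class MultiCounterSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("MultiCounterSketch").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            MultiCounter counters = new MultiCounter(sc, "seen", "matched");

            // Worker-side code increments the named accumulators; the driver reads them back.
            sc.parallelize(Arrays.asList("a", "bb", "ccc", "dddd")).foreach(s -> {
                counters.increment("seen");
                if (s.length() > 2) {
                    counters.increment("matched");
                }
            });

            counters.dump(); // prints "seen: 4" and "matched: 2" (map iteration order not guaranteed)
        }
    }
}

--------------------------------------------------------------------------------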
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/NameDatabaseWikiData.java:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/NameDatabaseWikiData.java
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ParquetDatasetLoader.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.Set;
4 |
5 | import org.apache.spark.api.java.JavaRDD;
6 | import org.apache.spark.api.java.JavaSparkContext;
7 | import org.apache.spark.sql.Dataset;
8 | import org.apache.spark.sql.Encoders;
9 | import org.apache.spark.sql.Row;
10 | import org.apache.spark.sql.SparkSession;
11 |
12 | import scala.Tuple4;
13 |
14 | public class ParquetDatasetLoader implements DatasetLoader {
15 |
16 | private static SparkSession session;
17 |
18 | private static void initialize(JavaSparkContext sc) {
19 | if (session == null) {
20 | session = new SparkSession(sc.sc());
21 | }
22 | }
23 |
24 | @Override
25 | public JavaRDD<DatasetLoader.Article> loadArticles(JavaSparkContext sc, String datasetPath, Set<String> languageFilter) {
26 | ParquetDatasetLoader.initialize(sc);
27 | Dataset<Row> df = ParquetDatasetLoader.session.read().parquet(datasetPath);
28 |
29 | /* Expected schema:
30 | * articleUID: String
31 | * website: String
32 | * content: String
33 | * date: String
34 | */
35 |
36 | return df.select(df.col("articleUID"), df.col("website"), df.col("content"), df.col("date"))
37 | .map(x -> {
38 | String uid = x.getString(0);
39 | String url = x.getString(1);
40 | String content = x.getString(2);
41 | String date = x.getString(3);
42 | return new Tuple4<>(uid, content, url, date);
43 | }, Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.STRING(), Encoders.STRING()))
44 | .javaRDD()
45 | .map(x -> new Spinn3rTextDatasetLoader.Article(x._1(), x._2(), x._3(), x._4()));
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Pattern.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.util.ArrayList;
5 | import java.util.Collections;
6 | import java.util.Iterator;
7 | import java.util.List;
8 | import java.util.regex.Matcher;
9 |
10 | /**
11 | * This class defines a pattern used for representing regular expressions.
12 | * In this context, a valid regular expression is defined by the following set of rules:
13 | * - The matching is done on a per-token basis (and not on a per-character basis).
14 | * - A token usually corresponds to a full word or to a punctuation mark.
15 | *
16 | * A valid pattern contains:
17 | * - Exactly one quotation token $Q
18 | * - Exactly one speaker token $S (which can match a variable number of tokens)
19 | * - Any number of text tokens (i.e. generic tokens), which are matched verbatim
20 | * - Any number of "ANY" tokens $* (equivalent to "." in Perl), each of which matches exactly one token
21 | *
22 | * Additionally, a pattern is subject to the following constraints:
23 | * - It must not start or end with an ANY token or a speaker token
24 | */
25 | public final class Pattern implements Serializable, Iterable<Token>, Comparable<Pattern> {
26 |
27 | private static final long serialVersionUID = -3877164912177196115L;
28 |
29 | public static final String QUOTATION_PLACEHOLDER = "$Q";
30 | public static final String SPEAKER_PLACEHOLDER = "$S";
31 | public static final String ANY_PLACEHOLDER = "$*";
32 |
33 | private final List<Token> tokens;
34 | private final double confidenceMetric;
35 |
36 | public Pattern(String expression, double confidence) {
37 | tokens = new ArrayList<>();
38 | String[] strTokens = expression.split(" ");
39 | for (String token : strTokens) {
40 | Token next;
41 | switch (token) {
42 | case QUOTATION_PLACEHOLDER:
43 | next = new Token(null, Token.Type.QUOTATION);
44 | break;
45 | case SPEAKER_PLACEHOLDER:
46 | next = new Token(null, Token.Type.SPEAKER);
47 | break;
48 | case ANY_PLACEHOLDER:
49 | next = new Token(null, Token.Type.ANY);
50 | break;
51 | default:
52 | next = new Token(token, Token.Type.GENERIC);
53 | break;
54 | }
55 | this.tokens.add(next);
56 | }
57 | confidenceMetric = confidence;
58 | sanityCheck();
59 | }
60 |
61 | public Pattern(List<Token> tokens, double confidence) {
62 | this.tokens = new ArrayList<>(tokens); // Defensive copy
63 | this.confidenceMetric = confidence;
64 | sanityCheck();
65 | }
66 |
67 | public Pattern(List<Token> tokens) {
68 | this(tokens, Double.NaN);
69 | }
70 |
71 | public Pattern(String expression) {
72 | this(expression, Double.NaN);
73 | }
74 |
75 | public List<Token> getTokens() {
76 | return Collections.unmodifiableList(tokens);
77 | }
78 |
79 | public int getTokenCount() {
80 | return tokens.size();
81 | }
82 |
83 | /**
84 | * @return the number of text tokens
85 | */
86 | public int getCardinality() {
87 | int count = 0;
88 | for (Token t : tokens) {
89 | if (t.getType() == Token.Type.GENERIC) {
90 | count++;
91 | }
92 | }
93 | return count;
94 | }
95 |
96 | public double getConfidenceMetric() {
97 | return confidenceMetric;
98 | }
99 |
100 | public boolean isSpeakerSurroundedByAny() {
101 | for (int i = 0; i < tokens.size(); i++) {
102 | if (tokens.get(i).getType() == Token.Type.SPEAKER) {
103 | return tokens.get(i - 1).getType() == Token.Type.ANY
104 | || tokens.get(i + 1).getType() == Token.Type.ANY;
105 | }
106 | }
107 | throw new IllegalStateException("Speaker not found in pattern");
108 | }
109 |
110 | private void sanityCheck() {
111 | if (tokens.isEmpty()) {
112 | throw new IllegalArgumentException("Invalid pattern: the pattern is empty");
113 | }
114 | boolean quotationFound = false;
115 | boolean speakerFound = false;
116 | for (Token t : tokens) {
117 | if (t.getType() == Token.Type.QUOTATION) {
118 | if (quotationFound) {
119 | throw new IllegalArgumentException("Invalid pattern: more than one quotation placeholder found");
120 | }
121 | quotationFound = true;
122 | } else if (t.getType() == Token.Type.SPEAKER) {
123 | if (speakerFound) {
124 | throw new IllegalArgumentException("Invalid pattern: more than one speaker placeholder found");
125 | }
126 | speakerFound = true;
127 | }
128 | }
129 | if (!quotationFound) {
130 | throw new IllegalArgumentException("Invalid pattern: no quotation placeholder found");
131 | }
132 | if (!speakerFound) {
133 | throw new IllegalArgumentException("Invalid pattern: no speaker placeholder found");
134 | }
135 | if (tokens.get(0).getType() == Token.Type.SPEAKER) {
136 | throw new IllegalArgumentException("Invalid pattern: the pattern must not start with a speaker placeholder");
137 | }
138 | if (tokens.get(tokens.size() - 1).getType() == Token.Type.SPEAKER) {
139 | throw new IllegalArgumentException("Invalid pattern: the pattern must not end with a speaker placeholder");
140 | }
141 | if (tokens.get(0).getType() == Token.Type.ANY) {
142 | throw new IllegalArgumentException("Invalid pattern: the pattern must not start with an 'any' placeholder");
143 | }
144 | if (tokens.get(tokens.size() - 1).getType() == Token.Type.ANY) {
145 | throw new IllegalArgumentException("Invalid pattern: the pattern must not end with an 'any' placeholder");
146 | }
147 | }
148 |
149 | @Override
150 | public String toString() {
151 | return toString(true);
152 | }
153 |
154 | public String toString(boolean addConfidence) {
155 | StringBuilder str = new StringBuilder();
156 | Iterator<Token> it = tokens.iterator();
157 | while (it.hasNext()) {
158 | Token t = it.next();
159 | switch (t.getType()) {
160 | case QUOTATION:
161 | str.append(QUOTATION_PLACEHOLDER);
162 | break;
163 | case SPEAKER:
164 | str.append(SPEAKER_PLACEHOLDER);
165 | break;
166 | case ANY:
167 | str.append(ANY_PLACEHOLDER);
168 | break;
169 | default:
170 | str.append(t.toString());
171 | break;
172 | }
173 | if (it.hasNext()) {
174 | str.append(' ');
175 | }
176 | }
177 |
178 | String output = str.toString();
179 | if (addConfidence && confidenceMetric >= 0) {
180 | output = output.replace("\"", "\\\""); // Escape character
181 | output = "[\"" + output + "\": " + confidenceMetric + "]";
182 | }
183 | return output;
184 | }
185 |
186 | public static Pattern parse(String input) {
187 | java.util.regex.Pattern pattern = java.util.regex.Pattern.compile("^\\[\\\"(.*)\": ([0-9.eE]+)\\]$");
188 | Matcher m = pattern.matcher(input);
189 | if (m.matches() && m.groupCount() == 2) {
190 | String p = m.group(1).replace("\\\"", "\"");
191 | double confidence = Double.parseDouble(m.group(2));
192 | return new Pattern(p, confidence);
193 | }
194 | throw new IllegalArgumentException("Invalid pattern format: " + input);
195 | }
196 |
197 | @Override
198 | public int hashCode() {
199 | return tokens.hashCode();
200 | }
201 |
202 | @Override
203 | public boolean equals(Object obj) {
204 | if (obj instanceof Pattern) {
205 | Pattern p = (Pattern) obj;
206 | // The comparison is done only on the content
207 | return p.tokens.equals(tokens);
208 | }
209 | return false;
210 | }
211 |
212 | @Override
213 | public Iterator<Token> iterator() {
214 | return Collections.unmodifiableList(tokens).iterator();
215 | }
216 |
217 | @Override
218 | public int compareTo(Pattern other) {
219 | // Define a lexicographical order for patterns
220 | int n = Math.min(tokens.size(), other.tokens.size());
221 | for (int i = 0; i < n; i++) {
222 | int comp = tokens.get(i).compareTo(other.tokens.get(i));
223 | if (comp != 0) {
224 | return comp;
225 | }
226 | }
227 | return Integer.compare(tokens.size(), other.tokens.size());
228 | }
229 | }
230 |
--------------------------------------------------------------------------------
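Illustrative note (not part of the repository): a sketch of the pattern mini-language described in the class comment above; the example pattern and confidence value are invented for illustration.

// Hypothetical example class, not present in the repository.
public class PatternSketch {
    public static void main(String[] args) {
        // "<quotation> , said <speaker> ." with an (invented) confidence of 0.9
        Pattern p = new Pattern("$Q , said $S .", 0.9);

        System.out.println(p.getTokenCount());  // 5 tokens in total
        System.out.println(p.getCardinality()); // 3 text tokens: "," "said" "."
        System.out.println(p.toString(false));  // $Q , said $S .
        System.out.println(p.toString(true));   // ["$Q , said $S .": 0.9]

        // parse() is the inverse of toString(true); equality ignores the confidence value.
        Pattern q = Pattern.parse(p.toString(true));
        System.out.println(p.equals(q));        // true
    }
}

--------------------------------------------------------------------------------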
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/PatternExtractor.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | public class PatternExtractor {
7 |
8 | public static Pattern extractPattern(Sentence s, String quotation,
9 | List<Token> speaker, boolean caseSensitive) {
10 |
11 | int quotationPtr = 0;
12 |
13 | int indexOfSpeaker = getIndexOfSpeaker(s, speaker, caseSensitive);
14 | if (indexOfSpeaker == -1 || indexOfSpeaker == 0) {
15 | // If match not found, or, if match found at position 0...
16 | return null;
17 | }
18 |
19 | boolean patternStarted = false;
20 | boolean hasQuotation = false;
21 | boolean hasSpeaker = false;
22 | List<Token> tokens = s.getTokens();
23 | List<Token> patternTokens = new ArrayList<>();
24 | for (int i = 0; i < tokens.size(); i++) {
25 |
26 | if (i == indexOfSpeaker) {
27 | if (patternTokens.isEmpty()) {
28 | // The pattern cannot start with the speaker
29 | patternTokens.add(tokens.get(i - 1));
30 | patternStarted = true;
31 | }
32 | patternTokens.add(new Token(null, Token.Type.SPEAKER));
33 | i += speaker.size() - 1;
34 | hasSpeaker = true;
35 | } else {
36 | switch (tokens.get(i).getType()) {
37 | case QUOTATION:
38 | if (quotationPtr < 1 && tokens.get(i).toString().equals(quotation)) {
39 | patternStarted = true;
40 | quotationPtr++;
41 | patternTokens.add(new Token(null, Token.Type.QUOTATION));
42 | } else {
43 | return null;
44 | }
45 | if (quotationPtr == 1) {
46 | hasQuotation = true;
47 | }
48 | break;
49 | default:
50 | if (hasQuotation && hasSpeaker && patternTokens.get(patternTokens.size() - 1).getType() != Token.Type.SPEAKER) {
51 | patternStarted = false;
52 | }
53 |
54 | if (patternStarted) {
55 | patternTokens.add(tokens.get(i));
56 | }
57 | if (hasQuotation && hasSpeaker) {
58 | patternStarted = false;
59 | }
60 | break;
61 | }
62 | }
63 | }
64 |
65 | if (quotationPtr < 1) {
66 | // The quotation has not been matched entirely
67 | return null;
68 | }
69 |
70 | // Ensure that the speaker token is not last token in the pattern
71 | if (patternTokens.get(patternTokens.size() - 1).getType() == Token.Type.SPEAKER) {
72 | return null;
73 | }
74 |
75 | return new Pattern(patternTokens);
76 | }
77 |
78 | private static int getIndexOfSpeaker(Sentence s, List<Token> speaker, boolean caseSensitive) {
79 | List<Token> tokens = s.getTokens();
80 |
81 | if (!caseSensitive) {
82 | tokens = Token.caseFold(tokens);
83 | speaker = Token.caseFold(speaker);
84 | }
85 |
86 | for (Token t : speaker) {
87 | // Each token must be perfectly matched without duplicates (even if they are partial)
88 | int firstIndex = tokens.indexOf(t);
89 | int lastIndex = tokens.lastIndexOf(t);
90 | if (firstIndex == -1 || lastIndex != firstIndex) {
91 | return -1;
92 | }
93 | }
94 |
95 | for (int i = 0; i < tokens.size() - speaker.size() + 1; i++) {
96 | for (int j = 0; j < speaker.size(); j++) {
97 | if (!tokens.get(i + j).equals(speaker.get(j))) {
98 | break;
99 | }
100 | if (j == speaker.size() - 1) {
101 | // Match found!
102 | return i;
103 | }
104 | }
105 | }
106 |
107 | return -1;
108 | }
109 |
110 | }
111 |
--------------------------------------------------------------------------------
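Illustrative note (not part of the repository): a sketch of how extractPattern generalizes a known (quotation, speaker) pair into a pattern; the sentence and speaker below are invented for illustration.

import java.util.Arrays;
import java.util.List;

// Hypothetical example class, not present in the repository.
public class PatternExtractorSketch {
    public static void main(String[] args) {
        List<Token> tokens = Arrays.asList(
                new Token("I will run", Token.Type.QUOTATION),
                new Token(",", Token.Type.GENERIC),
                new Token("said", Token.Type.GENERIC),
                new Token("John", Token.Type.GENERIC),
                new Token("Doe", Token.Type.GENERIC),
                new Token(".", Token.Type.GENERIC));
        Sentence sentence = new Sentence(tokens);

        List<Token> speaker = Token.getTokens(Arrays.asList("John", "Doe"));

        // The quotation is replaced by $Q and the speaker tokens by $S.
        Pattern pattern = PatternExtractor.extractPattern(sentence, "I will run", speaker, false);
        System.out.println(pattern); // $Q , said $S .
    }
}

--------------------------------------------------------------------------------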
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/PatternMatcher.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collections;
5 | import java.util.List;
6 | import java.util.stream.Collectors;
7 |
8 | public abstract class PatternMatcher {
9 |
10 | protected List<Token> sentenceTokens;
11 | protected boolean speakerTokenFoundFlag;
12 | protected boolean quotationTokenFoundFlag;
13 |
14 | protected final List<Token> matchedQuotation;
15 | protected final List<Token> matchedSpeaker;
16 | protected final List<Match> matches;
17 |
18 | protected final int minSpeakerLength;
19 | protected final int maxSpeakerLength;
20 | protected final boolean caseSensitive;
21 |
22 | protected PatternMatcher(int speakerLengthMin, int speakerLengthMax, boolean caseSensitive) {
23 | matchedQuotation = new ArrayList<>();
24 | matchedSpeaker = new ArrayList<>();
25 | minSpeakerLength = speakerLengthMin;
26 | maxSpeakerLength = speakerLengthMax;
27 | this.caseSensitive = caseSensitive;
28 | matches = new ArrayList<>();
29 | }
30 |
31 | public abstract boolean match(Sentence s);
32 |
33 | public List<Match> getMatches(boolean longest) {
34 | if (longest && !matches.isEmpty()) {
35 | // We want only the pattern with the highest cardinality (i.e. number of text tokens)
36 | final int longestLength = Collections.max(matches,
37 | (x, y) -> Integer.compare(x.getPattern().getCardinality(), y.getPattern().getCardinality()))
38 | .getPattern().getCardinality();
39 |
40 | return matches.stream()
41 | .filter(x -> x.getPattern().getCardinality() == longestLength)
42 | .collect(Collectors.toList());
43 | } else {
44 | return Collections.unmodifiableList(matches);
45 | }
46 | }
47 |
48 | protected final boolean matchTokens(Token patternToken, Token token, int speakerTokensLeft) {
49 | switch (patternToken.getType()) {
50 | case GENERIC:
51 | if (token.getType() == Token.Type.GENERIC) {
52 | String tokenStr = token.toString();
53 | String patternTokenStr = patternToken.toString();
54 | if (caseSensitive) {
55 | return tokenStr.equals(patternTokenStr);
56 | } else {
57 | return tokenStr.equalsIgnoreCase(patternTokenStr);
58 | }
59 | }
60 | return false;
61 | case SPEAKER:
62 | if (speakerTokensLeft > 0 && token.getType() == Token.Type.GENERIC) {
63 | speakerTokenFoundFlag = true;
64 | matchedSpeaker.add(token);
65 | return true;
66 | }
67 | return false;
68 | case QUOTATION:
69 | if (token.getType() == Token.Type.QUOTATION) {
70 | matchedQuotation.add(token);
71 | quotationTokenFoundFlag = true;
72 | return true;
73 | }
74 | return false;
75 | case ANY:
76 | return true;
77 | default:
78 | throw new IllegalStateException();
79 | }
80 | }
81 |
82 | public final static class Match {
83 | private final String matchedQuotation;
84 | private final List<Token> matchedSpeaker;
85 | private final Pattern matchedPattern;
86 |
87 | public Match(List<Token> quotation, List<Token> speaker, Pattern pattern) {
88 | if (quotation.size() != 1) {
89 | throw new IllegalArgumentException("Invalid quotation token");
90 | }
91 | matchedQuotation = quotation.get(0).toString();
92 | matchedSpeaker = new ArrayList<>(speaker);
93 | matchedPattern = pattern;
94 | }
95 |
96 | public final String getQuotation() {
97 | return matchedQuotation;
98 | }
99 |
100 | public final List<Token> getSpeaker() {
101 | return matchedSpeaker;
102 | }
103 |
104 | public final Pattern getPattern() {
105 | return matchedPattern;
106 | }
107 | }
108 |
109 | }
110 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Sentence.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.util.ArrayList;
5 | import java.util.Collections;
6 | import java.util.Iterator;
7 | import java.util.List;
8 | import java.util.function.Predicate;
9 | import java.util.stream.Collectors;
10 |
11 | import ch.epfl.dlab.spinn3r.Tokenizer;
12 | import ch.epfl.dlab.spinn3r.TokenizerImpl;
13 | import scala.Tuple2;
14 | import scala.Tuple5;
15 |
16 | /**
17 | * A "Sentence" is an extracted context from an article.
18 | * It consists of:
19 | * - A *single* quotation between quotation marks (or equivalent HTML tags)
20 | * - The context surrounding the quotation (before and after)
21 | */
22 | public final class Sentence implements Serializable, Iterable<Token>, Comparable<Sentence> {
23 |
24 | private static final long serialVersionUID = -4971549043386733522L;
25 |
26 | private final List<Token> tokens;
27 | private final String articleUid;
28 | private final int index;
29 | private final int quotationOffset;
30 | private final int leftOffset;
31 | private final int rightOffset;
32 | private transient int cachedHashCode;
33 |
34 | /**
35 | * Constructs a new Sentence from the given list of tokens.
36 | * @param tokens The list of tokens (must contain one quotation token)
37 | * @param articleUid The UID of the article from which this sentence has been extracted
38 | * @param index The index of this sentence within the article (starting from 0)
39 | * @param skipChecks Skip safety checks (for debug or test purposes)
40 | */
41 | public Sentence(List<Token> tokens, String articleUid, int index, int quotationOffset, int leftOffset, int rightOffset, boolean skipChecks) {
42 | this.tokens = tokens;
43 | this.articleUid = articleUid;
44 | this.index = index;
45 | this.quotationOffset = quotationOffset;
46 | this.leftOffset = leftOffset;
47 | this.rightOffset = rightOffset;
48 | if (!skipChecks) {
49 | sanityCheck();
50 | }
51 | }
52 |
53 | /**
54 | * Constructs a new Sentence from the given list of tokens.
55 | * @param tokens The list of tokens (must contain one quotation token)
56 | * @param articleUid The UID of the article from which this sentence has been extracted
57 | * @param index The index of this sentence within the article (starting from 0)
58 | */
59 | public Sentence(List<Token> tokens, String articleUid, int index, int quotationOffset, int leftOffset, int rightOffset) {
60 | this(tokens, articleUid, index, quotationOffset, leftOffset, rightOffset, false);
61 | }
62 |
63 | /**
64 | * Constructs a new Sentence from the given list of tokens.
65 | * @param tokens The list of tokens (must contain one quotation token)
66 | */
67 | public Sentence(List<Token> tokens) {
68 | this(tokens, "", -1, 0, 0, 0, false);
69 | }
70 |
71 | /**
72 | * Constructs a new Sentence from the given list of tokens.
73 | * @param tokens The list of tokens (must contain one quotation token)
74 | * @param skipChecks Skip safety checks (for debug or test purposes)
75 | */
76 | public Sentence(List<Token> tokens, boolean skipChecks) {
77 | this(tokens, "", -1, 0, 0, 0, skipChecks);
78 | }
79 |
80 | public boolean matches(Predicate<Iterator<Token>> predicate) {
81 | return predicate.test(tokens.iterator());
82 | }
83 |
84 | public List<Token> getTokens() {
85 | return tokens;
86 | }
87 |
88 | public List<Token> getTokensByType(Token.Type type) {
89 | return tokens.stream()
90 | .filter(x -> x.getType() == type)
91 | .collect(Collectors.toCollection(ArrayList::new));
92 | }
93 |
94 | public String getQuotation() {
95 | for (Token t : tokens) {
96 | if (t.getType() == Token.Type.QUOTATION) {
97 | return t.toString();
98 | }
99 | }
100 | throw new IllegalStateException("No quotation found in this sentence");
101 | }
102 |
103 | public String getArticleUid() {
104 | return articleUid;
105 | }
106 |
107 | public int getIndex() {
108 | return index;
109 | }
110 |
111 | public int getQuotationOffset() {
112 | return quotationOffset;
113 | }
114 |
115 | /**
116 | * @return the leftOffset
117 | */
118 | public int getLeftOffset() {
119 | return leftOffset;
120 | }
121 |
122 | /**
123 | * @return the rightOffset
124 | */
125 | public int getRightOffset() {
126 | return rightOffset;
127 | }
128 |
129 | /**
130 | * Gets the key that uniquely identifies this Sentence.
131 | * @return An (Article UID, index within article) pair
132 | */
133 | public Tuple2<String, Integer> getKey() {
134 | return new Tuple2<>(articleUid, index);
135 | }
136 |
137 | /**
138 | * Gets all the attributes of this Sentence.
139 | * @return An (article UID, index within article, quotation offset, left offset, right offset) tuple
140 | */
141 | public Tuple5<String, Integer, Integer, Integer, Integer> getInfo() {
142 | return new Tuple5<>(articleUid, index, quotationOffset, leftOffset, rightOffset);
143 | }
144 |
145 | public int getTokenCount() {
146 | return tokens.size();
147 | }
148 |
149 | @Override
150 | public String toString() {
151 | Iterator<Token> it = tokens.iterator();
152 | StringBuilder str = new StringBuilder();
153 | while (it.hasNext()) {
154 | Token next = it.next();
155 | if (next.getType() == Token.Type.QUOTATION) {
156 | str.append('[');
157 | }
158 | str.append(next.toString());
159 | if (next.getType() == Token.Type.QUOTATION) {
160 | str.append(']');
161 | }
162 | if (it.hasNext()) {
163 | str.append(' ');
164 | }
165 |
166 | }
167 | return str.toString();
168 | }
169 |
170 | public String toHumanReadableString(boolean htmlQuotation) {
171 | if (tokens.isEmpty()) {
172 | return "";
173 | }
174 |
175 | Tokenizer tokenizer = new TokenizerImpl();
176 | List<String> buffer = new ArrayList<>();
177 | StringBuilder result = new StringBuilder();
178 | for (int i = 0; i < tokens.size(); i++) {
179 | Token t = tokens.get(i);
180 | if (t.getType() == Token.Type.GENERIC && StaticRules.isHtmlTag(t.toString())) {
181 | continue; // Discard HTML tags
182 | }
183 |
184 | if (i == 0 && t.getType() == Token.Type.GENERIC && (t.toString().equals(".") || t.toString().equals(","))) {
185 | continue; // Discard leading punctuation
186 | }
187 |
188 | if (t.getType() != Token.Type.QUOTATION) {
189 | buffer.add(t.toString());
190 | } else {
191 | if (!buffer.isEmpty()) {
192 | result.append(tokenizer.untokenize(buffer));
193 | result.append(" ");
194 | }
195 |
196 | if (htmlQuotation) {
197 | result.append("" + t.toString() + "
");
198 | } else {
199 | result.append("\"" + t.toString() + "\"");
200 | }
201 | buffer.clear();
202 | }
203 | }
204 | if (!buffer.isEmpty()) {
205 | result.append(" ");
206 | result.append(tokenizer.untokenize(buffer));
207 | }
208 |
209 |
210 | return result.toString();
211 | }
212 |
213 | @Override
214 | public int hashCode() {
215 | // Racy single-check idiom
216 | int h = cachedHashCode;
217 | if (h == 0) {
218 | h = tokens.hashCode();
219 | cachedHashCode = h;
220 | }
221 | return h;
222 | }
223 |
224 | @Override
225 | public boolean equals(Object obj) {
226 | if (obj instanceof Sentence) {
227 | Sentence p = (Sentence) obj;
228 | return p.tokens.equals(tokens);
229 | }
230 | return false;
231 | }
232 |
233 | @Override
234 | public Iterator<Token> iterator() {
235 | return Collections.unmodifiableList(tokens).iterator();
236 | }
237 |
238 | @Override
239 | public int compareTo(Sentence other) {
240 | // Define a lexicographical order for sentences
241 | int n = Math.min(tokens.size(), other.tokens.size());
242 | for (int i = 0; i < n; i++) {
243 | int comp = tokens.get(i).compareTo(other.tokens.get(i));
244 | if (comp != 0) {
245 | return comp;
246 | }
247 | }
248 | return Integer.compare(tokens.size(), other.tokens.size());
249 | }
250 |
251 | private void sanityCheck() {
252 | if (tokens.isEmpty()) {
253 | throw new IllegalArgumentException("Invalid sentence: the sentence is empty");
254 | }
255 |
256 | boolean quotationFound = false;
257 | for (Token t : tokens) {
258 | if (t.getType() == Token.Type.QUOTATION) {
259 | if (quotationFound) {
260 | throw new IllegalArgumentException("Invalid sentence: more than one quotation placeholder found");
261 | }
262 | quotationFound = true;
263 | }
264 | }
265 | if (!quotationFound) {
266 | throw new IllegalArgumentException("Invalid sentence: no quotation placeholder found");
267 | }
268 | }
269 | }
270 |
--------------------------------------------------------------------------------
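Illustrative note (not part of the repository): a sketch of constructing and inspecting a Sentence; the tokens and article UID are invented for illustration.

import java.util.Arrays;
import java.util.List;

// Hypothetical example class, not present in the repository.
public class SentenceSketch {
    public static void main(String[] args) {
        List<Token> tokens = Arrays.asList(
                new Token("It", Token.Type.GENERIC),
                new Token("was", Token.Type.GENERIC),
                new Token("clear", Token.Type.GENERIC),
                new Token(":", Token.Type.GENERIC),
                new Token("we won", Token.Type.QUOTATION),
                new Token(",", Token.Type.GENERIC),
                new Token("she", Token.Type.GENERIC),
                new Token("said", Token.Type.GENERIC));

        // Article UID "a-1", sentence index 0 within the article; offsets set to 0 for brevity.
        Sentence s = new Sentence(tokens, "a-1", 0, 0, 0, 0);

        System.out.println(s.getQuotation()); // we won
        System.out.println(s.getKey());       // (a-1,0)
        System.out.println(s);                // It was clear : [we won] , she said
    }
}

--------------------------------------------------------------------------------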
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/SimplePatternMatcher.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | public class SimplePatternMatcher extends PatternMatcher {
7 |
8 | private final List<Pattern> patterns;
9 | private List<Token> patternTokens;
10 | private Pattern currentPattern;
11 |
12 | public SimplePatternMatcher(Pattern pattern,
13 | int speakerLengthMin, int speakerLengthMax, boolean caseSensitive) {
14 | super(speakerLengthMin, speakerLengthMax, caseSensitive);
15 | this.patterns = new ArrayList<>();
16 | patterns.add(pattern);
17 | }
18 |
19 | public SimplePatternMatcher(List<Pattern> patterns,
20 | int speakerLengthMin, int speakerLengthMax, boolean caseSensitive) {
21 | super(speakerLengthMin, speakerLengthMax, caseSensitive);
22 | this.patterns = new ArrayList<>(patterns);
23 | }
24 |
25 | @Override
26 | public boolean match(Sentence s) {
27 | sentenceTokens = s.getTokens();
28 |
29 | matches.clear();
30 | boolean result = false;
31 | for (int i = 0; i < sentenceTokens.size(); i++) {
32 | for (Pattern currentPattern : patterns) {
33 | this.currentPattern = currentPattern;
34 | patternTokens = currentPattern.getTokens();
35 |
36 | if (matchImpl(0, i, maxSpeakerLength)) {
37 | result = true;
38 | }
39 |
40 | assert matchedQuotation.isEmpty();
41 | assert matchedSpeaker.isEmpty();
42 | assert !speakerTokenFoundFlag;
43 | assert !quotationTokenFoundFlag;
44 | }
45 | }
46 |
47 | return result;
48 | }
49 |
50 | private boolean matchImpl(int i, int j, int speakerTokensLeft) {
51 | if (i == patternTokens.size()) {
52 | // End of pattern reached
53 | matches.add(new Match(matchedQuotation, matchedSpeaker, currentPattern));
54 | return true;
55 | }
56 |
57 | if (j == sentenceTokens.size()) {
58 | // End of sentence reached: no match possible
59 | return false;
60 | }
61 |
62 | if (matchTokens(patternTokens.get(i), sentenceTokens.get(j), speakerTokensLeft)) {
63 | if (speakerTokenFoundFlag) {
64 | speakerTokenFoundFlag = false;
65 | boolean m1 = false;
66 | if (matchedSpeaker.size() >= minSpeakerLength) {
67 | m1 = matchImpl(i + 1, j + 1, 0);
68 | }
69 | boolean m2 = matchImpl(i, j + 1, speakerTokensLeft - 1);
70 | matchedSpeaker.remove(matchedSpeaker.size() - 1);
71 | return m1 || m2;
72 |
73 | } else if (quotationTokenFoundFlag) {
74 | quotationTokenFoundFlag = false;
75 | boolean m = matchImpl(i + 1, j + 1, speakerTokensLeft);
76 | matchedQuotation.remove(matchedQuotation.size() - 1);
77 | return m;
78 |
79 | } else {
80 | return matchImpl(i + 1, j + 1, speakerTokensLeft);
81 | }
82 | }
83 |
84 | return false;
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
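Illustrative note (not part of the repository): a sketch of applying a known pattern to a sentence with SimplePatternMatcher; the sentence, pattern, and speaker-length bounds are invented for illustration.

import java.util.Arrays;
import java.util.List;

// Hypothetical example class, not present in the repository.
public class SimplePatternMatcherSketch {
    public static void main(String[] args) {
        List<Token> tokens = Arrays.asList(
                new Token("We are ready", Token.Type.QUOTATION),
                new Token(",", Token.Type.GENERIC),
                new Token("said", Token.Type.GENERIC),
                new Token("Jane", Token.Type.GENERIC),
                new Token("Smith", Token.Type.GENERIC),
                new Token(".", Token.Type.GENERIC));
        Sentence sentence = new Sentence(tokens);

        Pattern pattern = new Pattern("$Q , said $S .");

        // Speaker length between 1 and 5 tokens, case-insensitive text matching.
        SimplePatternMatcher matcher = new SimplePatternMatcher(pattern, 1, 5, false);
        if (matcher.match(sentence)) {
            for (PatternMatcher.Match m : matcher.getMatches(true)) {
                System.out.println(m.getQuotation() + " <- " + Token.getStrings(m.getSpeaker()));
                // We are ready <- [Jane, Smith]
            }
        }
    }
}

--------------------------------------------------------------------------------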
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/SpeakerAlias.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.util.ArrayList;
5 | import java.util.List;
6 |
7 | import scala.Tuple2;
8 |
9 | public class SpeakerAlias implements Serializable {
10 |
11 | private static final long serialVersionUID = -6037497222642880562L;
12 |
13 | private final List<String> aliasTokens;
14 | private final List<Tuple2<String, String>> ids;
15 | private final List<Tuple2<Integer, Integer>> offsets;
16 |
17 | public SpeakerAlias(List<String> aliasTokens, List<Tuple2<String, String>> ids, Tuple2<Integer, Integer> offset) {
18 | this.aliasTokens = aliasTokens;
19 | this.ids = ids;
20 | this.offsets = new ArrayList<>();
21 | this.offsets.add(offset);
22 | }
23 |
24 | public List<String> getAlias() {
25 | return aliasTokens;
26 | }
27 |
28 | public List<Tuple2<String, String>> getIds() {
29 | return ids;
30 | }
31 |
32 | public List<Tuple2<Integer, Integer>> getOffsets() {
33 | return offsets;
34 | }
35 |
36 | public void addOffset(Tuple2<Integer, Integer> offset) {
37 | offsets.add(offset);
38 | }
39 |
40 | @Override
41 | public int hashCode() {
42 | return aliasTokens.hashCode() ^ ids.hashCode();
43 | }
44 |
45 | @Override
46 | public boolean equals(Object o) {
47 | if (o instanceof SpeakerAlias) {
48 | SpeakerAlias oa = (SpeakerAlias) o;
49 | return oa.ids.equals(ids); //oa.aliasTokens.equals(aliasTokens) &&
50 | }
51 | return false;
52 | }
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Spinn3rDatasetLoader.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.net.MalformedURLException;
5 | import java.net.URL;
6 | import java.util.Arrays;
7 | import java.util.Collections;
8 | import java.util.List;
9 | import java.util.Locale;
10 | import java.util.Set;
11 |
12 | import org.apache.spark.api.java.JavaRDD;
13 | import org.apache.spark.api.java.JavaSparkContext;
14 |
15 | import com.google.gson.Gson;
16 |
17 | import ch.epfl.dlab.spinn3r.EntryWrapper;
18 |
19 | public class Spinn3rDatasetLoader implements DatasetLoader, Serializable {
20 |
21 | private static final long serialVersionUID = 7689367526740335445L;
22 |
23 | private final boolean caseFold;
24 |
25 | public Spinn3rDatasetLoader() {
26 | caseFold = ConfigManager.getInstance().isCaseFoldingEnabled();
27 | }
28 |
29 | @Override
30 | public JavaRDD<DatasetLoader.Article> loadArticles(JavaSparkContext sc, String datasetPath, Set<String> languageFilter) {
31 | return sc.textFile(datasetPath)
32 | .map(x -> new Gson().fromJson(x, EntryWrapper.class).getPermalinkEntry())
33 | //.filter(x -> languageFilter.contains(x.getLanguage()))
34 | .map(x -> {
35 | String domain;
36 | try {
37 | final URL url = new URL(x.getUrl());
38 | domain = url.getHost();
39 | } catch (MalformedURLException e) {
40 | domain = "";
41 | }
42 |
43 | String time;
44 | if (!x.getLastPublished().isEmpty() && !x.getDateFound().isEmpty()) {
45 | time = Collections.min(Arrays.asList(x.getLastPublished(), x.getDateFound()));
46 | } else if (!x.getLastPublished().isEmpty()) {
47 | time = x.getLastPublished();
48 | } else {
49 | time = x.getDateFound();
50 | }
51 |
52 | List<String> tokenizedContent = x.getContent();
53 | if (caseFold) {
54 | tokenizedContent.replaceAll(token -> token.toLowerCase(Locale.ROOT));
55 | }
56 |
57 | return new Article(Long.parseLong(x.getIdentifier()), tokenizedContent, domain, time);
58 | });
59 | }
60 |
61 | public static class Article implements DatasetLoader.Article {
62 |
63 | private final long articleUID;
64 | private final List<String> articleContent;
65 | private final String website;
66 | private final String date;
67 |
68 | public Article(long articleUID, List<String> articleContent, String website, String date) {
69 | this.articleUID = articleUID;
70 | this.articleContent = articleContent;
71 | this.website = website;
72 | this.date = date;
73 | }
74 |
75 | @Override
76 | public String getArticleUID() {
77 | return Long.toString(articleUID);
78 | }
79 |
80 | @Override
81 | public List<String> getArticleContent() {
82 | return articleContent;
83 | }
84 |
85 | @Override
86 | public String getWebsite() {
87 | return website;
88 | }
89 |
90 | @Override
91 | public String getDate() {
92 | return date;
93 | }
94 |
95 | @Override
96 | public String getVersion() {
97 | return null;
98 | }
99 |
100 | @Override
101 | public String getTitle() {
102 | return null;
103 | }
104 |
105 |
106 |
107 | }
108 |
109 | }
110 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Spinn3rTextDatasetLoader.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.text.SimpleDateFormat;
5 | import java.util.ArrayList;
6 | import java.util.Arrays;
7 | import java.util.HashMap;
8 | import java.util.HashSet;
9 | import java.util.List;
10 | import java.util.Locale;
11 | import java.util.Set;
12 |
13 | import org.apache.hadoop.conf.Configuration;
14 | import org.apache.hadoop.io.LongWritable;
15 | import org.apache.hadoop.io.Text;
16 | import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
17 | import org.apache.spark.SparkConf;
18 | import org.apache.spark.api.java.JavaRDD;
19 | import org.apache.spark.api.java.JavaSparkContext;
20 |
21 |
22 | public class Spinn3rTextDatasetLoader implements DatasetLoader, Serializable {
23 |
24 | private static final long serialVersionUID = 1432174520856283513L;
25 |
26 | private final boolean caseFold;
27 | private transient Configuration config = null;
28 |
29 | public Spinn3rTextDatasetLoader() {
30 | caseFold = ConfigManager.getInstance().isCaseFoldingEnabled();
31 | }
32 |
33 | @Override
34 | public JavaRDD<DatasetLoader.Article> loadArticles(JavaSparkContext sc, String datasetPath, Set<String> languageFilter) {
35 | if (config == null) {
36 | // Copy configuration
37 | config = new Configuration(sc.hadoopConfiguration());
38 | config.set("textinputformat.record.delimiter", "\n\n");
39 | }
40 |
41 | JavaRDD<DatasetLoader.Article> records = sc.newAPIHadoopFile(datasetPath, TextInputFormat.class, LongWritable.class, Text.class, config)
42 | .map(x -> new Spinn3rDocument(x._2.toString()))
43 | .filter(x -> !x.isGarbled
44 | && x.content != null && !x.content.isEmpty()
45 | && x.date != null
46 | && x.urlString != null
47 | //&& languageFilter.contains(x.getProbableLanguage())
48 | )
49 | .map(x -> {
50 | final String id = x.docId;
51 | final String domain = x.urlString;
52 | final String time = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(x.date);
53 | final String version = x.version.name();
54 | final String title = x.title;
55 |
56 | String content = x.content;
57 | if (caseFold) {
58 | content = content.toLowerCase(Locale.ROOT);
59 | }
60 |
61 | return new Article(id, content, domain, time, version, title);
62 | });
63 | //.mapToPair(x -> new Tuple2<>(x.getArticleUID(), x))
64 | //.reduceByKey((x, y) -> x.getArticleUID() < y.getArticleUID() ? x : y) // Deterministic distinct
65 | //.map(x -> x._2);
66 |
67 | /*System.out.println(sc.newAPIHadoopFile(datasetPath, TextInputFormat.class, LongWritable.class, Text.class, config)
68 | .map(x -> new Spinn3rDocument(x._2.toString()))
69 | .take(2));
70 |
71 | System.out.println(records.take(1));
72 | System.exit(0);*/
73 |
74 | //String suffix = ConfigManager.getInstance().getLangSuffix();
75 | //records = Utils.loadCache(records, "documents-" + suffix);
76 | return records;
77 | }
78 |
79 | public static void main(String[] args) {
80 | final SparkConf conf = new SparkConf()
81 | .setAppName("QuotationExtraction")
82 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
83 | .registerKryoClasses(new Class<?>[] { ArrayList.class, Token.class, Token.Type.class, Sentence.class, Pattern.class,
84 | Trie.class, Trie.Node.class, String[].class, Object[].class, HashMap.class, Hashed.class });
85 |
86 | if (ConfigManager.getInstance().isLocalModeEnabled()) {
87 | conf.setMaster("local[*]")
88 | .set("spark.executor.memory", "8g")
89 | .set("spark.driver.memory", "8g");
90 | }
91 |
92 | try (JavaSparkContext sc = new JavaSparkContext(conf)) {
93 | Spinn3rTextDatasetLoader loader = new Spinn3rTextDatasetLoader();
94 |
95 | Set<String> langFilter = new HashSet<>(ConfigManager.getInstance().getLangFilter());
96 | loader.loadArticles(sc, "C:\\Users\\Dario\\Documents\\EPFL\\part-r-00000.snappy", langFilter);
97 | }
98 |
99 |
100 | }
101 |
102 | public static class Article implements DatasetLoader.Article, Serializable {
103 |
104 | private static final long serialVersionUID = -5411421564171041258L;
105 |
106 | private final String articleUID;
107 | private final String articleContent;
108 | private final String website;
109 | private final String date;
110 | private final String version;
111 | private final String title;
112 |
113 | public Article(String articleUID, String articleContent, String website, String date, String version, String title) {
114 | this.articleUID = articleUID;
115 | this.articleContent = articleContent;
116 | this.website = website;
117 | this.date = date;
118 | this.version = version;
119 | this.title = title;
120 | }
121 |
122 | public Article(String articleUID, String articleContent, String website, String date) {
123 | this(articleUID, articleContent, website, date, "", "");
124 | }
125 |
126 | @Override
127 | public String getArticleUID() {
128 | return articleUID;
129 | }
130 |
131 | @Override
132 | public List<String> getArticleContent() {
133 | // Construct on the fly
134 | return new ArrayList<>(Arrays.asList(articleContent.split(" ")));
135 | }
136 |
137 | @Override
138 | public String getWebsite() {
139 | return website;
140 | }
141 |
142 | @Override
143 | public String getDate() {
144 | return date;
145 | }
146 |
147 | @Override
148 | public String getVersion() {
149 | return version;
150 | }
151 |
152 | @Override
153 | public String getTitle() {
154 | return title;
155 | }
156 |
157 | @Override
158 | public String toString() {
159 | return articleUID + ": " + articleContent;
160 | }
161 |
162 | }
163 |
164 | }
165 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/StaticRules.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collections;
5 | import java.util.List;
6 | import java.util.Optional;
7 |
8 | import org.apache.commons.lang.StringEscapeUtils;
9 |
10 | public class StaticRules {
11 |
12 | private static final java.util.regex.Pattern NONASCII_REGEX = java.util.regex.Pattern.compile("[^\u0000-\u007F]+");
13 | private static final java.util.regex.Pattern QUEST_REGEX = java.util.regex.Pattern.compile("\\?+");
14 |
15 | public static boolean isHtmlTag(String token) {
16 | return token.startsWith("<") && token.endsWith(">");
17 | }
18 |
19 | public static boolean isPunctuation(String token) {
20 | return token.equals(",") || token.equals(".");
21 | }
22 |
23 | public static String normalizePre(String str) {
24 | str = str.toLowerCase();
25 | str = NONASCII_REGEX.matcher(str).replaceAll("?");
26 | str = QUEST_REGEX.matcher(str).replaceAll("?");
27 | str = StringEscapeUtils.unescapeHtml(str);
28 | // Need to do this again because HTML-decoding may reintroduce non-ASCII characters.
29 | str = NONASCII_REGEX.matcher(str).replaceAll("?");
30 | str = QUEST_REGEX.matcher(str).replaceAll("?");
31 |
32 | return str;
33 | }
34 |
35 | public static String normalize(String str) {
36 | str = str.toLowerCase();
37 | str = NONASCII_REGEX.matcher(str).replaceAll("?");
38 | str = QUEST_REGEX.matcher(str).replaceAll("?");
39 | str = StringEscapeUtils.unescapeHtml(str);
40 | // Need to do this again because HTML-decoding may reintroduce non-ASCII characters.
41 | str = NONASCII_REGEX.matcher(str).replaceAll("?");
42 | str = QUEST_REGEX.matcher(str).replaceAll("?");
43 |
44 | // Post-process
45 | String[] inTokens = str.split(" ");
46 | List<String> outTokens = new ArrayList<>();
47 | for (int i = 0; i < inTokens.length; i++) {
48 | String token = inTokens[i];
49 |
50 | boolean added = false;
51 | if (token.length() > 1) {
52 | if (token.startsWith("?")) {
53 | outTokens.add("?");
54 | token = token.substring(1);
55 | }
56 | if (token.endsWith("?")) {
57 | token = token.substring(0, token.length() - 1);
58 | if (token.contains("?") && token.contains("-")) {
59 | token = String.join(" - ", token.split("-"));
60 | }
61 | outTokens.add(token);
62 | outTokens.add("?");
63 | added = true;
64 | }
65 | }
66 |
67 | if (!added) {
68 | if (token.contains("?") && token.contains("-")) {
69 | token = String.join(" - ", token.split("-"));
70 | }
71 | outTokens.add(token);
72 | }
73 | }
74 |
75 | return String.join(" ", outTokens);
76 | }
77 |
78 | public static String canonicalizeQuotation(String str) {
79 |
80 | // Normalize
81 | str = normalize(str);
82 |
83 | StringBuilder sb = new StringBuilder();
84 | str.codePoints()
85 | .filter(c -> Character.isWhitespace(c) || Character.isLetterOrDigit(c) || c == '?')
86 | .map(c -> Character.isWhitespace(c) ? ' ' : c)
87 | .mapToObj(c -> Character.isAlphabetic(c) ? Character.toLowerCase(c) : c)
88 | .forEach(sb::appendCodePoint);
89 |
90 | return sb.toString()
91 | .trim()
92 | .replaceAll(" +", " "); // Remove double (or more) spaces
93 | }
94 |
95 | public static boolean matchSpeakerApprox(List<Token> first, List<Token> second,
96 | boolean caseSensitive) {
97 | if (second == null) {
98 | return false;
99 | }
100 | if (!caseSensitive) {
101 | first = Token.caseFold(first);
102 | second = Token.caseFold(second);
103 | }
104 | // Return true if they have at least one token in common
105 | return !Collections.disjoint(first, second);
106 | }
107 |
108 | public static Optional<List<String>> matchSpeakerApprox(List<String> first,
109 | Iterable<List<String>> choices, boolean caseSensitive) {
110 | // Return the match with the highest number of tokens in common
111 | Optional<List<String>> bestMatch = Optional.empty();
112 | int bestMatchLen = 0;
113 | boolean dirty = false; // Used to track conflicts
114 | for (List<String> choice : choices) {
115 | int matches = 0;
116 | // O(n^2) loop, but it is fine since these lists are very small
117 | for (int i = 0; i < choice.size(); i++) {
118 | for (int j = 0; j < first.size(); j++) {
119 | boolean equals;
120 | if (caseSensitive) {
121 | equals = choice.get(i).equals(first.get(j));
122 | } else {
123 | equals = choice.get(i).equalsIgnoreCase(first.get(j));
124 | }
125 | if (equals) {
126 | matches++;
127 | break;
128 | }
129 | }
130 | }
131 | if (matches > bestMatchLen) {
132 | bestMatchLen = matches;
133 | bestMatch = Optional.of(choice);
134 | dirty = false;
135 | } else if (matches == bestMatchLen) {
136 | dirty = true;
137 | }
138 | }
139 |
140 | if (dirty && bestMatchLen > 1) {
141 | throw new IllegalStateException("Conflicting speakers during ground truth evaluation: "
142 | + first + " " + bestMatch.get());
143 | }
144 |
145 | if (bestMatchLen >= 2) {
146 | return bestMatch;
147 | } else {
148 | return Optional.empty();
149 | }
150 | }
151 |
152 | }
--------------------------------------------------------------------------------
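Illustrative note (not part of the repository): a sketch of the normalization helpers in StaticRules; the input strings are invented for illustration.

// Hypothetical example class, not present in the repository.
public class StaticRulesSketch {
    public static void main(String[] args) {
        String raw = "We won\u2019t stop!"; // contains a non-ASCII apostrophe

        // normalize lowercases the string and replaces non-ASCII runs with "?" placeholders.
        System.out.println(StaticRules.normalize(raw)); // we won?t stop!

        // canonicalizeQuotation additionally keeps only letters, digits, "?" and single spaces,
        // so near-duplicate quotations compare equal.
        String canonical = StaticRules.canonicalizeQuotation(raw);
        System.out.println(canonical); // we won?t stop
        System.out.println(canonical.equals(StaticRules.canonicalizeQuotation("we won\u2019t STOP"))); // true
    }
}

--------------------------------------------------------------------------------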
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Token.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.util.ArrayList;
5 | import java.util.Collection;
6 | import java.util.List;
7 | import java.util.Locale;
8 | import java.util.Objects;
9 | import java.util.stream.Collectors;
10 |
11 | import org.apache.commons.lang3.ObjectUtils;
12 |
13 | public final class Token implements Serializable, Comparable<Token> {
14 |
15 | private static final long serialVersionUID = 160449932354710762L;
16 |
17 | public enum Type {
18 | GENERIC, QUOTATION, SPEAKER, ANY;
19 | }
20 |
21 | private final String text;
22 | private final Type type;
23 |
24 | public Token(String text, Type type) {
25 | this.text = text;
26 | this.type = type;
27 | }
28 |
29 | public final Type getType() {
30 | return this.type;
31 | }
32 |
33 | public static List<String> getStrings(Collection<Token> tokens) {
34 | return tokens.stream()
35 | .map(x -> x.text)
36 | .collect(Collectors.toList());
37 | }
38 |
39 | public static List<Token> getTokens(Collection<String> tokens) {
40 | return tokens.stream()
41 | .map(x -> new Token(x, Type.GENERIC))
42 | .collect(Collectors.toList());
43 | }
44 |
45 | @Override
46 | public int hashCode() {
47 | int hashCode = Objects.hashCode(text);
48 |
49 | // Important, since the JVM uses the object hash code for enums
50 | switch (type) {
51 | case GENERIC:
52 | hashCode += 377280272;
53 | break;
54 | case QUOTATION:
55 | hashCode += 116811174;
56 | break;
57 | case SPEAKER:
58 | hashCode += 637353343;
59 | break;
60 | case ANY:
61 | hashCode += 265741433;
62 | break;
63 | }
64 | return hashCode;
65 | }
66 |
67 | @Override
68 | public boolean equals(Object obj) {
69 | if (obj instanceof Token) {
70 | Token t = (Token) obj;
71 | return type == t.type && Objects.equals(text, t.text);
72 | }
73 | return false;
74 | }
75 |
76 | public boolean equalsIgnoreCase(Token t) {
77 | if (t == null) {
78 | return false;
79 | }
80 |
81 | return type == t.type && text.equalsIgnoreCase(t.text);
82 | }
83 |
84 | public static List<Token> caseFold(List<Token> tokens) {
85 | return tokens.stream()
86 | .map(x -> {
87 | if (x.text != null) {
88 | String s = x.text.toLowerCase(Locale.ROOT);
89 | return new Token(s, x.type);
90 | }
91 | return x;
92 | })
93 | .collect(Collectors.toCollection(ArrayList::new));
94 | }
95 |
96 | @Override
97 | public String toString() {
98 | if (text == null && type == Type.QUOTATION) {
99 | return Pattern.QUOTATION_PLACEHOLDER;
100 | }
101 | if (text == null && type == Type.SPEAKER) {
102 | return Pattern.SPEAKER_PLACEHOLDER;
103 | }
104 | if (text == null && type == Type.ANY) {
105 | return Pattern.ANY_PLACEHOLDER;
106 | }
107 | return text;
108 | }
109 |
110 | @Override
111 | public int compareTo(Token other) {
112 | // Define a lexicographical order of tokens
113 | if (type == other.type)
114 | {
115 | return ObjectUtils.compare(text, other.text); // Handles nulls
116 | }
117 | return type.compareTo(other.type);
118 | }
119 |
120 | }
121 |
--------------------------------------------------------------------------------
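A brief sketch (not part of the repository) of the comparison semantics defined above: equals() is case- and type-sensitive, while equalsIgnoreCase() and caseFold() ignore case.

    import java.util.Arrays;
    import java.util.List;
    import ch.epfl.dlab.quootstrap.Token;

    public class TokenDemo {
        public static void main(String[] args) {
            List<Token> name = Token.getTokens(Arrays.asList("Barack", "Obama"));
            List<Token> folded = Token.caseFold(name); // lower-cases the text, keeps the type
            System.out.println(name.get(0).equals(folded.get(0)));           // false: equals() is case-sensitive
            System.out.println(name.get(0).equalsIgnoreCase(folded.get(0))); // true: same type, text matches ignoring case
        }
    }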
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Trie.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.Serializable;
4 | import java.util.ArrayList;
5 | import java.util.Collection;
6 | import java.util.Collections;
7 | import java.util.HashMap;
8 | import java.util.Iterator;
9 | import java.util.List;
10 | import java.util.Locale;
11 | import java.util.Map;
12 |
13 | public class Trie implements Serializable {
14 |
15 | private static final long serialVersionUID = -9164247688163451454L;
16 |
17 | private final Node rootNode;
18 | private final boolean caseSensitive;
19 |
20 | public Trie(Collection<Pattern> patterns, boolean caseSensitive) {
21 | this.rootNode = new RootNode(caseSensitive);
22 | this.caseSensitive = caseSensitive;
23 | List<Pattern> sortedPatterns = new ArrayList<>(patterns);
24 | sortedPatterns.sort(Collections.reverseOrder());
25 | sortedPatterns.forEach(this::insertPattern);
26 | }
27 |
28 | private void insertPattern(Pattern pattern) {
29 | Node current = rootNode;
30 |
31 | Iterator<Token> it = pattern.iterator();
32 | while (it.hasNext()) {
33 | Token token = it.next();
34 | Node next = null;
35 |
36 | if (token.getType() == Token.Type.GENERIC) {
37 | // Text token
38 | String key = caseSensitive
39 | ? token.toString()
40 | : token.toString().toLowerCase(Locale.ROOT);
41 | next = current.getTextChild(key);
42 | if (next == null) {
43 | next = it.hasNext()
44 | ? new InnerNode(token, caseSensitive)
45 | : new TerminalNode(token, pattern.getConfidenceMetric());
46 | ((RootNode) current).textChildren.put(key, next);
47 | }
48 | } else {
49 | // Special token
50 | for (Node candidateNext : current) {
51 | if (candidateNext.getToken().equals(token)) {
52 | next = candidateNext;
53 | break;
54 | }
55 | }
56 | if (next == null) {
57 | next = it.hasNext()
58 | ? new InnerNode(token, caseSensitive)
59 | : new TerminalNode(token, pattern.getConfidenceMetric());
60 | ((RootNode) current).children.add(next);
61 | }
62 | }
63 |
64 | current = next;
65 | }
66 | }
67 |
68 | public boolean isCaseSensitive() {
69 | return caseSensitive;
70 | }
71 |
72 | public List<Pattern> getAllPatterns() {
73 | List<Pattern> allPatterns = new ArrayList<>();
74 |
75 | List<Token> currentPattern = new ArrayList<>();
76 | DFS(rootNode, currentPattern, allPatterns);
77 | return allPatterns;
78 | }
79 |
80 | public Node getRootNode() {
81 | return rootNode;
82 | }
83 |
84 | private void DFS(Node current, List<Token> currentPattern, List<Pattern> allPatterns) {
85 | if (current.hasChildren()) {
86 | for (Node next : current.getTextChildren()) {
87 | currentPattern.add(next.getToken());
88 | DFS(next, currentPattern, allPatterns);
89 | currentPattern.remove(currentPattern.size() - 1);
90 | }
91 |
92 | for (Node next : current) {
93 | currentPattern.add(next.getToken());
94 | DFS(next, currentPattern, allPatterns);
95 | currentPattern.remove(currentPattern.size() - 1);
96 | }
97 | } else {
98 | // Terminal node
99 | allPatterns.add(new Pattern(new ArrayList<>(currentPattern), current.getConfidenceFactor()));
100 | }
101 | }
102 |
103 | public interface Node extends Iterable<Node>, Serializable {
104 | Token getToken();
105 | boolean hasChildren();
106 | Node getTextChild(String key);
107 | Iterable<Node> getTextChildren();
108 | double getConfidenceFactor();
109 | }
110 |
111 | private static class RootNode implements Node {
112 |
113 | private static final long serialVersionUID = -3680161357395993244L;
114 |
115 | private final Map<String, Node> textChildren;
116 | private final List<Node> children;
117 | private final boolean caseSensitive;
118 |
119 | public RootNode(boolean caseSensitive) {
120 | this.children = new ArrayList<>();
121 | this.textChildren = new HashMap<>();
122 | this.caseSensitive = caseSensitive;
123 | }
124 |
125 | @Override
126 | public Iterator<Node> iterator() {
127 | return children.iterator();
128 | }
129 |
130 | @Override
131 | public Token getToken() {
132 | throw new UnsupportedOperationException("No token in the root node");
133 | }
134 |
135 | @Override
136 | public boolean hasChildren() {
137 | return !children.isEmpty() || !textChildren.isEmpty();
138 | }
139 |
140 | @Override
141 | public String toString() {
142 | return "*ROOT*";
143 | }
144 |
145 | @Override
146 | public double getConfidenceFactor() {
147 | throw new UnsupportedOperationException("This is not a terminal node");
148 | }
149 |
150 | @Override
151 | public Node getTextChild(String key) {
152 | if (!caseSensitive) {
153 | key = key.toLowerCase(Locale.ROOT);
154 | }
155 | return textChildren.get(key);
156 | }
157 |
158 | @Override
159 | public Iterable<Node> getTextChildren() {
160 | return textChildren.values();
161 | }
162 | }
163 |
164 | private static class InnerNode extends RootNode {
165 |
166 | private static final long serialVersionUID = 8115648454995183879L;
167 |
168 | private final Token token;
169 |
170 | private InnerNode(Token t, boolean caseSensitive) {
171 | super(caseSensitive);
172 | token = t;
173 | }
174 |
175 | public Token getToken() {
176 | return token;
177 | }
178 |
179 | @Override
180 | public String toString() {
181 | return token.toString();
182 | }
183 | }
184 |
185 | private static class TerminalNode implements Node {
186 |
187 | private static final long serialVersionUID = 6535993027251722667L;
188 |
189 | private final Token token;
190 | private final double confidenceFactor;
191 |
192 | public TerminalNode(Token t, double confidence) {
193 | token = t;
194 | confidenceFactor = confidence;
195 | }
196 |
197 | @Override
198 | public Iterator<Node> iterator() {
199 | return Collections.emptyIterator();
200 | }
201 |
202 | @Override
203 | public Token getToken() {
204 | return token;
205 | }
206 |
207 | @Override
208 | public boolean hasChildren() {
209 | return false;
210 | }
211 |
212 | @Override
213 | public String toString() {
214 | return token.toString();
215 | }
216 |
217 | @Override
218 | public double getConfidenceFactor() {
219 | return confidenceFactor;
220 | }
221 |
222 | @Override
223 | public Node getTextChild(String key) {
224 | return null;
225 | }
226 |
227 | @Override
228 | public Iterable<Node> getTextChildren() {
229 | return Collections.emptyList();
230 | }
231 | }
232 | }
233 |
--------------------------------------------------------------------------------
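A minimal sketch (not part of the repository) of building a Trie from one pattern and reading it back. It assumes the Pattern(List<Token>, double) constructor used in Trie.DFS above is accessible from this package, and that a pattern is an ordered list of tokens containing quotation/speaker placeholders.

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;
    import ch.epfl.dlab.quootstrap.Pattern;
    import ch.epfl.dlab.quootstrap.Token;
    import ch.epfl.dlab.quootstrap.Trie;

    public class TrieDemo {
        public static void main(String[] args) {
            List<Token> tokens = Arrays.asList(
                new Token(null, Token.Type.QUOTATION), // quotation placeholder
                new Token(",", Token.Type.GENERIC),
                new Token("said", Token.Type.GENERIC),
                new Token(null, Token.Type.SPEAKER));  // speaker placeholder
            Trie trie = new Trie(Collections.singletonList(new Pattern(tokens, 1.0)), false);
            trie.getAllPatterns().forEach(System.out::println); // prints the single inserted pattern
        }
    }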
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/TriePatternMatcher.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 | import java.util.stream.Collectors;
6 |
7 | public class TriePatternMatcher extends PatternMatcher {
8 |
9 | private final Trie trie;
10 | private final List<Trie.Node> currentPattern;
11 | private double matchedConfidenceFactor;
12 |
13 | public TriePatternMatcher(Trie trie, int speakerLengthMin, int speakerLengthMax) {
14 | super(speakerLengthMin, speakerLengthMax, trie.isCaseSensitive());
15 | this.trie = trie;
16 | currentPattern = new ArrayList<>();
17 | matchedConfidenceFactor = Double.NaN;
18 | }
19 |
20 | @Override
21 | public boolean match(Sentence s) {
22 | sentenceTokens = s.getTokens();
23 |
24 | boolean result = false;
25 | matches.clear();
26 | for (int i = 0; i < sentenceTokens.size(); i++) {
27 |
28 | matchedConfidenceFactor = Double.NaN;
29 | if (sentenceTokens.get(i).getType() == Token.Type.GENERIC) {
30 | Trie.Node node = trie.getRootNode().getTextChild(sentenceTokens.get(i).toString());
31 | if (node != null) {
32 | currentPattern.add(node);
33 | if (matchImpl(node, i, maxSpeakerLength)) {
34 | result = true;
35 | }
36 | currentPattern.remove(currentPattern.size() - 1);
37 | }
38 | }
39 | for (Trie.Node node : trie.getRootNode()) {
40 | currentPattern.add(node);
41 | if (matchImpl(node, i, maxSpeakerLength)) {
42 | result = true;
43 | }
44 | currentPattern.remove(currentPattern.size() - 1);
45 | }
46 | }
47 |
48 | return result;
49 | }
50 |
51 | private Pattern getMatchedPattern() {
52 | return new Pattern(new ArrayList<>(currentPattern.stream()
53 | .map(x -> x.getToken())
54 | .collect(Collectors.toList())), matchedConfidenceFactor);
55 | }
56 |
57 | private boolean matchImplNext(Trie.Node node, int j, int speakerTokensLeft) {
58 | if (node.hasChildren()) {
59 | boolean result = false;
60 |
61 | if (j + 1 < sentenceTokens.size() && sentenceTokens.get(j + 1).getType() == Token.Type.GENERIC) {
62 | Trie.Node next = node.getTextChild(sentenceTokens.get(j + 1).toString());
63 | if (next != null) {
64 | currentPattern.add(next);
65 | if (matchImpl(next, j + 1, speakerTokensLeft)) {
66 | result = true;
67 | }
68 | currentPattern.remove(currentPattern.size() - 1);
69 | }
70 | }
71 | for (Trie.Node next : node) {
72 | currentPattern.add(next);
73 |
74 | if (matchImpl(next, j + 1, speakerTokensLeft)) {
75 | result = true;
76 | }
77 | currentPattern.remove(currentPattern.size() - 1);
78 | }
79 | return result;
80 | } else {
81 | // End of pattern reached (terminal node)
82 | matchedConfidenceFactor = node.getConfidenceFactor();
83 | matches.add(new Match(matchedQuotation, matchedSpeaker, getMatchedPattern()));
84 | return true;
85 | }
86 | }
87 |
88 | private boolean matchImpl(Trie.Node node, int j, int speakerTokensLeft) {
89 | if (j == sentenceTokens.size()) {
90 | // End of text reached. No matches are possible.
91 | return false;
92 | }
93 |
94 | if (matchTokens(node.getToken(), sentenceTokens.get(j), speakerTokensLeft)) {
95 | if (speakerTokenFoundFlag) {
96 | speakerTokenFoundFlag = false;
97 | boolean m1 = false;
98 | if (matchedSpeaker.size() >= minSpeakerLength) {
99 | m1 = matchImplNext(node, j, 0);
100 | }
101 | boolean m2 = matchImpl(node, j + 1, speakerTokensLeft - 1);
102 | matchedSpeaker.remove(matchedSpeaker.size() - 1);
103 | return m1 || m2;
104 |
105 | } else if (quotationTokenFoundFlag) {
106 | quotationTokenFoundFlag = false;
107 | boolean m = matchImplNext(node, j, speakerTokensLeft);
108 | matchedQuotation.remove(matchedQuotation.size() - 1);
109 | return m;
110 |
111 | } else {
112 | return matchImplNext(node, j, speakerTokensLeft);
113 | }
114 | }
115 |
116 | return false;
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/TupleExtractor.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.util.Iterator;
4 | import java.util.List;
5 |
6 | import ch.epfl.dlab.spinn3r.converter.AbstractDecoder;
7 | import scala.Tuple4;
8 |
9 | public class TupleExtractor
10 | // Tuple: (quotation, speaker, full context, pattern that extracted the quotation)
11 | extends AbstractDecoder<Tuple4<String, List<Token>, Sentence, Pattern>> {
12 |
13 | private final Iterator<Sentence> it;
14 | private final PatternMatcher patternMatcher;
15 |
16 | private Iterator<PatternMatcher.Match> currentMatch;
17 | private Sentence currentSentence;
18 |
19 | public TupleExtractor(Iterator<Sentence> input, PatternMatcher pm) {
20 | it = input;
21 | patternMatcher = pm;
22 | currentMatch = null;
23 | currentSentence = null;
24 | }
25 |
26 | @Override
27 | protected Tuple4<String, List<Token>, Sentence, Pattern> getNextImpl() {
28 | while (true) {
29 | if (currentMatch == null) {
30 | if (it.hasNext()) {
31 | currentSentence = it.next();
32 | if (patternMatcher.match(currentSentence)) {
33 | currentMatch = patternMatcher.getMatches(false).iterator();
34 | }
35 | } else {
36 | // No more results
37 | return null;
38 | }
39 | }
40 |
41 | if (currentMatch != null) {
42 | if (currentMatch.hasNext()) {
43 | PatternMatcher.Match m = currentMatch.next();
44 | return new Tuple4<>(m.getQuotation(), m.getSpeaker(), currentSentence, m.getPattern());
45 | } else {
46 | currentMatch = null;
47 | }
48 | }
49 | }
50 | }
51 |
52 | }
53 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Utils.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.quootstrap;
2 |
3 | import java.io.File;
4 | import java.io.IOException;
5 | import java.nio.file.Files;
6 | import java.nio.file.Paths;
7 | import java.util.Collection;
8 | import java.util.HashMap;
9 | import java.util.List;
10 | import java.util.Map;
11 | import java.util.stream.Collectors;
12 |
13 | import org.apache.commons.io.FileUtils;
14 | import org.apache.hadoop.fs.FileSystem;
15 | import org.apache.hadoop.fs.Path;
16 | import org.apache.spark.api.java.JavaPairRDD;
17 | import org.apache.spark.api.java.JavaRDD;
18 | import org.apache.spark.api.java.JavaSparkContext;
19 |
20 | import ch.epfl.dlab.spinn3r.converter.ProtoToJson;
21 | import scala.Tuple2;
22 |
23 | public class Utils {
24 |
25 | /**
26 | * Load or generate a cached RDD.
27 | * If caching is disabled, the RDD is returned as-is without being saved.
28 | * If caching is enabled, the RDD is loaded from disk if it is already cached,
29 | * otherwise it is generated, saved, and returned.
30 | * @param rdd the rdd to cache
31 | * @param fileName the file name of the cache file
32 | * @return the output RDD (cached or generated)
33 | */
34 | public static <T> JavaRDD<T> loadCache(JavaRDD<T> rdd, String fileName) {
35 | if (!ConfigManager.getInstance().isCacheEnabled()) {
36 | return rdd;
37 | }
38 | final String cacheDir = ConfigManager.getInstance().getCachePath();
39 | final String path = cacheDir + "/" + fileName;
40 | try {
41 | FileSystem hdfs = org.apache.hadoop.fs.FileSystem.get(rdd.context().hadoopConfiguration());
42 | Path hdfsPath = new Path(path);
43 | if (!hdfs.exists(hdfsPath)) {
44 | rdd.saveAsObjectFile(path);
45 | }
46 | return JavaSparkContext.fromSparkContext(rdd.context()).objectFile(path);
47 | } catch (IOException e) {
48 | throw new IllegalStateException(e);
49 | }
50 |
51 | }
52 |
53 | /**
54 | * Load or generate a cached PairRDD.
55 | * If caching is disabled, the RDD is returned as-is without being saved.
56 | * If caching is enabled, the RDD is loaded from disk if it is already cached,
57 | * otherwise it is generated, saved, and returned.
58 | * @param rdd the rdd to cache
59 | * @param fileName the file name of the cache file
60 | * @return the output RDD (cached or generated)
61 | */
62 | public static <K, V> JavaPairRDD<K, V> loadCache(JavaPairRDD<K, V> rdd, String fileName) {
63 | if (!ConfigManager.getInstance().isCacheEnabled()) {
64 | return rdd;
65 | }
66 | final String cacheDir = ConfigManager.getInstance().getCachePath();
67 | final String path = cacheDir + "/" + fileName;
68 | try {
69 | FileSystem hdfs = org.apache.hadoop.fs.FileSystem.get(rdd.context().hadoopConfiguration());
70 | Path hdfsPath = new Path(path);
71 | if (!hdfs.exists(hdfsPath)) {
72 | rdd.saveAsObjectFile(path);
73 | }
74 | return JavaPairRDD.fromJavaRDD(JavaSparkContext.fromSparkContext(rdd.context()).objectFile(path));
75 | } catch (IOException e) {
76 | throw new IllegalStateException(e);
77 | }
78 | }
79 |
80 | public static <T> void dumpRDD(JavaRDD<T> rdd, String fileName) {
81 | if (!ConfigManager.getInstance().isLocalModeEnabled()) {
82 | throw new IllegalArgumentException("The method dumpRDD can be used only in local mode");
83 | }
84 | fileName = ConfigManager.getInstance().getOutputPath() + fileName;
85 | FileUtils.deleteQuietly(new File(fileName));
86 | rdd.saveAsTextFile(fileName);
87 | try {
88 | ProtoToJson.mergeHdfsFile(fileName);
89 | } catch (IOException e) {
90 | throw new IllegalStateException(e);
91 | }
92 | }
93 |
94 | public static <T> void dumpRDDLocal(JavaRDD<T> rdd, String fileName) {
95 | fileName = ConfigManager.getInstance().getOutputPath() + fileName;
96 | FileUtils.deleteQuietly(new File(fileName));
97 | List<String> lines = rdd.collect()
98 | .stream()
99 | .map(x -> x.toString())
100 | .collect(Collectors.toList());
101 | try {
102 | Files.write(Paths.get(fileName), lines);
103 | } catch (IOException e) {
104 | throw new IllegalStateException(e);
105 | }
106 | }
107 |
108 | public static <T> void dumpCollection(Collection<T> data, String fileName) {
109 | fileName = ConfigManager.getInstance().getOutputPath() + fileName;
110 | List<String> lines = data.stream()
111 | .map(x -> x.toString())
112 | .sorted()
113 | .collect(Collectors.toList());
114 |
115 | try {
116 | Files.write(Paths.get(fileName), lines);
117 | } catch (IOException e) {
118 | throw new IllegalStateException(e);
119 | }
120 | }
121 |
122 | public static <K, V> void dumpRDD(JavaPairRDD<K, V> rdd, String fileName) {
123 | if (!ConfigManager.getInstance().isLocalModeEnabled()) {
124 | throw new IllegalArgumentException("The method dumpRDD can be used only in local mode");
125 | }
126 | FileUtils.deleteQuietly(new File(fileName));
127 | rdd.saveAsTextFile(fileName);
128 | try {
129 | ProtoToJson.mergeHdfsFile(fileName);
130 | } catch (IOException e) {
131 | throw new IllegalStateException(e);
132 | }
133 | }
134 |
135 | public static <T> List<T> findLongestSuperstring(List<T> needle, Iterable<List<T>> haystack) {
136 | List<T> bestMatch = null;
137 | boolean dirty = false;
138 | for (List<T> candidate : haystack) {
139 | if (bestMatch == null || candidate.size() >= bestMatch.size()) {
140 | for (int i = 0; i < candidate.size() - needle.size() + 1; i++) {
141 | List<T> subCandidate = candidate.subList(i, i + needle.size());
142 | if (subCandidate.equals(needle)) {
143 | if (bestMatch == null || candidate.size() > bestMatch.size()) {
144 | bestMatch = candidate;
145 | dirty = false;
146 | } else {
147 | dirty = true;
148 | }
149 | }
150 | }
151 | }
152 | }
153 |
154 | // If we have multiple superstrings of the same length, return no match to avoid conflicts
155 | // e.g. "John Doe" could be extended to either "John Doe Jr" or "John Doe Sr"
156 | if (dirty) {
157 | return null;
158 | }
159 | return bestMatch;
160 | }
161 |
162 | public static SpeakerAlias findUniqueSuperstring(List<Token> needle,
163 | Iterable<SpeakerAlias> haystack, boolean caseSensitive) {
164 | SpeakerAlias bestMatch = null;
165 | for (SpeakerAlias candidate : haystack) {
166 | List<Token> candTok = Token.getTokens(candidate.getAlias());
167 | for (int i = 0; i < candTok.size() - needle.size() + 1; i++) {
168 | List<Token> subCandidate = candTok.subList(i, i + needle.size());
169 | boolean equals;
170 | if (caseSensitive) {
171 | equals = subCandidate.equals(needle);
172 | } else {
173 | equals = equalsCaseInsensitive(subCandidate, needle);
174 | }
175 | if (equals) {
176 | if (bestMatch != null) {
177 | return null; // Conflict detected
178 | }
179 | bestMatch = candidate;
180 | break; // Done with this candidate; continue with the next one
181 | }
182 | }
183 | }
184 |
185 | return bestMatch;
186 | }
187 |
188 | private static boolean equalsCaseInsensitive(List<Token> a, List<Token> b) {
189 | if (a.size() != b.size()) {
190 | return false;
191 | }
192 |
193 | for (int i = 0; i < a.size(); i++) {
194 | if (!a.get(i).equalsIgnoreCase(b.get(i))) {
195 | return false;
196 | }
197 | }
198 |
199 | return true;
200 | }
201 |
202 | /**
203 | * Returns the element that appears with the highest frequency,
204 | * along with its count.
205 | * If there are ties, this method returns null.
206 | */
207 | public static <T> Tuple2<T, Integer> maxFrequencyItem(Iterable<T> it) {
208 | Map<T, Integer> frequencies = new HashMap<>();
209 | for (T elem : it) {
210 | frequencies.put(elem, frequencies.getOrDefault(elem, 0) + 1);
211 | }
212 | Map.Entry<T, Integer> best = null;
213 | boolean dirty = true; // We want to avoid ties
214 | for (Map.Entry<T, Integer> entry : frequencies.entrySet()) {
215 | if (best == null || entry.getValue() > best.getValue()) {
216 | dirty = false;
217 | best = entry;
218 | } else if (entry.getValue() == best.getValue()) {
219 | dirty = true;
220 | }
221 | }
222 |
223 | if (dirty) {
224 | return null;
225 | } else {
226 | return new Tuple2<>(best.getKey(), best.getValue());
227 | }
228 | }
229 |
230 | public static boolean doubleEquals(double a, double b) {
231 | return Math.abs(a - b) < 1e-6;
232 | }
233 | }
234 |
--------------------------------------------------------------------------------
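A short sketch (not part of the repository) of the tie-aware helpers above, using plain strings:

    import java.util.Arrays;
    import java.util.List;
    import ch.epfl.dlab.quootstrap.Utils;
    import scala.Tuple2;

    public class UtilsDemo {
        public static void main(String[] args) {
            // ("a", 2); with the input ("a", "b") the method would return null because of the tie
            Tuple2<String, Integer> top = Utils.maxFrequencyItem(Arrays.asList("a", "b", "a"));

            // ["Mr", "John", "Doe", "Jr"]: the unique longest list containing the needle as a contiguous sublist
            List<String> longest = Utils.findLongestSuperstring(
                Arrays.asList("John", "Doe"),
                Arrays.asList(Arrays.asList("John", "Doe", "Jr"),
                              Arrays.asList("Mr", "John", "Doe", "Jr")));
            System.out.println(top + " / " + longest);
        }
    }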
/quootstrap/src/main/java/ch/epfl/dlab/spinn3r/EntryWrapper.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.spinn3r;
2 | import java.io.Serializable;
3 | import java.util.List;
4 |
5 | /**
6 | * This apparently purpose-less class wraps Spinn3r entries
7 | * and is used as input for a JSON serializer such as Gson.
8 | */
9 | public class EntryWrapper implements Serializable {
10 |
11 | private static final long serialVersionUID = 3319808914288405234L;
12 |
13 | private final SourceWrapper source;
14 | private final FeedWrapper feed;
15 | private final FeedEntryWrapper feedEntry;
16 | private final PermalinkEntryWrapper permalinkEntry;
17 |
18 | public EntryWrapper(SourceWrapper source, FeedWrapper feed, FeedEntryWrapper feedEntry,
19 | PermalinkEntryWrapper permalinkEntry) {
20 | this.source = source;
21 | this.feed = feed;
22 | this.feedEntry = feedEntry;
23 | this.permalinkEntry = permalinkEntry;
24 | }
25 |
26 | public SourceWrapper getSource() {
27 | return source;
28 | }
29 |
30 | public FeedWrapper getFeed() {
31 | return feed;
32 | }
33 |
34 | public FeedEntryWrapper getFeedEntry() {
35 | return feedEntry;
36 | }
37 |
38 | public PermalinkEntryWrapper getPermalinkEntry() {
39 | return permalinkEntry;
40 | }
41 |
42 | public static class SourceWrapper implements Serializable {
43 |
44 | private static final long serialVersionUID = 884787890709132785L;
45 |
46 | private final String url;
47 | private final String title;
48 | private final String language;
49 | private final String description;
50 | private final String lastPosted;
51 | private final String lastPublished;
52 | private final String dateFound;
53 | private final String publisherType;
54 |
55 | public SourceWrapper(String url, String title, String language, String description, String lastPosted,
56 | String lastPublished, String dateFound, String publisherType) {
57 | this.url = url;
58 | this.title = title;
59 | this.language = language;
60 | this.description = description;
61 | this.lastPosted = lastPosted;
62 | this.lastPublished = lastPublished;
63 | this.dateFound = dateFound;
64 | this.publisherType = publisherType;
65 | }
66 |
67 | public String getUrl() {
68 | return url;
69 | }
70 |
71 | public String getTitle() {
72 | return title;
73 | }
74 |
75 | public String getLanguage() {
76 | return language;
77 | }
78 |
79 | public String getDescription() {
80 | return description;
81 | }
82 |
83 | public String getLastPosted() {
84 | return lastPosted;
85 | }
86 |
87 | public String getLastPublished() {
88 | return lastPublished;
89 | }
90 |
91 | public String getDateFound() {
92 | return dateFound;
93 | }
94 |
95 | public String getPublisherType() {
96 | return publisherType;
97 | }
98 |
99 | }
100 |
101 | public static class FeedWrapper implements Serializable {
102 |
103 | private static final long serialVersionUID = 928980040260245901L;
104 |
105 | private final String url;
106 | private final String title;
107 | private final String language;
108 | private final String lastPosted;
109 | private final String lastPublished;
110 | private final String dateFound;
111 | private final String channelUrl;
112 |
113 | public FeedWrapper(String url, String title, String language, String lastPosted, String lastPublished,
114 | String dateFound, String channelUrl) {
115 | this.url = url;
116 | this.title = title;
117 | this.language = language;
118 | this.lastPosted = lastPosted;
119 | this.lastPublished = lastPublished;
120 | this.dateFound = dateFound;
121 | this.channelUrl = channelUrl;
122 | }
123 |
124 | public String getUrl() {
125 | return url;
126 | }
127 |
128 | public String getTitle() {
129 | return title;
130 | }
131 |
132 | public String getLanguage() {
133 | return language;
134 | }
135 |
136 | public String getLastPosted() {
137 | return lastPosted;
138 | }
139 |
140 | public String getLastPublished() {
141 | return lastPublished;
142 | }
143 |
144 | public String getDateFound() {
145 | return dateFound;
146 | }
147 |
148 | public String getChannelUrl() {
149 | return channelUrl;
150 | }
151 | }
152 |
153 | public static class FeedEntryWrapper implements Serializable {
154 |
155 | private static final long serialVersionUID = 5137190259805666093L;
156 |
157 | private final String identifier;
158 | private final String url;
159 | private final String title;
160 | private final String language;
161 | private final String authorName;
162 | private final String authorEmail;
163 | private final String lastPublished;
164 | private final String dateFound;
165 | private final List<String> tokenizedContent;
166 |
167 | public FeedEntryWrapper(String identifier, String url, String title, String language, String authorName,
168 | String authorEmail, String lastPublished, String dateFound, List<String> tokenizedContent) {
169 | super();
170 | this.identifier = identifier;
171 | this.url = url;
172 | this.title = title;
173 | this.language = language;
174 | this.authorName = authorName;
175 | this.authorEmail = authorEmail;
176 | this.lastPublished = lastPublished;
177 | this.dateFound = dateFound;
178 | this.tokenizedContent = tokenizedContent;
179 | }
180 |
181 | public String getIdentifier() {
182 | return identifier;
183 | }
184 |
185 | public String getUrl() {
186 | return url;
187 | }
188 |
189 | public String getTitle() {
190 | return title;
191 | }
192 |
193 | public String getLanguage() {
194 | return language;
195 | }
196 |
197 | public String getAuthorName() {
198 | return authorName;
199 | }
200 |
201 | public String getAuthorEmail() {
202 | return authorEmail;
203 | }
204 |
205 | public String getLastPublished() {
206 | return lastPublished;
207 | }
208 |
209 | public String getDateFound() {
210 | return dateFound;
211 | }
212 |
213 | public List<String> getContent() {
214 | return tokenizedContent;
215 | }
216 | }
217 |
218 | public static class PermalinkEntryWrapper implements Serializable {
219 |
220 | private static final long serialVersionUID = -1886188543863283140L;
221 |
222 | private final String identifier;
223 | private final String url;
224 | private final String title;
225 | private final String language;
226 | private final String authorName;
227 | private final String authorEmail;
228 | private final String lastPublished;
229 | private final String dateFound;
230 | private final List<String> tokenizedContent;
231 |
232 | public PermalinkEntryWrapper(String identifier, String url, String title, String language, String authorName,
233 | String authorEmail, String lastPublished, String dateFound, List<String> tokenizedContent) {
234 | this.identifier = identifier;
235 | this.url = url;
236 | this.title = title;
237 | this.language = language;
238 | this.authorName = authorName;
239 | this.authorEmail = authorEmail;
240 | this.lastPublished = lastPublished;
241 | this.dateFound = dateFound;
242 | this.tokenizedContent = tokenizedContent;
243 | }
244 |
245 | public String getIdentifier() {
246 | return identifier;
247 | }
248 |
249 | public String getUrl() {
250 | return url;
251 | }
252 |
253 | public String getTitle() {
254 | return title;
255 | }
256 |
257 | public String getLanguage() {
258 | return language;
259 | }
260 |
261 | public String getAuthorName() {
262 | return authorName;
263 | }
264 |
265 | public String getAuthorEmail() {
266 | return authorEmail;
267 | }
268 |
269 | public String getLastPublished() {
270 | return lastPublished;
271 | }
272 |
273 | public String getDateFound() {
274 | return dateFound;
275 | }
276 |
277 | public List<String> getContent() {
278 | return tokenizedContent;
279 | }
280 | }
281 | }
282 |
--------------------------------------------------------------------------------
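As the class comment above suggests, these wrappers exist only to be fed to a JSON serializer; a minimal Gson sketch (the field values are made up):

    import com.google.gson.Gson;
    import ch.epfl.dlab.spinn3r.EntryWrapper;

    public class EntryWrapperDemo {
        public static void main(String[] args) {
            EntryWrapper.SourceWrapper source = new EntryWrapper.SourceWrapper(
                "http://example.org", "Example", "en", null, null, null, null, "MAINSTREAM_NEWS");
            EntryWrapper entry = new EntryWrapper(source, null, null, null);
            System.out.println(new Gson().toJson(entry)); // nested wrappers become nested JSON objects
        }
    }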
/quootstrap/src/main/java/ch/epfl/dlab/spinn3r/Tokenizer.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.spinn3r;
2 |
3 | import java.util.List;
4 |
5 | public interface Tokenizer {
6 |
7 | List<String> tokenize(String sentence);
8 | String untokenize(List<String> tokens);
9 |
10 | }
11 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/spinn3r/TokenizerImpl.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.spinn3r;
2 |
3 | import java.io.StringReader;
4 | import java.util.ArrayList;
5 | import java.util.List;
6 |
7 | import edu.stanford.nlp.ling.CoreLabel;
8 | import edu.stanford.nlp.process.CoreLabelTokenFactory;
9 | import edu.stanford.nlp.process.PTBTokenizer;
10 |
11 | public class TokenizerImpl implements Tokenizer {
12 |
13 | /**
14 | * Configuration for Stanford PTBTokenizer.
15 | */
16 | private static final String TOKENIZER_SETTINGS = "tokenizeNLs=false, americanize=false, " +
17 | "normalizeCurrency=false, normalizeParentheses=false," +
18 | "normalizeOtherBrackets=false, unicodeQuotes=false, ptb3Ellipsis=true," +
19 | "escapeForwardSlashAsterisk=false, untokenizable=noneKeep, normalizeSpace=false";
20 |
21 |
22 | @Override
23 | public List<String> tokenize(String sentence) {
24 | PTBTokenizer<CoreLabel> ptbt = new PTBTokenizer<>(new StringReader(sentence),
25 | new CoreLabelTokenFactory(), TOKENIZER_SETTINGS);
26 | List<String> tokens = new ArrayList<>();
27 | while (ptbt.hasNext()) {
28 | CoreLabel label = ptbt.next();
29 | tokens.add(label.toString());
30 | }
31 | return tokens;
32 | }
33 |
34 | @Override
35 | public String untokenize(List<String> tokens) {
36 | return PTBTokenizer.ptb2Text(tokens);
37 | }
38 |
39 | }
40 |
--------------------------------------------------------------------------------
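A small sketch (not part of the repository) of a tokenize/untokenize round trip with the PTB-based implementation above:

    import java.util.List;
    import ch.epfl.dlab.spinn3r.Tokenizer;
    import ch.epfl.dlab.spinn3r.TokenizerImpl;

    public class TokenizerDemo {
        public static void main(String[] args) {
            Tokenizer tokenizer = new TokenizerImpl();
            List<String> tokens = tokenizer.tokenize("He said: \"hello, world\".");
            String text = tokenizer.untokenize(tokens); // close to the input, modulo PTB normalization
            System.out.println(tokens + " -> " + text);
        }
    }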
/quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/AbstractDecoder.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.spinn3r.converter;
2 | import java.util.Iterator;
3 |
4 | /**
5 | * Provides a convenient interface to create custom iterators
6 | * without materializing intermediate results.
7 | */
8 | public abstract class AbstractDecoder<T> implements Iterator<T> {
9 |
10 | private boolean dirty;
11 | private boolean finished;
12 | private T cachedNext;
13 |
14 | protected AbstractDecoder() {
15 | dirty = true;
16 | finished = false;
17 | }
18 |
19 | /**
20 | * Obtain the next element.
21 | */
22 | protected abstract T getNextImpl();
23 |
24 | private void extractNext()
25 | {
26 | if (!finished && dirty) {
27 | cachedNext = getNextImpl();
28 | dirty = false;
29 |
30 | if (cachedNext == null) {
31 | finished = true;
32 | }
33 | }
34 | }
35 |
36 | @Override
37 | public boolean hasNext() {
38 | extractNext();
39 | return !finished;
40 | }
41 |
42 | @Override
43 | public T next() {
44 | extractNext();
45 | T next = cachedNext;
46 | cachedNext = null;
47 | dirty = true;
48 | return next;
49 | }
50 |
51 | }
52 |
--------------------------------------------------------------------------------
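A trivial illustration (not part of the repository) of the getNextImpl() contract: return the next element, or null to signal the end of the stream.

    import ch.epfl.dlab.spinn3r.converter.AbstractDecoder;

    /** Yields the integers 0 .. n-1, one per call to next(). */
    public class RangeDecoder extends AbstractDecoder<Integer> {
        private final int n;
        private int next = 0;

        public RangeDecoder(int n) {
            this.n = n;
        }

        @Override
        protected Integer getNextImpl() {
            return next < n ? next++ : null; // null terminates the iteration
        }
    }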
/quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/CombinedDecoder.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.spinn3r.converter;
2 | import java.util.Iterator;
3 |
4 | import org.apache.spark.input.PortableDataStream;
5 |
6 | import ch.epfl.dlab.spinn3r.EntryWrapper;
7 | import scala.Tuple2;
8 |
9 | public class CombinedDecoder extends AbstractDecoder<EntryWrapper> {
10 |
11 | private final Iterator<Tuple2<String, PortableDataStream>> fileInput;
12 | private Decompressor decompressor;
13 | private SpinnerDecoder spinnerDecoder;
14 | private final boolean clean;
15 | private final boolean tokenize;
16 |
17 | /**
18 | * Initialize this decoder.
19 | * @param in The streams to decode and convert.
20 | * @param clean Whether or not to clean useless HTML tags.
21 | * @param tokenize Whether to tokenize the output.
22 | */
23 | public CombinedDecoder(Iterator<Tuple2<String, PortableDataStream>> in, boolean clean, boolean tokenize) {
24 | fileInput = in;
25 | this.clean = clean;
26 | this.tokenize = tokenize;
27 | }
28 |
29 | @Override
30 | protected EntryWrapper getNextImpl() {
31 | while (true) {
32 |
33 | if (decompressor == null) {
34 | if (fileInput.hasNext()) {
35 | // Decompress next archive
36 | decompressor = new Decompressor(fileInput.next()._2.open());
37 | } else {
38 | return null;
39 | }
40 | }
41 |
42 | if (spinnerDecoder == null) {
43 | if (decompressor.hasNext()) {
44 | // Decode next file from archive
45 | spinnerDecoder = new SpinnerDecoder(decompressor.next());
46 | } else {
47 | decompressor = null;
48 | }
49 | }
50 |
51 | if (spinnerDecoder != null) {
52 | if (spinnerDecoder.hasNext()) {
53 | // Decode next entry from file
54 | return EntryWrapperBuilder.buildFrom(spinnerDecoder.next(), clean, tokenize);
55 | } else {
56 | spinnerDecoder = null;
57 | }
58 | }
59 | }
60 | }
61 |
62 | }
63 |
--------------------------------------------------------------------------------
/quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/Decompressor.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.spinn3r.converter;
2 | import java.io.IOException;
3 | import java.io.InputStream;
4 |
5 | import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
6 | import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
7 | import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
8 |
9 | /**
10 | * This class implements an on-the-fly decoder for .tar.gz archives.
11 | * The stream is decompressed in memory, and the result is returned as an iterator,
12 | * without materializing the entire result.
13 | */
14 | public class Decompressor extends AbstractDecoder<byte[]> {
15 |
16 | private final TarArchiveInputStream tar;
17 |
18 | /**
19 | * Initializes this iterator using the given stream.
20 | * @param inStream the stream to decode
21 | */
22 | public Decompressor(InputStream inStream) {
23 | try {
24 | tar = new TarArchiveInputStream(new GzipCompressorInputStream(inStream));
25 | } catch (IOException e) {
26 | throw new IllegalStateException("Unable to open stream", e);
27 | }
28 | }
29 |
30 | @Override
31 | protected byte[] getNextImpl() {
32 | try {
33 | TarArchiveEntry entry;
34 | while ((entry = tar.getNextTarEntry()) != null) {
35 | if (entry.isDirectory()) {
36 | continue;
37 | }
38 |
39 | byte[] data = new byte[tar.available()];
40 | tar.read(data);
41 | return data;
42 | }
43 |
44 | // Entry is null
45 | tar.close();
46 | return null;
47 | } catch (IOException e) {
48 | throw new IllegalStateException("Unable to read stream", e);
49 | }
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
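A short sketch (not part of the repository; the archive path is hypothetical) of iterating over the members of a local .tar.gz with the decoder above:

    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import ch.epfl.dlab.spinn3r.converter.Decompressor;

    public class DecompressorDemo {
        public static void main(String[] args) throws IOException {
            try (InputStream in = new FileInputStream("dump.tar.gz")) { // hypothetical archive
                Decompressor decompressor = new Decompressor(in);
                while (decompressor.hasNext()) {
                    byte[] entry = decompressor.next(); // raw bytes of one archive member
                    System.out.println(entry.length + " bytes");
                }
            }
        }
    }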
/quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/ProtoToJson.java:
--------------------------------------------------------------------------------
1 | package ch.epfl.dlab.spinn3r.converter;
2 | import java.io.BufferedOutputStream;
3 | import java.io.File;
4 | import java.io.FileInputStream;
5 | import java.io.FileOutputStream;
6 | import java.io.IOException;
7 | import java.io.InputStream;
8 | import java.nio.file.Files;
9 | import java.nio.file.Paths;
10 | import java.util.ArrayList;
11 | import java.util.Arrays;
12 | import java.util.List;
13 | import java.util.Optional;
14 |
15 | import org.apache.commons.io.FileUtils;
16 | import org.apache.commons.io.IOUtils;
17 | import org.apache.hadoop.io.compress.CompressionCodec;
18 | import org.apache.spark.SparkConf;
19 | import org.apache.spark.api.java.JavaRDD;
20 | import org.apache.spark.api.java.JavaSparkContext;
21 | import org.apache.spark.util.LongAccumulator;
22 |
23 | import com.google.gson.Gson;
24 | import com.google.gson.GsonBuilder;
25 |
26 | import ch.epfl.dlab.spinn3r.EntryWrapper;
27 | import scala.Tuple2;
28 |
29 | public class ProtoToJson {
30 |
31 | private static final Gson JSON_BUILDER = new GsonBuilder().disableHtmlEscaping().create();
32 |
33 | public static void main(final String[] args) throws IOException {
34 |
35 | if (args.length < 2) {
36 | System.err.println("Usage: ProtoToJson [optional arguments]");
37 | System.err.println("[--master=] specifies a master if spark-submit is not used.");
38 | System.err.println("[--sample=] specifies the fraction of data to sample.");
39 | System.err.println("[--merge] merges HDFS output into one single file (if the application is run locally).");
40 | System.err.println("[--partitions=] specifies the number of partitions to save (useful if compression is used)");
41 | System.err.println("[--compress=] compresses the output using the given codec (GzipCodec, Lz4Codec, BZip2Codec, SnappyCodec)");
42 | System.err.println("[--source-type=] specifies the source type(s) to extract (e.g. MAINSTREAM_NEWS). Separate with ;");
43 | System.err.println("[--clean=] clean useless HTML tags (default: true)");
44 | System.err.println("[--tokenize=