├── utils
├── __init__.py
├── README.md
├── preprocess_transformations.py
├── preprocess_reading_lookup.py
├── preprocess_output_vocab.py
├── helpers.py
├── preprocess_lang8.py
├── preprocess_wiki.py
├── errorify.py
└── edits.py
├── .gitignore
├── data
└── output_vocab
│ ├── detect.txt
│ └── labels.txt
├── requirements.txt
├── app.yaml
├── main.py
├── templates
└── index.html
├── evaluate.py
├── train.py
├── README.md
├── model.py
└── LICENSE
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | data/corpora/*
3 | data/model/*
4 |
--------------------------------------------------------------------------------
/data/output_vocab/detect.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | CORRECT
3 | INCORRECT
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==2.5.0
2 | transformers==4.5.1
3 | fugashi==1.1.0
4 | unidic-lite==1.0.8
5 | Flask==2.0.1
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
1 | # Dataset Preprocessing Modules
2 | - `utils.preprocess_wiki`: Generates a TFRecordDataset from a Wikipedia dump extracted by WikiExtractor.
3 | - `utils.preprocess_lang8`: Generates a TFRecordDataset from the Lang8 corpus.
4 | - `utils.edits`: Module for edit-tagging parallel sentences.
5 | - `utils.errorify`: Module for generating synthetic errors in sentences.
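6 | - `utils.helpers`: General common helper functions.
7 | - `utils.preprocess_output_vocab`: Builds the output label vocabulary and class weights from the corpora's edit frequencies.
8 | - `utils.preprocess_transformations`: Generates verb/i-adjective inflection transformations.
9 | - `utils.preprocess_reading_lookup`: Builds a reading-to-kanji lookup dictionary from KANJIDIC and JMdict.
10 |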
--------------------------------------------------------------------------------
/app.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | runtime: python39
16 |
--------------------------------------------------------------------------------
/utils/preprocess_transformations.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from itertools import chain
4 |
5 | from fugashi import Tagger
6 |
7 | from .errorify import Errorify
8 |
9 |
def preprocess_transformations(verbs_file, adjs_file, output_file):
    """Generate verb/adj stem transformations for all verbs in vocab.

    For every base form, each ordered pair of distinct inflection forms
    returned by ``Errorify.get_forms`` is written as one line
    ``{orth1}_{orth2}:{form1}_{form2}``.

    Args:
        verbs_file: Path to a JSON object whose keys are verb base forms.
        adjs_file: Path to a JSON object whose keys are i-adjective base forms.
        output_file: Path of the UTF-8 text file to write.
    """
    # The frequency files hold Japanese keys: decode them as UTF-8 explicitly
    # instead of relying on the platform's default encoding (which is not
    # UTF-8 on e.g. Windows).
    with open(verbs_file, encoding='utf-8') as f:
        verbs = list(json.load(f).keys())
    with open(adjs_file, encoding='utf-8') as f:
        adjs = list(json.load(f).keys())
    print(f'Loaded {len(verbs)} verbs and {len(adjs)} adjectives.')
    errorify = Errorify()
    lines = []
    for baseform in chain(verbs, adjs):
        forms = errorify.get_forms(baseform)
        for form1, orth1 in forms.items():
            for form2, orth2 in forms.items():
                if form1 != form2:
                    lines.append(f'{orth1}_{orth2}:{form1}_{form2}\n')
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)
    print(f'Wrote {len(lines)} transformations to {output_file}.')
29 |
def main(args):
    """Run the transformation preprocessing with parsed command-line args."""
    verbs, adjs, output = args.verbs, args.adjs, args.output
    preprocess_transformations(verbs, adjs, output)
32 |
33 |
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-v', '--verbs', required=True,
                            help='Path to verbs frequencies file')
    arg_parser.add_argument('-a', '--adjs', required=True,
                            help='Path to i-adjectives frequencies file')
    arg_parser.add_argument('-o', '--output', required=True,
                            help='Path to output file')
    args = arg_parser.parse_args()
    main(args)
47 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # [START gae_python38_app]
16 | # [START gae_python3_app]
17 | import unicodedata
18 | from difflib import ndiff
19 |
20 | from flask import Flask, render_template, request, jsonify
21 | from model import GEC
22 |
23 |
# If `entrypoint` is not defined in app.yaml, App Engine will look for an app
# called `app` in `main.py`.
app = Flask(__name__)
# The model is built once, at import time, and the same `gec` object is used
# by every request this process handles (see `correct`).
gec = GEC(pretrained_weights_path='data/model/model_checkpoint')
28 |
29 |
@app.route('/', methods=['GET'])
def index():
    """Serve the single-page front end."""
    page = 'index.html'
    return render_template(page)
33 |
34 |
@app.route('/correct', methods=['POST'])
def correct():
    """Correct the text posted as JSON and return it with a character diff.

    Expects a JSON body ``{"text": "..."}``. The text is NFKC-normalized and
    ASCII spaces are removed before it is passed to the model.

    Returns:
        JSON ``{"correctedText": str, "diffs": [str, ...]}`` where ``diffs``
        are the character-level ``difflib.ndiff`` lines between the
        normalized input and the correction. If the body is not a JSON
        object with a string ``text`` field, returns a JSON ``error`` message
        with HTTP status 400.
    """
    payload = request.get_json(silent=True)
    raw_text = payload.get('text') if isinstance(payload, dict) else None
    if not isinstance(raw_text, str):
        # Reject malformed requests as a client error instead of letting a
        # KeyError/TypeError escape as an HTTP 500.
        return jsonify({
            'error': 'Request body must be JSON with a string "text" field.'
        }), 400
    text = unicodedata.normalize('NFKC', raw_text).replace(' ', '')
    correct_text = gec.correct(text)
    diffs = list(ndiff(text, correct_text))
    print(f'Correction: {text} -> {correct_text}')
    return jsonify({
        'correctedText': correct_text,
        'diffs': diffs
    })
45 |
46 |
if __name__ == '__main__':
    # Local development server only. When deployed to Google App Engine, a
    # webserver process such as Gunicorn serves `app` instead; that can be
    # configured by adding an `entrypoint` to app.yaml.
    run_options = dict(host='127.0.0.1', port=8080, threaded=False,
                       use_reloader=False)
    app.run(**run_options)
52 | # [END gae_python3_app]
53 | # [END gae_python38_app]
54 |
--------------------------------------------------------------------------------
/utils/preprocess_reading_lookup.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import argparse
3 | import json
4 | import xml.etree.ElementTree as ET
5 | import re
6 |
7 | import jaconv
8 |
9 |
# Matches one CJK ideograph in the range U+4E00 (一) .. U+9FAF (龯).
kanji_re = re.compile('([\u4e00-\u9faf])')
11 |
12 |
def preprocess_reading_lookup(kanjidic_path, jmdict_path, output_path):
    """Generate reading to kanji lookup dictionary.

    Builds ``{katakana reading: [orthography, ...]}`` from two sources:

    * KANJIDIC: ``ja_on`` / ``ja_kun`` readings of characters that have a
      ``misc/grade`` entry (joyo kanji).
    * JMdict: entries with a ``k_ele/ke_pri`` priority tag whose first
      ``sense/pos`` mentions "noun" and whose first ``k_ele/keb`` contains a
      kanji; every ``r_ele/reb`` reading maps to every ``k_ele/keb`` spelling.

    Readings are converted to katakana with ``jaconv.hira2kata``. Each output
    list is sorted so the JSON file is reproducible between runs.

    Args:
        kanjidic_path: Path to the KANJIDIC XML file.
        jmdict_path: Path to the JMdict XML file.
        output_path: Path of the JSON file to write.
    """
    reading_lookup = defaultdict(set)

    kd_root = ET.parse(kanjidic_path).getroot()
    characters = kd_root.findall('character')
    print(f'Loaded {len(characters)} characters from kanjidic')
    for c in characters:
        if not c.findtext('misc/grade'):  # only use joyo kanji
            continue
        literal = c.findtext('literal')
        for reading in c.findall('reading_meaning/rmgroup/reading'):
            if reading.attrib.get('r_type') in ('ja_on', 'ja_kun') and reading.text:
                reading_lookup[jaconv.hira2kata(reading.text)].add(literal)

    jd_root = ET.parse(jmdict_path).getroot()
    entries = jd_root.findall('entry')
    print(f'Loaded {len(entries)} entries from JMdict')
    for e in entries:
        # findtext() returns None when the element is missing; fall back to ''
        # so the substring / regex checks below cannot raise TypeError.
        pos = e.findtext('sense/pos') or ''
        pri = e.findtext('k_ele/ke_pri')
        keb = e.findtext('k_ele/keb') or ''
        if pri and 'noun' in pos and kanji_re.search(keb):
            readings = {jaconv.hira2kata(r.text)
                        for r in e.findall('r_ele/reb') if r.text}
            orths = {k.text for k in e.findall('k_ele/keb') if k.text}
            for r in readings:
                reading_lookup[r] |= orths

    # Sets are unordered; sort each list so repeated runs write identical files.
    serializable = {k: sorted(v) for k, v in reading_lookup.items()}
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(serializable, f)
    print(f'Reading lookup output to {output_path}')
43 |
44 |
def main(args):
    """Build the reading lookup from parsed command-line arguments."""
    preprocess_reading_lookup(kanjidic_path=args.kanjidic,
                              jmdict_path=args.jmdict,
                              output_path=args.output)
47 |
48 |
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-k', '--kanjidic', required=True,
                            help='Path to KANJIDIC file')
    arg_parser.add_argument('-j', '--jmdict', required=True,
                            help='Path to JMDict file')
    arg_parser.add_argument('-o', '--output', required=True,
                            help='Path to output file')
    args = arg_parser.parse_args()
    main(args)
62 |
--------------------------------------------------------------------------------
/utils/preprocess_output_vocab.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import math
5 | from collections import Counter
6 |
7 |
def get_class_weights(classes, freqs):
    """Compute log inverse-frequency weights for ``classes``.

    With ``N`` the total frequency of all ``classes``, a class ``c`` with a
    positive count gets ``max(log(N / freqs[c]), 1.0)``. Classes missing from
    ``freqs`` or with a non-positive count get the neutral weight ``1.0``.

    Args:
        classes: Sequence of class names, in output order.
        freqs: Mapping of class name to count (``dict`` or ``Counter``).

    Returns:
        List of float weights aligned with ``classes``.
    """
    # .get() so plain dicts behave like Counter for classes they do not list.
    n_samples = sum(freqs.get(c, 0) for c in classes)
    class_weights = []
    for c in classes:
        freq = freqs.get(c, 0)
        if freq > 0:
            class_weights.append(max(math.log(n_samples / freq), 1.0))
        else:
            # Avoid division by zero / log of a non-positive ratio.
            class_weights.append(1.0)
    return class_weights
16 |
17 |
def preprocess_output_vocab(output_file, weights_file,
                            edit_freq_files=('data/corpora/jawiki/edit_freq.json',
                                             'data/corpora/lang8/edit_freq.json'),
                            min_freq=500):
    """Generate output vocab from all corpora.

    Sums the edit frequencies stored in ``edit_freq_files`` (JSON objects
    mapping edit tag -> count). The labels are ``[PAD]`` and ``[UNK]``,
    followed by every edit seen at least ``min_freq`` times, most frequent
    first. Writes:

    * ``output_file``: one label per line.
    * ``weights_file``: JSON ``[labels_class_weights, detect_class_weights]``.
      Both are computed by ``get_class_weights``. The detect weights are
      prefixed with ``1.0`` so they line up with the ``[PAD]``, ``CORRECT``,
      ``INCORRECT`` entries of ``detect.txt``.

    Args:
        output_file: Path of the label vocabulary to write.
        weights_file: Path of the class-weights JSON to write.
        edit_freq_files: Edit-frequency JSON files to combine. Defaults to
            the jawiki and lang8 corpus outputs.
        min_freq: Minimum combined count for an edit to become a label.
    """
    edit_freq = Counter()
    for path in edit_freq_files:
        with open(path) as f:
            edit_freq.update(json.load(f))
    ordered = sorted(edit_freq.items(), key=lambda x: x[1], reverse=True)
    labels = ['[PAD]', '[UNK]']
    labels += [edit for edit, freq in ordered if freq >= min_freq]
    # Counter yields 0 for '[PAD]' / '[UNK]', so they add nothing to the total.
    n_samples = sum(edit_freq[edit] for edit in labels)
    # Debug output: the 100 most frequent edits and their share of n_samples.
    print(ordered[:100])
    print([freq / n_samples for _, freq in ordered[:100]] if n_samples else [])
    n_classes = len(labels)
    labels_class_weights = get_class_weights(labels, edit_freq)
    n_correct = edit_freq['$KEEP']
    detect_class_weights = get_class_weights(
        ['CORRECT', 'INCORRECT'],
        {'CORRECT': n_correct, 'INCORRECT': n_samples - n_correct}
    )
    detect_class_weights = [1.] + detect_class_weights
    with open(weights_file, 'w', encoding='utf-8') as f:
        json.dump([labels_class_weights, detect_class_weights], f)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(f'{label}\n' for label in labels)
    print(f'{n_classes} edits output to {output_file}.')
    print(f'Class weights output to {weights_file}.')
48 |
49 |
def main(args):
    """Write the output vocab and class weights using parsed CLI arguments."""
    preprocess_output_vocab(output_file=args.output,
                            weights_file=args.weights)
52 |
53 |
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-o', '--output', required=True,
                            help='Path to output file')
    arg_parser.add_argument('-w', '--weights', required=True,
                            help='Path to class weights output file')
    args = arg_parser.parse_args()
    main(args)
64 |
--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | gector-ja
12 |
13 |
16 |
17 |
18 |
19 |
20 |
gector-ja
21 |
Grammatical error correction model for Japanese, based on transformers and keras/tensorflow 2. Github.
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
...
33 |
34 |
35 |
36 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import tensorflow as tf
5 | from fugashi import Tagger
6 | from nltk.translate.gleu_score import corpus_gleu
7 |
8 | from model import GEC
9 |
10 |
# fugashi (MeCab) tagger, created once at import time and reused by `tokenize`.
# NOTE(review): '-Owakati' selects MeCab's word-split output format; `tokenize`
# only reads each node's `surface`, so it presumably does not change results.
tagger = Tagger('-Owakati')
12 |
13 |
def tokenize(sentence):
    """Return the surface strings of the MeCab tokens of ``sentence``."""
    surfaces = []
    for word in tagger(sentence):
        surfaces.append(word.surface)
    return surfaces
16 |
17 |
def main(weights_path, vocab_dir, transforms_file, corpus_dir, batch_size=64):
    """Evaluate a GEC model on a TMU-style corpus and print corpus GLEU.

    Reads the first ``*.src`` file and every ``*.ref*`` file in
    ``corpus_dir`` (one sentence per line, aligned by line number). It
    corrects the source sentences in batches, tokenizes predictions and
    references with ``tokenize``, and prints NLTK's corpus GLEU.

    Args:
        weights_path: Path to the trained model weights.
        vocab_dir: Path to the output vocab folder.
        transforms_file: Path to the verb/adj transforms file.
        corpus_dir: Directory containing the ``.src`` and ``.ref*`` files.
        batch_size: Number of sentences per ``gec.correct`` call (default 64).

    Raises:
        ValueError: if no source or reference file is found, or a reference
            file does not have exactly one line per source sentence.
    """
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        print('TPUs: ', tf.config.list_logical_devices('TPU'))
    except (ValueError, KeyError):
        # KeyError: COLAB_TPU_ADDR is unset; ValueError: presumably no TPU
        # could be resolved. Either way, fall back to a non-TPU strategy.
        tpu = None

    source_paths = tf.io.gfile.glob(os.path.join(corpus_dir, '*.src'))
    if not source_paths:
        raise ValueError(f'No *.src file found in {corpus_dir}')
    with tf.io.gfile.GFile(source_paths[0], 'r') as f:
        # Drop the line terminator but keep every line (even blank ones) so
        # sources and references stay aligned by index.
        source_sents = [line.rstrip('\r\n') for line in f.readlines()]

    reference_paths = tf.io.gfile.glob(os.path.join(corpus_dir, '*.ref*'))
    if not reference_paths:
        raise ValueError(f'No *.ref* files found in {corpus_dir}')
    reference_tokens = []
    for reference_path in reference_paths:
        with tf.io.gfile.GFile(reference_path, 'r') as f:
            tokens = [tokenize(line.rstrip('\r\n')) for line in f.readlines()]
        if len(tokens) != len(source_sents):
            raise ValueError(
                f'{reference_path} has {len(tokens)} lines, expected '
                f'{len(source_sents)} (one per source sentence)')
        reference_tokens.append(tokens)
    # Regroup per-file lists into per-sentence tuples of references.
    reference_tokens = list(zip(*reference_tokens))
    print(f'Loaded {len(source_sents)} src, {len(reference_tokens)} ref')

    if tpu:
        strategy = tf.distribute.TPUStrategy(tpu)
    else:
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        gec = GEC(vocab_path=vocab_dir, verb_adj_forms_path=transforms_file,
                  pretrained_weights_path=weights_path)

    pred_tokens = []
    source_batches = [source_sents[i:i + batch_size]
                      for i in range(0, len(source_sents), batch_size)]
    for i, source_batch in enumerate(source_batches):
        print(f'Predict batch {i+1}/{len(source_batches)}')
        pred_batch = gec.correct(source_batch)
        pred_tokens.extend(tokenize(sent) for sent in pred_batch)
    print('Corpus GLEU', corpus_gleu(reference_tokens, pred_tokens))
55 |
56 |
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-w', '--weights', required=True,
                            help='Path to model weights')
    arg_parser.add_argument('-v', '--vocab_dir', default='./data/output_vocab',
                            help='Path to output vocab folder')
    arg_parser.add_argument('-f', '--transforms_file',
                            default='./data/transform.txt',
                            help='Path to verb/adj transforms file')
    arg_parser.add_argument('-c', '--corpus_dir', required=True,
                            help='Path to directory of TMU evaluation corpus')
    args = arg_parser.parse_args()
    main(args.weights, args.vocab_dir, args.transforms_file, args.corpus_dir)
73 |
--------------------------------------------------------------------------------
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 | from tensorflow import keras
5 | from tensorflow.data import TFRecordDataset, AUTOTUNE
6 | from tensorflow.train import Features, Feature, Example, BytesList, Int64List
7 | from tensorflow.io import (TFRecordWriter, TFRecordOptions, FixedLenFeature,
8 | parse_single_example)
9 |
10 |
class WeightedSCCE(keras.losses.Loss):
    """Unreduced sparse categorical cross-entropy with optional class weights.

    ``__call__`` returns the per-element cross-entropy (reduction ``NONE``).
    When class weights are configured, each element is multiplied by the
    weight of its target class, ``class_weight[y_true]``.

    Args:
        class_weight: Sequence of per-class weights indexed by label id, or
            ``None``. ``None`` or all-ones disables weighting.
        from_logits: Whether ``y_pred`` contains logits instead of
            probabilities.
        name: Name of the loss.
    """

    def __init__(self, class_weight, from_logits=False, name='weighted_scce'):
        # Initialise the base Loss so attributes it owns (e.g. `name`) exist.
        # Reduction is NONE because weighting is applied per element.
        super().__init__(reduction=keras.losses.Reduction.NONE, name=name)
        if class_weight is None or all(v == 1. for v in class_weight):
            self.class_weight = None
        else:
            self.class_weight = tf.convert_to_tensor(class_weight,
                                                     dtype=tf.float32)
        self.unreduced_scce = keras.losses.SparseCategoricalCrossentropy(
            from_logits=from_logits, name=name,
            reduction=self.reduction)

    def __call__(self, y_true, y_pred, sample_weight=None):
        loss = self.unreduced_scce(y_true, y_pred, sample_weight)
        if self.class_weight is not None:
            # tf.gather needs integer indices; labels may arrive as floats.
            weight_mask = tf.gather(self.class_weight,
                                    tf.cast(y_true, tf.int32))
            loss = tf.math.multiply(loss, weight_mask)
        return loss
29 |
30 |
class Vocab:
    """Bidirectional mapping between words and integer ids.

    A word's id is its position in ``words``. Looking up an unknown word
    returns the id of ``'[UNK]'`` when present in the vocabulary, else ``-1``.
    """

    def __init__(self, words):
        self.id2word = words
        self.word2id = {word: idx for idx, word in enumerate(self.id2word)}
        self.unk_id = self.word2id.get('[UNK]', -1)

    @classmethod
    def from_file(cls, file):
        """Load a vocabulary from a UTF-8 file with one word per line.

        Surrounding whitespace is stripped and blank lines are skipped.

        Raises:
            ValueError: if ``file`` does not exist.
        """
        if not os.path.exists(file):
            raise ValueError(f'Vocab file {file} does not exist')
        with open(file, 'r', encoding='utf-8') as f:
            stripped = (raw.strip() for raw in f)
            words = [word for word in stripped if word]
        return cls(words)

    def __len__(self):
        return len(self.id2word)

    def __getitem__(self, key):
        """Map a ``str`` key to its id; any other key indexes ``id2word``."""
        if not isinstance(key, str):
            return self.id2word[key]
        return self.word2id.get(key, self.unk_id)
57 |
58 |
def write_dataset(path, examples):
    """Serialize ``examples`` into a GZIP-compressed TFRecord file at ``path``.

    Args:
        path: Output file path.
        examples: Iterable of objects with ``SerializeToString()``, such as
            the ``Example`` objects returned by ``create_example``.
    """
    gzip_options = TFRecordOptions(compression_type='GZIP')
    with TFRecordWriter(path, options=gzip_options) as writer:
        for record in map(lambda ex: ex.SerializeToString(), examples):
            writer.write(record)
64 |
65 |
def read_dataset(paths):
    """Read GZIP TFRecord file(s) and parse every record with ``parse_example``.

    Returns:
        A dataset of ``(token_ids, (label_ids, detect_ids), att_mask)``.
    """
    raw_records = TFRecordDataset(paths, compression_type='GZIP',
                                  num_parallel_reads=AUTOTUNE)
    return raw_records.map(parse_example, num_parallel_calls=AUTOTUNE)
70 |
71 |
def create_example(tokens, edits, tokenizer, labels_vocab, detect_vocab,
                   max_tokens_len=128):
    """Build a ``tf.train.Example`` for one edit-tagged sentence.

    When there are more than ``max_tokens_len`` tokens, both ``tokens`` and
    ``edits`` are truncated. Each feature is then right-padded with 0:

    * ``token_ids``: ``tokenizer.convert_tokens_to_ids(tokens)``
    * ``label_ids``: ``labels_vocab[edit]`` for each edit
    * ``detect_ids``: ``detect_vocab['CORRECT']`` for ``'$KEEP'`` edits and
      ``detect_vocab['INCORRECT']`` otherwise

    Raises:
        AssertionError: if a padded feature is not exactly ``max_tokens_len``
            long (e.g. ``tokens`` and ``edits`` differ in length). Asserts
            are skipped under ``python -O``.
    """
    if len(tokens) > max_tokens_len:
        tokens, edits = tokens[:max_tokens_len], edits[:max_tokens_len]
    # After truncation len(tokens) <= max_tokens_len.
    padding = [0] * (max_tokens_len - len(tokens))

    token_ids = list(tokenizer.convert_tokens_to_ids(tokens)) + padding
    label_ids = [labels_vocab[e] for e in edits] + padding
    correct_id = detect_vocab['CORRECT']
    incorrect_id = detect_vocab['INCORRECT']
    detect_ids = [correct_id if e == '$KEEP' else incorrect_id
                  for e in edits] + padding

    for ids in (token_ids, label_ids, detect_ids):
        assert len(ids) == max_tokens_len

    features = Features(feature={
        'token_ids': int64_list_feature(token_ids),
        'label_ids': int64_list_feature(label_ids),
        'detect_ids': int64_list_feature(detect_ids),
    })
    return Example(features=features)
98 |
99 |
def parse_example(example, max_tokens_len=128):
    """Parse one serialized record written by ``create_example``.

    Returns:
        ``(token_ids, (label_ids, detect_ids), att_mask)``. ``token_ids`` is
        cast to int32, the label/detect ids stay int64, and ``att_mask`` is a
        boolean tensor that is True where ``token_ids`` is non-zero.
    """
    fixed_int64 = FixedLenFeature([max_tokens_len], tf.int64)
    parsed = parse_single_example(example, {
        'token_ids': fixed_int64,
        'label_ids': fixed_int64,
        'detect_ids': fixed_int64,
    })
    token_ids = tf.cast(parsed['token_ids'], tf.int32)
    att_mask = tf.not_equal(token_ids, 0)
    return token_ids, (parsed['label_ids'], parsed['detect_ids']), att_mask
112 |
113 |
def int64_list_feature(value):
    """Wrap a list of bool / enum / int / uint values as an int64 ``Feature``."""
    int_list = Int64List(value=value)
    return Feature(int64_list=int_list)
117 |
--------------------------------------------------------------------------------
/utils/preprocess_lang8.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import re
5 | import unicodedata
6 | from collections import Counter
7 | from multiprocessing import Pool
8 | from itertools import chain
9 | from difflib import SequenceMatcher
10 |
11 | from .edits import EditTagger
12 | from .helpers import write_dataset
13 |
14 |
# Control characters U+0000-U+001F; removed from each raw corpus line before
# it is parsed as JSON.
invalid_bytes_re = re.compile(r'[\x00-\x1F]+')
# A whole [sline]...[/sline] span (shortest match); the span is deleted.
sline_re = re.compile(r'\[sline\].*?\[/sline\]')
# Lang-8 formatting tags, removed while keeping the text they wrap.
color_tags = [tag
              for style in ('f-blue', 'f-red', 'f-bold')
              for tag in (f'[{style}]', f'[/{style}]')]
# Any hiragana (ぁ-ん) or katakana (ァ-ン) character.
ja_re = re.compile(r'([ぁ-んァ-ン])')
# A simple lowercase HTML-like tag such as <br> or </b>.
html_re = re.compile(r'<(\/?[a-z]+)>')
# Module-level tagger shared by every part handled in one worker process;
# preprocess_lang8_part resets its `edit_freq` counter at the start of each part.
edit_tagger = EditTagger()
23 |
24 |
def clean_line(line):
    """Normalize one Lang-8 sentence and strip Lang-8 markup from it.

    Steps:
    1. NFKC-normalize the stripped line and remove ASCII spaces.
    2. Drop a trailing ``GOOD`` or ``OK`` marker.
    3. Delete the tags listed in ``color_tags``.
    4. Remove ``[sline]...[/sline]`` spans and any stray ``[/sline]``.
    """
    text = unicodedata.normalize('NFKC', line.strip()).replace(' ', '')
    for marker in ('GOOD', 'OK'):
        if text.endswith(marker):
            text = text[:-len(marker)]
            break
    for tag in color_tags:
        text = text.replace(tag, '')
    text = sline_re.sub('', text)
    return text.replace('[/sline]', '')
35 |
36 |
def preprocess_lang8_part(args,
                          correct_file='corr_sentences.txt',
                          incorrect_file='incorr_sentences.txt',
                          edit_tags_file='edit_tagged_sentences.tfrec.gz'):
    """Clean, pair and edit-tag one chunk of Lang-8 entries.

    Args:
        args: ``(rows, part_output_dir)`` packed into one tuple so the function
            can be used with ``Pool.imap_unordered``. Index 4 of each row holds
            the learner sentences; index 5 holds the matching corrections.
        correct_file: Name of the corrected-sentences text output.
        incorrect_file: Name of the learner-sentences text output.
        edit_tags_file: Name of the GZIP TFRecord of edit-tagged rows.

    Returns:
        ``(n_sentence_pairs, n_edit_tagged_rows, edit_freq)`` where
        ``edit_freq`` is the tagger's edit counter for this part.
    """
    # Reset per part: the tagger is reused across tasks in this worker and
    # the caller sums the returned counters.
    edit_tagger.edit_freq = Counter()
    rows, part_output_dir = args
    pairs = set()
    for row in rows:
        for learner_sent, corrections in zip(row[4], row[5]):
            if not ja_re.search(learner_sent) or html_re.search(learner_sent):
                continue
            learner_sent = clean_line(learner_sent)
            if not learner_sent:
                # Only markup was left; an empty source cannot be diffed
                # below (it would raise IndexError).
                continue
            if not corrections:
                pairs.add((learner_sent, learner_sent))
                continue
            for target_sent in corrections:
                if (not target_sent or not ja_re.search(target_sent)
                        or html_re.search(target_sent)):
                    continue
                target_sent = clean_line(target_sent)
                if target_sent:
                    pairs.add((learner_sent, target_sent))

    corr_lines = []
    incorr_lines = []
    edit_rows = []
    # Sort: set iteration order varies between runs with string hashing,
    # which would make the output files non-reproducible.
    for learner_sent, target_sent in sorted(pairs):
        # Remove comments that correctors appended after the sentence: a
        # trailing insertion after a sentence-final mark (or a long one) is
        # cut off, and a disproportionately long final replacement is dropped.
        matcher = SequenceMatcher(None, learner_sent, target_sent)
        tag, i1, i2, j1, j2 = matcher.get_opcodes()[-1]
        if tag == 'insert' and (learner_sent[-1] in '。.!?' or j2 - j1 >= 10):
            target_sent = target_sent[:j1]
        elif tag == 'replace' and (j2 - j1) / (i2 - i1) >= 10:
            continue
        corr_lines.append(f'{target_sent}\n')
        incorr_lines.append(f'{learner_sent}\n')
        levels = edit_tagger(learner_sent, target_sent, levels=True)
        edit_rows.extend(levels)

    os.makedirs(part_output_dir, exist_ok=True)
    corr_path = os.path.join(part_output_dir, correct_file)
    incorr_path = os.path.join(part_output_dir, incorrect_file)
    edit_tags_path = os.path.join(part_output_dir, edit_tags_file)
    with open(corr_path, 'w', encoding='utf-8') as f:
        f.writelines(corr_lines)
    with open(incorr_path, 'w', encoding='utf-8') as f:
        f.writelines(incorr_lines)
    write_dataset(edit_tags_path, edit_rows)
    print(f'Processed {len(corr_lines)} sentences, '
          f'{len(edit_rows)} edit-tagged sentences to {part_output_dir}')
    return len(corr_lines), len(edit_rows), edit_tagger.edit_freq
87 |
88 |
def preprocess_lang8(source_file, output_dir, processes,
                     edit_freq_file='edit_freq.json', part_size=512):
    """Generate edit-tagged sentence corpus from Lang8 corpus.

    ``source_file`` holds one JSON array per line. Rows whose index 2 is
    ``'Japanese'`` are split into chunks of ``part_size`` rows. Each chunk is
    processed by ``preprocess_lang8_part`` in a worker pool and written to
    ``output_dir/<part number>``. The summed edit frequencies are written to
    ``output_dir/edit_freq_file``.

    Args:
        source_file: Path to the Lang-8 corpus file.
        output_dir: Output directory (created if missing).
        processes: Number of worker processes; ``None`` lets ``Pool`` use
            ``os.cpu_count()``.
        edit_freq_file: Name of the combined edit-frequency JSON file.
        part_size: Number of rows per part (default 512).
    """
    rows = []
    with open(source_file, encoding='utf-8') as f:
        for line in f:
            line = invalid_bytes_re.sub('', line)
            if not line.strip():
                continue  # e.g. a trailing empty line; json.loads('') fails
            row = json.loads(line)
            if row[2] == 'Japanese':
                rows.append(row)
    rows_parts = [
        (rows[i:i + part_size],
         os.path.join(output_dir, str(i // part_size + 1)))
        for i in range(0, len(rows), part_size)
    ]
    print(f'Loaded {len(rows)} Japanese entries into {len(rows_parts)} parts')
    # Make sure the edit-frequency file can be written even with no parts.
    os.makedirs(output_dir, exist_ok=True)
    n_sents = 0
    n_edit_sents = 0
    edit_freq = Counter()
    # Pool's context manager terminates the workers on exit, so all results
    # must be consumed inside the `with` block.
    with Pool(processes) as pool:
        for part_sents, part_edit_sents, part_freq in pool.imap_unordered(
                preprocess_lang8_part, rows_parts):
            n_sents += part_sents
            n_edit_sents += part_edit_sents
            edit_freq.update(part_freq)
    with open(os.path.join(output_dir, edit_freq_file), 'w') as f:
        json.dump(edit_freq, f)
    print(f'Processed {n_sents} sentences and '
          f'{n_edit_sents} edit-tagged sentences.')
117 |
118 |
def main(args):
    """Preprocess the Lang-8 corpus using parsed command-line arguments."""
    preprocess_lang8(source_file=args.source,
                     output_dir=args.output_dir,
                     processes=args.processes)
121 |
122 |
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-s', '--source', required=True,
                            help='Path to Lang8 corpus file')
    arg_parser.add_argument('-o', '--output_dir', required=True,
                            help='Path to output directory')
    arg_parser.add_argument('-p', '--processes', type=int, required=False,
                            help='Number of processes')
    args = arg_parser.parse_args()
    main(args)
136 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 |
5 | import tensorflow as tf
6 | from tensorflow import keras
7 | import numpy as np
8 | from transformers import AdamWeightDecay
9 | from sklearn.metrics import classification_report
10 |
11 | from model import GEC
12 | from utils.helpers import read_dataset, WeightedSCCE
13 |
14 |
# Shorthand for tf.data's AUTOTUNE sentinel, used in this module as the
# parallelism / buffer-size argument of prefetch() and batch().
AUTO = tf.data.AUTOTUNE
16 |
17 |
def _connect_tpu():
    """Return a TPU resolver connected via ``COLAB_TPU_ADDR``, or ``None``."""
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        print('TPUs: ', tf.config.list_logical_devices('TPU'))
    except (ValueError, KeyError):
        # KeyError: COLAB_TPU_ADDR is unset; ValueError: presumably no TPU
        # could be resolved. Either way, train without a TPU.
        return None
    return tpu


def _split_dataset(dataset, dataset_len, dev_ratio, batch_size):
    """Split ``dataset`` into batched, prefetched ``(train_set, dev_set)``.

    The first ``int(dataset_len * dev_ratio)`` records form the dev set. The
    remaining records are shuffled (buffer 1024, reshuffled each epoch) for
    training. The split is made before shuffling. Shuffling first would
    reshuffle on every iteration, so take()/skip() would then overlap and
    the dev set would change every epoch.
    """
    dev_len = int(dataset_len * dev_ratio)
    train_set = dataset.skip(dev_len).shuffle(buffer_size=1024)
    dev_set = dataset.take(dev_len)
    print(train_set.cardinality().numpy(), dev_set.cardinality().numpy())
    # Batch first, then prefetch whole batches.
    train_set = train_set.batch(batch_size, num_parallel_calls=AUTO).prefetch(AUTO)
    dev_set = dev_set.batch(batch_size, num_parallel_calls=AUTO).prefetch(AUTO)
    return train_set, dev_set


def train(corpora_dir, output_weights_path, vocab_dir, transforms_file,
          pretrained_weights_path, batch_size, n_epochs, dev_ratio, dataset_len,
          dataset_ratio, bert_trainable, learning_rate, class_weight_path,
          filename='edit_tagged_sentences.tfrec.gz'):
    """Train the GEC model on every ``filename`` TFRecord under ``corpora_dir``.

    Args:
        corpora_dir: Folder searched recursively for ``filename`` files.
        output_weights_path: Final weights path. The best checkpoint by
            ``val_labels_probs_sparse_categorical_accuracy`` is saved to
            ``output_weights_path + '_checkpoint'``.
        vocab_dir: Forwarded to ``GEC`` as ``vocab_path``.
        transforms_file: Forwarded to ``GEC`` as ``verb_adj_forms_path``.
        pretrained_weights_path: Forwarded to ``GEC``.
        batch_size: Samples per batch for both splits.
        n_epochs: Maximum number of epochs.
        dev_ratio: Fraction of the (possibly subsampled) data used for dev.
        dataset_len: Total number of records in the files. Required, because
            the record count is not known in advance and the split needs it.
        dataset_ratio: If in (0, 1), only the first
            ``int(dataset_len * dataset_ratio)`` records are used.
        bert_trainable: Forwarded to ``GEC``.
        learning_rate: Forwarded to ``GEC``.
        class_weight_path: Currently unused.
        filename: Name of the TFRecord files to load.

    Raises:
        ValueError: if ``dataset_len`` is None or no ``filename`` file exists.
    """
    # NOTE(review): class_weight_path is never read and the imported
    # WeightedSCCE loss is unused here; presumably the class weights were meant
    # to reach GEC's loss. Confirm against model.py.
    if dataset_len is None:
        raise ValueError('dataset_len (-l/--dataset_len) is required to '
                         'split the dataset into train and dev sets')
    tpu = _connect_tpu()
    files = [os.path.join(root, filename)
             for root, _, names in tf.io.gfile.walk(corpora_dir)
             if filename in names]
    if not files:
        raise ValueError(f'No {filename} files found under {corpora_dir}')

    # Not shuffled here: see _split_dataset for why the split comes first.
    dataset = read_dataset(files)
    if dataset_len:
        dataset_card = tf.data.experimental.assert_cardinality(dataset_len)
        dataset = dataset.apply(dataset_card)
        if 0 < dataset_ratio < 1:
            dataset_len = int(dataset_len * dataset_ratio)
            dataset = dataset.take(dataset_len)
    print(dataset, dataset.cardinality().numpy())
    print('Loaded dataset')

    train_set, dev_set = _split_dataset(dataset, dataset_len, dev_ratio,
                                        batch_size)
    print(f'Using {dev_ratio} of dataset for dev set')

    if tpu:
        strategy = tf.distribute.TPUStrategy(tpu)
    else:
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        gec = GEC(vocab_path=vocab_dir, verb_adj_forms_path=transforms_file,
                  pretrained_weights_path=pretrained_weights_path,
                  bert_trainable=bert_trainable, learning_rate=learning_rate)
    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=output_weights_path + '_checkpoint',
        save_weights_only=True,
        monitor='val_labels_probs_sparse_categorical_accuracy',
        mode='max',
        save_best_only=True)
    # NOTE(review): stops on *training* loss, not val_loss; confirm intent.
    early_stopping_callback = keras.callbacks.EarlyStopping(
        monitor='loss', patience=3)
    gec.model.fit(train_set, epochs=n_epochs, validation_data=dev_set,
                  callbacks=[model_checkpoint_callback, early_stopping_callback])
    gec.model.save_weights(output_weights_path)
70 |
71 |
def main(args):
    """Run ``train`` with the parsed command-line arguments."""
    train(corpora_dir=args.corpora_dir,
          output_weights_path=args.output_weights_path,
          vocab_dir=args.vocab_dir,
          transforms_file=args.transforms_file,
          pretrained_weights_path=args.pretrained_weights_path,
          batch_size=args.batch_size,
          n_epochs=args.n_epochs,
          dev_ratio=args.dev_ratio,
          dataset_len=args.dataset_len,
          dataset_ratio=args.dataset_ratio,
          bert_trainable=args.bert_trainable,
          learning_rate=args.learning_rate,
          class_weight_path=args.class_weight_path)
77 |
78 |
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument
    add('-c', '--corpora_dir', required=True,
        help='Path to dataset folder')
    add('-o', '--output_weights_path', required=True,
        help='Path to save model weights to')
    add('-v', '--vocab_dir', default='./data/output_vocab',
        help='Path to output vocab folder')
    add('-t', '--transforms_file', default='./data/transform.txt',
        help='Path to verb/adj transforms file')
    add('-p', '--pretrained_weights_path',
        help='Path to pretrained model weights')
    add('-b', '--batch_size', type=int, default=32,
        help='Number of samples per batch')
    add('-e', '--n_epochs', type=int, default=10,
        help='Number of epochs')
    add('-d', '--dev_ratio', type=float, default=0.01,
        help='Percent of whole dataset to use for dev set')
    add('-l', '--dataset_len', type=int,
        help='Cardinality of dataset')
    add('-r', '--dataset_ratio', type=float, default=1.0,
        help='Percent of whole dataset to use')
    add('-bt', '--bert_trainable', action='store_true',
        help='Enable training for BERT encoder layers')
    add('-lr', '--learning_rate', type=float, default=1e-5,
        help='Learning rate')
    add('-cw', '--class_weight_path',
        help='Path to class weight file')
    args = arg_parser.parse_args()
    main(args)
119 |
--------------------------------------------------------------------------------
/utils/preprocess_wiki.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | import unicodedata
5 | import json
6 | from collections import Counter
7 | from multiprocessing import Pool
8 |
9 | from .errorify import Errorify
10 | from .edits import EditTagger
11 | from .helpers import write_dataset
12 |
13 |
# Lines containing five or more consecutive ASCII words (letters, optional
# non-word characters, then whitespace) are treated as English and skipped.
en_sentence_re = re.compile(r'([a-zA-Z]+[\W]*\s){5,}')
# Module-level helpers, constructed at import time and shared by every part
# processed in the same worker: `errorify` generates synthetic errors and
# `edit_tagger` tags (incorrect, correct) pairs; preprocess_wiki_part resets
# `edit_tagger.edit_freq` per part.
errorify = Errorify()
edit_tagger = EditTagger()
17 |
18 |
def _split_sentences(line):
    """Split ``line`` after each '。' that is outside 「」 quotes and ( ) brackets.

    Each piece keeps its terminating '。'; text after the last top-level '。'
    is dropped. Nesting levels are not clamped at zero, so an unmatched
    closing bracket suppresses splitting for the rest of the line.
    """
    quote_lvl = 0
    brackets_lvl = 0
    start_i = 0
    sents = []
    for i, c in enumerate(line):
        if c == '「':
            quote_lvl += 1
        elif c == '」':
            quote_lvl -= 1
        elif c == '(':
            brackets_lvl += 1
        elif c == ')':
            brackets_lvl -= 1
        elif c == '。' and quote_lvl == 0 and brackets_lvl == 0:
            sents.append(line[start_i:i + 1])
            start_i = i + 1
    return sents


def _read_wiki_sentences(fp):
    """Return the normalized Japanese sentences of one WikiExtractor file.

    Skips blank lines, lines starting with '<' (markup), lines ending in '.',
    lines that do not end in '。', and lines matching ``en_sentence_re``. A
    line that does not end in '。' also causes the following line to be
    skipped. Kept lines are NFKC-normalized, ASCII spaces are removed, and
    they are split with ``_split_sentences``.
    """
    sentences = []
    with open(fp, encoding='utf-8') as file:
        skip = False
        for line in file:
            line = line.strip()
            if not line or line[0] == '<' or line[-1] == '.' or skip:
                skip = False
                continue
            if line[-1] != '。':
                skip = True
                continue
            if en_sentence_re.search(line):
                continue
            line = unicodedata.normalize('NFKC', line).replace(' ', '')
            for sent in _split_sentences(line):
                sent = sent.strip().lstrip('。')
                if sent:
                    sentences.append(sent)
    return sentences


def preprocess_wiki_part(args,
                         correct_file='corr_sentences.txt',
                         incorrect_file='incorr_sentences.txt',
                         edit_tags_file='edit_tagged_sentences.tfrec.gz'):
    """Build the synthetic-error and edit-tag files for one WikiExtractor file.

    Args:
        args: ``(root, fn, output_dir, use_existing)`` packed into one tuple
            for ``Pool.imap_unordered``. Outputs go to
            ``output_dir/<basename(root)>/<fn>/``. With ``use_existing``, the
            previously written correct/incorrect files are re-tagged instead
            of generating new errors.
        correct_file: Name of the correct-sentences text file.
        incorrect_file: Name of the error-injected sentences text file.
        edit_tags_file: Name of the GZIP TFRecord of edit-tagged rows.

    Returns:
        ``(n_sentences, n_edit_tagged_rows, edit_freq)`` for this part.
    """
    # Reset per part: the tagger is reused across tasks in this worker and
    # the caller sums the returned counters.
    edit_tagger.edit_freq = Counter()
    root, fn, output_dir, use_existing = args
    fp = os.path.join(root, fn)
    base_path = os.path.join(output_dir, os.path.basename(root), fn)
    os.makedirs(base_path, exist_ok=True)
    corr_path = os.path.join(base_path, correct_file)
    incorr_path = os.path.join(base_path, incorrect_file)
    if use_existing:
        with open(corr_path, 'r', encoding='utf-8') as f:
            corr_lines = f.readlines()
        with open(incorr_path, 'r', encoding='utf-8') as f:
            incorr_lines = f.readlines()
    else:
        corr_lines = []
        incorr_lines = []
        for sent in _read_wiki_sentences(fp):
            corr_lines.append(f'{sent}\n')
            incorr_lines.append(f'{errorify(sent)}\n')
        with open(corr_path, 'w', encoding='utf-8') as file:
            file.writelines(corr_lines)
        with open(incorr_path, 'w', encoding='utf-8') as file:
            file.writelines(incorr_lines)

    edit_rows = []
    for incorr_line, corr_line in zip(incorr_lines, corr_lines):
        incorr_line = incorr_line.strip()
        corr_line = corr_line.strip()
        if not incorr_line or not corr_line:
            continue
        edit_rows.extend(edit_tagger(incorr_line, corr_line))
    edit_tags_path = os.path.join(base_path, edit_tags_file)
    write_dataset(edit_tags_path, edit_rows)
    print(f'Processed {len(corr_lines)} sentences, '
          f'{len(edit_rows)} edit-tagged sentences in {fp}')
    return len(corr_lines), len(edit_rows), edit_tagger.edit_freq
93 |
94 |
def preprocess_wiki(source_dir, output_dir, processes, use_existing,
                    edit_freq_file='edit_freq.json'):
    """Generate synthetic error corpus from Wikipedia dump.

    Every file in a leaf directory of ``source_dir`` is processed by
    ``preprocess_wiki_part`` in a worker pool. Per-part sentence counts are
    summed, and the per-part edit frequency counters are merged and written
    as JSON to ``os.path.join(output_dir, edit_freq_file)``.

    Args:
        source_dir: Text folder produced by WikiExtractor.
        output_dir: Folder receiving the generated corpus files.
        processes: Number of worker processes; ``None`` lets
            ``multiprocessing.Pool`` choose its default.
        use_existing: Forwarded to each part; when true, previously
            generated correct/incorrect sentence files are re-read instead
            of being regenerated.
        edit_freq_file: File name (inside ``output_dir``) for the merged
            edit frequency counts.

    Raises:
        ValueError: If ``source_dir`` is not an existing directory.
    """
    if not os.path.isdir(source_dir):
        raise ValueError(f'WikiExtractor text folder not found at {source_dir}')
    pool_inputs = []
    for root, dirs, files in os.walk(source_dir):
        # only leaf directories contain WikiExtractor output files
        if not dirs:
            for fn in files:
                pool_inputs.append([root, fn, output_dir, use_existing])
    print(f'Processing {len(pool_inputs)} parts...')
    n_sents = 0
    n_edit_sents = 0
    edit_freq = Counter()
    # the context manager tears the worker processes down once all results
    # have been consumed (the original never closed the pool)
    with Pool(processes) as pool:
        for part_sents, part_edit_sents, part_freq in pool.imap_unordered(
                preprocess_wiki_part, pool_inputs):
            n_sents += part_sents
            n_edit_sents += part_edit_sents
            edit_freq.update(part_freq)

    # workers create output_dir as a side effect, but not when no parts exist
    os.makedirs(output_dir, exist_ok=True)
    freq_path = os.path.join(output_dir, edit_freq_file)
    with open(freq_path, 'w', encoding='utf-8') as f:
        json.dump(edit_freq, f)
    print(f'Processed {n_sents} sentences, {n_edit_sents} edit-tagged ' \
          'sentences from Wikipedia dump')
    print(f'Synthetic error corpus output to {output_dir}')
123 |
124 |
def main(args):
    """Run the Wikipedia preprocessing pipeline using parsed CLI arguments."""
    preprocess_wiki(
        args.source_dir,
        args.output_dir,
        args.processes,
        args.use_existing,
    )
128 |
129 |
if __name__ == '__main__':
    # Command-line entry point: parse options and hand them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-s', '--source_dir', required=True,
        help='Path to text folder extracted by WikiExtractor')
    cli.add_argument(
        '-o', '--output_dir', required=True,
        help='Path to output folder')
    cli.add_argument(
        '-p', '--processes', type=int, required=False,
        help='Number of processes')
    cli.add_argument(
        '-e', '--use_existing', action='store_true',
        help='Edit tag existing error-generated sentences')
    args = cli.parse_args()
    main(args)
146 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gector-ja
2 |
3 | Grammatical error correction model described in the paper ["GECToR -- Grammatical Error Correction: Tag, Not Rewrite" (Omelianchuk et al. 2020)](https://arxiv.org/abs/2005.12592), implemented for Japanese. This project's code is based on the official implementation (https://github.com/grammarly/gector).
4 |
5 | The [pretrained Japanese BERT model](https://huggingface.co/cl-tohoku/bert-base-japanese-v2) used in this project was provided by Tohoku University NLP Lab.
6 |
7 | ## Datasets
8 |
9 | - [Japanese Wikipedia dump](https://dumps.wikimedia.org/), extracted with [WikiExtractor](https://github.com/attardi/wikiextractor), synthetic errors generated using preprocessing scripts
10 | - 19,841,767 training sentences
11 | - [NAIST Lang8 Learner Corpora](https://sites.google.com/site/naistlang8corpora/)
12 | - 6,066,306 training sentences (generated from 3,084,0376 original sentences)
13 |
14 | ### Synthetically Generated Error Corpus
15 |
16 | The Wikipedia corpus was used to synthetically generate errorful sentences, with a method similar to [Awasthi et al. 2019](https://github.com/awasthiabhijeet/PIE/tree/master/errorify), but with adjustments for Japanese. The details of the implementation can be found in the [preprocessing scripts](https://github.com/jonnyli1125/gector-ja/blob/main/utils/) in this repository.
17 |
18 | Example error-generated sentence:
19 | ```
20 | 西口側には宿泊施設や地元の日本酒や海、山の幸を揃えた飲食店、呑み屋など多くある。 # Correct
21 | 西口側までは宿泊から施設や地元の日本酒や、山の幸を揃えた飲食は店、呑み屋など多くあろう。 # Errorful
22 | ```
23 |
24 | ### Edit Tagging
25 |
26 | Using the preprocessed Wikipedia corpus and Lang8 corpus, the errorful sentences were tokenized using the WordPiece tokenizer from the [pretrained BERT model](https://huggingface.co/cl-tohoku/bert-base-japanese-v2). Each token was then mapped to a minimal sequence of token transformations, such that when the transformations are applied to the errorful sentence, it will lead to the target sentence. The GECToR paper explains this preprocessing step in more detail (section 3), and the code specifics can be found in the [official implementation](https://github.com/grammarly/gector/blob/master/utils/preprocess_data.py).
27 |
28 | Example edit-tagged sentence (using the same pair of sentences above):
29 | ```
30 | [CLS] 西口 側 まで は 宿泊 から 施設 や 地元 の 日本 酒 や 、 山 の 幸 を 揃え た 飲食 は 店 、 呑 ##み ##屋 など 多く あろう 。 [SEP]
31 | $KEEP $KEEP $KEEP $REPLACE_に $KEEP $KEEP $DELETE $KEEP $KEEP $KEEP $KEEP $KEEP $KEEP $APPEND_海 $KEEP $KEEP $KEEP $KEEP $KEEP $KEEP $KEEP $KEEP $DELETE $KEEP $KEEP $KEEP $KEEP $KEEP $KEEP $KEEP $TRANSFORM_VBV_VB $KEEP $KEEP
32 | ```
33 |
34 | Furthermore, on top of the basic 4 token transformations (`$KEEP`, `$DELETE`, `$APPEND`, `$REPLACE`), there is a set of special transformations called "g-transformations" (e.g. `$TRANSFORM_VBV_VB` in the example above). G-transformations are mainly used for common replacements, such as switching verb conjugations, as described in the GECToR paper (section 3). The g-transformations in this model were redefined to accommodate Japanese verbs and i-adjectives, which both inflect for tense.
35 |
36 | ## Model Architecture
37 |
38 | The model consists of a [pretrained BERT encoder layer](https://huggingface.co/cl-tohoku/bert-base-japanese-v2) and two linear classification heads, one for `labels` and one for `detect`. `labels` predicts a specific edit transformation (`$KEEP`, `$DELETE`, `$APPEND_x`, etc), and `detect` predicts whether the token is `CORRECT` or `INCORRECT`. The results from the two are used to make a prediction. The predicted transformations are then applied to the errorful input sentence to obtain a corrected sentence.
39 |
40 | Furthermore, in some cases, one pass of predicted transformations is not sufficient to transform the errorful sentence into the target sentence. Therefore, the process is repeated on the result of the previous pass until the sentence no longer changes (up to a fixed maximum number of iterations).
41 |
42 | For more details about the model architecture and __iterative sequence tagging approach__, refer to sections 4 and 5 of the GECToR paper or the [official implementation](https://github.com/grammarly/gector/blob/master/gector/seq2labels_model.py).
43 |
44 | ## Training
45 |
46 | The model was trained in Colab with TPUs on each corpus with the following hyperparameters (default is used if unspecified):
47 |
48 | ```
49 | batch_size: 64
50 | learning_rate: 1e-5
51 | bert_trainable: true
52 | ```
53 |
54 | Synthetic error corpus (Wikipedia dump):
55 | ```
56 | length: 19841767
57 | epochs: 3
58 | ```
59 |
60 | Lang8 corpus:
61 | ```
62 | length: 6066306
63 | epochs: 10
64 | ```
65 |
66 | ## Demo App
67 |
68 | Trained weights can be downloaded [here](https://drive.google.com/file/d/1nhWzDZnZKxLvqwYMLlwRNOkMK2aXv4-5/view?usp=sharing).
69 |
70 | Extract `model.zip` to the `data/` directory. You should have the following folder structure:
71 |
72 | ```
73 | gector-ja/
74 | data/
75 | model/
76 | checkpoint
77 | model_checkpoint.data-00000-of-00001
78 | model_checkpoint.index
79 | ...
80 | main.py
81 | ...
82 | ```
83 |
84 | After downloading and extracting the weights, the demo app can be run with the command `python main.py`.
85 |
86 | You may need to `pip install flask` if Flask is not already installed.
87 |
88 | ## Evaluation
89 |
90 | The model can be evaluated with `evaluate.py` on a parallel sentences corpus. The evaluation corpus used was [TMU Evaluation Corpus for Japanese Learners (Koyama et al. 2020)](https://www.aclweb.org/anthology/2020.lrec-1.26/), and the metric is GLEU score.
91 |
92 | The model trained with the parameters described above achieved a GLEU score of around 0.81. This appears to outperform the CNN-based method of Chollampatt and Ng, 2018 (state of the art on the CoNLL-2014 dataset prior to transformer-based models), which Koyama et al. 2020 chose to use in their paper.
93 |
94 | #### CoNLL-2014 (GEC dataset for English)
95 | | Method | F0.5 |
96 | | ------------------------- | ----- |
97 | | Chollampatt and Ng, 2018 | 56.52 |
98 | | Omelianchuk et al., 2020 | 66.5 |
99 |
100 | #### TMU Evaluation Corpus for Japanese Learners (GEC dataset for Japanese)
101 | | Method | GLEU |
102 | | ------------------------- | ----- |
103 | | Chollampatt and Ng, 2018 | 0.739 |
104 | | __gector-ja (this project)__ | __0.81__ |
105 |
106 | The GECToR paper used the F0.5 score, which can also be computed with [errant](https://github.com/chrisjbryant/errant) and [m2scorer](https://github.com/nusnlp/m2scorer). However, these tools were designed for evaluation on the CoNLL-2014 dataset, and using them for this project would require modifying their source code to accommodate Japanese. This project therefore uses the GLEU score, as in Koyama et al. 2020, which works "out of the box" with the NLTK library.
107 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow import keras
6 | from tensorflow.keras import layers
7 | from transformers import TFAutoModel, AutoTokenizer, AdamWeightDecay
8 |
9 | from utils.helpers import Vocab
10 |
11 |
class GEC:
    """GECToR-style grammatical error correction model for Japanese.

    A BERT encoder feeds two token-level softmax heads: ``labels_probs``
    (one edit transformation per token, vocabulary ``labels.txt``) and
    ``detect_probs`` (per-token detection tags, vocabulary ``detect.txt``).
    Sentences are corrected by repeatedly applying the predicted edits.
    """

    def __init__(self, max_len=128, confidence=0.0, min_error_prob=0.0,
                 learning_rate=1e-5,
                 vocab_path='data/output_vocab/',
                 verb_adj_forms_path='data/transform.txt',
                 bert_model='cl-tohoku/bert-base-japanese-v2',
                 pretrained_weights_path=None,
                 bert_trainable=True):
        """Build the model, vocabularies and verb/adjective transform table.

        Args:
            max_len: Fixed token length of the model input.
            confidence: Bias added to the ``$KEEP`` label probability at
                inference time (only applied when > 0).
            min_error_prob: Sentences whose maximum INCORRECT probability is
                below this are left unchanged; also used as the per-token
                threshold on the chosen label probability.
            learning_rate: Learning rate of the ``AdamWeightDecay`` optimizer.
            vocab_path: Folder containing ``labels.txt`` and ``detect.txt``.
            verb_adj_forms_path: Form table used to decode ``$TRANSFORM_*``.
            bert_model: Hugging Face name of the encoder and tokenizer.
            pretrained_weights_path: Optional weights loaded into the model.
            bert_trainable: Whether the encoder weights are trainable.
        """
        self.max_len = max_len
        self.confidence = confidence
        self.min_error_prob = min_error_prob
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model)
        vocab_labels_path = os.path.join(vocab_path, 'labels.txt')
        vocab_detect_path = os.path.join(vocab_path, 'detect.txt')
        self.vocab_labels = Vocab.from_file(vocab_labels_path)
        self.vocab_detect = Vocab.from_file(vocab_detect_path)
        self.model = self.get_model(bert_model, bert_trainable, learning_rate)
        if pretrained_weights_path:
            self.model.load_weights(pretrained_weights_path)
        self.transform = self.get_transforms(verb_adj_forms_path)

    def get_model(self, bert_model, bert_trainable=True, learning_rate=None):
        """Build and compile the two-headed Keras model.

        Returns:
            A compiled ``keras.Model`` mapping ``input_ids`` of shape
            ``(batch, max_len)`` to ``[labels_probs, detect_probs]``.
        """
        encoder = TFAutoModel.from_pretrained(bert_model)
        encoder.bert.trainable = bert_trainable
        input_ids = layers.Input(shape=(self.max_len,), dtype=tf.int32,
                                 name='input_ids')
        # id 0 is treated as padding (assumes the tokenizer's [PAD] id is 0)
        attention_mask = input_ids != 0
        embedding = encoder(input_ids, attention_mask=attention_mask,
                            training=bert_trainable)[0]
        n_labels = len(self.vocab_labels)
        n_detect = len(self.vocab_detect)
        labels_probs = layers.Dense(n_labels, activation='softmax',
                                    name='labels_probs')(embedding)
        detect_probs = layers.Dense(n_detect, activation='softmax',
                                    name='detect_probs')(embedding)
        model = keras.Model(
            inputs=input_ids,
            outputs=[labels_probs, detect_probs]
        )
        losses = [keras.losses.SparseCategoricalCrossentropy(),
                  keras.losses.SparseCategoricalCrossentropy()]
        optimizer = AdamWeightDecay(learning_rate=learning_rate)
        model.compile(optimizer=optimizer, loss=losses,
                      weighted_metrics=['sparse_categorical_accuracy'])
        return model

    def predict(self, input_ids):
        """Run the model and decode the most likely tag per token.

        Args:
            input_ids: Token ids of shape (B, S); id 0 counts as padding.

        Returns:
            dict with numpy arrays ``labels_probs`` (B, S, n_labels),
            ``detect_probs`` (B, S, n_detect), ``max_error_probs`` (B,), and
            nested lists of tag strings ``labels`` and ``detect`` (B x S).
        """
        labels_probs, detect_probs = self.model(input_ids, training=False)

        # maximum INCORRECT probability over non-padding tokens per sequence
        incorr_index = self.vocab_detect['INCORRECT']
        mask = tf.cast(input_ids != 0, tf.float32)
        error_probs = detect_probs[:, :, incorr_index] * mask
        max_error_probs = tf.math.reduce_max(error_probs, axis=-1)

        # boost $KEEP probability by self.confidence
        if self.confidence > 0:
            keep_index = self.vocab_labels['$KEEP']
            prob_change = np.zeros(labels_probs.shape[-1])
            prob_change[keep_index] = self.confidence
            # Broadcast the (n_labels,) bias over (B, S, n_labels). The cast is
            # required: np.zeros is float64 and TF does not promote
            # float64 + float32, so the original addition raised an error.
            labels_probs = labels_probs + tf.cast(prob_change,
                                                  labels_probs.dtype)

        output_dict = {
            'labels_probs': labels_probs.numpy(),  # (B, S, n_labels)
            'detect_probs': detect_probs.numpy(),  # (B, S, n_detect)
            'max_error_probs': max_error_probs.numpy(),  # (B,)
        }

        # get decoded text labels
        for namespace in ['labels', 'detect']:
            vocab = getattr(self, f'vocab_{namespace}')
            probs = output_dict[f'{namespace}_probs']
            decoded_batch = []
            for seq in probs:
                argmax_idx = np.argmax(seq, axis=-1)
                decoded_batch.append([vocab[i] for i in argmax_idx])
            output_dict[namespace] = decoded_batch

        return output_dict

    def correct(self, sentences, max_iter=10):
        """Iteratively correct one sentence or a sequence of sentences.

        Edits are re-predicted and re-applied until an iteration leaves every
        sentence unchanged, or ``max_iter`` iterations have run.

        Args:
            sentences: A single ``str`` or an iterable of strings.
            max_iter: Maximum number of correction passes.

        Returns:
            A corrected ``str`` for ``str`` input, otherwise a list of
            corrected strings (an empty list for empty input).
        """
        single = isinstance(sentences, str)
        # list() so tuple input compares equal to correct_once's list output
        cur_sentences = [sentences] if single else list(sentences)
        if not cur_sentences:
            # nothing to tokenize; the tokenizer would fail on an empty batch
            return cur_sentences
        for _ in range(max_iter):
            new_sentences = self.correct_once(cur_sentences)
            if cur_sentences == new_sentences:
                break
            cur_sentences = new_sentences
        return cur_sentences[0] if single else cur_sentences

    def correct_once(self, sentences):
        """Apply one pass of predicted edits to each sentence.

        Args:
            sentences: List of sentences (each must tokenize to at most
                ``max_len`` tokens, since no truncation is applied).

        Returns:
            List of sentences after applying the predicted edits.
        """
        input_dict = self.tokenizer(sentences, add_special_tokens=True,
                                    padding='max_length',
                                    max_length=self.max_len,
                                    return_tensors='tf')
        output_dict = self.predict(input_dict['input_ids'])
        labels = output_dict['labels']
        # probability of the chosen (argmax) label for every token
        labels_probs = np.max(output_dict['labels_probs'], axis=-1)
        new_sentences = []
        for i, sentence in enumerate(sentences):
            if output_dict['max_error_probs'][i] < self.min_error_prob:
                new_sentences.append(sentence)
                continue
            input_ids = input_dict['input_ids'][i].numpy()
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
            mask = input_dict['attention_mask'][i].numpy()
            for j in range(len(tokens)):
                label = labels[i][j]
                if not mask[j]:
                    tokens[j] = ''
                elif labels_probs[i][j] < self.min_error_prob:
                    continue
                elif label in ['[PAD]', '[UNK]', '$KEEP']:
                    continue
                elif label == '$DELETE':
                    tokens[j] = ''
                elif label.startswith('$APPEND_'):
                    # split back into separate tokens by the join/split below
                    tokens[j] += ' ' + label.replace('$APPEND_', '')
                elif label.startswith('$REPLACE_'):
                    tokens[j] = label.replace('$REPLACE_', '')
                elif label.startswith('$TRANSFORM_'):
                    transform_op = label.replace('$TRANSFORM_', '')
                    key = f'{tokens[j]}_{transform_op}'
                    if key in self.transform:
                        tokens[j] = self.transform[key]
            tokens = ' '.join(tokens).split()
            tokens = [t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']]
            new_sentence = self.tokenizer.convert_tokens_to_string(tokens)
            new_sentences.append(new_sentence.replace(' ', ''))
        return new_sentences

    def get_transforms(self, verb_adj_forms_path):
        """Load the decode table for ``$TRANSFORM_*`` labels.

        Each non-blank line of the file has the form
        ``word1_word2:tag1_tag2``.

        Returns:
            dict mapping ``'word1_tag1_tag2'`` to ``word2``; for duplicate
            keys the first line wins.
        """
        decode = {}
        with open(verb_adj_forms_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # tolerate blank/trailing lines instead of failing to unpack
                    continue
                words, tags = line.split(':')
                word1, word2 = words.split('_')
                tag1, tag2 = tags.strip().split('_')
                decode.setdefault(f'{word1}_{tag1}_{tag2}', word2)
        return decode
159 |
--------------------------------------------------------------------------------
/utils/errorify.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import traceback
4 |
5 | import numpy as np
6 | from numpy.random import choice, randint, uniform, binomial
7 | from fugashi import Tagger
8 |
9 |
class Errorify:
    """Generate artificial errors in sentences.

    Sentences are tokenized with a fugashi ``Tagger``; a randomly drawn
    number of eligible tokens are then deleted, replaced, misinflected, or
    followed by an inserted particle.
    """

    # Keys of the verb form tables, in the order get_forms returns them:
    # plain (終止形), imperfect (未然形), conjunctive (連用形),
    # conjunctive geminate (連用形-促音便), imperative (命令形),
    # volitional (意志推量形), stem/subword token.
    _VERB_FORM_KEYS = ('VB', 'VBI', 'VBC', 'VBCG', 'VBP', 'VBV', 'VBS')
    # する is irregular: these are full forms, not stem + ending.
    _SURU_FORMS = ('する', 'し', 'し', 'し', 'しろ', 'しよう', 'する')
    # 行く has a geminate conjunctive form and keeps く as its subword form.
    _IKU_ENDINGS = ('く', 'か', 'き', 'っ', 'け', 'こう', 'く')
    # ru-verbs (一段)
    _ICHIDAN_ENDINGS = ('る', '', '', '', 'ろ', 'よう', '')
    # u-verbs (五段), keyed by the final kana of the base form
    _GODAN_ENDINGS = {
        'る': ('る', 'ら', 'り', 'っ', 'れ', 'ろう', ''),
        'つ': ('つ', 'た', 'ち', 'っ', 'て', 'とう', ''),
        'う': ('う', 'わ', 'い', 'っ', 'え', 'おう', ''),
        'く': ('く', 'か', 'き', 'い', 'け', 'こう', ''),
        'ぐ': ('ぐ', 'が', 'ぎ', 'い', 'げ', 'ごう', ''),
        'す': ('す', 'さ', 'し', 'し', 'せ', 'そう', ''),
        'む': ('む', 'ま', 'み', 'ん', 'め', 'もう', ''),
        'ぬ': ('ぬ', 'な', 'に', 'ん', 'ね', 'のう', ''),
        'ぶ': ('ぶ', 'ば', 'び', 'ん', 'べ', 'ぼう', ''),
    }
    # i-adjectives: plain, conjunctive, conjunctive geminate, stem
    _ADJ_ENDINGS = {'ADJ': 'い', 'ADJC': 'く', 'ADJCG': 'かっ', 'ADJS': ''}

    def __init__(self, reading_lookup_path='./data/reading_lookup.json'):
        """Create the tagger and load the reading lookup table.

        Args:
            reading_lookup_path: JSON file mapping readings to lists of
                candidate words (consumed by ``replace_error``).
        """
        self.tagger = Tagger('-Owakati')
        # error_prob[k] is the probability of introducing k errors
        self.error_prob = [0.05, 0.07, 0.25, 0.35, 0.28]
        self.core_particles = ['が', 'の', 'を', 'に', 'へ', 'と', 'で', 'から',
                               'より', 'は', 'も']
        # NOTE(review): 'やら' is listed twice, so it is sampled twice as often
        self.other_particles = ['か', 'の', 'や', 'に', 'と', 'やら', 'なり',
                                'だの', 'ばかり', 'まで', 'だけ', 'ほど', 'くらい',
                                'など', 'やら', 'こそ', 'でも', 'しか', 'さえ',
                                'だに', 'ば', 'て', 'のに', 'ので', 'から']
        # JSON is UTF-8; don't depend on the platform's locale encoding
        with open(reading_lookup_path, encoding='utf-8') as f:
            self.reading_lookup = json.load(f)

    def delete_error(self, token, feature):
        """Delete a token."""
        return ''

    def inflection_error(self, token, feature):
        """Misinflect a verb/adj stem.

        Returns a random form from ``get_forms`` for the token's base form,
        or the token unchanged when no base form or forms are available.
        """
        baseform = feature.orthBase or feature.lemma
        if not baseform:
            return token
        morphs = list(self.get_forms(baseform).values())
        if not morphs:
            return token
        return choice(morphs)

    def insert_error(self, token, feature):
        """Insert a random particle."""
        return token + choice(self.other_particles)

    def replace_error(self, token, feature):
        """Replace a particle or word with another word of the same reading.

        Case/binding particles become a random core particle. Verbs and
        adjectives look up ``'<kana stem>.<last kana>'`` and keep the token's
        inflected ending; other words look up their ``kanaBase``. The token
        is returned unchanged when no lookup is possible.
        """
        if feature.pos2 in ['格助詞', '係助詞']:
            return choice(self.core_particles)
        elif feature.pos1 in ['動詞', '形容詞']:
            # kanaBase/orthBase can be missing (e.g. unknown words); slicing
            # None would raise TypeError
            if not feature.kanaBase or not feature.orthBase:
                return token
            reading = f'{feature.kanaBase[:-1]}.{feature.kanaBase[-1]}'
            if reading not in self.reading_lookup:
                return token
            ending = token[len(feature.orthBase)-1:]
            return choice(self.reading_lookup[reading]) + ending
        else:
            if feature.kanaBase not in self.reading_lookup:
                return token
            return choice(self.reading_lookup[feature.kanaBase])

    def __call__(self, sentence):
        """Get sentence with artificially generated errors."""
        # Copy surface/feature out of the tagger nodes immediately; the
        # original author notes a fugashi bug when nodes are used later.
        tokens = [(t.surface, t.feature) for t in self.tagger(sentence)]
        tokens_surface = [surface for surface, _ in tokens]
        n_errors = choice(range(len(self.error_prob)), p=self.error_prob)
        # numerals, proper nouns and symbols are never modified
        candidate_tokens = [i for i, (_, f) in enumerate(tokens)
                            if f.pos2 not in ['数詞', '固有名詞']
                            and f.pos1 not in ['記号', '補助記号']]
        if not candidate_tokens:
            return sentence
        # sampled with replacement: a token may be chosen more than once
        error_token_ids = choice(candidate_tokens, size=(n_errors,))
        for token_id in error_token_ids:
            token, feat = tokens[token_id]
            if feat.pos2 in ['格助詞', '係助詞']:
                error_func = choice([self.delete_error, self.replace_error])
            elif feat.pos1 in ['動詞', '形容詞']:
                error_func = choice([self.replace_error, self.inflection_error],
                                    p=[0.05, 0.95])
            elif feat.pos1 == '名詞':
                error_func = choice([self.insert_error, self.replace_error],
                                    p=[0.05, 0.95])
            else:
                error_func = choice([self.insert_error, self.delete_error],
                                    p=[0.05, 0.95])
            tokens_surface[token_id] = error_func(token, feat)
        return ''.join(tokens_surface)

    def get_forms(self, baseform):
        """Return the inflected forms of a verb or i-adjective.

        Args:
            baseform: Dictionary form of the word.

        Returns:
            dict mapping form tags (``VB``... or ``ADJ``...) to surface
            strings; empty when the word class is not recognised.
        """
        tagged = self.tagger(baseform)
        if not tagged:
            return {}
        f = tagged[0].feature
        # irregular verbs
        if f.orthBase == 'する' and f.lemma == '為る':
            return dict(zip(self._VERB_FORM_KEYS, self._SURU_FORMS))
        if f.kanaBase == 'イク' or (f.orthBase and f.orthBase[-2:] == '行く'):
            endings = dict(zip(self._VERB_FORM_KEYS, self._IKU_ENDINGS))
        elif f.pos1 == '形容詞':  # i-adj
            endings = self._ADJ_ENDINGS
        elif '一段' in (f.cType or ''):  # ru-verbs
            endings = dict(zip(self._VERB_FORM_KEYS, self._ICHIDAN_ENDINGS))
        elif baseform[-1] in self._GODAN_ENDINGS:  # u-verbs
            endings = dict(zip(self._VERB_FORM_KEYS,
                               self._GODAN_ENDINGS[baseform[-1]]))
        else:
            endings = {}
        stem = baseform[:-1]
        return {form: stem + end for form, end in endings.items()}
220 |
--------------------------------------------------------------------------------
/utils/edits.py:
--------------------------------------------------------------------------------
1 | from difflib import SequenceMatcher
2 | from collections import Counter
3 |
4 | from transformers import AutoTokenizer
5 | import numpy as np
6 | import Levenshtein
7 |
8 | from .helpers import Vocab, create_example
9 |
10 |
11 | class EditTagger:
12 | """
13 | Get edit sequences to transform source sentence to target sentence.
14 |
15 | Original reference code @ https://github.com/grammarly/gector (see README).
16 | """
17 |
18 | def __init__(self,
19 | verb_adj_forms_path='data/transform.txt',
20 | vocab_detect_path='data/output_vocab/detect.txt',
21 | vocab_labels_path='data/output_vocab/labels.txt'):
22 | self.tokenizer = AutoTokenizer.from_pretrained(
23 | 'cl-tohoku/bert-base-japanese-v2')
24 | encode, decode = self.get_verb_adj_form_dicts(verb_adj_forms_path)
25 | self.encode_verb_adj_form = encode
26 | self.decode_verb_adj_form = decode
27 | self.vocab_detect = Vocab.from_file(vocab_detect_path)
28 | self.vocab_labels = Vocab.from_file(vocab_labels_path)
29 | self.edit_freq = Counter()
30 |
31 | def get_verb_adj_form_dicts(self, verb_adj_forms_path):
32 | encode, decode = {}, {}
33 | with open(verb_adj_forms_path, 'r', encoding='utf-8') as f:
34 | for line in f:
35 | words, tags = line.split(':')
36 | tags = tags.strip()
37 | word1, word2 = words.split('_')
38 | tag1, tag2 = tags.split('_')
39 | decode_key = f'{word1}_{tag1}_{tag2}'
40 | if decode_key not in decode:
41 | encode[words] = tags
42 | decode[decode_key] = word2
43 | return encode, decode
44 |
45 | def tokenize(self, sentence, **kwargs):
46 | ids = self.tokenizer(sentence, **kwargs)['input_ids']
47 | return self.tokenizer.convert_ids_to_tokens(ids)
48 |
49 | def join_tokens(self, tokens):
50 | return self.tokenizer.convert_tokens_to_string(tokens).replace(' ', '')
51 |
52 | def __call__(self, source, target, levels=False):
53 | edit_rows = []
54 | if levels:
55 | edit_levels = self.get_edit_levels(source, target)
56 | else:
57 | edit_levels = [self.get_edits(source, target)]
58 | for cur_tokens, cur_edits in edit_levels:
59 | cur_edits = [e[0] for e in cur_edits]
60 | self.edit_freq.update(cur_edits)
61 | row = create_example(cur_tokens, cur_edits, self.tokenizer,
62 | self.vocab_labels, self.vocab_detect)
63 | edit_rows.append(row)
64 | return edit_rows
65 |
66 | def get_edits(self, source, target, add_special_tokens=True, max_len=128):
67 | source_tokens = self.tokenize(source,
68 | add_special_tokens=add_special_tokens)
69 | target_tokens = self.tokenize(target, add_special_tokens=True)
70 | if len(source_tokens) > max_len or len(target_tokens) > max_len:
71 | return [], []
72 | matcher = SequenceMatcher(None, source_tokens, target_tokens)
73 | diffs = list(matcher.get_opcodes())
74 | edits = []
75 | for tag, i1, i2, j1, j2 in diffs:
76 | source_part = source_tokens[i1:i2]
77 | target_part = target_tokens[j1:j2]
78 | if tag == 'equal':
79 | continue
80 | elif tag == 'delete':
81 | for i in range(i1, i2):
82 | edits.append((i, '$DELETE'))
83 | elif tag == 'insert':
84 | for target_token in target_part:
85 | edits.append((i1-1, f'$APPEND_{target_token}'))
86 | else: # tag == 'replace'
87 | _, alignments = self.perfect_align(source_part, target_part)
88 | for alignment in alignments:
89 | new_edits = self.convert_alignment_into_edits(alignment, i1)
90 | edits.extend(new_edits)
91 |
92 | # map edits to source tokens
93 | labels = [['$KEEP'] for i in range(len(source_tokens))]
94 | for i, edit in edits:
95 | if labels[i] == ['$KEEP']:
96 | labels[i] = []
97 | labels[i].append(edit)
98 |
99 | return source_tokens, labels
100 |
101 | def perfect_align(self, t, T, insertions_allowed=0,
102 | cost_function=Levenshtein.distance):
103 | # dp[i, j, k] is a minimal cost of matching first `i` tokens of `t` with
104 | # first `j` tokens of `T`, after making `k` insertions after last match
105 | # of token from `t`. In other words t[:i] aligned with T[:j].
106 |
107 | # Initialize with INFINITY (unknown)
108 | shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1)
109 | dp = np.ones(shape, dtype=int) * int(1e9)
110 | come_from = np.ones(shape, dtype=int) * int(1e9)
111 | come_from_ins = np.ones(shape, dtype=int) * int(1e9)
112 |
113 | dp[0, 0, 0] = 0 # Starting point. Nothing matched to nothing.
114 | for i in range(len(t) + 1): # Go inclusive
115 | for j in range(len(T) + 1): # Go inclusive
116 | for q in range(insertions_allowed + 1): # Go inclusive
117 | if i < len(t):
118 | # Given matched sequence of t[:i] and T[:j], match token
119 | # t[i] with following tokens T[j:k].
120 | for k in range(j, len(T) + 1):
121 | T_jk = ' '.join(T[j:k])
122 | transform = self.get_g_trans(t[i], T_jk)
123 | if transform:
124 | cost = 0
125 | else:
126 | cost = cost_function(t[i], T_jk)
127 | current = dp[i, j, q] + cost
128 | if dp[i + 1, k, 0] > current:
129 | dp[i + 1, k, 0] = current
130 | come_from[i + 1, k, 0] = j
131 | come_from_ins[i + 1, k, 0] = q
132 | if q < insertions_allowed:
133 | # Given matched sequence of t[:i] and T[:j], create
134 | # insertion with following tokens T[j:k].
135 | for k in range(j, len(T) + 1):
136 | cost = len(' '.join(T[j:k]))
137 | current = dp[i, j, q] + cost
138 | if dp[i, k, q + 1] > current:
139 | dp[i, k, q + 1] = current
140 | come_from[i, k, q + 1] = j
141 | come_from_ins[i, k, q + 1] = q
142 |
143 | # Solution is in the dp[len(t), len(T), *]. Backtracking from there.
144 | alignment = []
145 | i = len(t)
146 | j = len(T)
147 | q = dp[i, j, :].argmin()
148 | while i > 0 or q > 0:
149 | is_insert = (come_from_ins[i, j, q] != q) and (q != 0)
150 | j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q]
151 | if not is_insert:
152 | i -= 1
153 |
154 | if is_insert:
155 | alignment.append(['INSERT', T[j:k], i])
156 | else:
157 | alignment.append([f'REPLACE_{t[i]}', T[j:k], i])
158 |
159 | assert j == 0
160 |
161 | return dp[len(t), len(T)].min(), list(reversed(alignment))
162 |
163 | def get_g_trans(self, source_token, target_token):
164 | # check equal
165 | if source_token == target_token:
166 | return '$KEEP'
167 | # check transform verb/adj form possible
168 | key = f'{source_token}_{target_token}'
169 | encoding = self.encode_verb_adj_form.get(key, '')
170 | if source_token and encoding:
171 | return f'$TRANSFORM_{encoding}'
172 | return None
173 |
174 | def convert_alignment_into_edits(self, alignment, i1):
175 | edits = []
176 | action, target_tokens, new_idx = alignment
177 | shift_idx = new_idx + i1
178 | source_token = action.replace('REPLACE_', '')
179 |
180 | # check if delete
181 | if not target_tokens:
182 | return [(shift_idx, '$DELETE')]
183 |
184 | # check splits
185 | for i in range(1, len(target_tokens)):
186 | target_token = ''.join(target_tokens[:i + 1])
187 | transform = self.get_g_trans(source_token, target_token)
188 | if transform:
189 | edits.append((shift_idx, transform))
190 | for target in target_tokens[i + 1:]:
191 | edits.append((shift_idx, f'$APPEND_{target}'))
192 | return edits
193 |
194 | # default case
195 | transform_costs = []
196 | transforms = []
197 | for target_token in target_tokens:
198 | transform = self.get_g_trans(source_token, target_token)
199 | if transform:
200 | cost = 0
201 | else:
202 | cost = Levenshtein.distance(source_token, target_token)
203 | transforms.append(transform)
204 | transform_costs.append(cost)
205 | min_cost_idx = np.argmin(transform_costs)
206 | # append everything before min cost token (target) to the previous word
207 | for i in range(min_cost_idx):
208 | edits.append((shift_idx - 1, f'$APPEND_{target_tokens[i]}'))
209 | # replace/transform target word
210 | transform = transforms[min_cost_idx]
211 | if transform:
212 | target = transform
213 | else:
214 | target = f'$REPLACE_{target_tokens[min_cost_idx]}'
215 | edits.append((shift_idx, target))
216 | # append everything after target to this word
217 | for i in range(min_cost_idx + 1, len(target_tokens)):
218 | edits.append((shift_idx, f'$APPEND_{target_tokens[i]}'))
219 | return edits
220 |
def get_edit_levels(self, source, target, max_iter=10):
    """Split the source -> target correction into successive edit levels.

    Each pass tags the current sentence against ``target`` with
    ``self.get_edits``, records the (tokens, edits) pair, applies the edits
    with ``self.apply_edits`` and re-joins the result with
    ``self.join_tokens`` to obtain the sentence for the next pass.

    Iteration stops when:
      * ``get_edits`` yields no tokens, or
      * a pass other than the first produces nothing but ``['$KEEP']``
        tags (the first pass is always recorded, even if all-KEEP), or
      * ``max_iter`` passes have been made.

    Args:
        source: the starting (possibly erroneous) sentence.
        target: the corrected sentence to converge towards.
        max_iter: upper bound on the number of passes.

    Returns:
        A list of ``(tokens, edits)`` tuples, one per recorded pass.
    """
    levels = []
    sentence = source
    for level_idx in range(max_iter):
        is_first = level_idx == 0
        # Special tokens are requested only on the first pass; presumably
        # later sentences are rebuilt from tokens that already contain
        # them — TODO confirm against get_edits/join_tokens.
        tokens, token_edits = self.get_edits(sentence, target,
                                             add_special_tokens=is_first)
        if not tokens or (
                not is_first
                and all(tags == ['$KEEP'] for tags in token_edits)):
            break
        levels.append((tokens, token_edits))
        sentence = self.join_tokens(self.apply_edits(tokens, token_edits))
    # The tokenizer may emit [UNK], so the final sentence is not guaranteed
    # to equal ``target``; no equality check is performed here.
    return levels
237 |
def apply_edits(self, source_tokens, edits):
    """Apply one level of edit tags to a token sequence.

    Only the first tag of each token's edit list is applied; further tags
    are left for later levels.

    Supported tags:
      * ``$KEEP``                 -- keep the token unchanged.
      * ``$DELETE``               -- drop the token.
      * ``$APPEND_<tok>``         -- keep the token and insert ``<tok>`` after it.
      * ``$REPLACE_<tok>``        -- replace the token with ``<tok>``.
      * ``$TRANSFORM_<a>_<b>``    -- look up ``f'{token}_<a>_<b>'`` in
        ``self.decode_verb_adj_form`` and emit the mapped form.

    Args:
        source_tokens: tokens of the current sentence.
        edits: one list of edit tags per token, aligned with
            ``source_tokens``. Pairs are formed with ``zip``, so trailing
            items of the longer sequence are ignored.

    Returns:
        The new list of tokens.

    Raises:
        ValueError: if a tag is not one of the forms above.
        KeyError: if a ``$TRANSFORM_`` key is missing from
            ``self.decode_verb_adj_form``.
        IndexError: if a token's edit list is empty.
    """
    new_tokens = []
    for token, edit_list in zip(source_tokens, edits):
        edit = edit_list[0]
        if edit == '$KEEP':
            new_tokens.append(token)
        elif edit == '$DELETE':
            continue
        # Strip only the leading prefix: str.replace would also remove any
        # later occurrence of the prefix text inside the payload token.
        elif edit.startswith('$APPEND_'):
            new_tokens += [token, edit[len('$APPEND_'):]]
        elif edit.startswith('$REPLACE_'):
            new_tokens.append(edit[len('$REPLACE_'):])
        elif edit.startswith('$TRANSFORM_'):
            transform = edit[len('$TRANSFORM_'):]
            decode_key = f'{token}_{transform}'
            new_tokens.append(self.decode_verb_adj_form[decode_key])
        else:
            raise ValueError(f'Invalid edit {edit}')
    return new_tokens
257 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/data/output_vocab/labels.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [UNK]
3 | $KEEP
4 | $DELETE
5 | $APPEND_の
6 | $REPLACE_の
7 | $APPEND_た
8 | $APPEND_に
9 | $REPLACE_に
10 | $APPEND_て
11 | $APPEND_は
12 | $REPLACE_は
13 | $APPEND_で
14 | $REPLACE_年
15 | $REPLACE_を
16 | $APPEND_を
17 | $APPEND_が
18 | $REPLACE_が
19 | $APPEND_と
20 | $REPLACE_と
21 | $TRANSFORM_VBS_VB
22 | $REPLACE_で
23 | $TRANSFORM_VBS_VBCG
24 | $REPLACE_さ
25 | $APPEND_れ
26 | $REPLACE_月
27 | $REPLACE_た
28 | $APPEND_日
29 | $APPEND_も
30 | $REPLACE_て
31 | $REPLACE_も
32 | $TRANSFORM_VBP_VBCG
33 | $REPLACE_こと
34 | $APPEND_(
35 | $APPEND_)
36 | $APPEND_な
37 | $TRANSFORM_VBV_VBCG
38 | $TRANSFORM_VBCG_VBS
39 | $APPEND_、
40 | $APPEND_から
41 | $REPLACE_から
42 | $TRANSFORM_VBP_VB
43 | $TRANSFORM_VBCG_VB
44 | $REPLACE_し
45 | $APPEND_その
46 | $TRANSFORM_VBC_VB
47 | $TRANSFORM_VB_VBCG
48 | $APPEND_や
49 | $APPEND_この
50 | $APPEND_れる
51 | $REPLACE_県
52 | $REPLACE_市
53 | $APPEND_など
54 | $TRANSFORM_VB_VBS
55 | $TRANSFORM_VBI_VB
56 | $TRANSFORM_VBC_VBCG
57 | $REPLACE_ため
58 | $APPEND_的
59 | $APPEND_また
60 | $APPEND_第
61 | $APPEND_人
62 | $REPLACE_、
63 | $REPLACE_れ
64 | $APPEND_い
65 | $APPEND_だ
66 | $TRANSFORM_VBI_VBCG
67 | $APPEND_者
68 | $APPEND_よう
69 | $APPEND_られ
70 | $TRANSFORM_VBS_VBC
71 | $APPEND_まで
72 | $TRANSFORM_VBP_VBS
73 | $APPEND_ない
74 | $TRANSFORM_VBS_VBI
75 | $APPEND_か
76 | $REPLACE_後
77 | $REPLACE_もの
78 | $TRANSFORM_VBV_VB
79 | $APPEND_。
80 | $APPEND_し
81 | $APPEND_これ
82 | $REPLACE_。
83 | $REPLACE_部
84 | $TRANSFORM_VBCG_VBC
85 | $APPEND_つ
86 | $REPLACE_い
87 | $APPEND_こと
88 | $TRANSFORM_VBV_VBS
89 | $APPEND_家
90 | $APPEND_後
91 | $TRANSFORM_VB_VBC
92 | $TRANSFORM_ADJS_ADJ
93 | $REPLACE_時
94 | $APPEND_です
95 | $REPLACE_です
96 | $REPLACE_へ
97 | $REPLACE_大学
98 | $REPLACE_州
99 | $REPLACE_な
100 | $REPLACE_回
101 | $APPEND_中
102 | $REPLACE_(
103 | $APPEND_へ
104 | $APPEND_国
105 | $REPLACE_いる
106 | $APPEND_or
107 | $APPEND_だっ
108 | $REPLACE_駅
109 | $APPEND_ます
110 | $APPEND_ら
111 | $REPLACE_日
112 | $REPLACE_ます
113 | $TRANSFORM_VBI_VBC
114 | $APPEND_それ
115 | $TRANSFORM_VBCG_VBI
116 | $REPLACE_町
117 | $TRANSFORM_ADJC_ADJ
118 | $APPEND_まし
119 | $APPEND_ん
120 | $REPLACE_する
121 | $REPLACE_会
122 | $APPEND_位
123 | $REPLACE_代
124 | $APPEND_目
125 | $APPEND_さ
126 | $APPEND_いる
127 | $REPLACE_より
128 | $REPLACE_語
129 | $APPEND_しかし
130 | $REPLACE_名
131 | $REPLACE_いう
132 | $APPEND_所
133 | $APPEND_られる
134 | $REPLACE_ある
135 | $REPLACE_か
136 | $APPEND_戦
137 | $APPEND_大
138 | $REPLACE_き
139 | $APPEND_より
140 | $APPEND_化
141 | $TRANSFORM_ADJ_ADJC
142 | $REPLACE_号
143 | $APPEND_?
144 | $TRANSFORM_VB_VBI
145 | $REPLACE_以下
146 | $REPLACE_地
147 | $REPLACE_区
148 | $REPLACE_元
149 | $TRANSFORM_VBC_VBI
150 | $APPEND_ず
151 | $APPEND_彼
152 | $REPLACE_だ
153 | $APPEND_お
154 | $APPEND_歳
155 | $REPLACE_おり
156 | $REPLACE_学校
157 | $REPLACE_時代
158 | $REPLACE_選手
159 | $TRANSFORM_VBP_VBC
160 | $REPLACE_初
161 | $APPEND_ば
162 | $APPEND_性
163 | $REPLACE_他
164 | $APPEND_型
165 | $REPLACE_放送
166 | $APPEND_学
167 | $REPLACE_機
168 | $TRANSFORM_ADJS_ADJC
169 | $APPEND_なかっ
170 | $APPEND_線
171 | $REPLACE_だっ
172 | $APPEND_「
173 | $TRANSFORM_VBC_VBS
174 | $REPLACE_人
175 | $APPEND_約
176 | $APPEND_内
177 | $REPLACE_郡
178 | $REPLACE_ない
179 | $APPEND_全
180 | $REPLACE_使用
181 | $REPLACE_とも
182 | $APPEND_せ
183 | $REPLACE_る
184 | $REPLACE_1
185 | $APPEND_長
186 | $APPEND_車
187 | $APPEND_なお
188 | $REPLACE_都
189 | $APPEND_私
190 | $REPLACE_##い
191 | $TRANSFORM_ADJCG_ADJ
192 | $REPLACE_中
193 | $REPLACE_軍
194 | $REPLACE_)
195 | $REPLACE_賞
196 | $REPLACE_##っ
197 | $REPLACE_よう
198 | $REPLACE_社
199 | $REPLACE_映画
200 | $APPEND_たち
201 | $TRANSFORM_VBI_VBS
202 | $APPEND_」
203 | $REPLACE_・
204 | $REPLACE_期
205 | $APPEND_院
206 | $APPEND_および
207 | $REPLACE_なっ
208 | $REPLACE_同
209 | $REPLACE_ん
210 | $REPLACE_##し
211 | $APPEND_上
212 | $REPLACE_まし
213 | $APPEND_館
214 | $APPEND_のみ
215 | $APPEND_可能
216 | $APPEND_系
217 | $APPEND_する
218 | $REPLACE_##る
219 | $APPEND_いう
220 | $REPLACE_当時
221 | $APPEND_本
222 | $TRANSFORM_VBP_VBI
223 | $REPLACE_度
224 | $REPLACE_その
225 | $APPEND_ある
226 | $REPLACE_や
227 | $REPLACE_一部
228 | $APPEND_だけ
229 | $REPLACE_版
230 | $APPEND_小
231 | $REPLACE_れる
232 | $REPLACE_体
233 | $APPEND_権
234 | $REPLACE_参加
235 | $APPEND_さらに
236 | $APPEND_る
237 | $REPLACE_あり
238 | $APPEND_/
239 | $APPEND_等
240 | $APPEND_新
241 | $REPLACE_以降
242 | $APPEND_用
243 | $APPEND_##る
244 | $REPLACE_村
245 | $REPLACE_##く
246 | $APPEND_級
247 | $APPEND_そして
248 | $APPEND_氏
249 | $APPEND_世
250 | $REPLACE_法
251 | $REPLACE_役
252 | $TRANSFORM_ADJ_ADJS
253 | $REPLACE_府
254 | $REPLACE_位置
255 | $REPLACE_点
256 | $APPEND_同じ
257 | $REPLACE_旧
258 | $APPEND_党
259 | $REPLACE_とき
260 | $REPLACE_子
261 | $REPLACE_でし
262 | $APPEND_生
263 | $REPLACE_事
264 | $APPEND_現
265 | $REPLACE_中心
266 | $REPLACE_生まれ
267 | $APPEND_高等
268 | $TRANSFORM_ADJCG_ADJC
269 | $REPLACE_川
270 | $REPLACE_際
271 | $TRANSFORM_VBV_VBC
272 | $REPLACE_局
273 | $REPLACE_以上
274 | $REPLACE_曲
275 | $REPLACE_うち
276 | $REPLACE_よっ
277 | $REPLACE_式
278 | $REPLACE_数
279 | $APPEND_次
280 | $REPLACE_上
281 | $APPEND_ながら
282 | $APPEND_科
283 | $APPEND_場
284 | $REPLACE_ませ
285 | $REPLACE_大会
286 | $REPLACE_あっ
287 | $REPLACE_王
288 | $REPLACE_お
289 | $REPLACE_この
290 | $APPEND_たい
291 | $APPEND_たり
292 | $APPEND_あり
293 | $APPEND_ませ
294 | $REPLACE_たら
295 | $REPLACE_女性
296 | $REPLACE_ほか
297 | $APPEND_間
298 | $REPLACE_円
299 | $REPLACE_時間
300 | $APPEND_特に
301 | $REPLACE_世紀
302 | $REPLACE_番
303 | $APPEND_枚
304 | $REPLACE_艦
305 | $APPEND_##っ
306 | $REPLACE_共
307 | $REPLACE_移籍
308 | $REPLACE_にて
309 | $APPEND_官
310 | $REPLACE_国
311 | $APPEND_隊
312 | $APPEND_き
313 | $APPEND_さん
314 | $APPEND_作
315 | $TRANSFORM_VBS_VBP
316 | $APPEND_方
317 | $REPLACE_線
318 | $REPLACE_せ
319 | $REPLACE_それぞれ
320 | $APPEND_かつて
321 | $APPEND_でし
322 | $REPLACE_おい
323 | $REPLACE_種
324 | $REPLACE_都市
325 | $APPEND_もの
326 | $REPLACE_それ
327 | $REPLACE_ば
328 | $REPLACE_公開
329 | $REPLACE_文化
330 | $REPLACE_教授
331 | $REPLACE_試合
332 | $REPLACE_なる
333 | $APPEND_##い
334 | $APPEND_各
335 | $APPEND_せる
336 | $REPLACE_本
337 | $REPLACE_高校
338 | $REPLACE_よる
339 | $APPEND_ほど
340 | $REPLACE_地区
341 | $REPLACE_ところ
342 | $APPEND_初めて
343 | $APPEND_員
344 | $REPLACE_たい
345 | $APPEND_同様
346 | $REPLACE_構成
347 | $REPLACE_まま
348 | $APPEND_不
349 | $REPLACE_年間
350 | $REPLACE_クラブ
351 | $APPEND_主に
352 | $TRANSFORM_ADJ_ADJCG
353 | $APPEND_にて
354 | $REPLACE_方
355 | $APPEND_形
356 | $APPEND_そう
357 | $APPEND_そこ
358 | $REPLACE_目
359 | $APPEND_物
360 | $APPEND_派
361 | $APPEND_ただし
362 | $REPLACE_者
363 | $REPLACE_台
364 | $REPLACE_西
365 | $APPEND_彼女
366 | $APPEND_ここ
367 | $APPEND_最も
368 | $REPLACE_地方
369 | $REPLACE_なり
370 | $REPLACE_登場
371 | $APPEND_及び
372 | $APPEND_島
373 | $REPLACE_藩
374 | $APPEND_寺
375 | $REPLACE_##き
376 | $APPEND_省
377 | $APPEND_日本
378 | $REPLACE_まで
379 | $REPLACE_##ら
380 | $REPLACE_優勝
381 | $REPLACE_機関
382 | $APPEND_なり
383 | $REPLACE_次
384 | $REPLACE_?
385 | $REPLACE_受賞
386 | $REPLACE_ほとんど
387 | $REPLACE_##さ
388 | $APPEND_なっ
389 | $REPLACE_##り
390 | $REPLACE_一
391 | $REPLACE_戦争
392 | $APPEND_なら
393 | $REPLACE_頃
394 | $REPLACE_面
395 | $APPEND_何
396 | $APPEND_思い
397 | $APPEND_立
398 | $APPEND_下
399 | $APPEND_なく
400 | $APPEND_総
401 | $APPEND_こう
402 | $APPEND_よ
403 | $REPLACE_られ
404 | $REPLACE_全て
405 | $APPEND_店
406 | $APPEND_再
407 | $REPLACE_側
408 | $REPLACE_担当
409 | $REPLACE_##す
410 | $REPLACE_勝
411 | $REPLACE_##しい
412 | $REPLACE_身長
413 | $REPLACE_道
414 | $APPEND_城
415 | $APPEND_力
416 | $APPEND_大きな
417 | $REPLACE_国際
418 | $REPLACE_##ん
419 | $REPLACE_私
420 | $APPEND_よく
421 | $REPLACE_話
422 | $APPEND_くれ
423 | $REPLACE_##かっ
424 | $REPLACE_例
425 | $REPLACE_自身
426 | $REPLACE_##れ
427 | $REPLACE_形
428 | $REPLACE_すべて
429 | $TRANSFORM_VBCG_VBP
430 | $APPEND_たら
431 | $REPLACE_制作
432 | $REPLACE_でき
433 | $APPEND_なる
434 | $REPLACE_終了
435 | $REPLACE_率
436 | $REPLACE_最高
437 | $REPLACE_関する
438 | $APPEND_とも
439 | $REPLACE_利用
440 | $TRANSFORM_VB_VBP
441 | $REPLACE_必要
442 | $APPEND_翌
443 | $APPEND_再び
444 | $REPLACE_そう
445 | $APPEND_高
446 | $REPLACE_##ま
447 | $REPLACE_##か
448 | $REPLACE_下
449 | $REPLACE_年度
450 | $REPLACE_人口
451 | $REPLACE_父
452 | $REPLACE_大
453 | $APPEND_士
454 | $APPEND_##く
455 | $REPLACE_生産
456 | $REPLACE_時期
457 | $REPLACE_ず
458 | $REPLACE_見
459 | $REPLACE_エンジン
460 | $REPLACE_量
461 | $REPLACE_カップ
462 | $REPLACE_内
463 | $REPLACE_実施
464 | $REPLACE_両
465 | $REPLACE_字
466 | $REPLACE_自動
467 | $REPLACE_委員
468 | $APPEND_・
469 | $REPLACE_企業
470 | $REPLACE_神
471 | $REPLACE_県立
472 | $APPEND_どう
473 | $REPLACE_間
474 | $APPEND_重要
475 | $REPLACE_今
476 | $REPLACE_状態
477 | $REPLACE_これ
478 | $REPLACE_対する
479 | $APPEND_ため
480 | $REPLACE_設置
481 | $APPEND_団
482 | $REPLACE_期間
483 | $REPLACE_のち
484 | $REPLACE_時点
485 | $REPLACE_出身
486 | $APPEND_もう
487 | $APPEND_いずれ
488 | $REPLACE_##う
489 | $APPEND_特別
490 | $APPEND_いい
491 | $APPEND_ね
492 | $REPLACE_つ
493 | $REPLACE_名称
494 | $TRANSFORM_VBC_VBP
495 | $REPLACE_指定
496 | $REPLACE_」
497 | $REPLACE_家
498 | $REPLACE_協会
499 | $APPEND_でき
500 | $REPLACE_女子
501 | $APPEND_書
502 | $REPLACE_年代
503 | $REPLACE_収録
504 | $REPLACE_管理
505 | $APPEND_語
506 | $REPLACE_長
507 | $REPLACE_成功
508 | $REPLACE_先
509 | $REPLACE_変更
510 | $REPLACE_政治
511 | $REPLACE_ら
512 | $APPEND_しか
513 | $REPLACE_航空
514 | $APPEND_品
515 | $APPEND_副
516 | $REPLACE_マン
517 | $APPEND_時
518 | $REPLACE_施設
519 | $REPLACE_指揮
520 | $REPLACE_事業
521 | $APPEND_完全
522 | $APPEND_でしょう
523 | $REPLACE_製作
524 | $REPLACE_いっ
525 | $REPLACE_中央
526 | $APPEND_非
527 | $APPEND_様々
528 | $APPEND_向け
529 | $APPEND_公
530 | $REPLACE_手
531 | $REPLACE_けど
532 | $REPLACE_分
533 | $APPEND_洋
534 | $APPEND_達
535 | $REPLACE_公園
536 | $REPLACE_巻
537 | $APPEND_個
538 | $REPLACE_別
539 | $REPLACE_どう
540 | $APPEND_あるいは
541 | $APPEND_最
542 | $REPLACE_自分
543 | $REPLACE_前
544 | $APPEND_.
545 | $REPLACE_議員
546 | $REPLACE_気
547 | $REPLACE_##が
548 | $REPLACE_会長
549 | $REPLACE_対
550 | $REPLACE_思い
551 | $APPEND_制
552 | $REPLACE_階
553 | $REPLACE_み
554 | $APPEND_ぶり
555 | $APPEND_新た
556 | $APPEND_室
557 | $APPEND_無
558 | $REPLACE_もと
559 | $APPEND_史
560 | $REPLACE_完成
561 | $APPEND_=
562 | $REPLACE_的
563 | $REPLACE_なく
564 | $REPLACE_でしょう
565 | $REPLACE_最終
566 | $REPLACE_最後
567 | $REPLACE_戦闘
568 | $APPEND_教
569 | $REPLACE_「
570 | $APPEND_業
571 | $REPLACE_男性
572 | $REPLACE_工業
573 | $APPEND_非常
574 | $REPLACE_娘
575 | $APPEND_子
576 | $APPEND_類
577 | $REPLACE_選挙
578 | $REPLACE_##め
579 | $REPLACE_共和
580 | $REPLACE_とても
581 | $APPEND_けど
582 | $REPLACE_以外
583 | $APPEND_##し
584 | $APPEND_ほぼ
585 | $REPLACE_記念
586 | $APPEND_1
587 | $REPLACE_##あ
588 | $REPLACE_車両
589 | $REPLACE_島
590 | $APPEND_圏
591 | $APPEND_まだ
592 | $APPEND_だろう
593 | $REPLACE_つい
594 | $REPLACE_科学
595 | $REPLACE_当初
596 | $REPLACE_石
597 | $REPLACE_連邦
598 | $REPLACE_妻
599 | $REPLACE_専門
600 | $REPLACE_結婚
601 | $APPEND_つい
602 | $APPEND_有名
603 | $APPEND_あまり
604 | $REPLACE_正
605 | $TRANSFORM_ADJC_ADJS
606 | $APPEND_##れ
607 | $REPLACE_ザ
608 | $REPLACE_競技
609 | $APPEND_正式
610 | $APPEND_とき
611 | $REPLACE_なら
612 | $APPEND_金
613 | $APPEND_手
614 | $REPLACE_いい
615 | $APPEND_好き
616 | $REPLACE_基
617 | $REPLACE_帝国
618 | $APPEND_み
619 | $APPEND_打
620 | $APPEND_2
621 | $REPLACE_子供
622 | $REPLACE_対象
623 | $REPLACE_##は
624 | $REPLACE_##しく
625 | $REPLACE_また
626 | $REPLACE_##て
627 | $APPEND_今
628 | $REPLACE_学会
629 | $REPLACE_教会
630 | $APPEND_主要
631 | $REPLACE_花
632 | $REPLACE_以前
633 | $REPLACE_##け
634 | $APPEND_一
635 | $REPLACE_説
636 | $REPLACE_指導
637 | $REPLACE_##と
638 | $REPLACE_よ
639 | $REPLACE_一人
640 | $APPEND_一方
641 | $APPEND_末
642 | $REPLACE_##み
643 | $APPEND_]
644 | $APPEND_すぐ
645 | $REPLACE_初期
646 | $REPLACE_勝利
647 | $REPLACE_金
648 | $REPLACE_はじめ
649 | $REPLACE_など
650 | $APPEND_→
651 | $REPLACE_名前
652 | $REPLACE_演奏
653 | $APPEND_製
654 | $REPLACE_省
655 | $REPLACE_等
656 | $REPLACE_発生
657 | $REPLACE_##え
658 | $REPLACE_さん
659 | $REPLACE_##の
660 | $APPEND_[
661 | $REPLACE_段
662 | $REPLACE_公
663 | $APPEND_勉強
664 | $APPEND_なけれ
665 | $APPEND_海
666 | $REPLACE_生
667 | $REPLACE_工場
668 | $REPLACE_発行
669 | $REPLACE_##じ
670 | $REPLACE_朝
671 | $APPEND_誌
672 | $APPEND_校
673 | $REPLACE_大戦
674 | $REPLACE_愛称
675 | $APPEND_祭
676 | $REPLACE_ま
677 | $REPLACE_所
678 | $REPLACE_兵
679 | $REPLACE_たり
680 | $APPEND_べき
681 | $APPEND_ところ
682 | $REPLACE_結成
683 | $APPEND_事
684 | $REPLACE_半
685 | $REPLACE_対戦
686 | $APPEND_,
687 | $REPLACE_##せ
688 | $APPEND_##あ
689 | $REPLACE_組織
690 | $APPEND_園
691 | $REPLACE_なかっ
692 | $REPLACE_歳
693 | $APPEND_同じく
694 | $REPLACE_行き
695 | $REPLACE_橋
696 | $REPLACE_市長
697 | $APPEND_船
698 | $REPLACE_ね
699 | $REPLACE_学者
700 | $REPLACE_評価
701 | $REPLACE_着
702 | $APPEND_例えば
703 | $REPLACE_死
704 | $APPEND_族
705 | $APPEND_思っ
706 | $REPLACE_電気
707 | $APPEND_界
708 | $REPLACE_形成
709 | $REPLACE_群
710 | $APPEND_状
711 | $REPLACE_しょう
712 | $REPLACE_##しかっ
713 | $APPEND_未
714 | $REPLACE_後半
715 | $APPEND_しまい
716 | $APPEND_とても
717 | $REPLACE_発見
718 | $REPLACE_##ぎ
719 | $REPLACE_母
720 | $REPLACE_言葉
721 | $REPLACE_学
722 | $APPEND_たくさん
723 | $REPLACE_##わ
724 | $REPLACE_##む
725 | $APPEND_準
726 | $APPEND_##り
727 | $REPLACE_通称
728 | $APPEND_師
729 | $REPLACE_試験
730 | $REPLACE_みな
731 | $APPEND_ごと
732 | $REPLACE_国家
733 | $APPEND_主な
734 | $REPLACE_自由
735 | $REPLACE_息子
736 | $REPLACE_サン
737 | $REPLACE_共同
738 | $REPLACE_まだ
739 | $REPLACE_観光
740 | $APPEND_屋
741 | $APPEND_庁
742 | $REPLACE_以後
743 | $APPEND_いつ
744 | $REPLACE_写真
745 | $REPLACE_近く
746 | $APPEND_式
747 | $REPLACE_今日
748 | $REPLACE_.
749 | $REPLACE_同時
750 | $APPEND_ご
751 | $REPLACE_こう
752 | $REPLACE_公演
753 | $REPLACE_##ご
754 | $REPLACE_部隊
755 | $REPLACE_秒
756 | $REPLACE_企画
757 | $REPLACE_機能
758 | $APPEND_隻
759 | $REPLACE_姓
760 | $REPLACE_関し
761 | $REPLACE_効果
762 | $REPLACE_だけ
763 | $REPLACE_小
764 | $APPEND_超
765 | $REPLACE_せい
766 | $REPLACE_イン
767 | $REPLACE_条
768 | $REPLACE_衆議
769 | $REPLACE_感じ
770 | $REPLACE_紀元
771 | $APPEND_道
772 | $REPLACE_けれど
773 | $REPLACE_られる
774 | $REPLACE_決勝
775 | $REPLACE_公式
776 | $REPLACE_人々
777 | $REPLACE_う
778 | $TRANSFORM_ADJS_ADJCG
779 | $TRANSFORM_ADJC_ADJCG
780 | $APPEND_器
781 | $REPLACE_通じ
782 | $REPLACE_紹介
783 | $REPLACE_出版
784 | $REPLACE_##ー
785 | $REPLACE_戦
786 | $APPEND_意味
787 | $REPLACE_もう
788 | $REPLACE_舞台
789 | $REPLACE_男
790 | $REPLACE_何
791 | $REPLACE_すれ
792 | $REPLACE_そして
793 | $REPLACE_工事
794 | $APPEND_領
795 | $REPLACE_運行
796 | $REPLACE_重
797 | $REPLACE_始め
798 | $REPLACE_主
799 | $REPLACE_設計
800 | $APPEND_御
801 | $REPLACE_候補
802 | $REPLACE_特徴
803 | $REPLACE_文
804 | $REPLACE_来
805 | $REPLACE_現役
806 | $REPLACE_英
807 | $REPLACE_す
808 | $REPLACE_戦い
809 | $REPLACE_##よう
810 | $REPLACE_行動
811 | $REPLACE_くれ
812 | $REPLACE_##そ
813 | $REPLACE_しか
814 | $REPLACE_一種
815 | $REPLACE_高速
816 | $REPLACE_海
817 | $REPLACE_##た
818 | $REPLACE_性
819 | $APPEND_僕
820 | $REPLACE_よく
821 | $REPLACE_差
822 | $REPLACE_葉
823 | $REPLACE_木
824 | $REPLACE_地名
825 | $REPLACE_じ
826 | $REPLACE_支持
827 | $REPLACE_総合
828 | $REPLACE_飛行
829 | $REPLACE_週間
830 | $REPLACE_こ
831 | $APPEND_自分
832 | $REPLACE_##丈
833 | $REPLACE_刊行
834 | $APPEND_堂
835 | $APPEND_敗
836 | $REPLACE_がっこう
837 | $REPLACE_事実
838 | $REPLACE_##張
839 | $APPEND_湖
840 | $APPEND_見
841 | $REPLACE_以来
842 | $REPLACE_編成
843 | $APPEND_*
844 | $APPEND_##よう
845 | $APPEND_あっ
846 | $REPLACE_一方
847 | $REPLACE_新しい
848 | $REPLACE_章
849 | $REPLACE_競走
850 | $REPLACE_付近
851 | $REPLACE_提供
852 | $REPLACE_##こ
853 | $APPEND_##かっ
854 | $APPEND_とっ
855 | $REPLACE_形式
856 | $REPLACE_##ン
857 | $REPLACE_日間
858 | $REPLACE_科
859 | $REPLACE_物
860 | $REPLACE_状況
861 | $REPLACE_##しゃ
862 | $APPEND_少し
863 | $REPLACE_感
864 | $REPLACE_院
865 | $REPLACE_搭載
866 | $APPEND_様
867 | $REPLACE_編集
868 | $REPLACE_そこ
869 | $REPLACE_デ
870 | $REPLACE_立
871 | $REPLACE_対し
872 | $REPLACE_ち
873 | $REPLACE_人気
874 | $REPLACE_夏
875 | $REPLACE_2
876 | $REPLACE_場
877 | $APPEND_座
878 | $APPEND_付
879 | $APPEND_集
880 | $APPEND_両
881 | $REPLACE_歌
882 | $REPLACE_できる
883 | $REPLACE_職
884 | $REPLACE_ご
885 | $REPLACE_連合
886 | $REPLACE_党
887 | $APPEND_街
888 | $REPLACE_郵便
889 | $REPLACE_兄弟
890 | $REPLACE_声
891 | $APPEND_行き
892 | $REPLACE_たくさん
893 | $REPLACE_って
894 | $TRANSFORM_VBI_VBP
895 | $APPEND_もしくは
896 | $REPLACE_法人
897 | $REPLACE_km
898 | $REPLACE_自治
899 | $APPEND_"
900 | $REPLACE_兼
901 | $REPLACE_出
902 | $APPEND_入り
903 | $REPLACE_城
904 | $REPLACE_湾
905 | $REPLACE_学生
906 | $REPLACE_種類
907 | $REPLACE_作家
908 | $APPEND_来
909 | $REPLACE_講師
910 | $APPEND_属
911 | $REPLACE_記載
912 | $REPLACE_記事
913 | $APPEND_弾
914 | $APPEND_症
915 | $REPLACE_初めて
916 | $REPLACE_層
917 | $REPLACE_枠
918 | $REPLACE_行政
919 | $REPLACE_継承
920 | $REPLACE_表現
921 | $APPEND_分
922 | $REPLACE_復帰
923 | $REPLACE_なけれ
924 | $REPLACE_会議
925 | $REPLACE_表
926 | $REPLACE_彼
927 | $APPEND_明らか
928 | $APPEND_って
929 | $REPLACE_とっ
930 | $REPLACE_芸術
931 | $REPLACE_市場
932 | $REPLACE_いただ
933 | $REPLACE_工学
934 | $REPLACE_表示
935 | $REPLACE_あまり
936 | $REPLACE_所有
937 | $REPLACE_節
938 | $REPLACE_有する
939 | $REPLACE_成長
940 | $REPLACE_配信
941 | $REPLACE_学科
942 | $REPLACE_順
943 | $APPEND_歌
944 | $REPLACE_あと
945 | $REPLACE_盤
946 | $APPEND_独自
947 | $REPLACE_ほう
948 | $REPLACE_藩主
949 | $REPLACE_せん
950 | $REPLACE_一緒
951 | $REPLACE_敵
952 | $REPLACE_いく
953 | $REPLACE_秋
954 | $REPLACE_従
955 | $REPLACE_核
956 | $REPLACE_皇帝
957 | $REPLACE_みんな
958 | $APPEND_!
959 | $APPEND_##き
960 | $REPLACE_だろう
961 | $APPEND_流
962 | $REPLACE_##だ
963 | $APPEND_相
964 | $REPLACE_##つ
965 | $REPLACE_言語
966 | $APPEND_帯
967 | $REPLACE_労働
968 | $REPLACE_展開
969 | $APPEND_外
970 | $REPLACE_高
971 | $REPLACE_環境
972 | $REPLACE_病院
973 | $APPEND_##す
974 | $APPEND_しばしば
975 | $APPEND_計
976 | $REPLACE_首相
977 | $REPLACE_演じ
978 | $APPEND_気
979 | $REPLACE_後期
980 | $REPLACE_艦隊
981 | $APPEND_小さな
982 | $REPLACE_距離
983 | $REPLACE_取締
984 | $REPLACE_"
985 | $REPLACE_客
986 | $REPLACE_宇宙
987 | $REPLACE_参戦
988 | $REPLACE_杯
989 | $REPLACE_オン
990 | $APPEND_##ん
991 | $APPEND_##ら
992 | $REPLACE_残っ
993 | $REPLACE_絵
994 | $REPLACE_協力
995 | $REPLACE_障害
996 | $APPEND_##か
997 | $REPLACE_行く
998 | $REPLACE_課程
999 | $REPLACE_装置
1000 | $REPLACE_関連
1001 | $REPLACE_新
1002 | $APPEND_##え
1003 | $REPLACE_なか
1004 | $APPEND_山
1005 | $REPLACE_きっかけ
1006 | $APPEND_まず
1007 | $APPEND_う
1008 | $REPLACE_光
1009 | $APPEND_発
1010 | $APPEND_時間
1011 | $APPEND_つつ
1012 | $REPLACE_画
1013 | $APPEND_そんな
1014 | $APPEND_前
1015 | $REPLACE_いつ
1016 | $APPEND_しまっ
1017 | $REPLACE_トン
1018 | $REPLACE_弟
1019 | $REPLACE_ころ
1020 | $REPLACE_法学
1021 | $APPEND_記
1022 | $APPEND_優秀
1023 | $APPEND_宗
1024 | $APPEND_言い
1025 | $APPEND_思う
1026 | $REPLACE_山
1027 | $REPLACE_下記
1028 | $APPEND_##s
1029 | $APPEND_わずか
1030 | $APPEND_港
1031 | $APPEND_当
1032 | $REPLACE_支配
1033 | $REPLACE_##ち
1034 | $REPLACE_夜
1035 | $REPLACE_##える
1036 | $APPEND_聖
1037 | $REPLACE_化学
1038 | $REPLACE_劇場
1039 | $REPLACE_系統
1040 | $REPLACE_背
1041 | $REPLACE_詩
1042 | $APPEND_馬
1043 | $REPLACE_方向
1044 | $REPLACE_権
1045 | $REPLACE_愛
1046 | $APPEND_砲
1047 | $APPEND_風
1048 | $REPLACE_なし
1049 | $REPLACE_政策
1050 | $REPLACE_週
1051 | $REPLACE_##する
1052 | $REPLACE_てる
1053 | $REPLACE_移動
1054 | $APPEND_路
1055 | $REPLACE_家族
1056 | $APPEND_付き
1057 | $APPEND_##け
1058 | $REPLACE_面積
1059 | $APPEND_フル
1060 | $REPLACE_船
1061 | $REPLACE_ページ
1062 | $REPLACE_首都
1063 | $REPLACE_ル
1064 | $APPEND_どちら
1065 | $APPEND_値
1066 | $APPEND_ちゃん
1067 | $APPEND_ほう
1068 | $APPEND_誰
1069 | $APPEND_ま
1070 | $REPLACE_交換
1071 | $REPLACE_しまい
1072 | $REPLACE_地下
1073 | $REPLACE_行っ
1074 | $APPEND_更に
1075 | $REPLACE_##な
1076 | $REPLACE_機構
1077 | $APPEND_ばかり
1078 | $APPEND_##と
1079 | $REPLACE_始まっ
1080 | $REPLACE_専攻
1081 | $REPLACE_行
1082 | $REPLACE_編
1083 | $APPEND_どの
1084 | $REPLACE_わけ
1085 | $REPLACE_声優
1086 | $REPLACE_車
1087 | $REPLACE_ラン
1088 | $REPLACE_論
1089 | $REPLACE_街
1090 | $REPLACE_複数
1091 | $REPLACE_全
1092 | $REPLACE_ワン
1093 | $REPLACE_体制
1094 | $REPLACE_作戦
1095 | $APPEND_できる
1096 | $REPLACE_実
1097 | $APPEND_##う
1098 | $APPEND_##さ
1099 | $REPLACE_方面
1100 | $APPEND_丸
1101 | $REPLACE_##込ん
1102 | $APPEND_いわゆる
1103 | $REPLACE_貢献
1104 | $REPLACE_件
1105 | $REPLACE_倍
1106 | $REPLACE_基地
1107 | $REPLACE_死亡
1108 | $REPLACE_しれ
1109 | $REPLACE_普通
1110 | $APPEND_←
1111 | $REPLACE_勉強
1112 | $REPLACE_く
1113 | $REPLACE_便
1114 | $REPLACE_くる
1115 | $APPEND_ずつ
1116 | $REPLACE_後述
1117 | $REPLACE_仕様
1118 | $APPEND_言っ
1119 | $APPEND_かつ
1120 | $REPLACE_身
1121 | $REPLACE_連盟
1122 | $REPLACE_現代
1123 | $REPLACE_官
1124 | $REPLACE_音
1125 | $APPEND_もし
1126 | $REPLACE_代わり
1127 | $REPLACE_相当
1128 | $REPLACE_門
1129 | $APPEND_話
1130 | $REPLACE_席
1131 | $APPEND_あなた
1132 | $REPLACE_兄
1133 | $REPLACE_選択
1134 | $REPLACE_,
1135 | $REPLACE_日本
1136 | $APPEND_かなり
1137 | $REPLACE_ソロ
1138 | $REPLACE_比較
1139 | $APPEND_多
1140 | $REPLACE_機械
1141 | $REPLACE_親
1142 | $APPEND_##ま
1143 | $REPLACE_跡
1144 | $REPLACE_きょう
1145 | $REPLACE_単行
1146 | $REPLACE_化
1147 | $REPLACE_得点
1148 | $REPLACE_機会
1149 | $REPLACE_良い
1150 | $REPLACE_終わっ
1151 | $REPLACE_一時
1152 | $REPLACE_経
1153 | $REPLACE_商品
1154 | $REPLACE_投手
1155 | $APPEND_ただ
1156 | $REPLACE_家庭
1157 | $REPLACE_制度
1158 | $REPLACE_型
1159 | $APPEND_既に
1160 | $REPLACE_用
1161 | $APPEND_すでに
1162 | $REPLACE_たち
1163 | $REPLACE_相手
1164 | $REPLACE_規定
1165 | $REPLACE_/
1166 | $REPLACE_個人
1167 | $APPEND_す
1168 | $REPLACE_歩兵
1169 | $APPEND_心
1170 | $REPLACE_形態
1171 | $REPLACE_高等
1172 | $APPEND_言葉
1173 | $REPLACE_報告
1174 | $REPLACE_##ス
1175 | $APPEND_年
1176 | $REPLACE_マイル
1177 | $REPLACE_諸島
1178 | $REPLACE_足
1179 | $REPLACE_え
1180 | $REPLACE_せる
1181 | $APPEND_特殊
1182 | $REPLACE_ハイ
1183 | $APPEND_やがて
1184 | $APPEND_直接
1185 | $REPLACE_行為
1186 | $APPEND_てる
1187 | $REPLACE_食
1188 | $APPEND_なん
1189 | $REPLACE_経由
1190 | $APPEND_男
1191 | $REPLACE_改称
1192 | $REPLACE_製品
1193 | $APPEND_やすい
1194 | $REPLACE_ながら
1195 | $REPLACE_受け
1196 | $REPLACE_司令
1197 | $APPEND_すなわち
1198 | $APPEND_o
1199 | $APPEND_いっ
1200 | $REPLACE_すぐ
1201 | $REPLACE_本当
1202 | $REPLACE_更新
1203 | $APPEND_頭
1204 | $REPLACE_命じ
1205 | $REPLACE_作詞
1206 | $REPLACE_精神
1207 | $REPLACE_像
1208 | $REPLACE_ひ
1209 | $REPLACE_多く
1210 | $REPLACE_好き
1211 | $APPEND_単独
1212 | $REPLACE_指名
1213 | $REPLACE_応じ
1214 | $REPLACE_いき
1215 | $REPLACE_考え
1216 | $REPLACE_伝統
1217 | $REPLACE_あ
1218 | $REPLACE_##生
1219 | $REPLACE_ハ
1220 | $APPEND_一番
1221 | $TRANSFORM_VBCG_VBV
1222 | $TRANSFORM_VBS_VBV
1223 | $REPLACE_主張
1224 | $REPLACE_軌道
1225 | $APPEND_鉄
1226 | $REPLACE_課
1227 | $APPEND_全く
1228 | $REPLACE_傘下
1229 | $REPLACE_向上
1230 | $APPEND_あ
1231 | $APPEND_え
1232 | $REPLACE_##ず
1233 | $REPLACE_初め
1234 | $APPEND_酸
1235 | $REPLACE_不
1236 | $APPEND_:
1237 | $REPLACE_記述
1238 | $APPEND_質
1239 | $REPLACE_上記
1240 | $REPLACE_事故
1241 | $REPLACE_最近
1242 | $REPLACE_改正
1243 | $APPEND_助
1244 | $REPLACE_原子
1245 | $REPLACE_##まっ
1246 | $REPLACE_真
1247 | $REPLACE_自体
1248 | $REPLACE_新人
1249 | $REPLACE_書
1250 | $REPLACE_[
1251 | $REPLACE_世
1252 | $REPLACE_みたい
1253 | $REPLACE_ノ
1254 | $REPLACE_思っ
1255 | $APPEND_~
1256 | $REPLACE_維持
1257 | $REPLACE_こんな
1258 | $REPLACE_停止
1259 | $APPEND_つまり
1260 | $REPLACE_ください
1261 | $REPLACE_青
1262 | $REPLACE_天
1263 | $REPLACE_官位
1264 | $REPLACE_##び
1265 | $APPEND_”
1266 | $REPLACE_出し
1267 | $REPLACE_塔
1268 | $APPEND_しれ
1269 | $APPEND_星
1270 | $REPLACE_伯
1271 | $REPLACE_資料
1272 | $REPLACE_##ば
1273 | $APPEND_卒
1274 | $APPEND_考え
1275 | $REPLACE_制御
1276 | $APPEND_反
1277 | $APPEND_友達
1278 | $APPEND_行っ
1279 | $REPLACE_公立
1280 | $REPLACE_キ
1281 | $REPLACE_項
1282 | $REPLACE_##込む
1283 | $REPLACE_主人
1284 | $APPEND_朝
1285 | $APPEND_こ
1286 | $REPLACE_小学
1287 | $REPLACE_占領
1288 | $REPLACE_傾向
1289 | $APPEND_##つ
1290 | $REPLACE_なん
1291 | $APPEND_感じ
1292 | $REPLACE_常
1293 | $APPEND_低
1294 | $APPEND_料
1295 | $REPLACE_有し
1296 | $REPLACE_黒
1297 | $REPLACE_アンド
1298 | $REPLACE_案
1299 | $APPEND_数多く
1300 | $APPEND_同一
1301 | $APPEND_良い
1302 | $REPLACE_昇進
1303 | $REPLACE_勝ち
1304 | $REPLACE_行う
1305 | $REPLACE_しまっ
1306 | $APPEND_言う
1307 | $APPEND_自然
1308 | $REPLACE_比
1309 | $APPEND_##め
1310 | $REPLACE_ほど
1311 | $REPLACE_白
1312 | $REPLACE_わ
1313 | $REPLACE_だい
1314 | $REPLACE_店
1315 | $REPLACE_かん
1316 | $REPLACE_機体
1317 | $REPLACE_or
1318 | $REPLACE_##れる
1319 | $REPLACE_市街
1320 | $REPLACE_進学
1321 | $REPLACE_##お
1322 | $REPLACE_かけ
1323 | $REPLACE_師事
1324 | $REPLACE_視聴
1325 | $REPLACE_そんな
1326 | $REPLACE_言い
1327 | $REPLACE_増加
1328 | $APPEND_今日
1329 | $REPLACE_おお
1330 | $REPLACE_けん
1331 | $REPLACE_少し
1332 | $APPEND_仕事
1333 | $REPLACE_位
1334 | $REPLACE_サイド
1335 | $APPEND_出
1336 | $REPLACE_券
1337 | $APPEND_じゃ
1338 | $REPLACE_銃
1339 | $REPLACE_二人
1340 | $REPLACE_少女
1341 | $REPLACE_生息
1342 | $REPLACE_移行
1343 | $APPEND_くらい
1344 | $APPEND_病
1345 | $REPLACE_達成
1346 | $REPLACE_生じ
1347 | $APPEND_やや
1348 | $REPLACE_開
1349 | $REPLACE_馬
1350 | $APPEND_法
1351 | $REPLACE_貴族
1352 | $APPEND_いけ
1353 | $REPLACE_!
1354 | $APPEND_##n
1355 | $REPLACE_店舗
1356 | $REPLACE_##ぶ
1357 | $REPLACE_やっ
1358 | $APPEND_十分
1359 | $REPLACE_県道
1360 | $APPEND_さまざま
1361 | $APPEND_剤
1362 | $REPLACE_ちょっと
1363 | $REPLACE_いわ
1364 | $REPLACE_為
1365 | $REPLACE_資格
1366 | $REPLACE_交流
1367 | $REPLACE_施行
1368 | $REPLACE_管
1369 | $APPEND_諸
1370 | $REPLACE_##べ
1371 | $REPLACE_交渉
1372 | $REPLACE_夫
1373 | $REPLACE_意
1374 | $REPLACE_見つけ
1375 | $REPLACE_生物
1376 | $APPEND_君
1377 | $APPEND_こちら
1378 | $APPEND_##日
1379 | $APPEND_財
1380 | $REPLACE_大使
1381 | $REPLACE_解散
1382 | $APPEND_およそ
1383 | $REPLACE_再開
1384 | $REPLACE_土曜
1385 | $REPLACE_政
1386 | $REPLACE_えき
1387 | $REPLACE_侵攻
1388 | $REPLACE_とう
1389 | $REPLACE_進行
1390 | $REPLACE_始まり
1391 | $REPLACE_政権
1392 | $REPLACE_機器
1393 | $APPEND_わ
1394 | $APPEND_##せ
1395 | $REPLACE_ごろ
1396 | $REPLACE_前身
1397 | $REPLACE_銅
1398 | $REPLACE_##ぐ
1399 | $APPEND_選
1400 | $REPLACE_作
1401 | $APPEND_もっと
1402 | $REPLACE_創刊
1403 | $REPLACE_図
1404 | $REPLACE_会場
1405 | $REPLACE_三
1406 | $REPLACE_任じ
1407 | $REPLACE_##める
1408 | $APPEND_域
1409 | $REPLACE_大型
1410 | $REPLACE_原因
1411 | $REPLACE_旅
1412 | $REPLACE_火
1413 | $APPEND_区
1414 | $REPLACE_教え
1415 | $REPLACE_##取
1416 | $REPLACE_対抗
1417 | $REPLACE_艇
1418 | $APPEND_費
1419 | $REPLACE_技
1420 | $REPLACE_内閣
1421 | $REPLACE_感染
1422 | $APPEND_なぜ
1423 | $REPLACE_反応
1424 | $REPLACE_始まる
1425 | $REPLACE_インターネット
1426 | $REPLACE_酸化
1427 | $REPLACE_戸
1428 | $REPLACE_言っ
1429 | $REPLACE_転じ
1430 | $REPLACE_3
1431 | $REPLACE_心
1432 | $TRANSFORM_VBV_VBI
1433 | $REPLACE_潜水
1434 | $REPLACE_基準
1435 | $REPLACE_##や
1436 | $APPEND_[UNK]
1437 | $APPEND_有効
1438 | $APPEND_橋
1439 | $REPLACE_医師
1440 | $APPEND_##しい
1441 | $REPLACE_東部
1442 | $REPLACE_河川
1443 | $REPLACE_加盟
1444 | $REPLACE_禁止
1445 | $REPLACE_引き続き
1446 | $REPLACE_m
1447 | $APPEND_##が
1448 | $REPLACE_戦隊
1449 | $REPLACE_球
1450 | $REPLACE_気候
1451 | $REPLACE_養子
1452 | $REPLACE_派遣
1453 | $REPLACE_##まし
1454 | $REPLACE_劇
1455 | $REPLACE_陣
1456 | $APPEND_署
1457 | $REPLACE_登板
1458 | $REPLACE_##ル
1459 | $REPLACE_どんな
1460 | $REPLACE_装備
1461 | $REPLACE_全長
1462 | $REPLACE_いけ
1463 | $APPEND_うち
1464 | $REPLACE_東
1465 | $REPLACE_当主
1466 | $APPEND_##tt
1467 | $REPLACE_指摘
1468 | $REPLACE_連隊
1469 | $REPLACE_氏
1470 | $REPLACE_いた
1471 | $REPLACE_土地
1472 | $APPEND_どこ
1473 | $APPEND_水
1474 | $APPEND_沿い
1475 | $REPLACE_言う
1476 | $APPEND_行く
1477 | $REPLACE_後継
1478 | $REPLACE_##ト
1479 | $REPLACE_上位
1480 | $TRANSFORM_ADJCG_ADJS
1481 | $REPLACE_民族
1482 | $REPLACE_##ね
1483 | $REPLACE_地震
1484 | $APPEND_たかっ
1485 | $REPLACE_ガン
1486 | $REPLACE_##出し
1487 | $REPLACE_自転
1488 | $REPLACE_難しい
1489 | $APPEND_じ
1490 | $REPLACE_難
1491 | $REPLACE_わから
1492 | $REPLACE_しん
1493 | $REPLACE_急
1494 | $APPEND_一緒
1495 | $APPEND_しまう
1496 | $REPLACE_合成
1497 | $REPLACE_どの
1498 | $REPLACE_史上
1499 | $REPLACE_古代
1500 | $REPLACE_大きな
1501 | $REPLACE_発電
1502 | $REPLACE_半島
1503 | $REPLACE_呼称
1504 | $APPEND_わたし
1505 | $REPLACE_大賞
1506 | $REPLACE_なさ
1507 | $REPLACE_中止
1508 | $REPLACE_強化
1509 | $REPLACE_##に
1510 | $REPLACE_成
1511 | $REPLACE_つもり
1512 | $REPLACE_わかり
1513 | $REPLACE_む
1514 | $REPLACE_破壊
1515 | $REPLACE_美
1516 | $REPLACE_期待
1517 | $REPLACE_話し
1518 | $REPLACE_ー
1519 | $REPLACE_館
1520 | $APPEND_巨大
1521 | $REPLACE_遺跡
1522 | $REPLACE_しかし
1523 | $APPEND_紙
1524 | $REPLACE_##ける
1525 | $REPLACE_赤
1526 | $REPLACE_和
1527 | $REPLACE_共演
1528 | $APPEND_##わ
1529 | $REPLACE_ジム
1530 | $REPLACE_##込み
1531 | $REPLACE_果たし
1532 | $REPLACE_授業
1533 | $REPLACE_操作
1534 | $REPLACE_各
1535 | $REPLACE_##も
1536 | $REPLACE_組合
1537 | $REPLACE_国勢
1538 | $REPLACE_不詳
1539 | $REPLACE_持っ
1540 | $REPLACE_承認
1541 | $APPEND_色
1542 | $REPLACE_朝日
1543 | $REPLACE_風
1544 | $REPLACE_総
1545 | $APPEND_##er
1546 | $APPEND_ちょっと
1547 | $REPLACE_分離
1548 | $REPLACE_ここ
1549 | $REPLACE_もっと
1550 | $REPLACE_作成
1551 | $REPLACE_はず
1552 | $REPLACE_本線
1553 | $REPLACE_市内
1554 | $REPLACE_球団
1555 | $REPLACE_日曜
1556 | $APPEND_短
1557 | $REPLACE_漢字
1558 | $REPLACE_家臣
1559 | $REPLACE_訳
1560 | $REPLACE_周
1561 | $REPLACE_反乱
1562 | $REPLACE_カ
1563 | $REPLACE_外
1564 | $APPEND_##t
1565 | $APPEND_多様
1566 | $REPLACE_改名
1567 | $REPLACE_口
1568 | $REPLACE_もっ
1569 | $REPLACE_##こう
1570 | $REPLACE_終わり
1571 | $APPEND_やっ
1572 | $REPLACE_気持ち
1573 | $APPEND_書い
1574 | $REPLACE_交差
1575 | $APPEND_他
1576 | $REPLACE_徒
1577 | $REPLACE_ど
1578 | $REPLACE_修了
1579 | $APPEND_※
1580 | $APPEND_授業
1581 | $REPLACE_基礎
1582 | $REPLACE_服
1583 | $REPLACE_おも
1584 | $REPLACE_空
1585 | $REPLACE_水
1586 | $REPLACE_森
1587 | $REPLACE_停留
1588 | $APPEND_しばらく
1589 | $REPLACE_解放
1590 | $REPLACE_歌詞
1591 | $REPLACE_児童
1592 | $APPEND_ど
1593 | $REPLACE_ごう
1594 | $APPEND_度
1595 | $APPEND_単に
1596 | $APPEND_I
1597 | $APPEND_同士
1598 | $REPLACE_勢力
1599 | $REPLACE_延長
1600 | $REPLACE_医療
1601 | $REPLACE_金属
1602 | $REPLACE_角
1603 | $REPLACE_創業
1604 | $REPLACE_逆
1605 | $APPEND_ください
1606 | $REPLACE_同盟
1607 | $APPEND_もっとも
1608 | $REPLACE_教
1609 | $APPEND_著名
1610 | $REPLACE_##ざ
1611 | $REPLACE_統計
1612 | $REPLACE_回復
1613 | $APPEND_大幅
1614 | $REPLACE_政党
1615 | $REPLACE_新設
1616 | $REPLACE_死後
1617 | $REPLACE_翻訳
1618 | $REPLACE_収容
1619 | $APPEND_古
1620 | $REPLACE_周年
1621 | $APPEND_波
1622 | $APPEND_薬
1623 | $REPLACE_京
1624 | $REPLACE_際し
1625 | $REPLACE_しよう
1626 | $REPLACE_展示
1627 | $APPEND_民
1628 | $APPEND_ずっと
1629 | $REPLACE_主催
1630 | $REPLACE_たく
1631 | $REPLACE_侯
1632 | $REPLACE_わたし
1633 | $REPLACE_銀
1634 | $REPLACE_変
1635 | $REPLACE_太陽
1636 | $APPEND_かん
1637 | $REPLACE_アイ
1638 | $REPLACE_り
1639 | $REPLACE_紙
1640 | $REPLACE_満
1641 | $REPLACE_自己
1642 | $REPLACE_自
1643 | $REPLACE_国境
1644 | $APPEND_必ず
1645 | $APPEND_簡単
1646 | $REPLACE_実験
1647 | $APPEND_じょう
1648 | $REPLACE_危機
1649 | $REPLACE_やま
1650 | $APPEND_盛ん
1651 | $REPLACE_箇所
1652 | $REPLACE_海岸
1653 | $REPLACE_最
1654 | $REPLACE_け
1655 | $REPLACE_防御
1656 | $REPLACE_米
1657 | $REPLACE_##日
1658 | $REPLACE_思う
1659 | $REPLACE_商
1660 | $REPLACE_友人
1661 | $REPLACE_通っ
1662 | $REPLACE_信仰
1663 | $REPLACE_##まる
1664 | $REPLACE_司会
1665 | $APPEND_みんな
1666 | $REPLACE_顔
1667 | $APPEND_宮
1668 | $REPLACE_おら
1669 | $REPLACE_統治
1670 | $APPEND_観
1671 | $REPLACE_卿
1672 | $REPLACE_命名
1673 | $APPEND_けれど
1674 | $REPLACE_裏
1675 | $APPEND_冊
1676 | $REPLACE_歯
1677 | $REPLACE_形状
1678 | $REPLACE_うえ
1679 | $REPLACE_連絡
1680 | $REPLACE_##カ
1681 | $REPLACE_##よ
1682 | $APPEND_多く
1683 | $REPLACE_セント
1684 | $APPEND_##こ
1685 | $REPLACE_むら
1686 | $REPLACE_規制
1687 | $REPLACE_##出す
1688 | $REPLACE_東洋
1689 | $APPEND_##た
1690 | $REPLACE_投資
1691 | $REPLACE_ス
1692 | $REPLACE_列
1693 | $APPEND_場合
1694 | $REPLACE_連
1695 | $REPLACE_定期
1696 | $REPLACE_##せる
1697 | $REPLACE_通過
1698 | $REPLACE_見る
1699 | $APPEND_たとえば
1700 | $REPLACE_額
1701 | $REPLACE_起源
1702 | $REPLACE_負傷
1703 | $APPEND_本当
1704 | $REPLACE_減少
1705 | $APPEND_大学
1706 | $REPLACE_訓練
1707 | $REPLACE_軽
1708 | $REPLACE_##そう
1709 | $REPLACE_公共
1710 | $REPLACE_ぎ
1711 | $REPLACE_回転
1712 | $REPLACE_じゃ
1713 | $REPLACE_##へ
1714 | $REPLACE_段階
1715 | $REPLACE_広告
1716 | $REPLACE_開設
1717 | $REPLACE_っ
1718 | $REPLACE_初頭
1719 | $REPLACE_##ど
1720 | $REPLACE_黄
1721 | $REPLACE_ベイ
1722 | $REPLACE_分から
1723 | $REPLACE_さい
1724 | $REPLACE_過程
1725 | $REPLACE_友達
1726 | $REPLACE_ずっと
1727 | $REPLACE_挑戦
1728 | $REPLACE_票
1729 | $APPEND_日記
1730 | $REPLACE_下部
1731 | $REPLACE_生涯
1732 | $REPLACE_脳
1733 | $REPLACE_起用
1734 | $REPLACE_乱
1735 | $REPLACE_ナ
1736 | $REPLACE_定
1737 | $REPLACE_調整
1738 | $APPEND_皆
1739 | $APPEND_文
1740 | $REPLACE_範囲
1741 | $REPLACE_##わっ
1742 | $REPLACE_##s
1743 | $REPLACE_値
1744 | $REPLACE_先発
1745 | $REPLACE_食べ
1746 | $REPLACE_接する
1747 | $REPLACE_恋
1748 | $REPLACE_知ら
1749 | $REPLACE_##浮
1750 | $REPLACE_弁
1751 | $REPLACE_首
1752 | $REPLACE_##とう
1753 | $REPLACE_##リ
1754 | $APPEND_伝
1755 | $REPLACE_世帯
1756 | $APPEND_又
1757 | $REPLACE_二
1758 | $REPLACE_史
1759 | $REPLACE_集落
1760 | $REPLACE_休止
1761 | $APPEND_'
1762 | $REPLACE_共産
1763 | $APPEND_いく
1764 | $APPEND_高度
1765 | $REPLACE_帝
1766 | $REPLACE_没
1767 | $REPLACE_なんて
1768 | $REPLACE_得意
1769 | $REPLACE_興行
1770 | $APPEND_だい
1771 | $APPEND_郷
1772 | $REPLACE_海上
1773 | $APPEND_もともと
1774 | $REPLACE_再建
1775 | $REPLACE_がん
1776 | $REPLACE_##まり
1777 | $REPLACE_##っこ
1778 | $REPLACE_書い
1779 | $APPEND_代
1780 | $REPLACE_士
1781 | $REPLACE_##にち
1782 | $APPEND_表現
1783 | $REPLACE_ふ
1784 | $REPLACE_ダ
1785 | $REPLACE_第
1786 | $REPLACE_ガ
1787 | $REPLACE_くらい
1788 | $REPLACE_すご
1789 | $REPLACE_期限
1790 | $APPEND_どんな
1791 | $REPLACE_可
1792 | $REPLACE_しり
1793 | $APPEND_容易
1794 | $APPEND_食べ
1795 | $REPLACE_街道
1796 | $APPEND_文章
1797 | $REPLACE_縁
1798 | $REPLACE_走行
1799 | $APPEND_楽
1800 | $REPLACE_級
1801 | $REPLACE_こん
1802 | $REPLACE_コ
1803 | $REPLACE_起こし
1804 | $APPEND_らしい
1805 | $REPLACE_急行
1806 | $REPLACE_##合
1807 | $REPLACE_両親
1808 | $REPLACE_旧姓
1809 | $REPLACE_師
1810 | $APPEND_くれる
1811 | $APPEND_ノ
1812 | $APPEND_たく
1813 | $REPLACE_隊
1814 | $REPLACE_リン
1815 | $REPLACE_記者
1816 | $REPLACE_幹線
1817 | $APPEND_s
1818 | $APPEND_ぬ
1819 | $REPLACE_##げ
1820 | $REPLACE_中世
1821 | $REPLACE_かい
1822 | $REPLACE_したがっ
1823 | $REPLACE_週刊
1824 | $REPLACE_なくなっ
1825 | $REPLACE_報じ
1826 | $REPLACE_ソン
1827 | $REPLACE_聖
1828 | $APPEND_##w
1829 | $REPLACE_注意
1830 | $REPLACE_##会
1831 | $APPEND_i
1832 | $REPLACE_要求
1833 | $REPLACE_証明
1834 | $APPEND_ビッグ
1835 | $APPEND_わけ
1836 | $APPEND_軒
1837 | $REPLACE_##長
1838 | $REPLACE_まち
1839 | $REPLACE_衛星
1840 | $APPEND_せい
1841 | $REPLACE_系
1842 | $REPLACE_帰っ
1843 | $REPLACE_仕事
1844 | $APPEND_極めて
1845 | $APPEND_網
1846 | $REPLACE_##イ
1847 | $REPLACE_国会
1848 | $REPLACE_帰還
1849 | $APPEND_##み
1850 | $APPEND_令
1851 | $APPEND_べく
1852 | $REPLACE_意味
1853 | $REPLACE_最初
1854 | $REPLACE_幅
1855 | $APPEND_##ー
1856 | $APPEND_##じ
1857 | $REPLACE_血
1858 | $REPLACE_検討
1859 | $REPLACE_長官
1860 | $REPLACE_製
1861 | $REPLACE_たかっ
1862 | $APPEND_持っ
1863 | $APPEND_させ
1864 | $REPLACE_聞い
1865 | $REPLACE_彼女
1866 | $APPEND_学校
1867 | $APPEND_is
1868 | $REPLACE_レイ
1869 | $REPLACE_てん
1870 | $REPLACE_力
1871 | $APPEND_気持ち
1872 | $REPLACE_##ろ
1873 | $APPEND_大変
1874 | $REPLACE_生命
1875 | $APPEND_ふ
1876 | $REPLACE_甲
1877 | $REPLACE_変わっ
1878 | $REPLACE_め
1879 | $REPLACE_軸
1880 | $REPLACE_もし
1881 | $REPLACE_投稿
1882 | $REPLACE_理解
1883 | $APPEND_鏡
1884 | $APPEND_くる
1885 | $APPEND_光
1886 | $REPLACE_妹
1887 | $REPLACE_子ども
1888 | $REPLACE_早く
1889 | $APPEND_##そ
1890 | $REPLACE_機動
1891 | $REPLACE_文章
1892 | $REPLACE_次男
1893 | $REPLACE_先行
1894 | $REPLACE_窓
1895 | $REPLACE_多い
1896 | $REPLACE_うまく
1897 | $REPLACE_日記
1898 | $REPLACE_興
1899 | $REPLACE_→
1900 | $REPLACE_南東
1901 | $REPLACE_けい
1902 | $REPLACE_ぶ
1903 | $APPEND_あと
1904 | $REPLACE_##ク
1905 | $REPLACE_教皇
1906 | $APPEND_##む
1907 | $REPLACE_推定
1908 | $APPEND_##な
1909 | $REPLACE_竣工
1910 | $APPEND_ついに
1911 | $REPLACE_##性
1912 | $REPLACE_無料
1913 | $APPEND_始め
1914 | $APPEND_よっ
1915 | $REPLACE_武将
1916 | $REPLACE_小型
1917 | $REPLACE_精
1918 | $REPLACE_副
1919 | $REPLACE_方針
1920 | $REPLACE_##上
1921 | $APPEND_さえ
1922 | $REPLACE_仮
1923 | $REPLACE_オス
1924 | $REPLACE_マイ
1925 | $REPLACE_タンパク
1926 | $APPEND_有力
1927 | $REPLACE_解説
1928 | $REPLACE_##n
1929 | $REPLACE_##ぼ
1930 | $REPLACE_鉄
1931 | $REPLACE_携帯
1932 | $APPEND_書く
1933 | $REPLACE_兵器
1934 | $REPLACE_##ナ
1935 | $REPLACE_たか
1936 | $REPLACE_##かい
1937 | $REPLACE_##ろう
1938 | $REPLACE_雨
1939 | $REPLACE_本項
1940 | $APPEND_##の
1941 | $REPLACE_姉
1942 | $REPLACE_装甲
1943 | $REPLACE_下位
1944 | $REPLACE_無
1945 | $REPLACE_場所
1946 | $REPLACE_頭
1947 | $APPEND_実
1948 | $REPLACE_寺
1949 | $REPLACE_松
1950 | $APPEND_板
1951 | $APPEND_##ち
1952 | $REPLACE_用語
1953 | $REPLACE_剣
1954 | $REPLACE_強制
1955 | $REPLACE_よれ
1956 | $REPLACE_退職
1957 | $APPEND_##d
1958 | $REPLACE_規格
1959 | $REPLACE_憲法
1960 | $APPEND_視
1961 | $REPLACE_祖
1962 | $REPLACE_岩
1963 | $APPEND_書き
1964 | $REPLACE_端
1965 | $REPLACE_残る
1966 | $REPLACE_事態
1967 | $REPLACE_[UNK]
1968 | $APPEND_みたい
1969 | $APPEND_医
1970 | $REPLACE_修正
1971 | $REPLACE_キロ
1972 | $APPEND_##て
1973 | $REPLACE_解
1974 | $APPEND_頃
1975 | $REPLACE_ボタン
1976 | $REPLACE_しまう
1977 | $APPEND_##ご
1978 | $REPLACE_あげ
1979 | $REPLACE_心理
1980 | $REPLACE_大将
1981 | $REPLACE_土
1982 | $REPLACE_区別
1983 | $REPLACE_教師
1984 | $REPLACE_同期
1985 | $REPLACE_使い
1986 | $APPEND_必要
1987 | $APPEND_単純
1988 | $REPLACE_支部
1989 | $REPLACE_ひと
1990 | $APPEND_好
1991 | $REPLACE_退団
1992 | $REPLACE_亡くなっ
1993 | $REPLACE_この頃
1994 | $APPEND_つもり
1995 | $REPLACE_##ラ
1996 | $APPEND_話し
1997 | $REPLACE_暮らし
1998 | $REPLACE_有
1999 | $APPEND_##しく
2000 | $REPLACE_いか
2001 | $REPLACE_##ずか
2002 | $APPEND_突然
2003 | $REPLACE_##ぞ
2004 | $APPEND_紀
2005 | $APPEND_おそらく
2006 | $REPLACE_知っ
2007 | $REPLACE_夫人
2008 | $REPLACE_実行
2009 | $REPLACE_国内
2010 | $APPEND_明確
2011 | $REPLACE_沖
2012 | $REPLACE_受章
2013 | $REPLACE_意見
2014 | $REPLACE_校
2015 | $REPLACE_制定
2016 | $REPLACE_楽し
2017 | $REPLACE_##ぱ
2018 | $REPLACE_対策
2019 | $REPLACE_ひとり
2020 | $REPLACE_##らし
2021 | $REPLACE_リ
2022 | $REPLACE_渡っ
2023 | $REPLACE_工
2024 | $REPLACE_近郊
2025 | $APPEND_##u
2026 | $REPLACE_構想
2027 | $REPLACE_年齢
2028 | $REPLACE_あなた
2029 | $REPLACE_協議
2030 | $APPEND_わかり
2031 | $REPLACE_根
2032 | $REPLACE_地帯
2033 | $APPEND_##e
2034 | $REPLACE_現行
2035 | $REPLACE_しゅう
2036 | $APPEND_e
2037 | $REPLACE_菌
2038 | $REPLACE_要請
2039 | $APPEND_独特
2040 | $REPLACE_いま
2041 | $REPLACE_之
2042 | $REPLACE_交代
2043 | $REPLACE_危険
2044 | $APPEND_展
2045 | $REPLACE_境界
2046 | $REPLACE_原
2047 | $REPLACE_品
2048 | $REPLACE_手法
2049 | $APPEND_先
2050 | $REPLACE_普及
2051 | $REPLACE_信じ
2052 | $REPLACE_##込ま
2053 | $REPLACE_実装
2054 | $APPEND_ちょう
2055 | $REPLACE_相対
2056 | $REPLACE_戦死
2057 | $REPLACE_参謀
2058 | $REPLACE_ア
2059 | $REPLACE_執行
2060 | $REPLACE_筋
2061 | $APPEND_-
2062 | $REPLACE_員
2063 | $REPLACE_##ズ
2064 | $APPEND_##be
2065 | $REPLACE_振興
2066 | $REPLACE_番台
2067 | $REPLACE_好
2068 | $REPLACE_書き
2069 | $REPLACE_申請
2070 | $REPLACE_##たい
2071 | $REPLACE_植民
2072 | $REPLACE_故障
2073 | $REPLACE_特技
2074 | $REPLACE_入っ
2075 | $APPEND_先生
2076 | $REPLACE_同じ
2077 | $REPLACE_毛
2078 | $APPEND_ちゅう
2079 | $APPEND_自治
2080 | $REPLACE_##ょ
2081 | $REPLACE_車体
2082 | $APPEND_ち
2083 | $APPEND_市
2084 | $REPLACE_再
2085 | $REPLACE_ちゅう
2086 | $APPEND_a
2087 | $REPLACE_先生
2088 | $REPLACE_出来
2089 | $REPLACE_価値
2090 | $REPLACE_かかわら
2091 | $REPLACE_使う
2092 | $REPLACE_イ
2093 | $APPEND_源
2094 | $APPEND_府
2095 | $REPLACE_信
2096 | $REPLACE_~
2097 | $REPLACE_今回
2098 | $REPLACE_指示
2099 | $REPLACE_オ
2100 | $REPLACE_座
2101 | $REPLACE_南
2102 | $APPEND_新規
2103 | $APPEND_こんな
2104 | $REPLACE_なぜ
2105 | $REPLACE_卵
2106 | $REPLACE_碑
2107 | $REPLACE_武器
2108 | $APPEND_入っ
2109 | $REPLACE_司法
2110 | $REPLACE_称号
2111 | $REPLACE_良く
2112 | $REPLACE_分かり
2113 | $APPEND_t
2114 | $REPLACE_ましょう
2115 | $REPLACE_##聞
2116 | $APPEND_ちゃ
2117 | $REPLACE_普段
2118 | $REPLACE_格
2119 | $REPLACE_じょう
2120 | $REPLACE_推進
2121 | $REPLACE_僕
2122 | $REPLACE_身体
2123 | $REPLACE_棟
2124 | $APPEND_あたり
2125 | $APPEND_やはり
2126 | $REPLACE_直
2127 | $REPLACE_あん
2128 | $REPLACE_緑
2129 | $REPLACE_言わ
2130 | $APPEND_林
2131 | $REPLACE_通う
2132 | $APPEND_児
2133 | $REPLACE_郊外
2134 | $APPEND_程
2135 | $APPEND_おい
2136 | $REPLACE_かな
2137 | $APPEND_予定
2138 | $APPEND_m
2139 | $REPLACE_吸収
2140 | $REPLACE_症状
2141 | $REPLACE_ヒ
2142 | $REPLACE_平方
2143 | $REPLACE_=
2144 | $REPLACE_べ
2145 | $REPLACE_奏者
2146 | $REPLACE_てい
2147 | $APPEND_名
2148 | $REPLACE_谷
2149 | $REPLACE_望遠
2150 | $APPEND_いき
2151 | $REPLACE_資本
2152 | $REPLACE_故郷
2153 | $APPEND_ようやく
2154 | $APPEND_毎
2155 | $REPLACE_さらに
2156 | $REPLACE_##子
2157 | $APPEND_り
2158 | $APPEND_問題
2159 | $APPEND_め
2160 | $APPEND_輪
2161 | $REPLACE_かわ
2162 | $REPLACE_帰
2163 | $APPEND_英語
2164 | $APPEND_映画
2165 | $REPLACE_後方
2166 | $REPLACE_思い出
2167 | $REPLACE_楽
2168 | $REPLACE_##体
2169 | $APPEND_出来
2170 | $REPLACE_該当
2171 | $REPLACE_気分
2172 | $REPLACE_うる
2173 | $REPLACE_甲子
2174 | $APPEND_ひと
2175 | $REPLACE_##キ
2176 | $REPLACE_功績
2177 | $REPLACE_修士
2178 | $APPEND_いろいろ
2179 | $REPLACE_加工
2180 | $REPLACE_後者
2181 | $REPLACE_公表
2182 | $APPEND_しも
2183 | $REPLACE_メス
2184 | $REPLACE_よし
2185 | $REPLACE_幕
2186 | $REPLACE_仲間
2187 | $REPLACE_動画
2188 | $REPLACE_尾
2189 | $REPLACE_漢
2190 | $REPLACE_決まっ
2191 | $REPLACE_香
2192 | $APPEND_二
2193 | $REPLACE_よい
2194 | $REPLACE_ほん
2195 | $APPEND_いただ
2196 | $REPLACE_気象
2197 | $REPLACE_丘
2198 | $REPLACE_達
2199 | $REPLACE_使っ
2200 | $REPLACE_ただ
2201 | $REPLACE_提携
2202 | $REPLACE_得
2203 | $REPLACE_蒸気
2204 | $APPEND_っ
2205 | $APPEND_急速
2206 | $REPLACE_書記
2207 | $REPLACE_加
2208 | $REPLACE_先端
2209 | $REPLACE_町名
2210 | $REPLACE_大変
2211 | $REPLACE_あたる
2212 | $REPLACE_従事
2213 | $APPEND_料理
2214 | $REPLACE_改修
2215 | $REPLACE_##e
2216 | $REPLACE_ねん
2217 | $REPLACE_ヒル
2218 | $APPEND_使い
2219 | $REPLACE_披露
2220 | $REPLACE_##家
2221 | $REPLACE_拡張
2222 | $REPLACE_予定
2223 | $REPLACE_改
2224 | $APPEND_使っ
2225 | $REPLACE_楽器
2226 | $APPEND_名前
2227 | $REPLACE_末
2228 | $REPLACE_コーヒー
2229 | $REPLACE_ヒト
2230 | $REPLACE_退
2231 | $REPLACE_龍
2232 | $REPLACE_##ひ
2233 | $REPLACE_あたり
2234 | $REPLACE_観
2235 | $APPEND_まま
2236 | $REPLACE_起き
2237 | $APPEND_おも
2238 | $REPLACE_いれ
2239 | $REPLACE_校舎
2240 | $REPLACE_開館
2241 | $REPLACE_ウ
2242 | $REPLACE_様子
2243 | $REPLACE_玉
2244 | $REPLACE_交響
2245 | $REPLACE_相
2246 | $REPLACE_##出さ
2247 | $REPLACE_一番
2248 | $REPLACE_び
2249 | $APPEND_試験
2250 | $APPEND_一人
2251 | $APPEND_生活
2252 | $REPLACE_性格
2253 | $APPEND_ひ
2254 | $REPLACE_おけ
2255 | $REPLACE_##上がっ
2256 | $REPLACE_退任
2257 | $APPEND_よい
2258 | $APPEND_##は
2259 | $REPLACE_寝
2260 | $APPEND_あれ
2261 | $APPEND_次々
2262 | $APPEND_使う
2263 | $REPLACE_ちゃ
2264 | $REPLACE_##ょう
2265 | $APPEND_受け
2266 | $REPLACE_刑事
2267 | $REPLACE_あい
2268 | $REPLACE_開校
2269 | $APPEND_3
2270 | $REPLACE_しゃ
2271 | $REPLACE_菓子
2272 | $REPLACE_機種
2273 | $REPLACE_##過
2274 | $REPLACE_保険
2275 | $REPLACE_末期
2276 | $REPLACE_戦車
2277 | $REPLACE_そ
2278 | $REPLACE_辞任
2279 | $REPLACE_通る
2280 | $REPLACE_##行
2281 | $REPLACE_当
2282 | $REPLACE_ト
2283 | $REPLACE_進化
2284 | $REPLACE_はな
2285 | $REPLACE_一体
2286 | $REPLACE_聞き
2287 | $REPLACE_くれる
2288 | $REPLACE_渡り
2289 | $REPLACE_次第
2290 | $APPEND_之
2291 | $APPEND_冠
2292 | $REPLACE_##下
2293 | $REPLACE_青年
2294 | $REPLACE_べき
2295 | $REPLACE_知り
2296 | $APPEND_素
2297 | $APPEND_あらゆる
2298 | $REPLACE_どこ
2299 | $REPLACE_##位
2300 | $REPLACE_重賞
2301 | $APPEND_##yo
2302 | $REPLACE_太
2303 | $REPLACE_固定
2304 | $REPLACE_発射
2305 | $APPEND_奪
2306 | $REPLACE_現地
2307 | $REPLACE_貿易
2308 | $REPLACE_および
2309 | $REPLACE_鎮
2310 | $REPLACE_戦士
2311 | $REPLACE_崩壊
2312 | $REPLACE_勲
2313 | $REPLACE_大量
2314 | $REPLACE_提出
2315 | $REPLACE_僧
2316 | $REPLACE_皆
2317 | $REPLACE_属
2318 | $REPLACE_公国
2319 | $REPLACE_中間
2320 | $REPLACE_シ
2321 | $APPEND_巡
2322 | $APPEND_##f
2323 | $REPLACE_横
2324 | $REPLACE_里
2325 | $REPLACE_丁
2326 | $REPLACE_希望
2327 | $REPLACE_限定
2328 | $APPEND_産
2329 | $APPEND_回
2330 | $REPLACE_##シ
2331 | $REPLACE_刑
2332 | $APPEND_最近
2333 | $REPLACE_類
2334 | $REPLACE_奥
2335 | $REPLACE_酵素
2336 | $APPEND_たる
2337 | $REPLACE_##がっ
2338 | $APPEND_be
2339 | $REPLACE_マス
2340 | $REPLACE_通り
2341 | $REPLACE_過ごし
2342 | $APPEND_く
2343 | $REPLACE_分子
2344 | $REPLACE_簡単
2345 | $REPLACE_いえ
2346 | $REPLACE_並行
2347 | $REPLACE_ニ
2348 | $REPLACE_姿勢
2349 | $REPLACE_依頼
2350 | $REPLACE_方法
2351 | $APPEND_>
2352 | $APPEND_##しかっ
2353 | $APPEND_やすく
2354 | $REPLACE_午後
2355 | $REPLACE_降格
2356 | $REPLACE_異
2357 | $REPLACE_成果
2358 | $APPEND_見る
2359 | $REPLACE_ラーメン
2360 | $REPLACE_暑
2361 | $APPEND_教え
2362 | $REPLACE_仏教
2363 | $REPLACE_受験
2364 | $REPLACE_入れ
2365 | $REPLACE_##きょう
2366 | $REPLACE_ゆう
2367 | $REPLACE_##合う
2368 | $REPLACE_陽
2369 | $APPEND_ぐ
2370 | $REPLACE_いち
2371 | $REPLACE_##士
2372 | $APPEND_女
2373 | $APPEND_強力
2374 | $REPLACE_捜査
2375 | $APPEND_匹
2376 | $REPLACE_両方
2377 | $REPLACE_山地
2378 | $REPLACE_ばかり
2379 | $APPEND_ごく
2380 | $REPLACE_i
2381 | $REPLACE_厚生
2382 | $REPLACE_用い
2383 | $APPEND_こそ
2384 | $REPLACE_次元
2385 | $REPLACE_商人
2386 | $REPLACE_大隊
2387 | $APPEND_聞い
2388 | $APPEND_建て
2389 | $REPLACE_変え
2390 | $REPLACE_犬
2391 | $REPLACE_象徴
2392 | $REPLACE_のみ
2393 | $REPLACE_士官
2394 | $REPLACE_航海
2395 | $REPLACE_5
2396 | $APPEND_みる
2397 | $APPEND_技
2398 | $APPEND_理解
2399 | $APPEND_しよう
2400 | $REPLACE_騎士
2401 | $APPEND_##m
2402 | $REPLACE_推薦
2403 | $REPLACE_育成
2404 | $REPLACE_用意
2405 | $REPLACE_バン
2406 | $REPLACE_よかっ
2407 | $REPLACE_こそ
2408 | $REPLACE_楽団
2409 | $REPLACE_任ぜ
2410 | $REPLACE_##かり
2411 | $APPEND_ユニバーサル
2412 | $APPEND_##ぎ
2413 | $REPLACE_低下
2414 | $REPLACE_収集
2415 | $REPLACE_氷
2416 | $REPLACE_雪
2417 | $REPLACE_タ
2418 | $REPLACE_##合い
2419 | $APPEND_け
2420 | $REPLACE_決め
2421 | $APPEND_広
2422 | $APPEND_石
2423 | $REPLACE_##で
2424 | $REPLACE_上昇
2425 | $REPLACE_##だち
2426 | $REPLACE_予備
2427 | $APPEND_おり
2428 | $REPLACE_農
2429 | $REPLACE_管弦
2430 | $APPEND_ほとんど
2431 | $REPLACE_4
2432 | $REPLACE_大公
2433 | $REPLACE_鳥
2434 | $APPEND_みな
2435 | $REPLACE_所長
2436 | $REPLACE_艦長
2437 | $REPLACE_つけ
2438 | $REPLACE_がわ
2439 | $APPEND_適切
2440 | $REPLACE_背景
2441 | $REPLACE_降伏
2442 | $REPLACE_実際
2443 | $REPLACE_s
2444 | $REPLACE_いっぱい
2445 | $REPLACE_異常
2446 | $APPEND_頻繁
2447 | $REPLACE_遊
2448 | $REPLACE_落ち
2449 | $REPLACE_学ん
2450 | $REPLACE_役所
2451 | $REPLACE_退社
2452 | $APPEND_##お
2453 | $REPLACE_ケ
2454 | $REPLACE_過ぎ
2455 | $REPLACE_寮
2456 | $REPLACE_転向
2457 | $REPLACE_脚
2458 | $REPLACE_##かし
2459 | $APPEND_⇒
2460 | $REPLACE_##マ
2461 | $REPLACE_カ国
2462 | $REPLACE_特に
2463 | $REPLACE_##合っ
2464 | $REPLACE_##ド
2465 | $REPLACE_現象
2466 | $REPLACE_校長
2467 | $REPLACE_理由
2468 | $REPLACE_詞
2469 | $REPLACE_原題
2470 | $REPLACE_達する
2471 | $REPLACE_弾
2472 | $REPLACE_探し
2473 | $REPLACE_砦
2474 | $REPLACE_移っ
2475 | $APPEND_理由
2476 | $APPEND_対し
2477 | $REPLACE_繁殖
2478 | $REPLACE_見れ
2479 | $REPLACE_##心
2480 | $REPLACE_印
2481 | $REPLACE_アル
2482 | $REPLACE_波
2483 | $REPLACE_最新
2484 | $REPLACE_巡洋
2485 | $REPLACE_高級
2486 | $REPLACE_##ゃ
2487 | $REPLACE_連携
2488 | $REPLACE_帯
2489 | $APPEND_元々
2490 | $REPLACE_丸
2491 | $REPLACE_称する
2492 | $REPLACE_騎手
2493 | $REPLACE_短編
2494 | $REPLACE_守備
2495 | $REPLACE_復興
2496 | $REPLACE_様式
2497 | $REPLACE_10
2498 | $APPEND_あの
2499 | $REPLACE_##ゆ
2500 | $REPLACE_居
2501 | $REPLACE_通
2502 | $REPLACE_退官
2503 | $APPEND_一旦
2504 | $REPLACE_周波
2505 | $REPLACE_とり
2506 | $REPLACE_後任
2507 | $REPLACE_問題
2508 | $APPEND_スマート
2509 | $REPLACE_たった
2510 | $REPLACE_上京
2511 | $APPEND_漢字
2512 | $REPLACE_十字
2513 | $APPEND_樹
2514 | $REPLACE_楽しみ
2515 | $REPLACE_個体
2516 | $REPLACE_勝目
2517 | $APPEND_もらっ
2518 | $REPLACE_事情
2519 | $REPLACE_霊
2520 | $REPLACE_聖堂
2521 | $APPEND_はっきり
2522 | $REPLACE_刺史
2523 | $REPLACE_金曜
2524 | $REPLACE_##わり
2525 | $REPLACE_発着
2526 | $REPLACE_電動
2527 | $REPLACE_十分
2528 | $REPLACE_海戦
2529 | $REPLACE_策
2530 | $REPLACE_高架
2531 | $APPEND_##ぁ
2532 | $REPLACE_神話
2533 | $REPLACE_発し
2534 | $REPLACE_重量
2535 | $REPLACE_##タ
2536 | $REPLACE_塩
2537 | $APPEND_言わ
2538 | $REPLACE_よろ
2539 | $APPEND_勢
2540 | $REPLACE_体系
2541 | $REPLACE_t
2542 | $APPEND_場所
2543 | $REPLACE_終わる
2544 | $REPLACE_防止
2545 | $REPLACE_##ウ
2546 | $REPLACE_前回
2547 | $APPEND_犬
2548 | $REPLACE_高い
2549 | $APPEND_かけ
2550 | $REPLACE_フ
2551 | $REPLACE_##ない
2552 | $APPEND_着
2553 | $APPEND_当たり
2554 | $REPLACE_##らしい
2555 | $REPLACE_拒否
2556 | $APPEND_u
2557 | $REPLACE_ゲーム
2558 | $REPLACE_様
2559 | $APPEND_##ay
2560 | $REPLACE_助
2561 | $REPLACE_御
2562 | $REPLACE_港
2563 | $APPEND_すれ
2564 | $REPLACE_あれ
2565 | $REPLACE_作り
2566 | $APPEND_母
2567 | $APPEND_ないし
2568 | $REPLACE_混
2569 | $REPLACE_セイ
2570 | $REPLACE_発言
2571 | $REPLACE_メモリ
2572 | $REPLACE_妃
2573 | $REPLACE_使わ
2574 | $REPLACE_鑑賞
2575 | $REPLACE_上手
2576 | $REPLACE_恒星
2577 | $REPLACE_余儀
2578 | $APPEND_豊富
2579 | $APPEND_録
2580 | $REPLACE_##気
2581 | $REPLACE_##つい
2582 | $REPLACE_中退
2583 | $REPLACE_諸
2584 | $REPLACE_##t
2585 | $REPLACE_##法
2586 | $APPEND_もちろん
2587 | $REPLACE_先頭
2588 | $APPEND_写真
2589 | $REPLACE_生成
2590 | $REPLACE_取引
2591 | $REPLACE_監視
2592 | $REPLACE_洋
2593 | $REPLACE_##着
2594 | $REPLACE_養成
2595 | $REPLACE_##光
2596 | $REPLACE_解消
2597 | $REPLACE_n
2598 | $REPLACE_色
2599 | $REPLACE_前期
2600 | $REPLACE_清
2601 | $REPLACE_悪化
2602 | $REPLACE_おん
2603 | $REPLACE_集
2604 | $REPLACE_##取り
2605 | $APPEND_野
2606 | $REPLACE_虫
2607 | $APPEND_しょ
2608 | $REPLACE_トラ
2609 | $REPLACE_効率
2610 | $REPLACE_棒
2611 | $APPEND_願い
2612 | $APPEND_すら
2613 | $REPLACE_原則
2614 | $APPEND_##ore
2615 | $APPEND_OK
2616 | $APPEND_使
2617 | $REPLACE_正確
2618 | $REPLACE_墓地
2619 | $REPLACE_##かる
2620 | $REPLACE_バイ
2621 | $REPLACE_格闘
2622 | $REPLACE_方言
2623 | $REPLACE_田
2624 | $REPLACE_弦
2625 | $REPLACE_続編
2626 | $REPLACE_多
2627 | $REPLACE_I
2628 | $REPLACE_違い
2629 | $REPLACE_なさい
2630 | $REPLACE_夜間
2631 | $REPLACE_河
2632 | $REPLACE_遺
2633 | $REPLACE_取り
2634 | $REPLACE_サ
2635 | $APPEND_べ
2636 | $REPLACE_離婚
2637 | $REPLACE_課題
2638 | $REPLACE_主に
2639 | $REPLACE_かた
2640 | $REPLACE_草
2641 | $REPLACE_放棄
2642 | $REPLACE_うた
2643 | $REPLACE_か所
2644 | $REPLACE_楽しい
2645 | $REPLACE_族
2646 | $REPLACE_大き
2647 | $REPLACE_試作
2648 | $REPLACE_-
2649 | $REPLACE_正教
2650 | $REPLACE_知識
2651 | $REPLACE_数字
2652 | $REPLACE_競争
2653 | $REPLACE_a
2654 | $APPEND_にくい
2655 | $REPLACE_付属
2656 | $REPLACE_魚
2657 | $REPLACE_出来事
2658 | $REPLACE_全部
2659 | $REPLACE_伊
2660 | $REPLACE_行わ
2661 | $REPLACE_ちょう
2662 | $REPLACE_##頭
2663 | $REPLACE_想定
2664 | $REPLACE_短期
2665 | $APPEND_豊か
2666 | $APPEND_月
2667 | $APPEND_##h
2668 | $APPEND_舎
2669 | $REPLACE_将棋
2670 | $REPLACE_税
2671 | $REPLACE_確立
2672 | $APPEND_夜
2673 | $APPEND_在
2674 | $REPLACE_来る
2675 | $REPLACE_刑務
2676 | $APPEND_そ
2677 | $APPEND_しかも
2678 | $APPEND_##だ
2679 | $REPLACE_卒業
2680 | $REPLACE_箱
2681 | $REPLACE_相続
2682 | $REPLACE_じんじゃ
2683 | $APPEND_幼稚
2684 | $REPLACE_善
2685 | $REPLACE_コマ
2686 | $REPLACE_敗戦
2687 | $REPLACE_配給
2688 | $REPLACE_移り
2689 | $REPLACE_審議
2690 | $REPLACE_しゅ
2691 | $APPEND_俺
2692 | $APPEND_最初
2693 | $REPLACE_乾燥
2694 | $REPLACE_四
2695 | $APPEND_ぐらい
2696 | $REPLACE_##ム
2697 | $REPLACE_いろいろ
2698 | $APPEND_##na
2699 | $REPLACE_##々
2700 | $REPLACE_資産
2701 | $REPLACE_代理
2702 | $REPLACE_滝
2703 | $REPLACE_都立
2704 | $REPLACE_誰
2705 | $REPLACE_番手
2706 | $APPEND_分かり
2707 | $REPLACE_招待
2708 | $REPLACE_参考
2709 | $REPLACE_区画
2710 | $REPLACE_個
2711 | $REPLACE_連覇
2712 | $APPEND_とう
2713 | $REPLACE_繰り返し
2714 | $REPLACE_詩人
2715 | $APPEND_けん
2716 | $APPEND_きっ
2717 | $REPLACE_モ
2718 | $REPLACE_##中
2719 | $REPLACE_選定
2720 | $REPLACE_場合
2721 | $REPLACE_発
2722 | $APPEND_荘
2723 | $REPLACE_##わる
2724 | $REPLACE_e
2725 | $REPLACE_一環
2726 | $REPLACE_会館
2727 | $APPEND_##is
2728 | $APPEND_酒
2729 | $REPLACE_殿堂
2730 | $REPLACE_##ア
2731 | $REPLACE_##コ
2732 | $REPLACE_大きく
2733 | $REPLACE_##山
2734 | $REPLACE_加わっ
2735 | $REPLACE_完了
2736 | $REPLACE_味
2737 | $REPLACE_やり
2738 | $REPLACE_業
2739 | $REPLACE_助け
2740 | $REPLACE_古
2741 | $REPLACE_湯
2742 | $REPLACE_もらい
2743 | $REPLACE_いん
2744 | $APPEND_良く
2745 | $REPLACE_引っ
2746 | $REPLACE_##がい
2747 | $REPLACE_入
2748 | $REPLACE_移管
2749 | $REPLACE_良
2750 | $REPLACE_功
2751 | $REPLACE_移植
2752 | $REPLACE_守護
2753 | $REPLACE_自衛
2754 | $REPLACE_執政
2755 | $APPEND_上手
2756 | $REPLACE_単語
2757 | $REPLACE_郎
2758 | $REPLACE_ぬ
2759 | $REPLACE_旧制
2760 | $REPLACE_質
2761 | $APPEND_どれ
2762 | $REPLACE_講談
2763 | $REPLACE_考古
2764 | $REPLACE_o
2765 | $REPLACE_すば
2766 | $REPLACE_陵
2767 | $APPEND_続
2768 | $REPLACE_##館
2769 | $REPLACE_選考
2770 | $REPLACE_穴
2771 | $REPLACE_定年
2772 | $REPLACE_感情
2773 | $APPEND_げ
2774 | $REPLACE_友
2775 | $APPEND_殿
2776 | $REPLACE_創建
2777 | $APPEND_住ん
2778 | $APPEND_##削
2779 | $APPEND_内容
2780 | $REPLACE_関心
2781 | $REPLACE_出す
2782 | $REPLACE_旗
2783 | $REPLACE_主宰
2784 | $REPLACE_バラ
2785 | $REPLACE_法政
2786 | $REPLACE_日々
2787 | $REPLACE_冬季
2788 | $APPEND_異
2789 | $REPLACE_ぶん
2790 | $REPLACE_##正
2791 | $REPLACE_球場
2792 | $REPLACE_仏
2793 | $REPLACE_じゅう
2794 | $REPLACE_制
2795 | $REPLACE_昼
2796 | $REPLACE_既存
2797 | $REPLACE_代わっ
2798 | $REPLACE_周り
2799 | $APPEND_まったく
2800 | $REPLACE_表彰
2801 | $REPLACE_約
2802 | $REPLACE_もらっ
2803 | $REPLACE_特急
2804 | $REPLACE_行なわ
2805 | $REPLACE_介
2806 | $REPLACE_境
2807 | $REPLACE_守
2808 | $APPEND_む
2809 | $APPEND_ほん
2810 | $REPLACE_##形
2811 | $REPLACE_経っ
2812 | $REPLACE_ゴ
2813 | $REPLACE_高齢
2814 | $APPEND_##週
2815 | $REPLACE_能
2816 | $REPLACE_領
2817 | $APPEND_暦
2818 | $REPLACE_体長
2819 | $REPLACE_見え
2820 | $REPLACE_双方
2821 | $APPEND_使わ
2822 | $APPEND_炎
2823 | $REPLACE_冒頭
2824 | $REPLACE_適し
2825 | $REPLACE_半ば
2826 | $REPLACE_帰り
2827 | $APPEND_多い
2828 | $REPLACE_古典
2829 | $REPLACE_依存
2830 | $REPLACE_癌
2831 | $REPLACE_くだ
2832 | $REPLACE_投入
2833 | $REPLACE_商店
2834 | $REPLACE_祭
2835 | $REPLACE_転
2836 | $REPLACE_竜
2837 | $REPLACE_火曜
2838 | $APPEND_口
2839 | $REPLACE_アン
2840 | $REPLACE_激しい
2841 | $REPLACE_##入
2842 | $REPLACE_若干
2843 | $APPEND_むしろ
2844 | $REPLACE_戦線
2845 | $APPEND_<
2846 | $TRANSFORM_VBP_VBV
2847 | $REPLACE_塾
2848 | $REPLACE_温度
2849 | $APPEND_大切
2850 | $APPEND_もらい
2851 | $REPLACE_ゆえ
2852 | $REPLACE_意図
2853 | $REPLACE_題
2854 | $REPLACE_##フ
2855 | $REPLACE_減
2856 | $APPEND_あい
2857 | $APPEND_庫
2858 | $REPLACE_##線
2859 | $REPLACE_文献
2860 | $REPLACE_##っと
2861 | $REPLACE_死者
2862 | $REPLACE_北
2863 | $REPLACE_だり
2864 | $REPLACE_興業
2865 | $APPEND_走
2866 | $APPEND_##o
2867 | $REPLACE_変わり
2868 | $REPLACE_##まら
2869 | $REPLACE_アナ
2870 | $REPLACE_会い
2871 | $REPLACE_願い
2872 | $APPEND_くん
2873 | $REPLACE_強
2874 | $REPLACE_つき
2875 | $REPLACE_冠
2876 | $REPLACE_速
2877 | $APPEND_なんて
2878 | $APPEND_練習
2879 | $REPLACE_##食
2880 | $REPLACE_##べる
2881 | $REPLACE_##付け
2882 | $APPEND_すご
2883 | $APPEND_会社
2884 | $REPLACE_改装
2885 | $REPLACE_昨
2886 | $REPLACE_書く
2887 | $APPEND_なかなか
2888 | $REPLACE_げん
2889 | $REPLACE_##ジ
2890 | $REPLACE_タイ
2891 | $REPLACE_殺人
2892 | $APPEND_門
2893 | $REPLACE_変換
2894 | $REPLACE_直径
2895 | $APPEND_活発
2896 | $REPLACE_侵入
2897 | $REPLACE_6
2898 | $REPLACE_括弧
2899 | $REPLACE_工科
2900 | $REPLACE_だん
2901 | $REPLACE_##道
2902 | $REPLACE_シン
2903 | $APPEND_知っ
2904 | $REPLACE_マ
2905 | $REPLACE_みる
2906 | $REPLACE_州都
2907 | $REPLACE_警備
2908 | $REPLACE_##張っ
2909 | $REPLACE_意思
2910 | $REPLACE_検
2911 | $REPLACE_坂
2912 | $REPLACE_上場
2913 | $REPLACE_聖書
2914 | $REPLACE_酒
2915 | $REPLACE_五
2916 | $REPLACE_着い
2917 | $REPLACE_作っ
2918 | $APPEND_あくまで
2919 | $REPLACE_りゅう
2920 | $REPLACE_##d
2921 | $REPLACE_一家
2922 | $REPLACE_振り
2923 | $REPLACE_##進
2924 | $REPLACE_チ
2925 | $REPLACE_印象
2926 | $REPLACE_胸
2927 | $REPLACE_記号
2928 | $REPLACE_礼
2929 | $REPLACE_##起
2930 | $REPLACE_経路
2931 | $REPLACE_##ぜ
2932 | $REPLACE_歌唱
2933 | $REPLACE_史料
2934 | $REPLACE_笑い
2935 | $REPLACE_公社
2936 | $REPLACE_大夫
2937 | $REPLACE_定理
2938 | $REPLACE_ヤ
2939 | $REPLACE_包囲
2940 | $APPEND_単
2941 | $APPEND_子供
2942 | $APPEND_『
2943 | $APPEND_##p
2944 | $APPEND_ざる
2945 | $REPLACE_##ガ
2946 | $APPEND_中国
2947 | $APPEND_数
2948 | $APPEND_女性
2949 | $REPLACE_修
2950 | $APPEND_稀
2951 | $REPLACE_ありがとう
2952 | $REPLACE_ツ
2953 | $REPLACE_牛
2954 | $REPLACE_哨戒
2955 | $REPLACE_幸せ
2956 | $APPEND_わから
2957 | $APPEND_グレート
2958 | $REPLACE_転換
2959 | $REPLACE_ヶ月
2960 | $REPLACE_方程
2961 | $APPEND_亜
2962 | $REPLACE_艦名
2963 | $REPLACE_次い
2964 | $REPLACE_仲
2965 | $REPLACE_昇
2966 | $REPLACE_一貫
2967 | $REPLACE_官僚
2968 | $REPLACE_入り
2969 | $APPEND_停
2970 | $REPLACE_##づ
2971 | $REPLACE_なに
2972 | $REPLACE_関与
2973 | $REPLACE_起点
2974 | $REPLACE_解析
2975 | $REPLACE_##化
2976 | $REPLACE_公認
2977 | $REPLACE_皇
2978 | $REPLACE_つか
2979 | $REPLACE_甥
2980 | $REPLACE_過
2981 | $REPLACE_正規
2982 | $REPLACE_投
2983 | $REPLACE_屋
2984 | $APPEND_つき
2985 | $REPLACE_高原
2986 | $REPLACE_縦
2987 | $REPLACE_広がっ
2988 | $APPEND_##に
2989 | $REPLACE_同名
2990 | $APPEND_くだ
2991 | $REPLACE_人工
2992 | $REPLACE_概ね
2993 | $REPLACE_##張り
2994 | $APPEND_##こう
2995 | $APPEND_別
2996 | $APPEND_よる
2997 | $REPLACE_講演
2998 | $REPLACE_証券
2999 | $APPEND_すぎ
3000 | $REPLACE_星
3001 | $APPEND_被
3002 | $REPLACE_室
3003 | $APPEND_だん
3004 | $REPLACE_ク
3005 | $REPLACE_増
3006 | $APPEND_たった
3007 | $REPLACE_棋士
3008 | $REPLACE_地理
3009 | $REPLACE_小さな
3010 | $REPLACE_##っち
3011 | $APPEND_サン
3012 | $REPLACE_圏
3013 | $REPLACE_協定
3014 | $REPLACE_百貨
3015 | $REPLACE_班
3016 | $APPEND_真
3017 | $REPLACE_明
3018 | $REPLACE_演技
3019 | $REPLACE_抗議
3020 | $REPLACE_城主
3021 | $REPLACE_嬉
3022 | $APPEND_過ぎ
3023 | $REPLACE_桜
3024 | $REPLACE_見つ
3025 | $REPLACE_流
3026 | $APPEND_帳
3027 | $REPLACE_分解
3028 | $REPLACE_起
3029 | $REPLACE_##取っ
3030 | $REPLACE_園
3031 | $REPLACE_まず
3032 | $APPEND_難しい
3033 | $REPLACE_集まっ
3034 | $REPLACE_史跡
3035 | $REPLACE_構内
3036 | $REPLACE_ラップ
3037 | $APPEND_ちょうど
3038 | $APPEND_僅か
3039 | $APPEND_当然
3040 | $REPLACE_積
3041 | $REPLACE_将
3042 | $REPLACE_出店
3043 | $REPLACE_渡る
3044 | $REPLACE_どちら
3045 | $REPLACE_呼ば
3046 | $REPLACE_勃発
3047 | $REPLACE_競輪
3048 | $APPEND_##y
3049 | $REPLACE_行なっ
3050 | $REPLACE_成る
3051 | $APPEND_但し
3052 | $REPLACE_未
3053 | $REPLACE_解雇
3054 | $REPLACE_矢
3055 | $REPLACE_性質
3056 | $REPLACE_少
3057 | $APPEND_済み
3058 | $REPLACE_長い
3059 | $APPEND_##丈
3060 | $REPLACE_事前
3061 | $APPEND_オ
3062 | $REPLACE_株
3063 | $REPLACE_カン
3064 | $REPLACE_12
3065 | $REPLACE_##どう
3066 | $APPEND_アメリカ
3067 | $REPLACE_審判
3068 | $REPLACE_##w
3069 | $REPLACE_覚え
3070 | $APPEND_##っき
3071 | $REPLACE_遅れ
3072 | $APPEND_ぼく
3073 | $REPLACE_姫
3074 | $REPLACE_きん
3075 | $REPLACE_団
3076 | $APPEND_志
3077 | $REPLACE_創
3078 | $REPLACE_正しい
3079 | $REPLACE_余
3080 | $REPLACE_歌劇
3081 | $APPEND_方法
3082 | $REPLACE_##ロ
3083 | $REPLACE_月刊
3084 | $REPLACE_休
3085 | $REPLACE_##人
3086 | $REPLACE_相撲
3087 | $APPEND_草
3088 | $REPLACE_優先
3089 | $APPEND_悪
3090 | $REPLACE_廃車
3091 | $REPLACE_嫌
3092 | $REPLACE_同型
3093 | $APPEND_話す
3094 | $REPLACE_保障
3095 | $REPLACE_感謝
3096 | $REPLACE_死刑
3097 | $REPLACE_竹
3098 | $REPLACE_##部
3099 | $REPLACE_ぜん
3100 | $REPLACE_料理
3101 | $REPLACE_意向
3102 | $REPLACE_具
3103 | $REPLACE_乗
3104 | $APPEND_h
3105 | $REPLACE_最低
3106 | $REPLACE_附属
3107 | $REPLACE_推移
3108 | $APPEND_有
3109 | $APPEND_まるで
3110 | $REPLACE_叔父
3111 | $REPLACE_##名
3112 | $REPLACE_一員
3113 | $REPLACE_公爵
3114 | $APPEND_片
3115 | $REPLACE_##神
3116 | $REPLACE_##おう
3117 | $APPEND_寝
3118 | $REPLACE_幼少
3119 | $REPLACE_火災
3120 | $REPLACE_司教
3121 | $APPEND_峰
3122 | $REPLACE_##知
3123 | $REPLACE_じゅ
3124 | $REPLACE_市中
3125 | $REPLACE_ごと
3126 | $APPEND_##ば
3127 | $REPLACE_治世
3128 | $REPLACE_いろ
3129 | $APPEND_罪
3130 | $APPEND_衆
3131 | $REPLACE_私立
3132 | $REPLACE_##町
3133 | $REPLACE_##元
3134 | $REPLACE_腹
3135 | $REPLACE_きっ
3136 | $APPEND_こん
3137 | $REPLACE_いら
3138 | $APPEND_##th
3139 | $APPEND_旅行
3140 | $REPLACE_属し
3141 | $REPLACE_典型
3142 | $REPLACE_かし
3143 | $REPLACE_週末
3144 | $REPLACE_たび
3145 | $REPLACE_ドン
3146 | $REPLACE_当地
3147 | $REPLACE_大佐
3148 | $REPLACE_クロ
3149 | $REPLACE_東方
3150 | $REPLACE_綱
3151 | $REPLACE_##地
3152 | $REPLACE_売上
3153 | $APPEND_造
3154 | $REPLACE_会っ
3155 | $REPLACE_##つき
3156 | $APPEND_##to
3157 | $REPLACE_f
3158 | $APPEND_確実
3159 | $APPEND_赤
3160 | $APPEND_間違い
3161 | $REPLACE_h
3162 | $APPEND_初
3163 | $REPLACE_つま
3164 | $REPLACE_西洋
3165 | $APPEND_肢
3166 | $APPEND_##me
3167 | $REPLACE_大気
3168 | $APPEND_ドー
3169 | $REPLACE_見つかっ
3170 | $REPLACE_ヶ
3171 | $REPLACE_全線
3172 | $APPEND_続け
3173 | $APPEND_家族
3174 | $REPLACE_平
3175 | $REPLACE_反
3176 | $REPLACE_##管
3177 | $REPLACE_##物
3178 | $REPLACE_長年
3179 | $REPLACE_保守
3180 | $APPEND_簡易
3181 | $REPLACE_延伸
3182 | $REPLACE_*
3183 | $APPEND_##々
3184 | $REPLACE_将来
3185 | $APPEND_##消
3186 | $REPLACE_上部
3187 | $APPEND_准
3188 | $REPLACE_騎乗
3189 | $REPLACE_福祉
3190 | $REPLACE_宮
3191 | $APPEND_うまく
3192 | $APPEND_長らく
3193 | $REPLACE_いし
3194 | $REPLACE_てつ
3195 | $REPLACE_ならび
3196 | $APPEND_寄り
3197 | $APPEND_毎日
3198 | $REPLACE_別れ
3199 | $REPLACE_はた
3200 | $REPLACE_華
3201 | $REPLACE_任
3202 | $REPLACE_肩
3203 | $APPEND_骨
3204 | $REPLACE_傷
3205 | $REPLACE_炉
3206 | $APPEND_はず
3207 | $APPEND_とりわけ
3208 | $REPLACE_知
3209 | $APPEND_過
3210 | $REPLACE_##川
3211 | $APPEND_w
3212 | $APPEND_##/
3213 | $APPEND_原
3214 | $REPLACE_伝わる
3215 | $REPLACE_はじめて
3216 | $APPEND_##ろ
3217 | $REPLACE_アメリカ
3218 | $APPEND_はじめて
3219 | $REPLACE_##チ
3220 | $REPLACE_ふく
3221 | $REPLACE_エ
3222 | $REPLACE_整理
3223 | $REPLACE_治
3224 | $REPLACE_もん
3225 | $REPLACE_##ふ
3226 | $REPLACE_送信
3227 | $REPLACE_賞金
3228 | $REPLACE_食事
3229 | $REPLACE_続け
3230 | $REPLACE_広
3231 | $REPLACE_起工
3232 | $APPEND_大いに
3233 | $REPLACE_##機
3234 | $APPEND_me
3235 | $APPEND_会話
3236 | $REPLACE_習慣
3237 | $REPLACE_広がる
3238 | $REPLACE_きょく
3239 | $APPEND_以前
3240 | $REPLACE_雲
3241 | $REPLACE_##見
3242 | $REPLACE_体調
3243 | $APPEND_改めて
3244 | $REPLACE_りつ
3245 | $APPEND_関係
3246 | $APPEND_読み
3247 | $APPEND_魚
3248 | $REPLACE_言及
3249 | $REPLACE_##表
3250 | $REPLACE_親交
3251 | $REPLACE_記
3252 | $APPEND_##べ
3253 | $REPLACE_保
3254 | $REPLACE_改編
3255 | $APPEND_点
3256 | $REPLACE_過ご
3257 | $REPLACE_ろう
3258 | $REPLACE_頭部
3259 | $REPLACE_練習
3260 | $REPLACE_生活
3261 | $REPLACE_総裁
3262 | $APPEND_##よ
3263 | $REPLACE_組み合わせ
3264 | $REPLACE_きゅう
3265 | $APPEND_法的
3266 | $APPEND_顕著
3267 | $REPLACE_南朝
3268 | $REPLACE_##事
3269 | $REPLACE_同志
3270 | $REPLACE_抗争
3271 | $REPLACE_無視
3272 | $APPEND_準々
3273 | $APPEND_違い
3274 | $APPEND_OR
3275 | $REPLACE_か国
3276 | $REPLACE_打ち
3277 | $REPLACE_##文
3278 | $APPEND_f
3279 | $REPLACE_##水
3280 | $APPEND_次いで
3281 | $REPLACE_少数
3282 | $REPLACE_がく
3283 | $REPLACE_:
3284 | $REPLACE_終え
3285 | $REPLACE_トリ
3286 | $REPLACE_雰囲気
3287 | $REPLACE_]
3288 | $REPLACE_可能
3289 | $REPLACE_させ
3290 | $REPLACE_変身
3291 | $REPLACE_会話
3292 | $APPEND_会
3293 | $REPLACE_どれ
3294 | $REPLACE_違う
3295 | $REPLACE_籍
3296 | $REPLACE_モン
3297 | $REPLACE_みや
3298 | $REPLACE_答え
3299 | $APPEND_{
3300 | $REPLACE_発揮
3301 | $REPLACE_技師
3302 | $APPEND_あげ
3303 | $REPLACE_##ノ
3304 | $REPLACE_派
3305 | $REPLACE_代替
3306 | $REPLACE_##立っ
3307 | $REPLACE_やす
3308 | $REPLACE_##記
3309 | $REPLACE_非難
3310 | $REPLACE_新型
3311 | $REPLACE_厳しい
3312 | $APPEND_田
3313 | $REPLACE_見かけ
3314 | $REPLACE_航路
3315 | $REPLACE_##ツ
3316 | $APPEND_熱心
3317 | $REPLACE_思い出し
3318 | $REPLACE_主将
3319 | $REPLACE_電機
3320 | $APPEND_##"
3321 | $APPEND_不要
3322 | $REPLACE_将校
3323 | $APPEND_聞き
3324 | $REPLACE_影
3325 | $APPEND_##ca
3326 | $APPEND_材
3327 | $REPLACE_##上げ
3328 | $REPLACE_かなり
3329 | $REPLACE_大勢
3330 | $REPLACE_すぎ
3331 | $REPLACE_##上がり
3332 | $REPLACE_##屋
3333 | $REPLACE_質問
3334 | $REPLACE_##リー
3335 | $REPLACE_短
3336 | $REPLACE_屋根
3337 | $APPEND_しゃ
3338 | $REPLACE_添
3339 | $REPLACE_国土
3340 | $REPLACE_面し
3341 | $REPLACE_中断
3342 | $APPEND_役
3343 | $APPEND_##ta
3344 | $APPEND_たびたび
3345 | $REPLACE_道場
3346 | $REPLACE_デイ
3347 | $REPLACE_開放
3348 | $REPLACE_##年
3349 | $REPLACE_住ん
3350 | $REPLACE_かかり
3351 | $REPLACE_##バ
3352 | $REPLACE_時計
3353 | $REPLACE_そば
3354 | $REPLACE_かかる
3355 | $APPEND_とくに
3356 | $APPEND_大好き
3357 | $REPLACE_ニュース
3358 | $REPLACE_選
3359 | $REPLACE_参
3360 | $REPLACE_町長
3361 | $APPEND_色々
3362 | $REPLACE_みつ
3363 | $REPLACE_通行
3364 | $REPLACE_##トン
3365 | $APPEND_##an
3366 | $REPLACE_ソ
3367 | $REPLACE_衣装
3368 | $APPEND_銀
3369 | $REPLACE_ブドウ
3370 | $REPLACE_じん
3371 | $REPLACE_交
3372 | $REPLACE_ジ
3373 | $REPLACE_合唱
3374 | $REPLACE_パーティー
3375 | $REPLACE_宝
3376 | $APPEND_結果
3377 | $APPEND_早く
3378 | $REPLACE_任期
3379 | $REPLACE_罪
3380 | $APPEND_透明
3381 | $REPLACE_医科
3382 | $REPLACE_##しん
3383 | $REPLACE_近世
3384 | $REPLACE_##らい
3385 | $REPLACE_山岳
3386 | $APPEND_昨
3387 | $REPLACE_封じ
3388 | $APPEND_##。
3389 | $APPEND_深刻
3390 | $REPLACE_えん
3391 | $REPLACE_要塞
3392 | $REPLACE_ほ
3393 | $REPLACE_砲
3394 | $REPLACE_頂
3395 | $APPEND_すべて
3396 | $REPLACE_マンガ
3397 | $REPLACE_##切っ
3398 | $REPLACE_入省
3399 | $REPLACE_##戦
3400 | $REPLACE_c
3401 | $REPLACE_##返し
3402 | $REPLACE_宮内
3403 | $REPLACE_晩
3404 | $REPLACE_果たす
3405 | $REPLACE_反映
3406 | $REPLACE_旨
3407 | $REPLACE_伝承
3408 | $REPLACE_ヶ所
3409 | $REPLACE_ダン
3410 | $APPEND_雨
3411 | $REPLACE_枚
3412 | $REPLACE_季節
3413 | $REPLACE_駐車
3414 | $REPLACE_連結
3415 | $APPEND_相手
3416 | $REPLACE_##状
3417 | $REPLACE_読ん
3418 | $REPLACE_##職
3419 | $REPLACE_施工
3420 | $REPLACE_##f
3421 | $REPLACE_強い
3422 | $APPEND_性的
3423 | $REPLACE_接
3424 | $REPLACE_同行
3425 | $REPLACE_扉
3426 | $REPLACE_##ハ
3427 | $REPLACE_皮
3428 | $REPLACE_騎
3429 | $REPLACE_悪い
3430 | $REPLACE_抗
3431 | $REPLACE_おかげ
3432 | $REPLACE_キン
3433 | $REPLACE_はや
3434 | $APPEND_灯
3435 | $APPEND_新しい
3436 | $REPLACE_傍ら
3437 | $REPLACE_県内
3438 | $APPEND_同等
3439 | $APPEND_』
3440 | $APPEND_○
3441 | $APPEND_び
3442 | $REPLACE_運航
3443 | $REPLACE_しょ
3444 | $REPLACE_でん
3445 | $REPLACE_同校
3446 | $REPLACE_8
3447 | $REPLACE_ム
3448 | $REPLACE_読み
3449 | $REPLACE_観察
3450 | $APPEND_経
3451 | $REPLACE_自社
3452 | $REPLACE_成分
3453 | $REPLACE_糸
3454 | $APPEND_普通
3455 | $REPLACE_単
3456 | $APPEND_全部
3457 | $REPLACE_一連
3458 | $APPEND_##も
3459 | $REPLACE_公卿
3460 | $REPLACE_ダイ
3461 | $REPLACE_##分
3462 | $REPLACE_羽
3463 | $REPLACE_緑色
3464 | $APPEND_なに
3465 | $APPEND_おき
3466 | $REPLACE_メ
3467 | $REPLACE_靴
3468 | $REPLACE_重要
3469 | $REPLACE_リーグ
3470 | $REPLACE_術
3471 | $REPLACE_受
3472 | $REPLACE_野
3473 | $APPEND_##び
3474 | $REPLACE_##字
3475 | $REPLACE_らん
3476 | $APPEND_だり
3477 | $APPEND_##ural
3478 | $REPLACE_広報
3479 | $REPLACE_##書
3480 | $REPLACE_髪
3481 | $REPLACE_例外
3482 | $REPLACE_名詞
3483 | $REPLACE_##一
3484 | $APPEND_純粋
3485 | $REPLACE_がっ
3486 | $APPEND_最後
3487 | $APPEND_気分
3488 | $REPLACE_篇
3489 | $APPEND_ざ
3490 | $REPLACE_ロ
3491 | $REPLACE_出さ
3492 | $APPEND_岩
3493 | $REPLACE_トライ
3494 | $APPEND_やっと
3495 | $REPLACE_志
3496 | $REPLACE_##走
3497 | $APPEND_##on
3498 | $REPLACE_##ニ
3499 | $REPLACE_しゃべ
3500 | $REPLACE_##所
3501 | $REPLACE_下さい
3502 | $REPLACE_携わっ
3503 | $REPLACE_避難
3504 | $APPEND_度々
3505 | $REPLACE_典
3506 | $APPEND_悪い
3507 | $REPLACE_かつ
3508 | $REPLACE_あの
3509 | $REPLACE_##がり
3510 | $APPEND_決して
3511 | $REPLACE_##期
3512 | $REPLACE_##ゅう
3513 | $APPEND_興味
3514 | $REPLACE_寒
3515 | $REPLACE_走
3516 | $REPLACE_畑
3517 | $REPLACE_持ち
3518 | $REPLACE_義
3519 | $REPLACE_雄
3520 | $REPLACE_話す
3521 | $REPLACE_気筒
3522 | $REPLACE_ほしい
3523 | $APPEND_読ん
3524 | $REPLACE_刺激
3525 | $REPLACE_習得
3526 | $APPEND_入れ
3527 | $APPEND_終わっ
3528 | $REPLACE_非
3529 | $REPLACE_海洋
3530 | $REPLACE_支
3531 | $REPLACE_終
3532 | $REPLACE_##勝
3533 | $APPEND_##年
3534 | $REPLACE_考案
3535 | $REPLACE_台風
3536 | $REPLACE_武
3537 | $REPLACE_はし
3538 | $REPLACE_ありが
3539 | $REPLACE_##サ
3540 | $REPLACE_起こっ
3541 | $REPLACE_既
3542 | $APPEND_##まっ
3543 | $REPLACE_湖
3544 | $APPEND_順次
3545 | $APPEND_グローバル
3546 | $REPLACE_向かっ
3547 | $REPLACE_外国
3548 | $APPEND_##ぶん
3549 | $APPEND_厳密
3550 | $REPLACE_津
3551 | $APPEND_旗
3552 | $APPEND_##げ
3553 | $REPLACE_夏季
3554 | $REPLACE_四季
3555 | $REPLACE_魂
3556 | $REPLACE_林
3557 | $REPLACE_紀
3558 | $REPLACE_立つ
3559 | $REPLACE_##訳
3560 | $REPLACE_講座
3561 | $REPLACE_かみ
3562 | $REPLACE_特
3563 | $APPEND_##us
3564 | $APPEND_貴重
3565 | $REPLACE_解除
3566 | $REPLACE_帰る
3567 | $APPEND_遠
3568 | $APPEND_蔵
3569 | $REPLACE_出典
3570 | $REPLACE_神聖
3571 | $REPLACE_降
3572 | $REPLACE_奏
3573 | $REPLACE_##朝
3574 | $REPLACE_衛生
3575 | $REPLACE_損傷
3576 | $REPLACE_##かく
3577 | $REPLACE_##たか
3578 | $REPLACE_興味
3579 | $REPLACE_奇
3580 | $REPLACE_液
3581 | $REPLACE_国鉄
3582 | $APPEND_こく
3583 | $REPLACE_ミ
3584 | $REPLACE_ワ
3585 | $REPLACE_創始
3586 | $REPLACE_支線
3587 | $REPLACE_おこ
3588 | $REPLACE_あいだ
3589 | $REPLACE_引用
3590 | $APPEND_亭
3591 | $REPLACE_個々
3592 | $REPLACE_途中
3593 | $REPLACE_##マン
3594 | $APPEND_##ou
3595 | $REPLACE_おかし
3596 | $REPLACE_時刻
3597 | $REPLACE_夢
3598 | $REPLACE_テ
3599 | $REPLACE_同市
3600 | $REPLACE_長編
3601 | $REPLACE_ちゃん
3602 | $REPLACE_総会
3603 | $REPLACE_階段
3604 | $REPLACE_同僚
3605 | $REPLACE_##葉
3606 | $REPLACE_##内
3607 | $APPEND_ドン
3608 | $REPLACE_##ぶん
3609 | $REPLACE_悲
3610 | $REPLACE_歯科
3611 | $APPEND_限
3612 | $REPLACE_かかっ
3613 | $REPLACE_尚書
3614 | $REPLACE_##ワ
3615 | $REPLACE_ノン
3616 | $REPLACE_幹事
3617 | $APPEND_準備
3618 | $REPLACE_変わる
3619 | $REPLACE_対照
3620 | $APPEND_都市
3621 | $REPLACE_えい
3622 | $APPEND_補
3623 | $REPLACE_しま
3624 | $REPLACE_ぼ
3625 | $REPLACE_眼
3626 | $REPLACE_”
3627 | $REPLACE_証拠
3628 | $APPEND_洞
3629 | $REPLACE_委託
3630 | $REPLACE_##ック
3631 | $REPLACE_行事
3632 | $APPEND_ツ
3633 | $REPLACE_変わら
3634 | $APPEND_調
3635 | $REPLACE_'
3636 | $REPLACE_チャ
3637 | $REPLACE_春
3638 | $APPEND_##語
3639 | $APPEND_##える
3640 | $REPLACE_下院
3641 | $APPEND_駐
3642 | $REPLACE_夫妻
3643 | $REPLACE_盗塁
3644 | $REPLACE_げ
3645 | $APPEND_近
3646 | $APPEND_ぜ
3647 | $APPEND_おお
3648 | $APPEND_声
3649 | $REPLACE_##モ
3650 | $APPEND_しゅ
3651 | $REPLACE_素晴らしい
3652 | $APPEND_人々
3653 | $APPEND_安
3654 | $REPLACE_リン酸
3655 | $REPLACE_あき
3656 | $REPLACE_前線
3657 | $REPLACE_関係
3658 | $REPLACE_結果
3659 | $APPEND_広大
3660 | $REPLACE_周期
3661 | $REPLACE_要する
3662 | $REPLACE_界
3663 | $APPEND_がっ
3664 | $APPEND_じゅう
3665 | $REPLACE_##直
3666 | $APPEND_}
3667 | $REPLACE_利
3668 | $REPLACE_肉
3669 | $APPEND_##ず
3670 | $REPLACE_いっしょ
3671 | $REPLACE_ソナタ
3672 | $REPLACE_各社
3673 | $REPLACE_師匠
3674 | $APPEND_有利
3675 | $REPLACE_もつ
3676 | $REPLACE_一生
3677 | $REPLACE_寄付
3678 | $REPLACE_稼働
3679 | $REPLACE_回答
3680 | $REPLACE_警視
3681 | $APPEND_尚
3682 | $REPLACE_雇用
3683 | $REPLACE_延
3684 | $REPLACE_天王
3685 | $REPLACE_要因
3686 | $REPLACE_立て
3687 | $REPLACE_ジャ
3688 | $REPLACE_どおり
3689 | $REPLACE_徳
3690 | $APPEND_ス
3691 | $APPEND_知ら
3692 | $REPLACE_ともない
3693 | $REPLACE_岳
3694 | $REPLACE_ぎょう
3695 | $APPEND_##ぱ
3696 | $APPEND_ほ
3697 | $REPLACE_##ダ
3698 | $REPLACE_おと
3699 | $APPEND_よかっ
3700 | $REPLACE_へい
3701 | $REPLACE_幹部
3702 | $APPEND_夢
3703 | $APPEND_以上
3704 | $REPLACE_科目
3705 | $REPLACE_南海
3706 | $REPLACE_##段
3707 | $REPLACE_路
3708 | $REPLACE_結局
3709 | $REPLACE_容量
3710 | $REPLACE_付
3711 | $REPLACE_ロン
3712 | $REPLACE_自信
3713 | $REPLACE_##色
3714 | $REPLACE_応
3715 | $REPLACE_独
3716 | $REPLACE_ぐ
3717 | $APPEND_弟
3718 | $APPEND_4
3719 | $REPLACE_経緯
3720 | $REPLACE_自主
3721 | $REPLACE_7
3722 | $REPLACE_出会い
3723 | $APPEND_いえ
3724 | $REPLACE_たっ
3725 | $REPLACE_修行
3726 | $APPEND_##?
3727 | $REPLACE_対称
3728 | $REPLACE_女神
3729 | $REPLACE_要
3730 | $REPLACE_皇子
3731 | $APPEND_学生
3732 | $REPLACE_味方
3733 | $REPLACE_称
3734 | $APPEND_##しゃ
3735 | $REPLACE_みず
3736 | $REPLACE_臣
3737 | $REPLACE_ひがし
3738 | $REPLACE_はい
3739 | $APPEND_経験
3740 | $REPLACE_平行
3741 | $REPLACE_樹
3742 | $APPEND_ほしい
3743 | $APPEND_##ふ
3744 | $APPEND_現在
3745 | $REPLACE_u
3746 | $REPLACE_環状
3747 | $APPEND_歴
3748 | $REPLACE_勧め
3749 | $APPEND_c
3750 | $APPEND_犯
3751 | $REPLACE_##なか
3752 | $REPLACE_教養
3753 | $REPLACE_旗艦
3754 | $REPLACE_鉱山
3755 | $REPLACE_講義
3756 | $REPLACE_##ずる
3757 | $REPLACE_##代
3758 | $APPEND_##や
3759 | $REPLACE_行か
3760 | $REPLACE_教科
3761 | $REPLACE_最小
3762 | $REPLACE_繊維
3763 | $REPLACE_理
3764 | $APPEND_b
3765 | $REPLACE_せき
3766 | $REPLACE_ストーリー
3767 | $APPEND_分から
3768 | $REPLACE_綬章
3769 | $REPLACE_せよ
3770 | $REPLACE_商工
3771 | $REPLACE_戻っ
3772 | $REPLACE_略
3773 | $REPLACE_極
3774 | $REPLACE_##量
3775 | $REPLACE_果
3776 | $REPLACE_公会
3777 | $REPLACE_##ぁ
3778 | $REPLACE_判
3779 | $REPLACE_焼失
3780 | $REPLACE_協同
3781 | $REPLACE_分かっ
3782 | $APPEND_ましょう
3783 | $REPLACE_付き
3784 | $REPLACE_##ター
3785 | $REPLACE_欲しい
3786 | $APPEND_世界
3787 | $REPLACE_##工
3788 | $REPLACE_尽力
3789 | $REPLACE_ゲイ
3790 | $REPLACE_割合
3791 | $APPEND_刊
3792 | $REPLACE_工作
3793 | $REPLACE_紛争
3794 | $REPLACE_##用
3795 | $REPLACE_##エ
3796 | $APPEND_さい
3797 | $APPEND_##然
3798 | $REPLACE_与え
3799 | $REPLACE_底
3800 | $REPLACE_疾患
3801 | $REPLACE_チャン
3802 | $APPEND_三
3803 | $APPEND_##せる
3804 | $APPEND_微
3805 | $APPEND_公的
3806 | $REPLACE_完全
3807 | $REPLACE_##もう
3808 | $REPLACE_後年
3809 | $REPLACE_後退
3810 | $REPLACE_らしい
3811 | $REPLACE_ホ
3812 | $REPLACE_降下
3813 | $REPLACE_支店
3814 | $APPEND_“
3815 | $APPEND_実際
3816 | $REPLACE_##つく
3817 | $REPLACE_注
3818 | $REPLACE_団地
3819 | $REPLACE_##ブ
3820 | $APPEND_いっぱい
3821 | $REPLACE_わかっ
3822 | $REPLACE_上方
3823 | $REPLACE_##オ
3824 | $APPEND_書か
3825 | $REPLACE_欄
3826 | $REPLACE_##方
3827 | $REPLACE_水系
3828 | $REPLACE_棋
3829 | $REPLACE_支社
3830 | $APPEND_がち
3831 | $APPEND_質問
3832 | $REPLACE_書か
3833 | $REPLACE_##官
3834 | $APPEND_ヤ
3835 | $REPLACE_地中
3836 | $REPLACE_家系
3837 | $REPLACE_グループ
3838 | $REPLACE_不振
3839 | $APPEND_はるか
3840 | $REPLACE_分かる
3841 | $REPLACE_常務
3842 | $REPLACE_完
3843 | $APPEND_早
3844 | $REPLACE_交際
3845 | $REPLACE_磁気
3846 | $REPLACE_本体
3847 | $REPLACE_ござい
3848 | $REPLACE_禁じ
3849 | $APPEND_純
3850 | $REPLACE_運
3851 | $REPLACE_##ネ
3852 | $REPLACE_##添
3853 | $REPLACE_宗
3854 | $REPLACE_食べ物
3855 | $APPEND_##文
3856 | $APPEND_曲
3857 | $REPLACE_やっと
3858 | $APPEND_天気
3859 | $REPLACE_付け
3860 | $REPLACE_コン
3861 | $REPLACE_##こし
3862 | $REPLACE_素
3863 | $REPLACE_悪
3864 | $REPLACE_全身
3865 | $REPLACE_秘書
3866 | $REPLACE_胃
3867 | $APPEND_##ぶ
3868 | $APPEND_やり
3869 | $REPLACE_後世
3870 | $REPLACE_文芸
3871 | $REPLACE_供用
3872 | $REPLACE_##間
3873 | $REPLACE_腺
3874 | $REPLACE_##たく
3875 | $REPLACE_なくなり
3876 | $APPEND_よろ
3877 | $APPEND_けい
3878 | $APPEND_貴
3879 | $REPLACE_もたらし
3880 | $APPEND_ヶ
3881 | $REPLACE_粉
3882 | $REPLACE_##ット
3883 | $REPLACE_申し
3884 | $REPLACE_需要
3885 | $APPEND_まれ
3886 | $REPLACE_銘
3887 | $REPLACE_##テ
3888 | $REPLACE_れん
3889 | $REPLACE_##レ
3890 | $REPLACE_景
3891 | $REPLACE_医
3892 | $APPEND_単一
3893 | $REPLACE_電化
3894 | $REPLACE_##かけ
3895 | $REPLACE_進水
3896 | $REPLACE_やすい
3897 | $REPLACE_ロウ
3898 | $APPEND_漢
3899 | $REPLACE_商学
3900 | $APPEND_ー
3901 | $REPLACE_##点
3902 | $REPLACE_句
3903 | $APPEND_さま
3904 | $REPLACE_しも
3905 | $REPLACE_がた
3906 | $APPEND_地
3907 | $APPEND_クラス
3908 | $REPLACE_まる
3909 | $REPLACE_のりば
3910 | $REPLACE_少佐
3911 | $APPEND_##ね
3912 | $REPLACE_ちゃんと
3913 | $APPEND_##ス
3914 | $REPLACE_毎日
3915 | $REPLACE_イー
3916 | $APPEND_させる
3917 | $REPLACE_数々
3918 | $REPLACE_措置
3919 | $REPLACE_降り
3920 | $REPLACE_特性
3921 | $REPLACE_模し
3922 | $REPLACE_##けれ
3923 | $APPEND_n
3924 | $REPLACE_天体
3925 | $APPEND_ぞ
3926 | $APPEND_ゆ
3927 | $APPEND_未だ
3928 | $REPLACE_事例
3929 | $REPLACE_人称
3930 | $REPLACE_有機
3931 | $REPLACE_なが
3932 | $REPLACE_異なる
3933 | $REPLACE_シャツ
3934 | $REPLACE_##使
3935 | $APPEND_つけ
3936 | $REPLACE_電
3937 | $REPLACE_奨励
3938 | $REPLACE_##庁
3939 | $REPLACE_最中
3940 | $REPLACE_状
3941 | $APPEND_ドラマ
3942 | $APPEND_##たい
3943 | $REPLACE_全員
3944 | $REPLACE_振動
3945 | $REPLACE_恩
3946 | $APPEND_説明
3947 | $REPLACE_##調
3948 | $REPLACE_ひこ
3949 | $REPLACE_回収
3950 | $REPLACE_乗降
3951 | $REPLACE_器
3952 | $REPLACE_師範
3953 | $REPLACE_存続
3954 | $REPLACE_##国
3955 | $REPLACE_工廠
3956 | $APPEND_##ろう
3957 | $APPEND_はじめ
3958 | $REPLACE_動
3959 | $APPEND_天
3960 | $REPLACE_郷
3961 | $REPLACE_径
3962 | $REPLACE_航行
3963 | $APPEND_それぞれ
3964 | $REPLACE_##天
3965 | $REPLACE_創作
3966 | $APPEND_ワイルド
3967 | $APPEND_ゲーム
3968 | $REPLACE_ろ
3969 | $REPLACE_民
3970 | $REPLACE_引っ越し
3971 | $REPLACE_不正
3972 | $REPLACE_阻害
3973 | $REPLACE_ざん
3974 | $APPEND_##re
3975 | $APPEND_##at
3976 | $REPLACE_通貨
3977 | $REPLACE_辞
3978 | $APPEND_ほか
3979 | $REPLACE_すい
3980 | $REPLACE_なれ
3981 | $REPLACE_##がる
3982 | $REPLACE_試
3983 | $REPLACE_9
3984 | $REPLACE_統制
3985 | $REPLACE_##なく
3986 | $REPLACE_リー
3987 | $REPLACE_ヘ
3988 | $REPLACE_思わ
3989 | $APPEND_なか
3990 | $REPLACE_枝
3991 | $APPEND_##-
3992 | $REPLACE_最も
3993 | $REPLACE_##っき
3994 | $REPLACE_請求
3995 | $REPLACE_盛
3996 | $REPLACE_使者
3997 | $REPLACE_人生
3998 | $REPLACE_動詞
3999 | $REPLACE_先輩
4000 | $APPEND_##おう
4001 | $APPEND_以下
4002 | $REPLACE_##将
4003 | $REPLACE_全面
4004 | $APPEND_起き
4005 | $REPLACE_##ミ
4006 | $REPLACE_禅
4007 | $REPLACE_売り上げ
4008 | $REPLACE_殻
4009 | $REPLACE_##ュ
4010 | $REPLACE_成人
4011 | $REPLACE_似
4012 | $REPLACE_維新
4013 | $APPEND_##浮
4014 | $REPLACE_11
4015 | $REPLACE_梁
4016 | $REPLACE_立教
4017 | $APPEND_いか
4018 | $REPLACE_確率
4019 | $REPLACE_テン
4020 | $REPLACE_涙
4021 | $REPLACE_示す
4022 | $REPLACE_屋敷
4023 | $REPLACE_後楽
4024 | $REPLACE_眠
4025 | $REPLACE_##らさ
4026 | $REPLACE_みなみ
4027 | $REPLACE_生まれる
4028 | $REPLACE_##取る
4029 | $APPEND_確か
4030 | $APPEND_浴
4031 | $APPEND_ゆっくり
4032 | $REPLACE_囲碁
4033 | $REPLACE_大きい
4034 | $REPLACE_##出
4035 | $APPEND_池
4036 | $REPLACE_##艦
4037 | $APPEND_我々
4038 | $REPLACE_再会
4039 | $REPLACE_校区
4040 | $APPEND_添
4041 | $REPLACE_丘陵
4042 | $REPLACE_##もっ
4043 | $APPEND_重
4044 | $REPLACE_専
4045 | $REPLACE_公使
4046 | $REPLACE_##星
4047 | $APPEND_或
4048 | $REPLACE_##m
4049 | $REPLACE_種子
4050 | $REPLACE_隊長
4051 | $APPEND_邸
4052 | $APPEND_##物
4053 | $APPEND_##ing
4054 | $REPLACE_全車
4055 | $REPLACE_命
4056 | $REPLACE_参事
4057 | $REPLACE_のり
4058 | $REPLACE_基盤
4059 | $REPLACE_おおむね
4060 | $REPLACE_##領
4061 | $REPLACE_失
4062 | $REPLACE_##つか
4063 | $APPEND_江
4064 | $REPLACE_坊
4065 | $REPLACE_珍しい
4066 | $APPEND_p
4067 | $APPEND_かく
4068 | $REPLACE_床
4069 | $REPLACE_玄
4070 | $REPLACE_ぐらい
4071 | $REPLACE_劉
4072 | $REPLACE_折
4073 | $REPLACE_学士
4074 | $REPLACE_##質
4075 | $REPLACE_赴任
4076 | $REPLACE_出港
4077 | $REPLACE_教鞭
4078 | $APPEND_部分
4079 | $REPLACE_##上がる
4080 | $REPLACE_西端
4081 | $REPLACE_##政
4082 | $REPLACE_##足
4083 | $REPLACE_出る
4084 | $APPEND_抗
4085 | $APPEND_##そう
4086 | $REPLACE_感覚
4087 | $REPLACE_生かし
4088 | $REPLACE_夕方
4089 | $REPLACE_砲塔
4090 | $REPLACE_まえ
4091 | $REPLACE_いや
4092 | $APPEND_体
4093 | $REPLACE_磁
4094 | $REPLACE_##なり
4095 | $APPEND_尋常
4096 | $APPEND_頑
4097 | $REPLACE_無い
4098 | $REPLACE_亜科
4099 | $APPEND_建
4100 | $REPLACE_首位
4101 | $REPLACE_三角
4102 | $REPLACE_りょう
4103 | $REPLACE_産
4104 | $REPLACE_照明
4105 | $APPEND_5
4106 | $REPLACE_再現
4107 | $APPEND_特有
4108 | $REPLACE_##ゴ
4109 | $REPLACE_後ろ
4110 | $REPLACE_内戦
4111 | $REPLACE_養
4112 | $REPLACE_膜
4113 | $APPEND_ころ
4114 | $REPLACE_生態
4115 | $REPLACE_ともさ
4116 | $REPLACE_鼻
4117 | $REPLACE_旧名
4118 | $REPLACE_協奏
4119 | $REPLACE_防
4120 | $APPEND_##ど
4121 | $REPLACE_##めん
4122 | $APPEND_8
4123 | $APPEND_生徒
4124 | $REPLACE_照
4125 | $REPLACE_円形
4126 | $REPLACE_元帥
4127 | $REPLACE_人間
4128 | $APPEND_自身
4129 | $REPLACE_##せい
4130 | $REPLACE_灰
4131 | $REPLACE_##つけ
4132 | $REPLACE_全英
4133 | $REPLACE_吉
4134 | $REPLACE_阻止
4135 | $REPLACE_苦手
4136 | $APPEND_順調
4137 | $APPEND_はい
4138 | $REPLACE_きた
4139 | $REPLACE_巣
4140 | $REPLACE_達し
4141 | $REPLACE_符号
4142 | $REPLACE_基金
4143 | $REPLACE_たま
4144 | $APPEND_見え
4145 | $APPEND_言え
4146 | $APPEND_最適
4147 | $TRANSFORM_VB_VBV
4148 | $REPLACE_標
4149 | $REPLACE_会見
4150 | $REPLACE_大好き
4151 | $REPLACE_##わら
4152 | $REPLACE_孝
4153 | $REPLACE_上院
4154 | $APPEND_過去
4155 | $APPEND_来る
4156 | $REPLACE_終結
4157 | $REPLACE_##かす
4158 | $REPLACE_同誌
4159 | $REPLACE_域
4160 | $REPLACE_起こす
4161 | $REPLACE_唐
4162 | $APPEND_てん
4163 | $REPLACE_帰投
4164 | $REPLACE_甲板
4165 | $APPEND_なさい
4166 | $REPLACE_頑
4167 | $APPEND_単なる
4168 | $REPLACE_乗り
4169 | $REPLACE_渓谷
4170 | $REPLACE_やる
4171 | $REPLACE_気温
4172 | $REPLACE_やめ
4173 | $REPLACE_装飾
4174 | $REPLACE_高層
4175 | $REPLACE_##身
4176 | $REPLACE_口径
4177 | $APPEND_##分
4178 | $APPEND_電話
4179 | $REPLACE_サイ
4180 | $REPLACE_買い
4181 | $REPLACE_##p
4182 | $APPEND_番
4183 | $APPEND_楽しみ
4184 | $APPEND_思わ
4185 | $REPLACE_殆ど
4186 | $REPLACE_在任
4187 | $REPLACE_主導
4188 | $REPLACE_##がら
4189 | $REPLACE_色々
4190 | $REPLACE_たん
4191 | $REPLACE_女
4192 | $APPEND_作っ
4193 | $REPLACE_魔
4194 | $REPLACE_正体
4195 | $REPLACE_先住
4196 | $REPLACE_##付
4197 | $APPEND_きょう
4198 | $REPLACE_党首
4199 | $REPLACE_復
4200 | $REPLACE_車内
4201 | $APPEND_##ひ
4202 | $REPLACE_現
4203 | $REPLACE_雌
4204 | $APPEND_##te
4205 | $REPLACE_学習
4206 | $REPLACE_東南
4207 | $REPLACE_##r
4208 | $REPLACE_受ける
4209 | $REPLACE_##校
4210 | $REPLACE_##かん
4211 | $APPEND_対する
4212 | $APPEND_一般
4213 | $REPLACE_つながっ
4214 | $REPLACE_近
4215 | $REPLACE_献
4216 | $REPLACE_及び
4217 | $REPLACE_はく
4218 | $REPLACE_社団
4219 | $REPLACE_上り
4220 | $REPLACE_改訂
4221 | $REPLACE_累計
4222 | $REPLACE_王子
4223 | $REPLACE_鉱
4224 | $REPLACE_タン
4225 | $REPLACE_全く
4226 | $REPLACE_忘れ
4227 | $APPEND_知的
4228 | $REPLACE_なき
4229 | $REPLACE_缶
4230 | $REPLACE_語学
4231 | $REPLACE_降板
4232 | $REPLACE_##目
4233 | $REPLACE_鍵
4234 | $REPLACE_さき
4235 | $TRANSFORM_VBI_VBV
4236 | $REPLACE_宣教
4237 | $REPLACE_落
4238 | $APPEND_紹介
4239 | $APPEND_強
4240 | $REPLACE_栄養
4241 | $REPLACE_##門
4242 | $REPLACE_横断
4243 | $REPLACE_##手
4244 | $REPLACE_息
4245 | $APPEND_さらなる
4246 | $REPLACE_向け
4247 | $REPLACE_聖職
4248 | $REPLACE_専修
4249 | $REPLACE_党員
4250 | $REPLACE_バス
4251 | $REPLACE_なれる
4252 | $REPLACE_面する
4253 | $REPLACE_彗星
4254 | $REPLACE_まい
4255 | $REPLACE_探検
4256 | $REPLACE_製薬
4257 | $REPLACE_##週
4258 | $APPEND_##にち
4259 | $REPLACE_まさ
4260 | $APPEND_有数
4261 | $REPLACE_新潮
4262 | $REPLACE_個別
4263 | $REPLACE_ニン
4264 | $REPLACE_堂
4265 | $REPLACE_##口
4266 | $REPLACE_忙
4267 | $APPEND_##^
4268 | $REPLACE_行使
4269 | $REPLACE_残り
4270 | $APPEND_たろう
4271 | $REPLACE_肺
4272 | $REPLACE_会談
4273 | $APPEND_K
4274 | $APPEND_いま
4275 | $APPEND_##)
4276 | $REPLACE_万
4277 | $REPLACE_バカ
4278 | $REPLACE_知り合い
4279 | $REPLACE_詳しく
4280 | $REPLACE_征服
4281 | $REPLACE_同意
4282 | $REPLACE_都内
4283 | $REPLACE_対空
4284 | $REPLACE_少ない
4285 | $APPEND_ひら
4286 | $REPLACE_経験
4287 | $APPEND_##だん
4288 | $REPLACE_##切り
4289 | $REPLACE_蜂起
4290 | $REPLACE_輪
4291 | $REPLACE_##師
4292 | $REPLACE_反発
4293 | $REPLACE_現職
4294 | $REPLACE_高山
4295 | $REPLACE_退役
4296 | $REPLACE_薨去
4297 | $APPEND_違う
4298 | $REPLACE_逆転
4299 | $TRANSFORM_VBC_VBV
4300 | $REPLACE_宅
4301 | $REPLACE_現在
4302 | $REPLACE_続ける
4303 | $REPLACE_合
4304 | $REPLACE_駒
4305 | $REPLACE_就航
4306 | $APPEND_出し
4307 | $REPLACE_長調
4308 | $REPLACE_署名
4309 | $REPLACE_##らか
4310 | $REPLACE_##)
4311 | $REPLACE_固
4312 | $APPEND_廃
4313 | $REPLACE_喜
4314 | $REPLACE_勝負
4315 | $REPLACE_取っ
4316 | $APPEND_しん
4317 | $REPLACE_教団
4318 | $REPLACE_うつ
4319 | $APPEND_##訳
4320 | $REPLACE_##上げる
4321 | $REPLACE_嵐
4322 | $REPLACE_特典
4323 | $REPLACE_双
4324 | $REPLACE_##がし
4325 | $REPLACE_はら
4326 | $APPEND_司
4327 | $REPLACE_事項
4328 | $REPLACE_庁舎
4329 | $REPLACE_療法
4330 | $REPLACE_闘争
4331 | $REPLACE_令
4332 | $APPEND_斎
4333 | $REPLACE_住
4334 | $REPLACE_強く
4335 | $REPLACE_移る
4336 | $APPEND_残念
4337 | $REPLACE_内容
4338 | $REPLACE_##パ
4339 | $REPLACE_##跡
4340 | $REPLACE_おく
4341 | $REPLACE_激しく
4342 | $REPLACE_最寄り
4343 | $REPLACE_おき
4344 | $REPLACE_統
4345 | $REPLACE_サル
4346 | $REPLACE_総長
4347 | $REPLACE_##越
4348 | $REPLACE_##デ
4349 | $APPEND_長い
4350 | $APPEND_過ごし
4351 | $APPEND_今回
4352 | $REPLACE_作る
4353 | $REPLACE_忠
4354 | $REPLACE_起こる
4355 | $REPLACE_受け取っ
4356 | $APPEND_良好
4357 | $REPLACE_補給
4358 | $APPEND_丁寧
4359 | $APPEND_痛
4360 | $REPLACE_佐
4361 | $APPEND_とり
4362 | $REPLACE_一行
4363 | $APPEND_音
4364 | $REPLACE_学び
4365 | $APPEND_二人
4366 | $REPLACE_##んな
4367 | $REPLACE_針
4368 | $APPEND_庵
4369 | $REPLACE_センチ
4370 | $REPLACE_##種
4371 | $REPLACE_##グ
4372 | $APPEND_クリエイティブ
4373 | $REPLACE_乗っ
4374 | $APPEND_音楽
4375 | $APPEND_参加
4376 | $APPEND_単語
4377 | $REPLACE_油
4378 | $REPLACE_辺
4379 | $REPLACE_保証
4380 | $REPLACE_くろ
4381 | $REPLACE_##h
4382 | $REPLACE_禁
4383 | $REPLACE_初戦
4384 | $APPEND_タ
4385 | $REPLACE_継
4386 | $REPLACE_代々
4387 | $APPEND_学習
4388 | $REPLACE_附
4389 | $REPLACE_上がっ
4390 | $APPEND_かた
4391 | $REPLACE_看護
4392 | $REPLACE_西岸
4393 | $REPLACE_##城
4394 | $REPLACE_法科
4395 | $REPLACE_艦艇
4396 | $REPLACE_登記
4397 | $REPLACE_見つから
4398 | $REPLACE_こうえん
4399 | $REPLACE_もらえ
4400 | $APPEND_卒業
4401 | $REPLACE_考える
4402 | $REPLACE_##げる
4403 | $REPLACE_井
4404 | $REPLACE_戒
4405 | $REPLACE_まつり
4406 | $REPLACE_留まっ
4407 | $REPLACE_英語
4408 | $REPLACE_役者
4409 | $REPLACE_甲斐
4410 | $REPLACE_報
4411 | $REPLACE_虎
4412 | $REPLACE_新興
4413 | $REPLACE_##・
4414 | $REPLACE_こく
4415 | $REPLACE_旅行
4416 | $APPEND_考える
4417 | $REPLACE_天気
4418 | $APPEND_モ
4419 | $REPLACE_当たる
4420 | $REPLACE_きれい
4421 | $REPLACE_休み
4422 | $APPEND_##]
4423 | $APPEND_読む
4424 | $APPEND_急激
4425 | $REPLACE_特別
4426 | $REPLACE_武士
4427 | $REPLACE_辞退
4428 | $REPLACE_資
4429 | $REPLACE_炎
4430 | $REPLACE_巻き込ま
4431 | $REPLACE_やさ
4432 | $REPLACE_スーパー
4433 | $REPLACE_卒
4434 | $APPEND_懸命
4435 | $REPLACE_心臓
4436 | $REPLACE_さく
4437 | $REPLACE_すみ
4438 | $REPLACE_推理
4439 | $REPLACE_##成
4440 | $REPLACE_賛成
4441 | $REPLACE_症候
4442 | $REPLACE_冊
4443 | $REPLACE_持つ
4444 | $REPLACE_一般
4445 | $REPLACE_##かさ
4446 | $REPLACE_たけ
4447 | $REPLACE_##バー
4448 | $APPEND_沢山
4449 | $APPEND_我
4450 | $APPEND_重大
4451 | $REPLACE_総理
4452 | $REPLACE_しつ
4453 | $REPLACE_ユ
4454 | $REPLACE_会社
4455 | $REPLACE_補佐
4456 | $REPLACE_おいしい
4457 | $APPEND_##se
4458 | $APPEND_部屋
4459 | $APPEND_今年
4460 | $APPEND_婦
4461 | $REPLACE_ふたり
4462 | $REPLACE_遺構
4463 | $REPLACE_聞こえ
4464 | $REPLACE_##せん
4465 | $APPEND_近く
4466 | $REPLACE_有料
4467 | $REPLACE_次官
4468 | $REPLACE_##o
4469 | $REPLACE_昔
4470 | $APPEND_覚え
4471 | $APPEND_いれ
4472 | $APPEND_##かけ
4473 | $REPLACE_関わっ
4474 | $REPLACE_比例
4475 | $REPLACE_再興
4476 | $REPLACE_中隊
4477 | $REPLACE_関わる
4478 | $APPEND_つか
4479 | $REPLACE_ぶり
4480 | $REPLACE_家老
4481 | $REPLACE_間違い
4482 | $APPEND_温暖
4483 | $REPLACE_視覚
4484 | $REPLACE_伸ばし
4485 | $REPLACE_拘束
4486 | $REPLACE_庁
4487 | $REPLACE_消化
4488 | $APPEND_通り
4489 | $REPLACE_黒人
4490 | $REPLACE_後輩
4491 | $REPLACE_工程
4492 | $REPLACE_##制
4493 | $REPLACE_あさ
4494 | $REPLACE_ラ
4495 | $REPLACE_##つける
4496 | $REPLACE_史学
4497 | $REPLACE_シュ
4498 | $REPLACE_引き
4499 | $REPLACE_自然
4500 | $APPEND_鬼
4501 | $APPEND_安価
4502 | $REPLACE_無し
4503 | $REPLACE_薬
4504 | $REPLACE_修復
4505 | $REPLACE_投下
4506 | $REPLACE_配
4507 | $REPLACE_タバコ
4508 | $REPLACE_なお
4509 | $REPLACE_伝え
4510 | $REPLACE_編纂
4511 | $REPLACE_脂肪
4512 | $APPEND_It
4513 | $APPEND_もらえ
4514 | $APPEND_答え
4515 | $REPLACE_管制
4516 | $APPEND_鋼
4517 | $APPEND_絵
4518 | $REPLACE_カッコ
4519 | $REPLACE_糖
4520 | $REPLACE_立ち
4521 | $REPLACE_首席
4522 | $APPEND_帰っ
4523 | $REPLACE_広まっ
4524 | $REPLACE_炭
4525 | $REPLACE_捕手
4526 | $REPLACE_海底
4527 | $APPEND_外国
4528 | $APPEND_見つけ
4529 | $REPLACE_##ッ
4530 | $REPLACE_県庁
4531 | $REPLACE_小さい
4532 | $APPEND_##ン
4533 | $REPLACE_早
4534 | $REPLACE_伝わっ
4535 | $APPEND_共
4536 | $REPLACE_超
4537 | $REPLACE_##王
4538 | $REPLACE_庄
4539 | $REPLACE_交戦
4540 | $APPEND_##ざ
4541 | $REPLACE_絶対
4542 | $REPLACE_まんが
4543 | $APPEND_従
4544 | $APPEND_##ぐ
4545 | $REPLACE_##じょう
4546 | $APPEND_全て
4547 | $APPEND_##r
4548 | $REPLACE_警告
4549 | $APPEND_サ
4550 | $REPLACE_偏
4551 | $REPLACE_回想
4552 | $REPLACE_原型
4553 | $APPEND_証
4554 | $REPLACE_開戦
4555 | $REPLACE_定員
4556 | $REPLACE_表紙
4557 | $REPLACE_風邪
4558 | $REPLACE_投じ
4559 | $REPLACE_##科
4560 | $REPLACE_選帝
4561 | $REPLACE_##っくり
4562 | $REPLACE_互換
4563 | $REPLACE_食料
4564 | $REPLACE_想像
4565 | $REPLACE_間隔
4566 | $REPLACE_組み込ま
4567 | $REPLACE_隣
4568 | $REPLACE_始
4569 | $REPLACE_20
4570 | $REPLACE_果実
4571 | $REPLACE_##返
4572 | $REPLACE_##だい
4573 | $REPLACE_##ばっ
4574 | $REPLACE_犬種
4575 | $REPLACE_陰
4576 | $REPLACE_称し
4577 | $REPLACE_ヶ国
4578 | $REPLACE_回路
4579 | $REPLACE_##ビ
4580 | $REPLACE_サイト
4581 | $REPLACE_スポーツ
4582 | $REPLACE_楽しく
4583 | $APPEND_時代
4584 | $REPLACE_残念
4585 | $REPLACE_俗
4586 | $REPLACE_近縁
4587 | $REPLACE_尺
4588 | $APPEND_じょ
4589 | $REPLACE_芸
4590 | $REPLACE_招集
4591 | $REPLACE_##y
4592 | $REPLACE_干渉
4593 | $REPLACE_寄生
4594 | $REPLACE_##付ける
4595 | $APPEND_静か
4596 | $APPEND_双
4597 | $APPEND_ぴ
4598 | $REPLACE_##書き
4599 | $APPEND_シンプル
4600 | $REPLACE_下流
4601 | $APPEND_余
4602 | $REPLACE_チョウ
4603 | $APPEND_休み
4604 | $APPEND_高い
4605 | $REPLACE_電波
4606 | $REPLACE_めい
4607 | $APPEND_直ちに
4608 | $REPLACE_監査
4609 | $REPLACE_進
4610 | $APPEND_父
4611 | $REPLACE_実践
4612 | $REPLACE_結
4613 | $REPLACE_同性
4614 | $REPLACE_空手
4615 | $REPLACE_搭乗
4616 | $REPLACE_硬貨
4617 | $REPLACE_司
4618 | $REPLACE_号し
4619 | $REPLACE_見せ
4620 | $REPLACE_今年
4621 | $REPLACE_回避
4622 | $REPLACE_##高
4623 | $APPEND_##んな
4624 | $APPEND_人間
4625 | $REPLACE_r
4626 | $REPLACE_##賞
4627 | $REPLACE_繁栄
4628 | $APPEND_番組
4629 | $REPLACE_##メ
4630 | $REPLACE_集まり
4631 | $APPEND_無事
4632 | $REPLACE_調
4633 | $REPLACE_前日
4634 | $REPLACE_提示
4635 | $REPLACE_『
4636 | $REPLACE_妖怪
4637 | $REPLACE_in
4638 | $APPEND_らしく
4639 | $REPLACE_##わし
4640 | $REPLACE_みち
4641 | $REPLACE_浮上
4642 | $APPEND_##子
4643 | $APPEND_##食
4644 | $APPEND_ふう
4645 | $REPLACE_ドラマ
4646 | $APPEND_##っち
4647 | $REPLACE_結晶
4648 | $REPLACE_奉
4649 | $REPLACE_メイク
4650 | $REPLACE_招聘
4651 | $REPLACE_所管
4652 | $REPLACE_如来
4653 | $REPLACE_攻勢
4654 | $APPEND_すけ
4655 | $APPEND_壁
4656 | $APPEND_オフィシャル
4657 | $REPLACE_ノット
4658 | $APPEND_l
4659 | $REPLACE_大切
4660 | $REPLACE_説明
4661 | $REPLACE_加わり
4662 | $REPLACE_付加
4663 | $REPLACE_率いる
4664 | $APPEND_おう
4665 | $REPLACE_アマ
4666 | $REPLACE_つく
4667 | $REPLACE_バー
4668 | $REPLACE_はじまり
4669 | $REPLACE_30
4670 | $APPEND_嬉
4671 | $REPLACE_今度
4672 | $REPLACE_妨害
4673 | $REPLACE_持
4674 | $REPLACE_##プ
4675 | $REPLACE_ついに
4676 | $APPEND_しょう
4677 | $REPLACE_##らす
4678 | $REPLACE_習
4679 | $REPLACE_##球
4680 | $REPLACE_d
4681 | $REPLACE_##"
4682 | $REPLACE_良かっ
4683 | $REPLACE_認知
4684 | $APPEND_もっ
4685 | $REPLACE_ハット
4686 | $REPLACE_##声
4687 | $REPLACE_怒り
4688 | $REPLACE_挿入
4689 | $APPEND_アー
4690 | $APPEND_##ょ
4691 | $APPEND_状態
4692 | $REPLACE_##から
4693 | $APPEND_きれい
4694 | $REPLACE_描い
4695 | $REPLACE_気動
4696 | $REPLACE_わかる
4697 | $REPLACE_チーム
4698 | $REPLACE_さえ
4699 | $REPLACE_』
4700 | $REPLACE_ドラ
4701 | $REPLACE_理工
4702 | $REPLACE_衰退
4703 | $APPEND_審
4704 | $REPLACE_外観
4705 | $APPEND_##ut
4706 | $APPEND_正しい
4707 | $APPEND_心配
4708 | $APPEND_依然
4709 | $REPLACE_暗
4710 | $REPLACE_##ケ
4711 | $REPLACE_##臣
4712 | $REPLACE_広がり
4713 | $REPLACE_##張る
4714 | $REPLACE_岬
4715 | $APPEND_進
4716 | $APPEND_油
4717 | $REPLACE_聞く
4718 | $APPEND_##ト
4719 | $REPLACE_しばらく
4720 | $APPEND_##or
4721 | $REPLACE_付い
4722 | $REPLACE_専念
4723 | $REPLACE_砂
4724 | $REPLACE_打
4725 | $REPLACE_残し
4726 | $REPLACE_必ず
4727 | $APPEND_種類
4728 | $APPEND_食べる
4729 | $REPLACE_##もの
4730 | $REPLACE_乗用
4731 | $REPLACE_おか
4732 | $REPLACE_擁する
4733 | $REPLACE_##ごと
4734 | $REPLACE_離れ
4735 | $APPEND_##だち
4736 | $APPEND_##i
4737 | $REPLACE_フラン
4738 | $REPLACE_視点
4739 | $APPEND_嫌
4740 | $REPLACE_買っ
4741 | $APPEND_元
4742 | $REPLACE_伝道
4743 | $REPLACE_歓
4744 | $REPLACE_日系
4745 | $REPLACE_交わし
4746 | $APPEND_互い
4747 | $REPLACE_簡
4748 | $APPEND_多大
4749 | $APPEND_もはや
4750 | $REPLACE_被
4751 | $REPLACE_##交
4752 | $REPLACE_挿絵
4753 | $REPLACE_鏡
4754 | $REPLACE_酸
4755 | $APPEND_突如
4756 | $APPEND_けんど
4757 | $REPLACE_##学
4758 | $REPLACE_下し
4759 | $APPEND_あら
4760 | $APPEND_ぽ
4761 | $REPLACE_摂政
4762 | $REPLACE_貝
4763 | $REPLACE_脇
4764 | $REPLACE_15
4765 | $REPLACE_ずつ
4766 | $REPLACE_##開
4767 | $REPLACE_さわ
4768 | $REPLACE_原料
4769 | $REPLACE_銭
4770 | $REPLACE_l
4771 | $REPLACE_大同
4772 | $REPLACE_女の子
4773 | $REPLACE_そもそも
4774 | $REPLACE_掲示
4775 | $APPEND_偉大
4776 | $APPEND_しっかり
4777 | $REPLACE_神学
4778 | $REPLACE_まっ
4779 | $REPLACE_クラス
4780 | $APPEND_楽しい
4781 | $REPLACE_議
4782 | $REPLACE_懐
4783 | $APPEND_見事
4784 | $REPLACE_開店
4785 | $REPLACE_居士
4786 | $REPLACE_原産
4787 | $REPLACE_メッセージ
4788 | $APPEND_訳
4789 | $REPLACE_制約
4790 | $REPLACE_くら
4791 | $REPLACE_楕円
4792 | $REPLACE_ぜ
4793 | $REPLACE_扇
4794 | $REPLACE_前進
4795 | $APPEND_血
4796 | $REPLACE_勅
4797 | $REPLACE_婦人
4798 | $REPLACE_ガイ
4799 | $REPLACE_グラム
4800 | $REPLACE_住所
4801 | $REPLACE_##達
4802 | $REPLACE_入口
4803 | $REPLACE_##さい
4804 | $APPEND_##中
4805 | $REPLACE_韓
4806 | $REPLACE_繋がっ
4807 | $APPEND_忠実
4808 | $REPLACE_楽しめる
4809 | $REPLACE_導体
4810 | $REPLACE_学ぶ
4811 | $REPLACE_だら
4812 | $APPEND_##ch
4813 | $APPEND_##ans
4814 | $REPLACE_車種
4815 | $REPLACE_携わる
4816 | $REPLACE_##語
4817 | $REPLACE_限界
4818 | $REPLACE_行方
4819 | $APPEND_はる
4820 | $REPLACE_異動
4821 |
--------------------------------------------------------------------------------