├── speechgpt-gen
│   ├── imgs
│   │   ├── readme.md
│   │   ├── coi.png
│   │   └── model.png
│   └── README.md
├── speechgpt
│   ├── imgs
│   │   ├── readme.md
│   │   ├── logo.png
│   │   ├── SpeechGPT-main.png
│   │   ├── speechgpt-intro.png
│   │   ├── cases_cm_inst_follow.png
│   │   └── cases_spoken_dialogue.png
│   ├── prompt
│   │   ├── readme.md
│   │   ├── 0.wav
│   │   └── 1.wav
│   ├── utils
│   │   ├── .DS_Store
│   │   ├── text2unit
│   │   │   ├── spm.model
│   │   │   ├── README.md
│   │   │   ├── text2unit.py
│   │   │   └── binary
│   │   │       └── dict.en.txt
│   │   ├── prompter.py
│   │   ├── speech2unit
│   │   │   ├── README.md
│   │   │   └── speech2unit.py
│   │   └── vocoder
│   │       ├── vocoder.sh
│   │       ├── README.md
│   │       └── generate_waveform_from_code.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── ma_pretrain.sh
│   │   ├── cm_sft.sh
│   │   └── com_sft.sh
│   ├── src
│   │   ├── infer
│   │   │   ├── web_infer.py
│   │   │   └── cli_infer.py
│   │   └── train
│   │       ├── cm_sft.py
│   │       ├── com_sft.py
│   │       └── ma_pretrain.py
│   └── README.md
├── README.md
└── LICENSE

--------------------------------------------------------------------------------
/speechgpt-gen/imgs/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/speechgpt/imgs/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/speechgpt/prompt/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/speechgpt/imgs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/imgs/logo.png
--------------------------------------------------------------------------------
/speechgpt/prompt/0.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/prompt/0.wav
--------------------------------------------------------------------------------
/speechgpt/prompt/1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/prompt/1.wav
--------------------------------------------------------------------------------
/speechgpt-gen/imgs/coi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt-gen/imgs/coi.png
--------------------------------------------------------------------------------
/speechgpt/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/utils/.DS_Store
--------------------------------------------------------------------------------
/speechgpt-gen/imgs/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt-gen/imgs/model.png
--------------------------------------------------------------------------------
/speechgpt/imgs/SpeechGPT-main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/imgs/SpeechGPT-main.png
--------------------------------------------------------------------------------
/speechgpt/imgs/speechgpt-intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/imgs/speechgpt-intro.png
--------------------------------------------------------------------------------
/speechgpt/utils/text2unit/spm.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/utils/text2unit/spm.model
--------------------------------------------------------------------------------
/speechgpt/imgs/cases_cm_inst_follow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/imgs/cases_cm_inst_follow.png
--------------------------------------------------------------------------------
/speechgpt/imgs/cases_spoken_dialogue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/SpeechGPT/HEAD/speechgpt/imgs/cases_spoken_dialogue.png
--------------------------------------------------------------------------------
/speechgpt/utils/text2unit/README.md:
--------------------------------------------------------------------------------
 1 | # Text to unit
 2 | The text-to-unit generator adopts a Transformer encoder-decoder architecture. We trained it on LibriSpeech unit-text pairs.
 3 | 
 4 | ## Download
 5 | ```bash
 6 | t2u_dir="utils/text2unit"
 7 | cd ${t2u_dir}
 8 | wget https://huggingface.co/fnlp/text2unit/resolve/main/text2unit.pt
 9 | ```
10 | 
11 | ## Inference
12 | ```bash
13 | python3 utils/text2unit/text2unit.py --text "Today is a good day."
14 | ```
--------------------------------------------------------------------------------
/speechgpt/requirements.txt:
--------------------------------------------------------------------------------
 1 | accelerate==0.20.3
 2 | bitsandbytes==0.37.2
 3 | datasets==2.11.0
 4 | deepspeed==0.9.0
 5 | einops==0.6.0
 6 | evaluate==0.4.0
 7 | fairseq==0.12.2
 8 | fire==0.5.0
 9 | gradio==3.30.0
10 | gradio_client==0.2.4
11 | librosa==0.10.0.post2
12 | numpy==1.22.4
13 | pathlib==1.0.1
14 | peft==0.3.0
15 | PyYAML==6.0
16 | pydantic==1.10.7
17 | sentencepiece==0.1.98
18 | soundfile==0.12.1
19 | tensorboard==2.12.2
20 | tokenizers==0.13.3
21 | torch==1.13.1
22 | torchaudio==0.13.1
23 | torchvision==0.14.1
24 | tqdm==4.65.0
25 | transformers==4.33.1
--------------------------------------------------------------------------------
/speechgpt/utils/prompter.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import json
 3 | import os.path as osp
 4 | from typing import Union, List
 5 | 
 6 | 
 7 | class Prompter(object):
 8 |     """Builds prompts by concatenating an instruction prefix with optional text."""
 9 | 
10 |     def __init__(self, verbose: bool = False):
11 |         self._verbose = verbose
12 | 
13 |     def generate_prompt(
14 |         self,
15 |         prefix: str,
16 |         text: Union[None, str] = None,
17 |     ) -> str:
18 | 
19 |         res = prefix
20 |         if text:
21 |             res = f"{res}{text}"
22 |         if self._verbose:
23 |             print(res)
24 |         return res
25 | 
--------------------------------------------------------------------------------
/speechgpt/utils/speech2unit/README.md:
--------------------------------------------------------------------------------
 1 | # Speech2unit
 2 | We employ [mHuBERT](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/textless_s2st_real_data.md) as the speech tokenizer: it discretizes speech into discrete units, and repeated units from adjacent frames are collapsed to obtain reduced units.
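For intuition, here is a minimal sketch of that reduction step (illustrative only; it mirrors `Speech2Unit.merge_duplicates` in `speech2unit.py` below):

```python
def reduce_units(cluster_ids):
    # Collapse runs of identical cluster ids and record each run's length.
    reduced, durations = [], []
    for unit in cluster_ids:
        if reduced and reduced[-1] == unit:
            durations[-1] += 1
        else:
            reduced.append(unit)
            durations.append(1)
    return reduced, durations

# e.g. [52, 52, 52, 7, 7, 991] -> ([52, 7, 991], [3, 2, 1])
```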
 3 | 
 4 | ## Download
 5 | ```bash
 6 | s2u_dir="utils/speech2unit"
 7 | cd ${s2u_dir}
 8 | wget https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3.pt
 9 | wget https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3_L11_km1000.bin
10 | ```
11 | 
12 | ## Discretize
13 | ```bash
14 | python3 speech2unit.py --wav path/to/wav
15 | ```
--------------------------------------------------------------------------------
/speechgpt/utils/vocoder/vocoder.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | UNITS=$1  # unit sequence passed on the command line
 4 | 
 5 | VOCODER_DIR="speechgpt/utils/vocoder"
 6 | IN_CODE_FILE=${VOCODER_DIR}/in_code_file.txt
 7 | VOCODER_CKPT=${VOCODER_DIR}/vocoder.pt
 8 | VOCODER_CFG=${VOCODER_DIR}/config.json
 9 | RESULTS_PATH="output/wav"
10 | 
11 | mkdir -p ${VOCODER_DIR}
12 | mkdir -p ${RESULTS_PATH}
13 | 
14 | 
15 | if [ ! -f ${VOCODER_CFG} ]; then
16 |     wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json -O ${VOCODER_CFG}
17 |     wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000 -O ${VOCODER_CKPT}
18 | fi
19 | 
20 | 
21 | echo $UNITS | sed -E 's/[^0-9]+/ /g' > ${IN_CODE_FILE}  # keep only the digits of each unit token
22 | 
23 | 
24 | # generate a waveform from the unit sequence
25 | python3 ${VOCODER_DIR}/generate_waveform_from_code.py \
26 |     --in-code-file ${IN_CODE_FILE} \
27 |     --vocoder ${VOCODER_CKPT} --vocoder-cfg ${VOCODER_CFG} \
28 |     --results-path ${RESULTS_PATH} --dur-prediction
29 | 
--------------------------------------------------------------------------------
/speechgpt/utils/vocoder/README.md:
--------------------------------------------------------------------------------
 1 | # Vocoder
 2 | We adopt a [unit-based HiFi-GAN vocoder](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/textless_s2st_real_data.md) to convert discrete units back to speech.
 3 | 
 4 | ## Download
 5 | Download the vocoder checkpoint and config files before running SpeechGPT inference.
 6 | ```bash
 7 | vocoder_dir="utils/vocoder/"
 8 | cd ${vocoder_dir}
 9 | wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json -O config.json
10 | wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000 -O vocoder.pt
11 | ```
12 | 
13 | ## Unit to speech
14 | ```bash
15 | units="<991><741><945><944><579><969><901><202><393><946><734><498><889><172><871><877><822><89><194><620><915><143><38><914><445><469><167><655><764><70><828><347><376><975><955><333><198><711><510><700><362><932><148><45><914><119><593><167><655><837><81><852><12><852><336><503><523><506><29><561><326><531><576><822><89><834><705><417><675><237><584>"
16 | bash vocoder.sh ${units}
17 | ```
--------------------------------------------------------------------------------
/speechgpt/scripts/ma_pretrain.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | METAROOT="llama/hf/7B"  # stage1
 4 | DATAROOT="data/stage1"
 5 | OUTROOT="output/stage1"
 6 | CACHEROOT="${DATAROOT}/cache/"
 7 | 
 8 | 
 9 | mkdir -p ${CACHEROOT}/tokenized/train/
10 | mkdir -p ${CACHEROOT}/tokenized/valid/
11 | mkdir -p ${CACHEROOT}/group/train/
12 | mkdir -p ${CACHEROOT}/group/valid/
13 | 
14 | 
15 | # ddp related
16 | NNODE=$1
17 | NODE_RANK=$2
18 | MASTER_ADDR=$3
19 | MASTER_PORT=$4
20 | 
21 | 
22 | echo "stage1: modality-adaptation pretraining"
23 | 
24 | 
25 | torchrun \
26 |     --nnode $NNODE \
27 |     --nproc_per_node 8 \
28 |     --node_rank $NODE_RANK \
29 |     --master_addr $MASTER_ADDR \
30 |     --master_port $MASTER_PORT \
31 |     src/train/ma_pretrain.py \
32 |     --bf16 True \
33 |     --block_size 1024 \
34 |     --model_name_or_path "${METAROOT}" \
35 |     --train_file ${DATAROOT}/train.txt \
36 |     --validation_file ${DATAROOT}/dev.txt \
37 |     --do_train \
38 |     --do_eval \
39 |     --output_dir "${OUTROOT}" \
40 |     --preprocessing_num_workers 100 \
41 |     --overwrite_output_dir \
42 |     --per_device_eval_batch_size 3 \
43 |     --per_device_train_batch_size 3 \
44 |     --gradient_accumulation_steps 8 \
45 |     --num_train_epochs 3 \
46 |     --log_level debug \
47 |     --logging_steps 1 \
48 |     --save_steps 300 \
49 |     --cache_dir ${CACHEROOT} \
50 |     --fsdp "full_shard auto_wrap" \
51 |     --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer'
52 | 
--------------------------------------------------------------------------------
/speechgpt/src/infer/web_infer.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | import gradio as gr
 4 | from speechgpt.utils.speech2unit.speech2unit import Speech2Unit
 5 | from speechgpt.src.infer.cli_infer import SpeechGPTInference
 6 | import soundfile as sf
 7 | import argparse
 8 | 
 9 | 
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("--model-name-or-path", type=str, default="")
12 | parser.add_argument("--lora-weights", type=str, default=None)
13 | parser.add_argument("--s2u-dir", type=str, default="speechgpt/utils/speech2unit/")
14 | parser.add_argument("--vocoder-dir", type=str, default="speechgpt/utils/vocoder/")
15 | parser.add_argument("--output-dir", type=str, default="speechgpt/output/")
16 | args = parser.parse_args()
17 | 
18 | os.makedirs(args.output_dir, exist_ok=True)
19 | 
20 | infer = SpeechGPTInference(
21 |     args.model_name_or_path,
22 |     args.lora_weights,
23 |     args.s2u_dir,
24 |     args.vocoder_dir,
25 |     args.output_dir
26 | )
27 | 
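# speech_dialogue: Gradio delivers microphone input as a (sample_rate, numpy array)
# tuple; we write it to disk so Speech2Unit can read it as a file, then return the
# synthesized (sample_rate, waveform) pair for playback.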
28 | def speech_dialogue(audio):
29 |     sr, data = audio
30 |     input_path = os.path.join(args.output_dir, "input.wav")  # save the recording next to the other outputs
31 |     sf.write(
32 |         input_path,
33 |         data,
34 |         sr,
35 |     )
36 |     prompts = [input_path]
37 |     sr, wav = infer(prompts)
38 |     return (sr, wav)
39 | 
40 | 
41 | demo = gr.Interface(
42 |     fn=speech_dialogue,
43 |     inputs="microphone",
44 |     outputs="audio",
45 |     title="SpeechGPT",
46 |     cache_examples=False
47 | )
48 | demo.launch(share=True)
49 | 
--------------------------------------------------------------------------------
/speechgpt/scripts/cm_sft.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | METAROOT="output/stage1"
 4 | DATAROOT="data/stage2"
 5 | OUTROOT="output/stage2"
 6 | CACHEROOT="${DATAROOT}/cache/"
 7 | 
 8 | 
 9 | mkdir -p ${CACHEROOT}/tokenized/train/
10 | mkdir -p ${CACHEROOT}/tokenized/valid/
11 | 
12 | 
13 | # ddp related
14 | NNODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
15 | MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16 | NODE_RANK=$(($(scontrol show hostnames "$SLURM_JOB_NODELIST" | grep -Fn $(hostname) | cut --delimiter=":" --fields=1)-1))
17 | 
18 | 
19 | echo "stage2: cross-modal instruction fine-tuning"
20 | 
21 | 
22 | torchrun \
23 |     --nnode $NNODE \
24 |     --nproc_per_node 8 \
25 |     --node_rank $NODE_RANK \
26 |     --master_addr $MASTER_ADDR \
27 |     --master_port 29501 \
28 |     speechgpt/src/train/cm_sft.py \
29 |     --model_name_or_path "${METAROOT}" \
30 |     --data_path "${DATAROOT}/SpeechInstruct_cross_modal.jsonl" \
31 |     --cache_dir ${CACHEROOT} \
32 |     --preprocessing_num_workers 10 \
33 |     --model_max_length 512 \
34 |     --bf16 True \
35 |     --do_train \
36 |     --do_eval \
37 |     --train_on_inputs True \
38 |     --output_dir "${OUTROOT}" \
39 |     --per_device_train_batch_size 6 \
40 |     --per_device_eval_batch_size 4 \
41 |     --gradient_accumulation_steps 12 \
42 |     --num_train_epochs 3 \
43 |     --evaluation_strategy "no" \
44 |     --save_strategy "steps" \
45 |     --save_steps 300 \
46 |     --learning_rate 2e-5 \
47 |     --weight_decay 0. \
48 |     --warmup_ratio 0.03 \
49 |     --lr_scheduler_type "cosine" \
50 |     --log_level debug \
51 |     --logging_steps 1 \
52 |     --overwrite_output_dir \
53 |     --fsdp "full_shard auto_wrap" \
54 |     --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer'
55 | 
--------------------------------------------------------------------------------
/speechgpt/scripts/com_sft.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | METAROOT="output/stage2"
 4 | DATAROOT="data/stage3"
 5 | OUTROOT="output/stage3"
 6 | CACHEROOT="${DATAROOT}/cache/"
 7 | 
 8 | 
 9 | mkdir -p ${CACHEROOT}/tokenized/train/
10 | mkdir -p ${CACHEROOT}/tokenized/valid/
11 | 
12 | 
13 | # ddp related
14 | NNODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
15 | MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16 | NODE_RANK=$(($(scontrol show hostnames "$SLURM_JOB_NODELIST" | grep -Fn $(hostname) | cut --delimiter=":" --fields=1)-1))
17 | 
18 | 
19 | echo "stage3: chain-of-modality instruction fine-tuning"
20 | 
21 | 
22 | torchrun \
23 |     --nnode $NNODE \
24 |     --nproc_per_node 8 \
25 |     --node_rank $NODE_RANK \
26 |     --master_addr $MASTER_ADDR \
27 |     --master_port 29502 \
28 |     speechgpt/src/train/com_sft.py \
29 |     --model_name_or_path "${METAROOT}" \
30 |     --data_path "${DATAROOT}/SpeechInstruct_chain_of_modality.jsonl" \
31 |     --cache_dir ${CACHEROOT} \
32 |     --lora_r 8 \
33 |     --lora_alpha 16 \
34 |     --lora_dropout 0.05 \
35 |     --lora_target_modules 'q_proj,v_proj' \
36 |     --preprocessing_num_workers 10 \
37 |     --model_max_length 1024 \
38 |     --val_set_size 10 \
39 |     --bf16 True \
40 |     --do_train \
41 |     --train_on_inputs True \
42 |     --output_dir "${OUTROOT}" \
43 |     --per_device_train_batch_size 12 \
44 |     --per_device_eval_batch_size 4 \
45 |     --gradient_accumulation_steps 4 \
46 |     --num_train_epochs 300 \
47 |     --evaluation_strategy "no" \
48 |     --save_strategy "steps" \
49 |     --save_steps 300 \
50 |     --learning_rate 2e-5 \
51 |     --weight_decay 0. \
52 |     --warmup_ratio 0.03 \
53 |     --lr_scheduler_type "cosine" \
54 |     --log_level debug \
55 |     --logging_steps 1
56 | 
--------------------------------------------------------------------------------
/speechgpt-gen/README.md:
--------------------------------------------------------------------------------
 1 | # SpeechGPT-Gen: Scaling Chain-of-Information Speech Generation
 2 | 
 3 | 
 4 | 
 5 | 
 8 | 
 9 | ## Introduction
10 | Benefiting from effective speech modeling, current Speech Large Language Models (SLLMs) have demonstrated exceptional capabilities in in-context speech generation and efficient generalization to unseen speakers.
11 | However, the prevailing information modeling process suffers from certain redundancies, leading to inefficiencies in speech generation.
12 | We propose Chain-of-Information Generation (CoIG), a method for decoupling semantic and perceptual information in large-scale speech generation. Building on this, we develop SpeechGPT-Gen, an 8-billion-parameter SLLM that models semantic and perceptual information efficiently. It comprises an LLM-based autoregressive model for semantic information modeling and a non-autoregressive model employing flow matching for perceptual information modeling. Additionally, we infuse semantic information into the prior distribution to improve the efficiency of flow matching; a rough sketch of this idea is given below.
13 | Extensive experimental results demonstrate that SpeechGPT-Gen markedly excels in zero-shot text-to-speech, zero-shot voice conversion, and speech-to-speech dialogue, underscoring CoIG's proficiency in capturing and modeling the semantic and perceptual dimensions of speech.
14 | 
15 | 
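The SpeechGPT-Gen code has not been released yet, so the following is only a conceptual sketch of a flow-matching training step with a semantics-infused prior; every name in it is hypothetical rather than taken from the actual implementation:

```python
import torch
import torch.nn.functional as F

def flow_matching_step(model, x1, semantic, alpha=0.5):
    # x1: target perceptual features (batch, frames, dim);
    # semantic: a same-shaped semantic representation of the utterance.
    # Semantics-infused prior: start the flow from noise centered on the
    # semantic representation instead of a plain N(0, I) sample.
    x0 = alpha * semantic + torch.randn_like(x1)
    t = torch.rand(x1.size(0), 1, 1, device=x1.device)  # per-example time in [0, 1]
    xt = (1 - t) * x0 + t * x1                          # point on the straight path x0 -> x1
    v_target = x1 - x0                                  # constant velocity of that path
    v_pred = model(xt, t.flatten(), semantic)           # network predicts the velocity field
    return F.mse_loss(v_pred, v_target)
```

Because x0 already carries the semantic content, the model only has to transport perceptual detail, which is what makes the matching more efficient.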
16 | 
17 | Illustration of SpeechGPT-Gen.
19 | 20 | 21 | ## Code 22 | We will soon open-source our codes and models, stay tuned! 23 | 24 | 25 | 26 | ## Citation 27 | -------------------------------------------------------------------------------- /speechgpt/utils/text2unit/text2unit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fairseq.models.transformer import TransformerModel 3 | import argparse 4 | from argparse import Namespace 5 | from fairseq.data import encoders 6 | import json 7 | from tqdm import tqdm 8 | from typing import List, Optional 9 | import torch 10 | 11 | class Text2Unit: 12 | def __init__( 13 | self, 14 | checkpoint_dir="speechgpt/utils/text2unit", 15 | checkpoint_file="text2unit.pt", 16 | data_name_or_path="speechgpt/utils/text2unit/binary", 17 | sentencepiece_model="speechgpt/utils/text2unit/spm.model" 18 | ) -> None: 19 | 20 | self.bpe_tokenizer = encoders.build_bpe( 21 | Namespace( 22 | bpe='sentencepiece', 23 | sentencepiece_model=sentencepiece_model, 24 | ) 25 | ) 26 | self.t2u = TransformerModel.from_pretrained( 27 | checkpoint_dir, 28 | checkpoint_file=checkpoint_file, 29 | data_name_or_path=data_name_or_path, 30 | 31 | ) 32 | 33 | @torch.no_grad() 34 | def forward( 35 | self, 36 | text, 37 | **kwargs 38 | ): 39 | encoded_text = [self.bpe_tokenizer.encode(x).strip() for x in text] if isinstance(text, list) else self.bpe_tokenizer.encode(text).strip() 40 | output = self.t2u.translate(encoded_text, **kwargs) 41 | return [self.postprocess(x) for x in output] if isinstance(output, list) else self.postprocess(output) 42 | 43 | 44 | def postprocess( 45 | self, 46 | input 47 | ): 48 | return ''+"".join(input.split())+'' 49 | 50 | 51 | def __call__(self, text, **kwargs): 52 | return self.forward(text, **kwargs) 53 | 54 | 55 | if __name__=='__main__': 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--text", type=str, default="") 58 | args = parser.parse_args() 59 | 60 | 61 | translator = Text2Unit( 62 | checkpoint_dir="speechgpt/utils/text2unit", 63 | checkpoint_file="text2unit.pt", 64 | data_name_or_path="speechgpt/utils/text2unit/binary", 65 | sentencepiece_model="speechgpt/utils/text2unit/spm.model" 66 | ) 67 | 68 | gen_args = { 69 | "max_len_b":1000, 70 | "beam":5, 71 | } 72 | 73 | units = translator(args.text, **gen_args) 74 | print(units) 75 | 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpeechGPT: Speech Large Language Models 2 | 3 |
 6 | 
 7 | 
 8 | - [**SpeechGPT**](speechgpt) (2023/05) - Empowering Large Language Models with Intrinsic Cross-Modal Conversational Abilities
 9 | 
10 | - [**SpeechGPT-Gen**](speechgpt-gen) (2024/01) - Scaling Chain-of-Information Speech Generation
11 | 
12 | 
13 | ## News
14 | - **[2024/2/20]** We proposed **AnyGPT: Unified Multimodal LLM with Discrete Sequence Modeling**. Check out the [paper](https://arxiv.org/abs/2402.12226) and [github](https://github.com/OpenMOSS/AnyGPT).
15 | - **[2024/1/25]** We released **SpeechGPT-Gen: Scaling Chain-of-Information Speech Generation**. Check out the [paper](https://arxiv.org/abs/2401.13527) and [github](https://github.com/0nutation/SpeechGPT/tree/main/speechgpt-gen).
16 | - **[2024/1/9]** We proposed **SpeechAgents: Human-Communication Simulation with Multi-Modal Multi-Agent Systems**. Check out the [paper](https://arxiv.org/abs/2401.03945) and [github](https://github.com/0nutation/SpeechAgents).
17 | - **[2023/9/15]** We released the SpeechGPT code, checkpoints, and the SpeechInstruct dataset.
18 | - **[2023/9/1]** We proposed **SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models** and released its code and checkpoints. Check out the [paper](https://arxiv.org/abs/2308.16692), [demo](https://0nutation.github.io/SpeechTokenizer.github.io/) and [github](https://github.com/ZhangXInFD/SpeechTokenizer).
19 | - **[2023/5/18]** We released **SpeechGPT: Empowering Large Language Models with Intrinsic Cross-Modal Conversational Abilities**. SpeechGPT is the first multi-modal LLM capable of perceiving and generating multi-modal content following multi-modal human instructions. Check out the [paper](https://arxiv.org/abs/2305.11000) and [demo](https://0nutation.github.io/SpeechGPT.github.io/).
20 | 
21 | 
22 | 
23 | ## Acknowledgements
24 | - We thank Fuliang Weng and Rong Ye for their valuable suggestions and guidance.
25 | 
26 | 
27 | 
28 | ## Citation
29 | If you find our work useful for your research and applications, please cite it using this BibTeX entry:
30 | 
31 | ```
32 | @misc{zhang2023speechgpt,
33 |       title={SpeechGPT: Empowering Large Language Models with Intrinsic Cross-Modal Conversational Abilities},
34 |       author={Dong Zhang and Shimin Li and Xin Zhang and Jun Zhan and Pengyu Wang and Yaqian Zhou and Xipeng Qiu},
35 |       year={2023},
36 |       eprint={2305.11000},
37 |       archivePrefix={arXiv},
38 |       primaryClass={cs.CL}
39 | }
40 | ```
41 | 
--------------------------------------------------------------------------------
/speechgpt/utils/vocoder/generate_waveform_from_code.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Facebook, Inc. and its affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
5 | 6 | import argparse 7 | import json 8 | import logging 9 | from pathlib import Path 10 | import random 11 | import soundfile as sf 12 | import torch 13 | 14 | from tqdm import tqdm 15 | 16 | from fairseq import utils 17 | from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder 18 | 19 | 20 | logging.basicConfig() 21 | logging.root.setLevel(logging.INFO) 22 | logging.basicConfig(level=logging.INFO) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def dump_result(args, sample_id, pred_wav, suffix=""): 27 | sf.write( 28 | f"{args.results_path}/{sample_id}{suffix}_pred.wav", 29 | pred_wav.detach().cpu().numpy(), 30 | 16000, 31 | ) 32 | 33 | 34 | def load_code(in_file): 35 | with open(in_file) as f: 36 | out = [list(map(int, line.strip().split())) for line in f] 37 | return out 38 | 39 | 40 | def main(args): 41 | logger.info(args) 42 | 43 | use_cuda = torch.cuda.is_available() and not args.cpu 44 | 45 | with open(args.vocoder_cfg) as f: 46 | vocoder_cfg = json.load(f) 47 | vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg) 48 | if use_cuda: 49 | vocoder = vocoder.cuda() 50 | 51 | multispkr = vocoder.model.multispkr 52 | if multispkr: 53 | logger.info("multi-speaker vocoder") 54 | num_speakers = vocoder_cfg.get( 55 | "num_speakers", 200 56 | ) # following the default in codehifigan to set to 200 57 | assert ( 58 | args.speaker_id < num_speakers 59 | ), f"invalid --speaker-id ({args.speaker_id}) with total #speakers = {num_speakers}" 60 | 61 | data = load_code(args.in_code_file) 62 | Path(args.results_path).mkdir(exist_ok=True, parents=True) 63 | for i, d in tqdm(enumerate(data), total=len(data)): 64 | x = { 65 | "code": torch.LongTensor(d).view(1, -1), 66 | } 67 | suffix = "" 68 | if multispkr: 69 | spk = ( 70 | random.randint(0, num_speakers - 1) 71 | if args.speaker_id == -1 72 | else args.speaker_id 73 | ) 74 | suffix = f"_spk{spk}" 75 | x["spkr"] = torch.LongTensor([spk]).view(1, 1) 76 | 77 | x = utils.move_to_cuda(x) if use_cuda else x 78 | wav = vocoder(x, args.dur_prediction) 79 | dump_result(args, i, wav, suffix=suffix) 80 | 81 | 82 | def cli_main(): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument( 85 | "--in-code-file", type=str, required=True, help="one unit sequence per line" 86 | ) 87 | parser.add_argument( 88 | "--vocoder", type=str, required=True, help="path to the CodeHiFiGAN vocoder" 89 | ) 90 | parser.add_argument( 91 | "--vocoder-cfg", 92 | type=str, 93 | required=True, 94 | help="path to the CodeHiFiGAN vocoder config", 95 | ) 96 | parser.add_argument("--results-path", type=str, required=True) 97 | parser.add_argument( 98 | "--dur-prediction", 99 | action="store_true", 100 | help="enable duration prediction (for reduced/unique code sequences)", 101 | ) 102 | parser.add_argument( 103 | "--speaker-id", 104 | type=int, 105 | default=-1, 106 | help="Speaker id (for vocoder that supports multispeaker). 
Set to -1 to randomly sample speakers.", 107 | ) 108 | parser.add_argument("--cpu", action="store_true", help="run on CPU") 109 | 110 | args = parser.parse_args() 111 | 112 | main(args) 113 | 114 | 115 | if __name__ == "__main__": 116 | cli_main() 117 | -------------------------------------------------------------------------------- /speechgpt/utils/speech2unit/speech2unit.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | import logging 3 | import os 4 | import sys 5 | import joblib 6 | import fire 7 | import fairseq 8 | import soundfile as sf 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from tqdm.auto import tqdm 13 | from einops import rearrange 14 | import re 15 | import numpy as np 16 | from functools import partial 17 | import torch.multiprocessing as mp 18 | import torchaudio 19 | import glob 20 | import tqdm 21 | import argparse 22 | from torchaudio.functional import resample 23 | 24 | logging.basicConfig( 25 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 26 | datefmt="%Y-%m-%d %H:%M:%S", 27 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 28 | stream=sys.stdout, 29 | ) 30 | logger = logging.getLogger('generate_pseudo_language') 31 | 32 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 33 | 34 | class FeatureReader(object): 35 | def __init__(self, ckpt_path, layer, max_chunk=1600000, fp16=False, sampling_rate=16000): 36 | ( 37 | model, 38 | cfg, 39 | task, 40 | ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) 41 | self.model = model[0].eval().to(DEVICE) 42 | self.task = task 43 | self.layer = layer 44 | self.max_chunk = max_chunk 45 | self.fp16 = fp16 46 | if fp16: 47 | self.model.half() 48 | 49 | self.layer_shift = 0 50 | self.target_sample_hz = sampling_rate 51 | 52 | logger.info(f"TASK CONFIG:\n{self.task.cfg}") 53 | 54 | def read_audio(self, path): 55 | wav, sr = torchaudio.load(path) 56 | if sr != self.target_sample_hz: 57 | wav = resample(wav, sr, self.target_sample_hz) 58 | return wav 59 | 60 | @torch.no_grad() 61 | def get_feats(self, waveform): 62 | x = waveform 63 | with torch.no_grad(): 64 | if self.fp16: 65 | x = x.half().cuda() 66 | else: 67 | x = x.float().cuda() 68 | if self.task.cfg.normalize: 69 | x = F.layer_norm(x, x.shape) 70 | x = x.view(1, -1) 71 | 72 | feat = [] 73 | for start in range(0, x.size(1), self.max_chunk): 74 | x_chunk = x[:, start: start + self.max_chunk] 75 | feat_chunk, _ = self.model.extract_features( 76 | source=x_chunk, 77 | padding_mask=None, 78 | mask=False, 79 | output_layer=self.layer + self.layer_shift, 80 | ) 81 | 82 | feat.append(feat_chunk) 83 | if len(feat) == 0: 84 | return torch.zeros(0, 0) 85 | return torch.cat(feat, 1).squeeze(0) 86 | 87 | 88 | 89 | 90 | class ApplyKmeans(object): 91 | def __init__(self, km_path): 92 | self.km_model = joblib.load(km_path) 93 | self.C_np = self.km_model.cluster_centers_.transpose() 94 | self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True) 95 | 96 | self.C = torch.from_numpy(self.C_np) 97 | self.Cnorm = torch.from_numpy(self.Cnorm_np) 98 | if torch.cuda.is_available(): 99 | self.C = self.C.cuda() 100 | self.Cnorm = self.Cnorm.cuda() 101 | 102 | def __call__(self, x): 103 | if isinstance(x, torch.Tensor): 104 | self.C = self.C.to(x) 105 | self.Cnorm = self.Cnorm.to(x) 106 | dist = ( 107 | x.pow(2).sum(1, keepdim=True) 108 | - 2 * torch.matmul(x, self.C) 109 | + self.Cnorm 110 | ) 111 | return dist.argmin(dim=1).cpu().numpy() 112 | else: 
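            # NumPy fallback: the same squared-Euclidean argmin as the torch
            # branch above, for inputs that are plain numpy arrays.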
113 |             dist = (
114 |                 (x ** 2).sum(1, keepdims=True)
115 |                 - 2 * np.matmul(x, self.C_np)
116 |                 + self.Cnorm_np
117 |             )
118 |             return np.argmin(dist, axis=1)
119 | 
120 | 
121 | class Speech2Unit(torch.nn.Module):
122 |     def __init__(
123 |         self,
124 |         ckpt_dir,
125 |         layer=11,
126 |         max_chunk=1600000,
127 |         fp16=False,
128 |         sampling_rate=16000,
129 |     ):
130 |         """
131 |         Args:
132 |             ckpt_dir(str): path to the HuBERT model dir (e.g. containing hubert_base_ls960.pt)
133 |             layer(int): which HuBERT layer to extract features from (default: 11)
134 |             max_chunk(int): maximum samples per forward chunk (default: 1600000)
135 |             fp16(bool): run the feature extractor in half precision (default: False)
136 |             sampling_rate(int): target sampling rate (default: 16000)
137 |         """
138 |         super().__init__()
139 | 
140 |         ckpt_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3.pt")
141 |         km_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3_L11_km1000.bin")
142 | 
143 |         self.feature_reader = FeatureReader(ckpt_path, layer, max_chunk, fp16, sampling_rate)
144 |         self.apply_kmeans = ApplyKmeans(km_path)
145 | 
146 |     @staticmethod
147 |     def merge_duplicates(cluster_ids):
148 |         dup_cluster_list = []
149 |         duration_list = []
150 |         count = 1
151 |         for i in range(0, len(cluster_ids)):
152 |             if i + 1 < len(cluster_ids) and cluster_ids[i] == cluster_ids[i+1]:
153 |                 count += 1
154 |             else:
155 |                 dup_cluster_list.append(cluster_ids[i])
156 |                 duration_list.append(count)
157 |                 count = 1
158 |         return dup_cluster_list, duration_list
159 | 
160 | 
161 |     def __call__(self, path, merged=True):
162 |         waveform = self.feature_reader.read_audio(path).to(DEVICE)
163 | 
164 |         feat = self.feature_reader.get_feats(waveform)
165 |         cluster_ids = self.apply_kmeans(feat).tolist()
166 |         dup_cluster_list, duration_list = self.merge_duplicates(cluster_ids)
167 | 
168 |         merged_units = "<sosp>" + "".join([f"<{str(x)}>" for x in dup_cluster_list]) + "<eosp>"
169 |         unmerged_units = "<sosp>" + "".join([f"<{str(x)}>" for x in cluster_ids]) + "<eosp>"
170 | 
171 |         if merged:
172 |             return merged_units
173 |         else:
174 |             return unmerged_units
175 |         # return {"continuous":feat, "units":dup_cluster_list, "duration":duration_list, "unmerged_units":cluster_ids}
176 | 
177 | 
178 | if __name__ == '__main__':
179 |     parser = argparse.ArgumentParser()
180 |     parser.add_argument("--wav", type=str)
181 |     args = parser.parse_args()
182 | 
183 |     ckpt_dir = "speechgpt/utils/speech2unit/"
184 | 
185 |     s2u = Speech2Unit(
186 |         ckpt_dir=ckpt_dir
187 |     )
188 | 
189 |     units = s2u(args.wav)
190 |     print(units)
191 | 
--------------------------------------------------------------------------------
/speechgpt/src/infer/cli_infer.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
  4 | import soundfile as sf
  5 | from typing import List
  6 | import argparse
  7 | import logging
  8 | import json
  9 | from tqdm import tqdm
 10 | import os, sys  # sys is needed for the torch.compile platform check below
 11 | import re
 12 | import traceback
 13 | from peft import PeftModel
 14 | from speechgpt.utils.speech2unit.speech2unit import Speech2Unit
 15 | import transformers
 16 | from transformers import AutoConfig, LlamaForCausalLM, LlamaTokenizer, GenerationConfig
 17 | 
 18 | 
 19 | logging.basicConfig()
 20 | logging.root.setLevel(logging.INFO)
 21 | logger = logging.getLogger(__name__)
 22 | 
 23 | 
 24 | 
 25 | NAME="SpeechGPT"
 26 | META_INSTRUCTION="You are an AI assistant whose name is SpeechGPT.\n- SpeechGPT is a intrinsic cross-modal conversational language model that is developed by Fudan
University. SpeechGPT can understand and communicate fluently with human through speech or text chosen by the user.\n- It can perceive cross-modal inputs and generate cross-modal outputs.\n"
 27 | DEFAULT_GEN_PARAMS = {
 28 |     "max_new_tokens": 1024,
 29 |     "min_new_tokens": 10,
 30 |     "temperature": 0.8,
 31 |     "do_sample": True,
 32 |     "top_k": 60,
 33 |     "top_p": 0.8,
 34 | }
 35 | device = torch.device('cuda')
 36 | 
 37 | 
 38 | def extract_text_between_tags(text, tag1='[SpeechGPT] :', tag2='<eoa>'):
 39 |     pattern = f'{re.escape(tag1)}(.*?){re.escape(tag2)}'
 40 |     match = re.search(pattern, text, re.DOTALL)
 41 |     if match:
 42 |         response = match.group(1)
 43 |     else:
 44 |         response = ""
 45 |     return response
 46 | 
 47 | 
 48 | 
 49 | class SpeechGPTInference:
 50 |     def __init__(
 51 |         self,
 52 |         model_name_or_path: str,
 53 |         lora_weights: str=None,
 54 |         s2u_dir: str="speechgpt/utils/speech2unit/",
 55 |         vocoder_dir: str="speechgpt/utils/vocoder/",
 56 |         output_dir="speechgpt/output/"
 57 |     ):
 58 | 
 59 |         self.meta_instruction = META_INSTRUCTION
 60 |         self.template = "[Human]: {question} <eoh>. [SpeechGPT]: "
 61 | 
 62 | 
 63 |         # speech2unit
 64 |         self.s2u = Speech2Unit(ckpt_dir=s2u_dir)
 65 | 
 66 |         # model
 67 |         self.model = LlamaForCausalLM.from_pretrained(
 68 |             model_name_or_path,
 69 |             load_in_8bit=False,
 70 |             torch_dtype=torch.float16,
 71 |             device_map="auto",
 72 |         )
 73 | 
 74 |         if lora_weights is not None:
 75 |             self.model = PeftModel.from_pretrained(
 76 |                 self.model,
 77 |                 lora_weights,
 78 |                 torch_dtype=torch.float16,
 79 |                 device_map="auto",
 80 |             )
 81 | 
 82 |         self.model.half()
 83 | 
 84 |         self.model.eval()
 85 |         if torch.__version__ >= "2" and sys.platform != "win32":
 86 |             self.model = torch.compile(self.model)
 87 | 
 88 |         # tokenizer
 89 |         self.tokenizer = LlamaTokenizer.from_pretrained(
 90 |             model_name_or_path)
 91 |         self.tokenizer.pad_token_id = (0)
 92 |         self.tokenizer.padding_side = "left"
 93 | 
 94 | 
 95 |         # generation
 96 |         self.generate_kwargs = DEFAULT_GEN_PARAMS
 97 | 
 98 | 
 99 |         # vocoder
100 |         vocoder = os.path.join(vocoder_dir, "vocoder.pt")
101 |         vocoder_cfg = os.path.join(vocoder_dir, "config.json")
102 |         with open(vocoder_cfg) as f:
103 |             vocoder_cfg = json.load(f)
104 |         self.vocoder = CodeHiFiGANVocoder(vocoder, vocoder_cfg).to(device)
105 | 
106 |         self.output_dir = output_dir
107 | 
108 | 
109 |     def preprocess(
110 |         self,
111 |         raw_text: str,
112 |     ):
113 |         processed_parts = []
114 |         for part in raw_text.split("is input:"):
115 |             if os.path.isfile(part.strip()) and os.path.splitext(part.strip())[-1] in [".wav", ".flac", ".mp4"]:
116 |                 processed_parts.append(self.s2u(part.strip(), merged=True))
117 |             else:
118 |                 processed_parts.append(part)
119 |         processed_text = "is input:".join(processed_parts)
120 | 
121 |         prompt_seq = self.meta_instruction + self.template.format(question=processed_text)
122 |         return prompt_seq
123 | 
124 | 
125 |     def postprocess(
126 |         self,
127 |         response: str,
128 |     ):
129 |         question = extract_text_between_tags(response, tag1="[Human]", tag2="<eoh>")
130 |         answer = extract_text_between_tags(response + '<eoa>', tag1=f"[SpeechGPT] :", tag2="<eoa>")
131 |         tq = extract_text_between_tags(response, tag1="[SpeechGPT] :", tag2="; [ta]") if "[ta]" in response else ''
132 |         ta = extract_text_between_tags(response, tag1="[ta]", tag2="; [ua]") if "[ta]" in response else ''
133 |         ua = extract_text_between_tags(response + '<eoa>', tag1="[ua]", tag2="<eoa>") if "[ua]" in response else ''
134 | 
135 |         return {"question":question, "answer":answer, "textQuestion":tq, "textAnswer":ta, "unitAnswer":ua}
136 | 
137 | 
138 |     def forward(
139 |         self,
140 |         prompts: List[str]
141 |     ):
142 |         with torch.no_grad():
143 |             # preprocess
144 |             preprocessed_prompts = []
145 |             for prompt in prompts:
146 |                 preprocessed_prompts.append(self.preprocess(prompt))
147 | 
148 |             input_ids = self.tokenizer(preprocessed_prompts, return_tensors="pt", padding=True).input_ids
149 |             # strip a trailing EOS token if the tokenizer appended one (with left
150 |             # padding every row ends at the same position, so check the first row)
151 |             if input_ids[0][-1] == self.tokenizer.eos_token_id:
152 |                 input_ids = input_ids[:, :-1]
153 | 
154 |             input_ids = input_ids.to(device)
155 | 
156 |             # generate
157 |             generation_config = GenerationConfig(
158 |                 temperature=0.7,
159 |                 top_p=0.8,
160 |                 top_k=50,
161 |                 do_sample=True,
162 |                 max_new_tokens=2048,
163 |                 min_new_tokens=10,
164 |             )
165 | 
166 |             generated_ids = self.model.generate(
167 |                 input_ids=input_ids,
168 |                 generation_config=generation_config,
169 |                 return_dict_in_generate=True,
170 |                 output_scores=True,
171 |                 # max_new_tokens=1024,
172 |             )
173 |             generated_ids = generated_ids.sequences
174 |             responses = self.tokenizer.batch_decode(generated_ids.cpu(), skip_special_tokens=True)
175 | 
176 |             # postprocess
177 |             responses = [self.postprocess(x) for x in responses]
178 | 
179 |             # save responses
180 |             init_num = sum(1 for line in open(f"{self.output_dir}/responses.json", 'r')) if os.path.exists(f"{self.output_dir}/responses.json") else 0
181 |             with open(f"{self.output_dir}/responses.json", 'a') as f:
182 |                 for r in responses:
183 |                     if r["textAnswer"] != "":
184 |                         print("Transcript:", r["textQuestion"])
185 |                         print("Text response:", r["textAnswer"])
186 |                     else:
187 |                         print("Response:\n", r["answer"])
188 |                     json_line = json.dumps(r)
189 |                     f.write(json_line+'\n')
190 | 
191 |             # dump wav
192 |             wav = torch.tensor(0)
193 |             os.makedirs(f"{self.output_dir}/wav/", exist_ok=True)
194 |             for i, response in enumerate(responses):
195 |                 if response["answer"] != '' and '<eosp>' in response["answer"]:
196 |                     unit = [int(num) for num in re.findall(r'<(\d+)>', response["answer"])]
197 |                     x = {
198 |                         "code": torch.LongTensor(unit).view(1, -1).to(device),
199 |                     }
200 |                     wav = self.vocoder(x, True)
201 |                     self.dump_wav(init_num+i, wav, prefix="answer")
202 |                     print(f"Speech response is saved in {self.output_dir}/wav/answer_{init_num+i}.wav")
203 |             print(f"Response json is saved in {self.output_dir}/responses.json")
204 | 
205 |         return 16000, wav.detach().cpu().numpy()
206 | 
207 |     def dump_wav(self, sample_id, pred_wav, prefix):
208 |         sf.write(
209 |             f"{self.output_dir}/wav/{prefix}_{sample_id}.wav",
210 |             pred_wav.detach().cpu().numpy(),
211 |             16000,
212 |         )
213 | 
214 |     def __call__(self, input):
215 |         return self.forward(input)
216 | 
217 | 
218 |     def interact(self):
219 |         prompt = str(input(f"Please talk with {NAME}:\n"))
220 |         while prompt != "quit":
221 |             try:
222 |                 self.forward([prompt])
223 |             except Exception as e:
224 |                 traceback.print_exc()
225 |                 print(e)
226 | 
227 |             prompt = str(input(f"Please input prompts for {NAME}:\n"))
228 | 
229 | 
230 | 
231 | if __name__=='__main__':
232 |     parser = argparse.ArgumentParser()
233 |     parser.add_argument("--model-name-or-path", type=str, default="")
234 |     parser.add_argument("--lora-weights", type=str, default=None)
235 |     parser.add_argument("--s2u-dir", type=str, default="speechgpt/utils/speech2unit/")
236 |     parser.add_argument("--vocoder-dir", type=str, default="speechgpt/utils/vocoder/")
237 |     parser.add_argument("--output-dir", type=str, default="speechgpt/output/")
238 |     args = parser.parse_args()
239 | 
240 |     os.makedirs(args.output_dir, exist_ok=True)
241 | 
242 | 
243 |     infer = SpeechGPTInference(
244 |         args.model_name_or_path,
245 |         args.lora_weights,
246 |         args.s2u_dir,
247 | 
args.vocoder_dir, 250 | args.output_dir 251 | ) 252 | 253 | infer.interact() 254 | 255 | 256 | 257 | 258 | 259 | 260 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /speechgpt/src/train/cm_sft.py: -------------------------------------------------------------------------------- 1 | # stage2: cross-modal instruct finetuning 2 | import copy 3 | import logging 4 | from dataclasses import dataclass, field 5 | from typing import Optional, Dict, Sequence 6 | import torch 7 | import transformers 8 | from torch.utils.data import Dataset 9 | from transformers import Trainer 10 | from datasets import load_dataset 11 | from transformers import LlamaForCausalLM, LlamaTokenizer, HfArgumentParser, TrainingArguments, DataCollatorForSeq2Seq 12 | from transformers.trainer_utils import get_last_checkpoint 13 | from speechgpt.utils.prompter import Prompter 14 | import os 15 | import logging 16 | import sys 17 | 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.DEBUG) 20 | 21 | @dataclass 22 | class ModelArguments: 23 | model_name_or_path: Optional[str] = field( 24 | default=None, 25 | metadata={ 26 | "help": ( 27 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 28 | ) 29 | }, 30 | ) 31 | 32 | 33 | @dataclass 34 | class DataArguments: 35 | data_path: str = field( 36 | default="", 37 | metadata={"help": "Path to the training data."}) 38 | prompt_template_name: str = field( 39 | default="alpaca", 40 | metadata={"help": "prompt_template_name"}, 41 | ) 42 | max_train_samples: Optional[int] = field( 43 | default=None, 44 | metadata={ 45 | "help": ( 46 | "For debugging purposes or quicker training, truncate the number of training examples to this " 47 | "value if set." 48 | ) 49 | }, 50 | ) 51 | max_eval_samples: Optional[int] = field( 52 | default=None, 53 | metadata={ 54 | "help": ( 55 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 56 | "value if set." 57 | ) 58 | }, 59 | ) 60 | 61 | @dataclass 62 | class TrainingArguments(transformers.TrainingArguments): 63 | cache_dir: Optional[str] = field( 64 | default=None, 65 | metadata={"help": "Where do you want to store the tokenized data"}, 66 | ) 67 | optim: str = field(default="adamw_torch") 68 | model_max_length: int = field( 69 | default=512, 70 | metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, 71 | ) 72 | val_set_size: int = field( 73 | default=2000, 74 | metadata={"help": "val_set_size"}, 75 | ) 76 | preprocessing_num_workers: int = field( 77 | default=100, 78 | metadata={"help": "preprocessing_num_workers for tokenizing"}, 79 | ) 80 | num_train_epochs: int = field( 81 | default=3, 82 | metadata={"help": "num_epochs"}, 83 | ) 84 | learning_rate: float = field( 85 | default=2e-5, 86 | metadata={"help": "learning_rate"}, 87 | ) 88 | output_dir: str = field( 89 | default="", 90 | metadata={"help": "output_dir"}, 91 | ) 92 | train_on_inputs: bool = field( 93 | default=True, 94 | metadata={"help": "if False, masks out inputs in loss"}, 95 | ) 96 | initial_global_step: int = field( 97 | default=0, 98 | metadata={"help": "initial_global_step"} 99 | ) 100 | 101 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 102 | """Collects the state dict and dump to disk.""" 103 | state_dict = trainer.model.state_dict() 104 | if trainer.args.should_save: 105 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 106 | del state_dict 107 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 108 | 109 | 110 | 111 | def train(): 112 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 113 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 114 | 115 | # Setup logging 116 | logging.basicConfig( 117 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 118 | datefmt="%m/%d/%Y %H:%M:%S", 119 | handlers=[logging.StreamHandler(sys.stdout)], 120 | ) 121 | 122 | 123 | # Log on each process the small summary: 124 | logger.warning( 125 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 126 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 127 | ) 128 | logger.info(f"Training/evaluation parameters {training_args}") 129 | 130 | 131 | # Detecting last checkpoint. 132 | last_checkpoint = None 133 | if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: 134 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 135 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 136 | raise ValueError( 137 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 138 | "Use --overwrite_output_dir to overcome." 139 | ) 140 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 141 | logger.info( 142 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 143 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 144 | ) 145 | 146 | 147 | prompter = Prompter() 148 | 149 | model = LlamaForCausalLM.from_pretrained( 150 | model_args.model_name_or_path, 151 | ) 152 | 153 | tokenizer = LlamaTokenizer.from_pretrained( 154 | model_args.model_name_or_path, 155 | model_max_length=training_args.model_max_length, 156 | padding_side="right", 157 | use_fast=False, 158 | ) 159 | tokenizer.pad_token_id = ( 160 | 0 # unk. 
we want this to be different from the eos token 161 | ) 162 | tokenizer.padding_side = "left" # Allow batched inference 163 | #Extend vocab for speech units 164 | if '<sosp>' not in tokenizer.get_vocab(): 165 | units_size=1000 166 | logger.info(f"Add special unit tokens <0>-<{units_size-1}> to tokenizer.vocab") 167 | new_tokens = [f"<{x}>" for x in range(units_size)] + ['<sosp>','<eosp>','[Human]','[SpeechGPT]','<eoh>','<eoa>'] 168 | tokenizer.add_tokens(new_tokens) 169 | for token in ['<sosp>','<eosp>','[Human]','[SpeechGPT]','<eoh>','<eoa>']: 170 | if token not in tokenizer.get_vocab(): 171 | logger.info(f"Add special token {token} to tokenizer.vocab") 172 | tokenizer.add_tokens([token]) 173 | 174 | #resize embedding 175 | embedding_size = model.get_input_embeddings().weight.shape[0] 176 | if len(tokenizer) > embedding_size: 177 | model.resize_token_embeddings(len(tokenizer)) 178 | 179 | def tokenize(prompt, add_eos_token=True): 180 | # there's probably a way to do this with the tokenizer settings 181 | # but again, gotta move fast 182 | result = tokenizer( 183 | prompt, 184 | truncation=True, 185 | max_length=tokenizer.model_max_length, 186 | padding=False, 187 | return_tensors=None, 188 | ) 189 | if ( 190 | result["input_ids"][-1] != tokenizer.eos_token_id 191 | and len(result["input_ids"]) < tokenizer.model_max_length 192 | and add_eos_token 193 | ): 194 | result["input_ids"].append(tokenizer.eos_token_id) 195 | result["attention_mask"].append(1) 196 | 197 | result["labels"] = result["input_ids"].copy() 198 | 199 | return result 200 | 201 | 202 | def generate_and_tokenize_prompt(data_point): 203 | ''' 204 | moss-style instructions 205 | ''' 206 | full_prompt = prompter.generate_prompt( 207 | data_point["prefix"], 208 | data_point["plain_text"], 209 | ) 210 | tokenized_full_prompt = tokenize(full_prompt) 211 | 212 | user_prompt = prompter.generate_prompt( 213 | data_point["prefix"] 214 | ) 215 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 216 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 217 | 218 | tokenized_full_prompt["labels"] = [  # -100 is ignored by the loss, so only response tokens are supervised 219 | -100 220 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 221 | user_prompt_len: 222 | ] # could be sped up, probably 223 | return tokenized_full_prompt 224 | 225 | 226 | if data_args.data_path.endswith(".json") or data_args.data_path.endswith(".jsonl"): 227 | data = load_dataset("json", data_files=data_args.data_path) 228 | else: 229 | data = load_dataset(data_args.data_path) 230 | 231 | tokenized_cache_file_names = { 232 | "train":os.path.join(training_args.cache_dir, 'tokenized', 'train', 'processed_train.arrow'), 233 | "test":os.path.join(training_args.cache_dir, 'tokenized', 'valid', 'processed_valid.arrow'), 234 | } 235 | 236 | if training_args.val_set_size > 0: 237 | train_val = data["train"].train_test_split( 238 | test_size=training_args.val_set_size, shuffle=True, seed=42 239 | ) 240 | train_val_data = ( 241 | train_val.map( 242 | generate_and_tokenize_prompt, 243 | batched=False, 244 | num_proc=training_args.preprocessing_num_workers, 245 | load_from_cache_file=True, 246 | cache_file_names=tokenized_cache_file_names, 247 | desc=f"generate_and_tokenize_prompt", 248 | ) 249 | ) 250 | train_data = train_val_data["train"] 251 | val_data = train_val_data["test"] 252 | 253 | else: 254 | train_data = data["train"].map( 255 | generate_and_tokenize_prompt, 256 | batched=False, 257 | num_proc=training_args.preprocessing_num_workers, 258 | load_from_cache_file=True, 259 | cache_file_names=tokenized_cache_file_names, 260 | 
desc=f"generate_and_tokenize_prompt", 261 | ) 262 | val_data = None 263 | 264 | 265 | data_collator = DataCollatorForSeq2Seq( 266 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 267 | ) 268 | 269 | trainer = Trainer( 270 | model=model, 271 | tokenizer=tokenizer, 272 | args=training_args, 273 | train_dataset=train_data if training_args.do_train else None, 274 | eval_dataset=val_data if training_args.do_eval else None, 275 | data_collator=data_collator 276 | ) 277 | 278 | if training_args.initial_global_step != 0: 279 | logger.info(f"Set initial global step={training_args.initial_global_step}") 280 | trainer.state.global_step = training_args.initial_global_step 281 | 282 | # Training 283 | if training_args.do_train: 284 | checkpoint = None 285 | if training_args.resume_from_checkpoint is not None: 286 | checkpoint = training_args.resume_from_checkpoint 287 | elif last_checkpoint is not None: 288 | checkpoint = last_checkpoint 289 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 290 | metrics = train_result.metrics 291 | 292 | max_train_samples = ( 293 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_data) 294 | ) 295 | metrics["train_samples"] = min(max_train_samples, len(train_data)) 296 | 297 | trainer.log_metrics("train", metrics) 298 | trainer.save_metrics("train", metrics) 299 | trainer.save_state() 300 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) 301 | 302 | # Evaluation 303 | if training_args.do_eval: 304 | logger.info("*** Evaluate ***") 305 | 306 | metrics = trainer.evaluate() 307 | 308 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_data) 309 | metrics["eval_samples"] = min(max_eval_samples, len(eval_data)) 310 | try: 311 | perplexity = math.exp(metrics["eval_loss"]) 312 | except OverflowError: 313 | perplexity = float("inf") 314 | metrics["perplexity"] = perplexity 315 | 316 | trainer.log_metrics("eval", metrics) 317 | trainer.save_metrics("eval", metrics) 318 | 319 | 320 | if __name__ == "__main__": 321 | train() -------------------------------------------------------------------------------- /speechgpt/src/train/com_sft.py: -------------------------------------------------------------------------------- 1 | #stage 3: chain-of-modality instruct finetuning 2 | import copy 3 | import logging 4 | from dataclasses import dataclass, field 5 | from typing import Optional, Dict, Sequence 6 | import torch 7 | import transformers 8 | from torch.utils.data import Dataset 9 | from datasets import load_dataset 10 | from peft import ( 11 | LoraConfig, 12 | get_peft_model, 13 | get_peft_model_state_dict, 14 | prepare_model_for_int8_training, 15 | set_peft_model_state_dict, 16 | ) 17 | from transformers import Trainer, LlamaForCausalLM, LlamaTokenizer, HfArgumentParser, TrainingArguments, DataCollatorForSeq2Seq 18 | from transformers.trainer_utils import get_last_checkpoint 19 | from speechgpt.utils.prompter import Prompter 20 | import os 21 | import logging 22 | import sys 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | logger.setLevel(logging.DEBUG) 27 | 28 | @dataclass 29 | class ModelArguments: 30 | model_name_or_path: Optional[str] = field( 31 | default=None, 32 | metadata={ 33 | "help": ( 34 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 
35 | ) 36 | }, 37 | ) 38 | lora_r: int = field( 39 | default=8, 40 | metadata={ 41 | "help": ( 42 | "LoRA rank" 43 | ) 44 | }, 45 | ) 46 | lora_alpha: int = field( 47 | default=16, 48 | metadata={ 49 | "help": ( 50 | "LoRA alpha" 51 | ) 52 | }, 53 | ) 54 | lora_dropout: float = field( 55 | default=0.05, 56 | metadata={ 57 | "help": ( 58 | "LoRA dropout" 59 | ) 60 | }, 61 | ) 62 | lora_target_modules: str = field( 63 | default="q_proj,v_proj", 64 | metadata={ 65 | "help": ( 66 | "comma-separated LoRA target modules" 67 | ) 68 | } 69 | ) 70 | 71 | 72 | @dataclass 73 | class DataArguments: 74 | data_path: str = field( 75 | default="", 76 | metadata={"help": "Path to the training data."}) 77 | prompt_template_name: str = field( 78 | default="alpaca", 79 | metadata={"help": "prompt_template_name"}, 80 | ) 81 | max_train_samples: Optional[int] = field( 82 | default=None, 83 | metadata={ 84 | "help": ( 85 | "For debugging purposes or quicker training, truncate the number of training examples to this " 86 | "value if set." 87 | ) 88 | }, 89 | ) 90 | max_eval_samples: Optional[int] = field( 91 | default=None, 92 | metadata={ 93 | "help": ( 94 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 95 | "value if set." 96 | ) 97 | }, 98 | ) 99 | 100 | @dataclass 101 | class TrainingArguments(transformers.TrainingArguments): 102 | cache_dir: Optional[str] = field( 103 | default=None, 104 | metadata={"help": "Where do you want to store the tokenized data"}, 105 | ) 106 | optim: str = field(default="adamw_torch") 107 | model_max_length: int = field( 108 | default=512, 109 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 110 | ) 111 | val_set_size: int = field( 112 | default=2000, 113 | metadata={"help": "val_set_size"}, 114 | ) 115 | preprocessing_num_workers: int = field( 116 | default=100, 117 | metadata={"help": "preprocessing_num_workers for tokenizing"}, 118 | ) 119 | num_train_epochs: int = field( 120 | default=3, 121 | metadata={"help": "num_epochs"}, 122 | ) 123 | learning_rate: float = field( 124 | default=2e-5, 125 | metadata={"help": "learning_rate"}, 126 | ) 127 | output_dir: str = field( 128 | default="", 129 | metadata={"help": "output_dir"}, 130 | ) 131 | train_on_inputs: bool = field( 132 | default=True, 133 | metadata={"help": "if False, masks out inputs in loss"}, 134 | ) 135 | initial_global_step: int = field( 136 | default=0, 137 | metadata={"help": "initial_global_step"} 138 | ) 139 | 140 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 141 | """Collects the state dict and dumps it to disk.""" 142 | state_dict = trainer.model.state_dict() 143 | if trainer.args.should_save: 144 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 145 | del state_dict 146 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 147 | 148 | 149 | 150 | def train(): 151 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 152 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 153 | 154 | # Setup logging 155 | logging.basicConfig( 156 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 157 | datefmt="%m/%d/%Y %H:%M:%S", 158 | handlers=[logging.StreamHandler(sys.stdout)], 159 | ) 160 | 161 | 162 | # Log on each process the small summary: 163 | logger.warning( 164 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 165 | + f", distributed 
training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 166 | ) 167 | logger.info(f"Training/evaluation parameters {training_args}") 168 | 169 | 170 | # Detecting last checkpoint. 171 | last_checkpoint = None 172 | if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: 173 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 174 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 175 | raise ValueError( 176 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 177 | "Use --overwrite_output_dir to overcome." 178 | ) 179 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 180 | logger.info( 181 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 182 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 183 | ) 184 | 185 | 186 | prompter = Prompter() 187 | 188 | device_map = "auto" 189 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 190 | ddp = world_size != 1 191 | if ddp: 192 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 193 | 194 | 195 | model = LlamaForCausalLM.from_pretrained( 196 | model_args.model_name_or_path, 197 | load_in_8bit=True, 198 | torch_dtype=torch.float16, 199 | device_map=device_map, 200 | ) 201 | 202 | model = prepare_model_for_int8_training(model) 203 | 204 | config = LoraConfig( 205 | r=model_args.lora_r, 206 | lora_alpha=model_args.lora_alpha, 207 | target_modules=model_args.lora_target_modules.split(','), 208 | lora_dropout=model_args.lora_dropout, 209 | bias="none", 210 | task_type="CAUSAL_LM", 211 | ) 212 | model = get_peft_model(model, config) 213 | 214 | tokenizer = LlamaTokenizer.from_pretrained( 215 | model_args.model_name_or_path, 216 | model_max_length=training_args.model_max_length, 217 | padding_side="right", 218 | use_fast=False, 219 | ) 220 | tokenizer.pad_token_id = ( 221 | 0 # unk. 
we want this to be different from the eos token 222 | ) 223 | tokenizer.padding_side = "left" # Allow batched inference 224 | 225 | #Extend vocab for speech units 226 | if '<sosp>' not in tokenizer.get_vocab(): 227 | units_size=1000 228 | logger.info(f"Add special unit tokens <0>-<{units_size-1}> to tokenizer.vocab") 229 | new_tokens = [f"<{x}>" for x in range(units_size)] 230 | tokenizer.add_tokens(new_tokens) 231 | for token in ['<sosp>','<eosp>','[Human]','[SpeechGPT]','<eoh>','<eoa>']: 232 | if token not in tokenizer.get_vocab(): 233 | logger.info(f"Add special token {token} to tokenizer.vocab") 234 | tokenizer.add_tokens([token]) 235 | 236 | #resize embedding 237 | embedding_size = model.get_input_embeddings().weight.shape[0] 238 | if len(tokenizer) > embedding_size: 239 | model.resize_token_embeddings(len(tokenizer)) 240 | 241 | def tokenize(prompt, add_eos_token=True): 242 | # there's probably a way to do this with the tokenizer settings 243 | # but again, gotta move fast 244 | result = tokenizer( 245 | prompt, 246 | truncation=True, 247 | max_length=tokenizer.model_max_length, 248 | padding=False, 249 | return_tensors=None, 250 | ) 251 | if ( 252 | result["input_ids"][-1] != tokenizer.eos_token_id 253 | and len(result["input_ids"]) < tokenizer.model_max_length 254 | and add_eos_token 255 | ): 256 | result["input_ids"].append(tokenizer.eos_token_id) 257 | result["attention_mask"].append(1) 258 | 259 | result["labels"] = result["input_ids"].copy() 260 | 261 | return result 262 | 263 | 264 | def generate_and_tokenize_prompt(data_point): 265 | ''' 266 | COM-style instructions 267 | ''' 268 | full_prompt = prompter.generate_prompt( 269 | data_point["prefix"], 270 | data_point["plain_text"], 271 | ) 272 | tokenized_full_prompt = tokenize(full_prompt) 273 | 274 | user_prompt = prompter.generate_prompt( 275 | data_point["prefix"], 276 | ) 277 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 278 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 279 | 280 | tokenized_full_prompt["labels"] = [  # -100 is ignored by the loss, so only response tokens are supervised 281 | -100 282 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 283 | user_prompt_len: 284 | ] 285 | return tokenized_full_prompt 286 | 287 | 288 | 289 | if data_args.data_path.endswith(".json") or data_args.data_path.endswith(".jsonl"): 290 | data = load_dataset("json", data_files=data_args.data_path) 291 | else: 292 | data = load_dataset(data_args.data_path) 293 | 294 | tokenized_cache_file_names = { 295 | "train":os.path.join(training_args.cache_dir, 'tokenized', 'train', 'processed_train.arrow'), 296 | "test":os.path.join(training_args.cache_dir, 'tokenized', 'valid', 'processed_valid.arrow'), 297 | } 298 | 299 | if training_args.val_set_size > 0: 300 | train_val = data["train"].train_test_split( 301 | test_size=training_args.val_set_size, shuffle=True, seed=42 302 | ) 303 | train_val_data = ( 304 | train_val.map( 305 | generate_and_tokenize_prompt, 306 | batched=False, 307 | num_proc=training_args.preprocessing_num_workers, 308 | load_from_cache_file=True, 309 | cache_file_names=tokenized_cache_file_names, 310 | desc=f"generate_and_tokenize_prompt", 311 | ) 312 | ) 313 | train_data = train_val_data["train"] 314 | val_data = train_val_data["test"] 315 | 316 | else: 317 | train_data = data["train"].map( 318 | generate_and_tokenize_prompt, 319 | batched=False, 320 | num_proc=training_args.preprocessing_num_workers, 321 | load_from_cache_file=True, 322 | cache_file_names=tokenized_cache_file_names, 323 | desc=f"generate_and_tokenize_prompt", 324 | ) 325 | val_data = None 326 | 327 | 328 
| data_collator = DataCollatorForSeq2Seq( 329 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 330 | ) 331 | 332 | if not ddp and torch.cuda.device_count() > 1: 333 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available 334 | model.is_parallelizable = True 335 | model.model_parallel = True 336 | 337 | trainer = Trainer( 338 | model=model, 339 | tokenizer=tokenizer, 340 | args=training_args, 341 | train_dataset=train_data if training_args.do_train else None, 342 | eval_dataset=val_data if training_args.do_eval else None, 343 | data_collator=data_collator 344 | ) 345 | 346 | # Training 347 | if training_args.do_train: 348 | checkpoint = None 349 | 350 | if training_args.resume_from_checkpoint is not None: 351 | # Check the available weights and load them 352 | checkpoint_name = os.path.join( 353 | training_args.resume_from_checkpoint, "pytorch_model.bin" 354 | ) # Full checkpoint 355 | if not os.path.exists(checkpoint_name): 356 | checkpoint_name = os.path.join( 357 | training_args.resume_from_checkpoint, "adapter_model.bin" 358 | ) # only LoRA model - LoRA config above has to fit 359 | 360 | # The two files above have a different name depending on how they were saved, but are actually the same. 361 | if os.path.exists(checkpoint_name): 362 | print(f"Restarting from {checkpoint_name}") 363 | adapters_weights = torch.load(checkpoint_name) 364 | model = set_peft_model_state_dict(model, adapters_weights) 365 | else: 366 | print(f"Checkpoint {checkpoint_name} not found") 367 | 368 | model.print_trainable_parameters() 369 | model.config.use_cache = False 370 | 371 | old_state_dict = model.state_dict 372 | model.state_dict = ( 373 | lambda self, *_, **__: get_peft_model_state_dict( 374 | self, old_state_dict() 375 | ) 376 | ).__get__(model, type(model)) 377 | 378 | if torch.__version__ >= "2" and sys.platform != "win32": 379 | model = torch.compile(model) 380 | 381 | 382 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 383 | metrics = train_result.metrics 384 | 385 | max_train_samples = ( 386 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_data) 387 | ) 388 | metrics["train_samples"] = min(max_train_samples, len(train_data)) 389 | 390 | trainer.log_metrics("train", metrics) 391 | trainer.save_metrics("train", metrics) 392 | trainer.save_state() 393 | # safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) 394 | model.save_pretrained(training_args.output_dir) 395 | 396 | # Evaluation 397 | if training_args.do_eval: 398 | logger.info("*** Evaluate ***") 399 | 400 | metrics = trainer.evaluate() 401 | 402 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(val_data) 403 | metrics["eval_samples"] = min(max_eval_samples, len(val_data)) 404 | try: 405 | perplexity = math.exp(metrics["eval_loss"]) 406 | except OverflowError: 407 | perplexity = float("inf") 408 | metrics["perplexity"] = perplexity 409 | 410 | trainer.log_metrics("eval", metrics) 411 | trainer.save_metrics("eval", metrics) 412 | 413 | 414 | if __name__ == "__main__": 415 | train() -------------------------------------------------------------------------------- /speechgpt/src/train/ma_pretrain.py: -------------------------------------------------------------------------------- 1 | # stage 1: modality-adaptation pretraining 2 | import random 3 | import copy 4 | import logging 5 | from dataclasses import dataclass, field 6 | from typing import Optional, Dict, 
Sequence, List 7 | import torch 8 | from datasets import load_dataset, Dataset, concatenate_datasets 9 | import evaluate 10 | import math 11 | import tqdm 12 | import glob 13 | import transformers 14 | from transformers import Trainer, LlamaForCausalLM, LlamaTokenizer, HfArgumentParser, TrainingArguments, DataCollatorForSeq2Seq 15 | from transformers.trainer_utils import get_last_checkpoint 16 | from speechgpt.utils.prompter import Prompter 17 | import os 18 | from itertools import chain 19 | import logging 20 | import sys 21 | 22 | logger = logging.getLogger(__name__) 23 | logger.setLevel(logging.DEBUG) 24 | 25 | @dataclass 26 | class ModelArguments: 27 | model_name_or_path: Optional[str] = field( 28 | default=None, 29 | metadata={ 30 | "help": ( 31 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 32 | ) 33 | }, 34 | ) 35 | cache_dir: Optional[str] = field( 36 | default=None, 37 | metadata={"help": "Where do you want to store the tokenized data"}, 38 | ) 39 | 40 | #tune embedding 41 | train_embeddings: bool = field( 42 | default=False, 43 | metadata={ 44 | "help": ( 45 | "only train embeddings while training" 46 | ) 47 | }, 48 | ) 49 | 50 | @dataclass 51 | class DataArguments: 52 | data_path: str = field( 53 | default="", 54 | metadata={"help": "Path to the training data."}) 55 | train_file: Optional[str] = field( 56 | default=None, 57 | metadata={"help": "The input training data file (a text file)."} 58 | ) 59 | validation_file: Optional[str] = field( 60 | default=None, 61 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 62 | ) 63 | max_train_samples: Optional[int] = field( 64 | default=None, 65 | metadata={ 66 | "help": ( 67 | "For debugging purposes or quicker training, truncate the number of training examples to this " 68 | "value if set." 69 | ) 70 | }, 71 | ) 72 | max_eval_samples: Optional[int] = field( 73 | default=None, 74 | metadata={ 75 | "help": ( 76 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 77 | "value if set." 78 | ) 79 | }, 80 | ) 81 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) 82 | block_size: Optional[int] = field( 83 | default=None, 84 | metadata={ 85 | "help": ( 86 | "Optional input sequence length after tokenization. " 87 | "The training dataset will be truncated in block of this size for training. " 88 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 
89 | ) 90 | }, 91 | ) 92 | overwrite_cache: bool = field( 93 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 94 | ) 95 | validation_split_percentage: Optional[int] = field( 96 | default=5, 97 | metadata={ 98 | "help": "The percentage of the train set used as validation set in case there's no validation split" 99 | }, 100 | ) 101 | preprocessing_num_workers: Optional[int] = field( 102 | default=None, 103 | metadata={"help": "The number of processes to use for the preprocessing."}, 104 | ) 105 | use_text: bool = field( 106 | default=False, metadata={"help": "Use text data for pretraining"} 107 | ) 108 | 109 | 110 | @dataclass 111 | class TrainingArguments(transformers.TrainingArguments): 112 | optim: str = field(default="adamw_torch") 113 | 114 | 115 | 116 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 117 | """Collects the state dict and dumps it to disk.""" 118 | state_dict = trainer.model.state_dict() 119 | if trainer.args.should_save: 120 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 121 | del state_dict 122 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 123 | 124 | 125 | def train(): 126 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 127 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 128 | 129 | # Setup logging 130 | logging.basicConfig( 131 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 132 | datefmt="%m/%d/%Y %H:%M:%S", 133 | handlers=[logging.StreamHandler(sys.stdout)], 134 | ) 135 | 136 | 137 | # Log on each process the small summary: 138 | logger.warning( 139 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 140 | + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 141 | ) 142 | logger.info(f"Training/evaluation parameters {training_args}") 143 | 144 | 145 | # Detecting last checkpoint. 146 | last_checkpoint = None 147 | if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: 148 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 149 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 150 | raise ValueError( 151 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 152 | "Use --overwrite_output_dir to overcome." 153 | ) 154 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 155 | logger.info( 156 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 157 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 158 | ) 159 | 160 | 161 | model = LlamaForCausalLM.from_pretrained( 162 | model_args.model_name_or_path, 163 | ).to(torch.device(training_args.device)) 164 | 165 | tokenizer = LlamaTokenizer.from_pretrained( 166 | model_args.model_name_or_path, 167 | ) 168 | tokenizer.pad_token_id = ( 169 | 0 # unk. 
we want this to be different from the eos token 170 | ) 171 | tokenizer.padding_side = "left" # Allow batched inference 172 | #Extend vocab for speech units 173 | if '<sosp>' not in tokenizer.get_vocab(): 174 | units_size=1000 175 | logger.info(f"Add special unit tokens <0>-<{units_size-1}> to tokenizer.vocab") 176 | new_tokens = [f"<{x}>" for x in range(units_size)] + ['<sosp>', '<eosp>'] 177 | tokenizer.add_tokens(new_tokens) 178 | 179 | #resize embedding 180 | embedding_size = model.get_input_embeddings().weight.shape[0] 181 | if len(tokenizer) > embedding_size: 182 | model.resize_token_embeddings(len(tokenizer)) 183 | 184 | if model_args.train_embeddings: 185 | logger.info("only update embedding parameters") 186 | for name, param in model.named_parameters(): 187 | if "embed" not in name: 188 | param.requires_grad = False 189 | 190 | 191 | #data 192 | data_files = {} 193 | dataset_args = {} 194 | if data_args.train_file is not None: 195 | data_files["train"] = data_args.train_file 196 | if data_args.validation_file is not None: 197 | data_files["validation"] = data_args.validation_file 198 | extension = ( 199 | data_args.train_file.split(".")[-1] 200 | if data_args.train_file is not None 201 | else data_args.validation_file.split(".")[-1] 202 | ) 203 | if extension == "txt": 204 | extension = "text" 205 | raw_datasets = load_dataset( 206 | extension, 207 | data_files=data_files, 208 | **dataset_args, 209 | ) 210 | # If no validation data is there, validation_split_percentage will be used to divide the dataset. 211 | if "validation" not in raw_datasets.keys(): 212 | raw_datasets["train"] = load_dataset( 213 | extension, 214 | data_files=data_files, 215 | split=f"train[{data_args.validation_split_percentage}%:]", 216 | **dataset_args, 217 | ) 218 | raw_datasets["validation"] = load_dataset( 219 | extension, 220 | data_files=data_files, 221 | split=f"train[:{data_args.validation_split_percentage}%]", 222 | **dataset_args, 223 | ) 224 | 225 | 226 | if training_args.do_train: 227 | column_names = list(raw_datasets["train"].features) 228 | else: 229 | column_names = list(raw_datasets["validation"].features) 230 | text_column_name = "text" if "text" in column_names else column_names[0] 231 | 232 | def tokenize_function(examples): 233 | output = tokenizer(examples[text_column_name]) 234 | return output 235 | 236 | 237 | tokenized_cache_file_names = { 238 | "train":os.path.join(model_args.cache_dir, 'tokenized', 'train', 'processed_train.arrow'), 239 | "validation":os.path.join(model_args.cache_dir, 'tokenized', 'valid', 'processed_valid.arrow'), 240 | } 241 | with training_args.main_process_first(desc="dataset map tokenization"): 242 | if not data_args.streaming: 243 | tokenized_datasets = raw_datasets.map( 244 | tokenize_function, 245 | batched=True, 246 | num_proc=data_args.preprocessing_num_workers, 247 | remove_columns=column_names, 248 | load_from_cache_file=not data_args.overwrite_cache, 249 | desc="Running tokenizer on dataset", 250 | cache_file_names=tokenized_cache_file_names 251 | ) 252 | else: 253 | tokenized_datasets = raw_datasets.map( 254 | tokenize_function, 255 | batched=True, 256 | remove_columns=column_names, 257 | ) 258 | 259 | if data_args.block_size is None: 260 | block_size = tokenizer.model_max_length 261 | if block_size > 1024: 262 | logger.warning( 263 | "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" 264 | " of 1024. 
If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" 265 | " override this default with `--block_size xxx`." 266 | ) 267 | block_size = 1024 268 | else: 269 | if data_args.block_size > tokenizer.model_max_length: 270 | logger.warning( 271 | f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" 272 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 273 | ) 274 | block_size = min(data_args.block_size, tokenizer.model_max_length) 275 | 276 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 277 | def group_texts(examples): 278 | # Concatenate all texts. 279 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 280 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 281 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 282 | # customize this part to your needs. 283 | if total_length >= block_size: 284 | total_length = (total_length // block_size) * block_size 285 | # Split by chunks of max_len. 286 | result = { 287 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 288 | for k, t in concatenated_examples.items() 289 | } 290 | result["labels"] = result["input_ids"].copy() 291 | return result 292 | 293 | 294 | group_cache_file_names = { 295 | "train":os.path.join(model_args.cache_dir, 'group', 'train', 'processed_train.arrow'), 296 | "validation":os.path.join(model_args.cache_dir, 'group', 'valid', 'processed_valid.arrow'), 297 | } 298 | 299 | with training_args.main_process_first(desc="grouping texts together"): 300 | if not data_args.streaming: 301 | lm_datasets = tokenized_datasets.map( 302 | group_texts, 303 | batched=True, 304 | num_proc=data_args.preprocessing_num_workers, 305 | load_from_cache_file=not data_args.overwrite_cache, 306 | desc=f"Grouping texts in chunks of {block_size}", 307 | cache_file_names=group_cache_file_names 308 | ) 309 | else: 310 | lm_datasets = tokenized_datasets.map( 311 | group_texts, 312 | batched=True, 313 | ) 314 | 315 | if training_args.do_train: 316 | if "train" not in tokenized_datasets: 317 | raise ValueError("--do_train requires a train dataset") 318 | train_dataset = lm_datasets["train"] 319 | if data_args.max_train_samples is not None: 320 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 321 | train_dataset = train_dataset.select(range(max_train_samples)) 322 | 323 | 324 | if training_args.do_eval: 325 | if "validation" not in tokenized_datasets: 326 | raise ValueError("--do_eval requires a validation dataset") 327 | eval_dataset = lm_datasets["validation"] 328 | if data_args.max_eval_samples is not None: 329 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 330 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 331 | 332 | 333 | data_collator = DataCollatorForSeq2Seq( 334 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 335 | ) 336 | 337 | trainer = Trainer( 338 | model=model, 339 | tokenizer=tokenizer, 340 | args=training_args, 341 | train_dataset=train_dataset if training_args.do_train else None, 342 | eval_dataset=eval_dataset if training_args.do_eval else None, 343 | data_collator=data_collator 344 | ) 345 | 346 | 347 | # Training 348 | if training_args.do_train: 349 | checkpoint = None 350 | if training_args.resume_from_checkpoint is not None: 351 | checkpoint = 
training_args.resume_from_checkpoint 352 | elif last_checkpoint is not None: 353 | checkpoint = last_checkpoint 354 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 355 | 356 | metrics = train_result.metrics 357 | 358 | max_train_samples = ( 359 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 360 | ) 361 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 362 | 363 | trainer.log_metrics("train", metrics) 364 | trainer.save_metrics("train", metrics) 365 | trainer.save_state() 366 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) 367 | 368 | 369 | 370 | # Evaluation 371 | if training_args.do_eval: 372 | logger.info("*** Evaluate ***") 373 | 374 | metrics = trainer.evaluate() 375 | 376 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 377 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 378 | try: 379 | perplexity = math.exp(metrics["eval_loss"]) 380 | except OverflowError: 381 | perplexity = float("inf") 382 | metrics["perplexity"] = perplexity 383 | 384 | trainer.log_metrics("eval", metrics) 385 | trainer.save_metrics("eval", metrics) 386 | 387 | 388 | 389 | if __name__ == "__main__": 390 | train() -------------------------------------------------------------------------------- /speechgpt/README.md: -------------------------------------------------------------------------------- 1 | # SpeechGPT: Empowering Large Language Models with Intrinsic Cross-Modal Conversational Abilities 2 | 3 | [![](https://img.shields.io/badge/Datasets-SpeechInstruct-yellow)](https://huggingface.co/datasets/fnlp/SpeechInstruct) 4 | 5 |
8 | 9 | ## Introduction 10 | SpeechGPT is a large language model with **intrinsic cross-modal conversational abilities**, capable of perceiving and generating multi-modal content following human instructions. With discrete speech representations, we first construct **SpeechInstruct**, a large-scale cross-modal speech instruction dataset. Additionally, we employ a three-stage training strategy that includes **modality-adaptation pre-training**, **cross-modal instruction fine-tuning**, and **chain-of-modality instruction fine-tuning**. The experimental results demonstrate that SpeechGPT has an impressive capacity to follow multi-modal human instructions and highlight the potential of handling multiple modalities with one model.
11 | SpeechGPT demos are shown on our [project page](https://0nutation.github.io/SpeechGPT.github.io/). As shown in the demos, SpeechGPT has strong cross-modal instruction-following and spoken dialogue abilities. SpeechGPT can be **a talking encyclopedia, your personal assistant, your chat partner, a poet, a psychologist and your educational assistant**... 12 | 13 |
14 | [figure]
17 | SpeechGPT’s capabilities to tackle multiple cross-modal tasks
19 | [figure]
23 | Left: SpeechInstruct construction process. Right: SpeechGPT model structure
24 |
25 | 26 | 27 | 28 | ## Table of Contents 29 | - [Open-source list](#open-source-list) 30 | - [Talk with SpeechGPT](#talk-with-speechgpt) 31 | - [Train SpeechGPT](#train-speechgpt) 32 | - [Finetune SpeechGPT](#finetune-speechgpt) 33 | 34 | 35 | ## Open-source list 36 | ### Models 37 | 38 | - [**SpeechGPT-7B-ma**](https://huggingface.co/fnlp/SpeechGPT-7B-ma): The model obtained after the first-stage modality-adaptation pre-training, which was initialized with LLaMA-7B and further pre-trained on LibriLight speech units. 39 | - [**SpeechGPT-7B-cm**](https://huggingface.co/fnlp/SpeechGPT-7B-cm): The model obtained after the second-stage cross-modal instruction finetuning, which was initialized with SpeechGPT-7B-ma and further finetuned on SpeechInstruct Cross-Modal Instruction set. This is a powerful foundational model that aligns speech and text. 40 | - [**SpeechGPT-7B-com**](https://huggingface.co/fnlp/SpeechGPT-7B-com): The model obtained after the third-stage chain-of-modality instruction lora-finetuning, which was initialized with SpeechGPT-7B-cm and further lora-finetuned on SpeechInstruct Chain-of-Modality Instruction set. This is an adapter-model of SpeechGPT-7B-cm for spoken dialogue. 41 | 42 | ### Datasets 43 | 44 | - [**SpeechInstruct-cross-modal**](https://huggingface.co/datasets/fnlp/SpeechInstruct): The cross-modal instruction set, about 9 million unit-text data pairs tokenized by mHuBERT from large-scale English ASR datasets. data format: 45 | - [**SpeechInstruct-chain-of-modality**](https://huggingface.co/datasets/fnlp/SpeechInstruct): The chain-of-thought style instructions for four input-output formats, namely Speech Instruction-Speech Response, Speech Instruction-Text Response, Text Instruction-Speech Response, and Text Instruction-Text Response. 46 | 47 | SpeechInstruct-cross-modal data format: 48 | ``` 49 | [ 50 | { 51 | "prefix": "You are an AI assistant whose name is SpeechGPT.\n- SpeechGPT is a intrinsic cross-modal conversational language model that is developed by Fudan University. SpeechGPT can understand and communicate fluently with human through speech or text chosen by the user.\n- It can perceive cross-modal inputs and generate cross-modal outputs.\n", 52 | "plain_text": "[Human]: Try to speak out this sentence, please. This is input: The alchemist rode in front, with the falcon on his shoulder. [SpeechGPT]: <661><588><604><157><596><499><596><106><596><189><63><189><665><991><162><202><393><946><327><905><907><597><660><351><557><794><788><59><754><12><977><877><333><873><835><67><940><118><686><613><169><72><644><553><535><935><101><741><384><173><894><787><380><787><196><555><721><944><250><56><812><222><915><143><390><479><330><435><647><246><650><816><325><506><686><208><613><417><755><193><411><452><111><735><6><735><63><665><644><991><535><271><333><196><918><29><202><393><946><734><390><479><330><776><167><761><907><597><660><351><557><794><75><788><15><366><896><627><168><654><659><177><183><609><710><187><493><361><470><821><59><56><198><912><742><840><431><531><76><668><576><803><791><380><660><325><801><549><366><377><164><309><584><605><193><71><39> " 53 | }, 54 | ] 55 | ``` 56 | 57 | SpeechInstruct-chain-of-modality data format: 58 | ``` 59 | [ 60 | { 61 | "prefix": "You are an AI assistant whose name is SpeechGPT.\n- SpeechGPT is a intrinsic cross-modal conversational language model that is developed by Fudan University. 
SpeechGPT can understand and communicate fluently with human through speech or text chosen by the user.\n- It can perceive cross-modal inputs and generate cross-modal outputs.\n", 62 | "plain_text": "[Human]: <661><987><511><732><951><997><111><982><189><63><665><991><535><101><741><173><945><944><503><641><124><565><734><870><290><978><833><238><761><907><430><901><185><403><557><244><583><788><663><969><896><627><143><515><663><969><660><691><251><412><260><41><740><677><253><380><382><268><506><876><417><755><16><819><80><651><80><651><80><987><588>. [SpeechGPT]: What is a bad term for poop?; [ta] A bad term for poop is excrement. It is usually used as a polite way to refer to fecal waste.; [ua] <497><63><264><644><710><823><565><577><154><331><384><173><945><29><244><326><583><728><576><663><969><896><627><143><38><515><663><24><382><251><676><412><260><41><740><677><253><382><268><876><233><878><609><389><771><865><641><124><878><609><423><384><879><487><219><522><589><337><126><119><663><748><12><671><877><377><385><902><819><619><842><419><997><829><111><666><42><277><63><665><644><389><771><685><437><641><124><258><436><139><340><11><59><518><56><948><86><258><436><139><340><347><376><940><118><944><878><173><641><124><362><734><179><961><931><878><609><423><384><879><219><522><866><337><243><935><101><741><822><89><194><630><86><555><105><79><868><220><156><824><998><870><390><422><330><776><663><969><523><105><79><799><220><357><390><479><422><330><776><485><165><86><501><119><716><205><521><787><935><101><741><89><194><664><835><67><940><118><613><417><755><902><415><772><497>." 63 | }, 64 | ] 65 | ``` 66 | 67 | ## Talk with SpeechGPT 68 | **Due to limited training data and resources, the performance of the open-source SpeechGPT is currently not optimal. Problems such as task recognition errors and inaccuracies in speech recognition may occur. As this project is primarily an exploration in research, we have not increased the amount of pretraining and sft data or training steps to enhance performance. Our hope is that SpeechGPT can serve as a foundational model to encourage research and exploration in the field of speech language models.** 69 | 70 | ### Installation 71 | 72 | ```bash 73 | git clone https://github.com/0nutation/SpeechGPT 74 | cd SpeechGPT/speechgpt 75 | conda create --name SpeechGPT python=3.8 76 | conda activate SpeechGPT 77 | pip install -r requirements.txt 78 | ``` 79 | 80 | 81 | ### Download 82 | To talk with SpeechGPT, you should download [SpeechGPT-7B-cm](https://huggingface.co/fnlp/SpeechGPT-7B-cm) and [SpeechGPT-7B-com](https://huggingface.co/fnlp/SpeechGPT-7B-com) locally. 83 | 84 | You should download mHuBERT model to ```utils/speech2unit/```. Please see [Speech2unit](https://github.com/0nutation/SpeechGPT/blob/main/speechgpt/utils/speech2unit/README.md) for details. 85 | ```bash 86 | s2u_dir="uitls/speech2unit" 87 | cd ${s2u_dir} 88 | wget https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3.pt 89 | wget https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3_L11_km1000.bin 90 | ``` 91 | 92 | You should download the unit-vocoder to ```utils/vocoder/```. Please see [vocoder](https://github.com/0nutation/SpeechGPT/blob/main/speechgpt/utils/vocoder/README.md) for details. 
93 | ```bash 94 | vocoder_dir="utils/vocoder/" 95 | cd ${vocoder_dir} 96 | wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json -O config.json 97 | wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000 -O vocoder.pt 98 | ``` 99 | 100 | ### CLI Inference 101 | ```bash 102 | python3 speechgpt/src/infer/cli_infer.py \ 103 | --model-name-or-path "path/to/SpeechGPT-7B-cm" \ 104 | --lora-weights "path/to/SpeechGPT-7B-com" \ 105 | --s2u-dir "${s2u_dir}" \ 106 | --vocoder-dir "${vocoder_dir}" \ 107 | --output-dir "output" 108 | ``` 109 | **Notes:** 110 | For speech input, you can provide the path to the audio file. For ASR or TTS tasks, you must prefix the speech or text with ```this is input: ```; otherwise, it may be recognized incorrectly. 111 | The speech response will be saved to a ```.wav``` file, and detailed responses will be saved in a JSON file. The paths to these files will be indicated in the response. 112 | 113 | Here are some examples of talking with SpeechGPT: 114 | 115 | **Textual dialogue example** 116 | ``` 117 | Please talk with SpeechGPT: 118 | Who is Lebron James? 119 | Response: 120 | Lebron James is an American professional basketball player for the Los Angeles Lakers of the National Basketball Association (NBA). He is considered one of the greatest basketball players of all time and is known for his athleticism, scoring ability, and leadership skills. He is a four-time NBA MVP, a 14-time NBA All-Star, a 13-time All-NBA selection, and a two-time Olympic gold medalist. 121 | Response json is saved in output/responses.json 122 | ``` 123 | 124 | **Spoken dialogue example** 125 | ``` 126 | Please talk with SpeechGPT: 127 | prompts/0.wav 128 | Transcript: What are the main causes of climate change? 129 | Text response: The main causes of climate change are human activities such as burning fossil fuels, deforestation, and agricultural practices. These activities release greenhouse gases, like carbon dioxide and Methane, into the atmosphere which trap heat and cause the Earth's temperature to rise. 130 | Speech response is saved in output/wav/answer_0.wav 131 | Response json is saved in output/responses.json 132 | ``` 133 | 134 | **ASR example** 135 | ``` 136 | Please talk with SpeechGPT: 137 | Recognize this speech, this is input: prompts/1.wav 138 | Response: 139 | today is a sunny day. 140 | Response json is saved in output/responses.json 141 | ``` 142 | 143 | **TTS example** 144 | ``` 145 | Please talk with SpeechGPT: 146 | Read this sentence aloud, this is input: Today is a sunny day. 
147 | Response: 148 | <661> <987> <520> <982> <681> <982> <681> <982> <681> <982> <681> <982> <189> <63> <662> <79> <868> <220> <196> <166> <549> <822> <89> <194> <633> <14> <855> <183> <609> <389> <771> <865> <641> <124> <362> <734> <742> <98> <519> <26> <204> <280> <668> <167> <104> <650> <179> <961> <428> <950> <82> <165> <196> <166> <549> <822> <89> <194> <458> <726> <603> <819> <651> <133> <651> <133> <186> <133> <186> <133> <186> <511> <186> <511> 149 | Speech repsonse is saved in output/wav/answer_1.wav 150 | Response json is saved in output/responses.json 151 | ``` 152 | 153 | 154 | ### Gradio Web UI 155 | ```bash 156 | python3 speechgpt/src/infer/web_infer.py \ 157 | --model-name-or-path "path/to/SpeechGPT-7B-cm" \ 158 | --lora-weights "path/to/SpeechGPT-7B-com" \ 159 | --s2u-dir "${s2u_dir}" \ 160 | --vocoder-dir "${vocoder_dir}" \ 161 | --output-dir "output/" 162 | ``` 163 | 164 | 165 | ## Train SpeechGPT 166 | ### Stage1: Modality-adaptation Pre-training 167 | First, utilize mHuBERT for discretizing the LibriLight dataset to obtain discrete unit sequences for stage1 training. You can refer to the data processing methods in [Speech2unit](https://github.com/0nutation/SpeechGPT/blob/main/speechgpt/utils/speech2unit/README.md). 168 | 169 | Second, divide the discrete units into a training set and a development set, and save them in the following format in the files ```data/stage1/train.txt``` and ```data/stage1/dev.txt```: 170 | ``` 171 | <189><247><922><991><821><258><485><974><284><466><969><523><196><202><881><331><822><853><432><32><742><98><519><26><204><280><576><384><879><901><555><944><366><641><124><362><734><156><824><462><761><907><430><81><597><716><205><521><470><821><677><355><483><641><124><243><290><978><82><620><915><470><821><576><384><466><398><212><455><931><579><969><778><45><914><445><469><576><803><6><803><791><377><506><835><67><940><613><417><755><237><224><452><121><736> 172 | <300><189><63><6><665><991><881><331><6><384><879><945><29><244><583><874><655><837><81><627><545><124><337><850><412><213><260><41><740><797><211><488><961><428><6><196><555><944><873><32><683><700><955><812><328><915><166><250><56><903><86><233><479><330><776><167><104><764><259><921><366><663><432><431><531><976><314><822><89><664><377><611><479><417> 173 | <189><735><991><39><565><734><32><742><98><519><26><204><280><668><576><803><791><660><555><233><787><101><741><466><969><219><107><459><491><556><384><733><219><501><445><137><910><523><793><50><981><230><534><321><948><86><116><281><62><462><104><70><918><743><15><212><455><143><836><173><944><958><390><422><66><776><258><436><139><663><432><742><98><519><589><243><126><260><41><444><6><655><764><969><219><727><85><297><700><362><493><6><493><361><393><946><6><470><821><246><655><837><81><969><916><584><819><544><452><158><452><736> 174 | ``` 175 | Third, you should download LLaMA 7B(HuggingFace) to ```llama/hf/7B```. 176 | 177 | Now you can start stage1 training: 178 | To perform distributed training, you must specify the correct values for ```NNODE```, ```NODE_RANK```, ```MASTER_ADDR```, and ```MASTER_PORT```. 179 | ```bash 180 | bash scripts/ma_pretrain.sh ${NNODE} ${NODE_RANK} ${MASTER_ADDR} ${MASTER_PORT} 181 | ``` 182 | 183 | ### Stage 2: Cross-modal Instruction Finetuning 184 | You should download [SpeechInstruct Cross-modal Instruction set](https://huggingface.co/datasets/fnlp/SpeechInstruct/resolve/main/cross_modal_instruction.jsonl) to ```data/stage2/```. 
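Before launching stage 2, it can help to sanity-check the downloaded file. The snippet below is a minimal sketch, not part of the official pipeline: it assumes the set was saved as ```data/stage2/cross_modal_instruction.jsonl``` and verifies that each record carries the ```prefix``` and ```plain_text``` fields that ```cm_sft.py``` reads when building prompts.

```python
import json

# Assumed local path -- adjust to wherever you saved the download.
path = "data/stage2/cross_modal_instruction.jsonl"

n_records = 0
with open(path, "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        # cm_sft.py builds each training example from these two fields.
        assert "prefix" in record and "plain_text" in record
        n_records += 1
print(f"{n_records} records with the expected fields")
```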
185 | 186 | If you want to skip stage1 training, you can download ```SpeechGPT-7B-ma``` to ```output/stage1/```. 187 | 188 | Now you can start stage2 training: 189 | To perform distributed training, you must specify the correct values for ```NNODE```, ```NODE_RANK```, ```MASTER_ADDR```, and ```MASTER_PORT```. 190 | ```bash 191 | bash scripts/cm_sft.sh ${NNODE} ${NODE_RANK} ${MASTER_ADDR} ${MASTER_PORT} 192 | ``` 193 | 194 | ### Stage 3: Chain-of-modality Instruction Finetuning 195 | You should download [SpeechInstruct Chain-of-modality Instruction set](https://huggingface.co/datasets/fnlp/SpeechInstruct/resolve/main/chain_of_modality_instruction.jsonl) to ```data/stage3/```. 196 | 197 | If you want to skip stage1 and stage2, you can download ```SpeechGPT-7B-cm``` to ```output/stage2/```. 198 | 199 | Now you can start stage3 training: 200 | To perform distributed training, you must specify the correct values for ```NNODE```, ```NODE_RANK```, ```MASTER_ADDR```, and ```MASTER_PORT```. 201 | ```bash 202 | bash scripts/com_sft.sh ${NNODE} ${NODE_RANK} ${MASTER_ADDR} ${MASTER_PORT} 203 | ``` 204 | 205 | ## Finetune SpeechGPT 206 | ```SpeechGPT-7B-cm``` is a foundational model with strong alignment between speech and text. We encourage fine-tuning SpeechGPT based on this model. 207 | 208 | Step 1: Prepare your data following the format in [SpeechInstruct Cross-modal Instruction set](https://huggingface.co/datasets/fnlp/SpeechInstruct/resolve/main/cross_modal_instruction.jsonl). 209 | 210 | Step 2: Download [SpeechGPT-7B-cm](https://huggingface.co/fnlp/SpeechGPT-7B-cm) locally. 211 | 212 | Step 3: Modify the ```METAROOT```, ```DATAROOT```, and ```OUTROOT``` parameters in the ```scripts/cm_sft.sh``` script to your own paths and then run it. For LoRA fine-tuning, update the ```METAROOT```, ```DATAROOT```, and ```OUTROOT``` parameters in the ```scripts/com_sft.sh``` script and run it; a placeholder sketch of these edits follows the acknowledgements below. 213 | 214 | 215 | ## Acknowledgements 216 | - We express our appreciation to Fuliang Weng and Rong Ye for their valuable suggestions and guidance. 217 | - [MOSS](https://github.com/OpenLMLab/MOSS): We use moss-sft-002-data. 218 | - [stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca): The codebase we built upon. 
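For reference, Step 3 above amounts to edits like the following at the top of ```scripts/cm_sft.sh```. The variable names come from the scripts themselves; the values are placeholders for illustration, not the scripts' actual defaults.

```bash
# Hypothetical values for the Step 3 edits -- replace with your own paths.
METAROOT="path/to/SpeechGPT-7B-cm"           # base model to finetune
DATAROOT="data/your_instruction_set.jsonl"   # data in SpeechInstruct cross-modal format
OUTROOT="output/finetune"                    # where checkpoints will be written
```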
219 | 220 | ## Citation 221 | If you find SpeechGPT useful for your research and applications, please cite using the BibTex: 222 | 223 | ``` 224 | @misc{zhang2023speechgpt, 225 | title={SpeechGPT: Empowering Large Language Models with Intrinsic Cross-Modal Conversational Abilities}, 226 | author={Dong Zhang and Shimin Li and Xin Zhang and Jun Zhan and Pengyu Wang and Yaqian Zhou and Xipeng Qiu}, 227 | year={2023}, 228 | eprint={2305.11000}, 229 | archivePrefix={arXiv}, 230 | primaryClass={cs.CL} 231 | } 232 | ``` 233 | -------------------------------------------------------------------------------- /speechgpt/utils/text2unit/binary/dict.en.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | <0> 1 4 | <1> 1 5 | <2> 1 6 | <3> 1 7 | <4> 1 8 | <5> 1 9 | <6> 1 10 | <7> 1 11 | <8> 1 12 | <9> 1 13 | <10> 1 14 | <11> 1 15 | <12> 1 16 | <13> 1 17 | <14> 1 18 | <15> 1 19 | <16> 1 20 | <17> 1 21 | <18> 1 22 | <19> 1 23 | <20> 1 24 | <21> 1 25 | <22> 1 26 | <23> 1 27 | <24> 1 28 | <25> 1 29 | <26> 1 30 | <27> 1 31 | <28> 1 32 | <29> 1 33 | <30> 1 34 | <31> 1 35 | <32> 1 36 | <33> 1 37 | <34> 1 38 | <35> 1 39 | <36> 1 40 | <37> 1 41 | <38> 1 42 | <39> 1 43 | <40> 1 44 | <41> 1 45 | <42> 1 46 | <43> 1 47 | <44> 1 48 | <45> 1 49 | <46> 1 50 | <47> 1 51 | <48> 1 52 | <49> 1 53 | <50> 1 54 | <51> 1 55 | <52> 1 56 | <53> 1 57 | <54> 1 58 | <55> 1 59 | <56> 1 60 | <57> 1 61 | <58> 1 62 | <59> 1 63 | <60> 1 64 | <61> 1 65 | <62> 1 66 | <63> 1 67 | <64> 1 68 | <65> 1 69 | <66> 1 70 | <67> 1 71 | <68> 1 72 | <69> 1 73 | <70> 1 74 | <71> 1 75 | <72> 1 76 | <73> 1 77 | <74> 1 78 | <75> 1 79 | <76> 1 80 | <77> 1 81 | <78> 1 82 | <79> 1 83 | <80> 1 84 | <81> 1 85 | <82> 1 86 | <83> 1 87 | <84> 1 88 | <85> 1 89 | <86> 1 90 | <87> 1 91 | <88> 1 92 | <89> 1 93 | <90> 1 94 | <91> 1 95 | <92> 1 96 | <93> 1 97 | <94> 1 98 | <95> 1 99 | <96> 1 100 | <97> 1 101 | <98> 1 102 | <99> 1 103 | <100> 1 104 | <101> 1 105 | <102> 1 106 | <103> 1 107 | <104> 1 108 | <105> 1 109 | <106> 1 110 | <107> 1 111 | <108> 1 112 | <109> 1 113 | <110> 1 114 | <111> 1 115 | <112> 1 116 | <113> 1 117 | <114> 1 118 | <115> 1 119 | <116> 1 120 | <117> 1 121 | <118> 1 122 | <119> 1 123 | <120> 1 124 | <121> 1 125 | <122> 1 126 | <123> 1 127 | <124> 1 128 | <125> 1 129 | <126> 1 130 | <127> 1 131 | <128> 1 132 | <129> 1 133 | <130> 1 134 | <131> 1 135 | <132> 1 136 | <133> 1 137 | <134> 1 138 | <135> 1 139 | <136> 1 140 | <137> 1 141 | <138> 1 142 | <139> 1 143 | <140> 1 144 | <141> 1 145 | <142> 1 146 | <143> 1 147 | <144> 1 148 | <145> 1 149 | <146> 1 150 | <147> 1 151 | <148> 1 152 | <149> 1 153 | <150> 1 154 | <151> 1 155 | <152> 1 156 | <153> 1 157 | <154> 1 158 | <155> 1 159 | <156> 1 160 | <157> 1 161 | <158> 1 162 | <159> 1 163 | <160> 1 164 | <161> 1 165 | <162> 1 166 | <163> 1 167 | <164> 1 168 | <165> 1 169 | <166> 1 170 | <167> 1 171 | <168> 1 172 | <169> 1 173 | <170> 1 174 | <171> 1 175 | <172> 1 176 | <173> 1 177 | <174> 1 178 | <175> 1 179 | <176> 1 180 | <177> 1 181 | <178> 1 182 | <179> 1 183 | <180> 1 184 | <181> 1 185 | <182> 1 186 | <183> 1 187 | <184> 1 188 | <185> 1 189 | <186> 1 190 | <187> 1 191 | <188> 1 192 | <189> 1 193 | <190> 1 194 | <191> 1 195 | <192> 1 196 | <193> 1 197 | <194> 1 198 | <195> 1 199 | <196> 1 200 | <197> 1 201 | <198> 1 202 | <199> 1 203 | <200> 1 204 | <201> 1 205 | <202> 1 206 | <203> 1 207 | <204> 1 208 | <205> 1 209 | <206> 1 210 | <207> 1 211 | <208> 1 212 | <209> 1 213 | <210> 1 214 | <211> 1 215 | <212> 1 216 | <213> 1 217 | <214> 1 218 | <215> 
1 219 | <216> 1 220 | <217> 1 221 | <218> 1 222 | <219> 1 223 | <220> 1 224 | <221> 1 225 | <222> 1 226 | <223> 1 227 | <224> 1 228 | <225> 1 229 | <226> 1 230 | <227> 1 231 | <228> 1 232 | <229> 1 233 | <230> 1 234 | <231> 1 235 | <232> 1 236 | <233> 1 237 | <234> 1 238 | <235> 1 239 | <236> 1 240 | <237> 1 241 | <238> 1 242 | <239> 1 243 | <240> 1 244 | <241> 1 245 | <242> 1 246 | <243> 1 247 | <244> 1 248 | <245> 1 249 | <246> 1 250 | <247> 1 251 | <248> 1 252 | <249> 1 253 | <250> 1 254 | <251> 1 255 | <252> 1 256 | <253> 1 257 | <254> 1 258 | <255> 1 259 | <256> 1 260 | <257> 1 261 | <258> 1 262 | <259> 1 263 | <260> 1 264 | <261> 1 265 | <262> 1 266 | <263> 1 267 | <264> 1 268 | <265> 1 269 | <266> 1 270 | <267> 1 271 | <268> 1 272 | <269> 1 273 | <270> 1 274 | <271> 1 275 | <272> 1 276 | <273> 1 277 | <274> 1 278 | <275> 1 279 | <276> 1 280 | <277> 1 281 | <278> 1 282 | <279> 1 283 | <280> 1 284 | <281> 1 285 | <282> 1 286 | <283> 1 287 | <284> 1 288 | <285> 1 289 | <286> 1 290 | <287> 1 291 | <288> 1 292 | <289> 1 293 | <290> 1 294 | <291> 1 295 | <292> 1 296 | <293> 1 297 | <294> 1 298 | <295> 1 299 | <296> 1 300 | <297> 1 301 | <298> 1 302 | <299> 1 303 | <300> 1 304 | <301> 1 305 | <302> 1 306 | <303> 1 307 | <304> 1 308 | <305> 1 309 | <306> 1 310 | <307> 1 311 | <308> 1 312 | <309> 1 313 | <310> 1 314 | <311> 1 315 | <312> 1 316 | <313> 1 317 | <314> 1 318 | <315> 1 319 | <316> 1 320 | <317> 1 321 | <318> 1 322 | <319> 1 323 | <320> 1 324 | <321> 1 325 | <322> 1 326 | <323> 1 327 | <324> 1 328 | <325> 1 329 | <326> 1 330 | <327> 1 331 | <328> 1 332 | <329> 1 333 | <330> 1 334 | <331> 1 335 | <332> 1 336 | <333> 1 337 | <334> 1 338 | <335> 1 339 | <336> 1 340 | <337> 1 341 | <338> 1 342 | <339> 1 343 | <340> 1 344 | <341> 1 345 | <342> 1 346 | <343> 1 347 | <344> 1 348 | <345> 1 349 | <346> 1 350 | <347> 1 351 | <348> 1 352 | <349> 1 353 | <350> 1 354 | <351> 1 355 | <352> 1 356 | <353> 1 357 | <354> 1 358 | <355> 1 359 | <356> 1 360 | <357> 1 361 | <358> 1 362 | <359> 1 363 | <360> 1 364 | <361> 1 365 | <362> 1 366 | <363> 1 367 | <364> 1 368 | <365> 1 369 | <366> 1 370 | <367> 1 371 | <368> 1 372 | <369> 1 373 | <370> 1 374 | <371> 1 375 | <372> 1 376 | <373> 1 377 | <374> 1 378 | <375> 1 379 | <376> 1 380 | <377> 1 381 | <378> 1 382 | <379> 1 383 | <380> 1 384 | <381> 1 385 | <382> 1 386 | <383> 1 387 | <384> 1 388 | <385> 1 389 | <386> 1 390 | <387> 1 391 | <388> 1 392 | <389> 1 393 | <390> 1 394 | <391> 1 395 | <392> 1 396 | <393> 1 397 | <394> 1 398 | <395> 1 399 | <396> 1 400 | <397> 1 401 | <398> 1 402 | <399> 1 403 | <400> 1 404 | <401> 1 405 | <402> 1 406 | <403> 1 407 | <404> 1 408 | <405> 1 409 | <406> 1 410 | <407> 1 411 | <408> 1 412 | <409> 1 413 | <410> 1 414 | <411> 1 415 | <412> 1 416 | <413> 1 417 | <414> 1 418 | <415> 1 419 | <416> 1 420 | <417> 1 421 | <418> 1 422 | <419> 1 423 | <420> 1 424 | <421> 1 425 | <422> 1 426 | <423> 1 427 | <424> 1 428 | <425> 1 429 | <426> 1 430 | <427> 1 431 | <428> 1 432 | <429> 1 433 | <430> 1 434 | <431> 1 435 | <432> 1 436 | <433> 1 437 | <434> 1 438 | <435> 1 439 | <436> 1 440 | <437> 1 441 | <438> 1 442 | <439> 1 443 | <440> 1 444 | <441> 1 445 | <442> 1 446 | <443> 1 447 | <444> 1 448 | <445> 1 449 | <446> 1 450 | <447> 1 451 | <448> 1 452 | <449> 1 453 | <450> 1 454 | <451> 1 455 | <452> 1 456 | <453> 1 457 | <454> 1 458 | <455> 1 459 | <456> 1 460 | <457> 1 461 | <458> 1 462 | <459> 1 463 | <460> 1 464 | <461> 1 465 | <462> 1 466 | <463> 1 467 | <464> 1 468 | <465> 1 469 | <466> 1 470 | <467> 1 471 | <468> 1 472 | 
<469> 1 473 | <470> 1 474 | <471> 1 475 | <472> 1 476 | <473> 1 477 | <474> 1 478 | <475> 1 479 | <476> 1 480 | <477> 1 481 | <478> 1 482 | <479> 1 483 | <480> 1 484 | <481> 1 485 | <482> 1 486 | <483> 1 487 | <484> 1 488 | <485> 1 489 | <486> 1 490 | <487> 1 491 | <488> 1 492 | <489> 1 493 | <490> 1 494 | <491> 1 495 | <492> 1 496 | <493> 1 497 | <494> 1 498 | <495> 1 499 | <496> 1 500 | <497> 1 501 | <498> 1 502 | <499> 1 503 | <500> 1 504 | <501> 1 505 | <502> 1 506 | <503> 1 507 | <504> 1 508 | <505> 1 509 | <506> 1 510 | <507> 1 511 | <508> 1 512 | <509> 1 513 | <510> 1 514 | <511> 1 515 | <512> 1 516 | <513> 1 517 | <514> 1 518 | <515> 1 519 | <516> 1 520 | <517> 1 521 | <518> 1 522 | <519> 1 523 | <520> 1 524 | <521> 1 525 | <522> 1 526 | <523> 1 527 | <524> 1 528 | <525> 1 529 | <526> 1 530 | <527> 1 531 | <528> 1 532 | <529> 1 533 | <530> 1 534 | <531> 1 535 | <532> 1 536 | <533> 1 537 | <534> 1 538 | <535> 1 539 | <536> 1 540 | <537> 1 541 | <538> 1 542 | <539> 1 543 | <540> 1 544 | <541> 1 545 | <542> 1 546 | <543> 1 547 | <544> 1 548 | <545> 1 549 | <546> 1 550 | <547> 1 551 | <548> 1 552 | <549> 1 553 | <550> 1 554 | <551> 1 555 | <552> 1 556 | <553> 1 557 | <554> 1 558 | <555> 1 559 | <556> 1 560 | <557> 1 561 | <558> 1 562 | <559> 1 563 | <560> 1 564 | <561> 1 565 | <562> 1 566 | <563> 1 567 | <564> 1 568 | <565> 1 569 | <566> 1 570 | <567> 1 571 | <568> 1 572 | <569> 1 573 | <570> 1 574 | <571> 1 575 | <572> 1 576 | <573> 1 577 | <574> 1 578 | <575> 1 579 | <576> 1 580 | <577> 1 581 | <578> 1 582 | <579> 1 583 | <580> 1 584 | <581> 1 585 | <582> 1 586 | <583> 1 587 | <584> 1 588 | <585> 1 589 | <586> 1 590 | <587> 1 591 | <588> 1 592 | <589> 1 593 | <590> 1 594 | <591> 1 595 | <592> 1 596 | <593> 1 597 | <594> 1 598 | <595> 1 599 | <596> 1 600 | <597> 1 601 | <598> 1 602 | <599> 1 603 | <600> 1 604 | <601> 1 605 | <602> 1 606 | <603> 1 607 | <604> 1 608 | <605> 1 609 | <606> 1 610 | <607> 1 611 | <608> 1 612 | <609> 1 613 | <610> 1 614 | <611> 1 615 | <612> 1 616 | <613> 1 617 | <614> 1 618 | <615> 1 619 | <616> 1 620 | <617> 1 621 | <618> 1 622 | <619> 1 623 | <620> 1 624 | <621> 1 625 | <622> 1 626 | <623> 1 627 | <624> 1 628 | <625> 1 629 | <626> 1 630 | <627> 1 631 | <628> 1 632 | <629> 1 633 | <630> 1 634 | <631> 1 635 | <632> 1 636 | <633> 1 637 | <634> 1 638 | <635> 1 639 | <636> 1 640 | <637> 1 641 | <638> 1 642 | <639> 1 643 | <640> 1 644 | <641> 1 645 | <642> 1 646 | <643> 1 647 | <644> 1 648 | <645> 1 649 | <646> 1 650 | <647> 1 651 | <648> 1 652 | <649> 1 653 | <650> 1 654 | <651> 1 655 | <652> 1 656 | <653> 1 657 | <654> 1 658 | <655> 1 659 | <656> 1 660 | <657> 1 661 | <658> 1 662 | <659> 1 663 | <660> 1 664 | <661> 1 665 | <662> 1 666 | <663> 1 667 | <664> 1 668 | <665> 1 669 | <666> 1 670 | <667> 1 671 | <668> 1 672 | <669> 1 673 | <670> 1 674 | <671> 1 675 | <672> 1 676 | <673> 1 677 | <674> 1 678 | <675> 1 679 | <676> 1 680 | <677> 1 681 | <678> 1 682 | <679> 1 683 | <680> 1 684 | <681> 1 685 | <682> 1 686 | <683> 1 687 | <684> 1 688 | <685> 1 689 | <686> 1 690 | <687> 1 691 | <688> 1 692 | <689> 1 693 | <690> 1 694 | <691> 1 695 | <692> 1 696 | <693> 1 697 | <694> 1 698 | <695> 1 699 | <696> 1 700 | <697> 1 701 | <698> 1 702 | <699> 1 703 | <700> 1 704 | <701> 1 705 | <702> 1 706 | <703> 1 707 | <704> 1 708 | <705> 1 709 | <706> 1 710 | <707> 1 711 | <708> 1 712 | <709> 1 713 | <710> 1 714 | <711> 1 715 | <712> 1 716 | <713> 1 717 | <714> 1 718 | <715> 1 719 | <716> 1 720 | <717> 1 721 | <718> 1 722 | <719> 1 723 | <720> 1 724 | <721> 1 725 | <722> 1 726 
| <723> 1 727 | <724> 1 728 | <725> 1 729 | <726> 1 730 | <727> 1 731 | <728> 1 732 | <729> 1 733 | <730> 1 734 | <731> 1 735 | <732> 1 736 | <733> 1 737 | <734> 1 738 | <735> 1 739 | <736> 1 740 | <737> 1 741 | <738> 1 742 | <739> 1 743 | <740> 1 744 | <741> 1 745 | <742> 1 746 | <743> 1 747 | <744> 1 748 | <745> 1 749 | <746> 1 750 | <747> 1 751 | <748> 1 752 | <749> 1 753 | <750> 1 754 | <751> 1 755 | <752> 1 756 | <753> 1 757 | <754> 1 758 | <755> 1 759 | <756> 1 760 | <757> 1 761 | <758> 1 762 | <759> 1 763 | <760> 1 764 | <761> 1 765 | <762> 1 766 | <763> 1 767 | <764> 1 768 | <765> 1 769 | <766> 1 770 | <767> 1 771 | <768> 1 772 | <769> 1 773 | <770> 1 774 | <771> 1 775 | <772> 1 776 | <773> 1 777 | <774> 1 778 | <775> 1 779 | <776> 1 780 | <777> 1 781 | <778> 1 782 | <779> 1 783 | <780> 1 784 | <781> 1 785 | <782> 1 786 | <783> 1 787 | <784> 1 788 | <785> 1 789 | <786> 1 790 | <787> 1 791 | <788> 1 792 | <789> 1 793 | <790> 1 794 | <791> 1 795 | <792> 1 796 | <793> 1 797 | <794> 1 798 | <795> 1 799 | <796> 1 800 | <797> 1 801 | <798> 1 802 | <799> 1 803 | <800> 1 804 | <801> 1 805 | <802> 1 806 | <803> 1 807 | <804> 1 808 | <805> 1 809 | <806> 1 810 | <807> 1 811 | <808> 1 812 | <809> 1 813 | <810> 1 814 | <811> 1 815 | <812> 1 816 | <813> 1 817 | <814> 1 818 | <815> 1 819 | <816> 1 820 | <817> 1 821 | <818> 1 822 | <819> 1 823 | <820> 1 824 | <821> 1 825 | <822> 1 826 | <823> 1 827 | <824> 1 828 | <825> 1 829 | <826> 1 830 | <827> 1 831 | <828> 1 832 | <829> 1 833 | <830> 1 834 | <831> 1 835 | <832> 1 836 | <833> 1 837 | <834> 1 838 | <835> 1 839 | <836> 1 840 | <837> 1 841 | <838> 1 842 | <839> 1 843 | <840> 1 844 | <841> 1 845 | <842> 1 846 | <843> 1 847 | <844> 1 848 | <845> 1 849 | <846> 1 850 | <847> 1 851 | <848> 1 852 | <849> 1 853 | <850> 1 854 | <851> 1 855 | <852> 1 856 | <853> 1 857 | <854> 1 858 | <855> 1 859 | <856> 1 860 | <857> 1 861 | <858> 1 862 | <859> 1 863 | <860> 1 864 | <861> 1 865 | <862> 1 866 | <863> 1 867 | <864> 1 868 | <865> 1 869 | <866> 1 870 | <867> 1 871 | <868> 1 872 | <869> 1 873 | <870> 1 874 | <871> 1 875 | <872> 1 876 | <873> 1 877 | <874> 1 878 | <875> 1 879 | <876> 1 880 | <877> 1 881 | <878> 1 882 | <879> 1 883 | <880> 1 884 | <881> 1 885 | <882> 1 886 | <883> 1 887 | <884> 1 888 | <885> 1 889 | <886> 1 890 | <887> 1 891 | <888> 1 892 | <889> 1 893 | <890> 1 894 | <891> 1 895 | <892> 1 896 | <893> 1 897 | <894> 1 898 | <895> 1 899 | <896> 1 900 | <897> 1 901 | <898> 1 902 | <899> 1 903 | <900> 1 904 | <901> 1 905 | <902> 1 906 | <903> 1 907 | <904> 1 908 | <905> 1 909 | <906> 1 910 | <907> 1 911 | <908> 1 912 | <909> 1 913 | <910> 1 914 | <911> 1 915 | <912> 1 916 | <913> 1 917 | <914> 1 918 | <915> 1 919 | <916> 1 920 | <917> 1 921 | <918> 1 922 | <919> 1 923 | <920> 1 924 | <921> 1 925 | <922> 1 926 | <923> 1 927 | <924> 1 928 | <925> 1 929 | <926> 1 930 | <927> 1 931 | <928> 1 932 | <929> 1 933 | <930> 1 934 | <931> 1 935 | <932> 1 936 | <933> 1 937 | <934> 1 938 | <935> 1 939 | <936> 1 940 | <937> 1 941 | <938> 1 942 | <939> 1 943 | <940> 1 944 | <941> 1 945 | <942> 1 946 | <943> 1 947 | <944> 1 948 | <945> 1 949 | <946> 1 950 | <947> 1 951 | <948> 1 952 | <949> 1 953 | <950> 1 954 | <951> 1 955 | <952> 1 956 | <953> 1 957 | <954> 1 958 | <955> 1 959 | <956> 1 960 | <957> 1 961 | <958> 1 962 | <959> 1 963 | <960> 1 964 | <961> 1 965 | <962> 1 966 | <963> 1 967 | <964> 1 968 | <965> 1 969 | <966> 1 970 | <967> 1 971 | <968> 1 972 | <969> 1 973 | <970> 1 974 | <971> 1 975 | <972> 1 976 | <973> 1 977 | <974> 1 978 | <975> 1 979 | <976> 1 
980 | <977> 1 981 | <978> 1 982 | <979> 1 983 | <980> 1 984 | <981> 1 985 | <982> 1 986 | <983> 1 987 | <984> 1 988 | <985> 1 989 | <986> 1 990 | <987> 1 991 | <988> 1 992 | <989> 1 993 | <990> 1 994 | <991> 1 995 | <992> 1 996 | <993> 1 997 | <994> 1 998 | <995> 1 999 | <996> 1 1000 | <997> 1 1001 | <998> 1 1002 | <999> 1 1003 | ▁t 1 1004 | he 1 1005 | ▁a 1 1006 | in 1 1007 | ▁the 1 1008 | re 1 1009 | ▁s 1 1010 | ▁w 1 1011 | ▁o 1 1012 | er 1 1013 | is 1 1014 | nd 1 1015 | ed 1 1016 | on 1 1017 | at 1 1018 | ▁b 1 1019 | ▁c 1 1020 | it 1 1021 | en 1 1022 | ▁f 1 1023 | ou 1 1024 | ▁h 1 1025 | or 1 1026 | ▁m 1 1027 | es 1 1028 | an 1 1029 | ar 1 1030 | ▁p 1 1031 | ▁of 1 1032 | as 1 1033 | ing 1 1034 | ▁in 1 1035 | al 1 1036 | ▁to 1 1037 | ▁d 1 1038 | ▁and 1 1039 | ▁l 1 1040 | ▁th 1 1041 | ic 1 1042 | om 1 1043 | ▁n 1 1044 | ll 1 1045 | le 1 1046 | ▁he 1 1047 | ion 1 1048 | ▁g 1 1049 | The 1 1050 | ly 1 1051 | ▁e 1 1052 | ▁is 1 1053 | ▁was 1 1054 | ent 1 1055 | ad 1 1056 | ▁be 1 1057 | st 1 1058 | ▁re 1 1059 | ow 1 1060 | ot 1 1061 | ve 1 1062 | id 1 1063 | im 1 1064 | ut 1 1065 | ac 1 1066 | am 1 1067 | et 1 1068 | ver 1 1069 | ro 1 1070 | se 1 1071 | ▁on 1 1072 | ld 1 1073 | gh 1 1074 | ur 1 1075 | ay 1 1076 | ir 1 1077 | ▁y 1 1078 | ▁for 1 1079 | ▁st 1 1080 | ▁u 1 1081 | ▁it 1 1082 | ith 1 1083 | ter 1 1084 | ▁as 1 1085 | ▁that 1 1086 | ri 1 1087 | ▁we 1 1088 | ▁with 1 1089 | ce 1 1090 | ct 1 1091 | ▁wh 1 1092 | ch 1 1093 | ▁his 1 1094 | ▁i 1 1095 | ▁you 1 1096 | her 1 1097 | il 1 1098 | ▁an 1 1099 | ol 1 1100 | ain 1 1101 | ation 1 1102 | and 1 1103 | ▁C 1 1104 | ▁at 1 1105 | ▁S 1 1106 | ▁se 1 1107 | ▁al 1 1108 | ght 1 1109 | ▁are 1 1110 | ill 1 1111 | red 1 1112 | od 1 1113 | ▁r 1 1114 | us 1 1115 | ge 1 1116 | ul 1 1117 | ▁con 1 1118 | ▁not 1 1119 | ▁by 1 1120 | ke 1 1121 | pe 1 1122 | ave 1 1123 | ist 1 1124 | ▁su 1 1125 | ▁had 1 1126 | ▁de 1 1127 | He 1 1128 | rom 1 1129 | ▁com 1 1130 | ▁sh 1 1131 | th 1 1132 | ers 1 1133 | ▁M 1 1134 | ore 1 1135 | ▁me 1 1136 | ess 1 1137 | est 1 1138 | nt 1 1139 | ▁her 1 1140 | ab 1 1141 | ap 1 1142 | ▁A 1 1143 | um 1 1144 | ould 1 1145 | It 1 1146 | ▁ne 1 1147 | if 1 1148 | all 1 1149 | ▁v 1 1150 | ▁k 1 1151 | qu 1 1152 | ant 1 1153 | ard 1 1154 | ▁ch 1 1155 | ▁from 1 1156 | ▁B 1 1157 | ate 1 1158 | op 1 1159 | res 1 1160 | os 1 1161 | ▁were 1 1162 | ▁ex 1 1163 | art 1 1164 | ▁have 1 1165 | ▁" 1 1166 | ame 1 1167 | un 1 1168 | ▁pro 1 1169 | ity 1 1170 | em 1 1171 | oc 1 1172 | ▁him 1 1173 | ive 1 1174 | ▁P 1 1175 | ▁this 1 1176 | ig 1 1177 | ie 1 1178 | ▁she 1 1179 | ight 1 1180 | ▁sa 1 1181 | our 1 1182 | el 1 1183 | ies 1 1184 | ort 1 1185 | ▁or 1 1186 | iv 1 1187 | ▁but 1 1188 | ome 1 1189 | so 1 1190 | rou 1 1191 | ok 1 1192 | ▁pl 1 1193 | ong 1 1194 | ▁all 1 1195 | Th 1 1196 | ast 1 1197 | ous 1 1198 | ▁do 1 1199 | ust 1 1200 | ▁so 1 1201 | ag 1 1202 | ▁man 1 1203 | ▁go 1 1204 | ated 1 1205 | ich 1 1206 | ▁also 1 1207 | ▁tr 1 1208 | ud 1 1209 | ▁j 1 1210 | ▁I 1 1211 | ▁one 1 1212 | ▁R 1 1213 | ally 1 1214 | out 1 1215 | ak 1 1216 | ▁le 1 1217 | own 1 1218 | ▁has 1 1219 | ▁wor 1 1220 | ost 1 1221 | ▁L 1 1222 | ▁they 1 1223 | ▁my 1 1224 | ▁whe 1 1225 | pp 1 1226 | ack 1 1227 | ▁int 1 1228 | ment 1 1229 | ▁up 1 1230 | ▁G 1 1231 | ▁kn 1 1232 | ind 1 1233 | ide 1 1234 | ▁ab 1 1235 | ▁D 1 1236 | ▁H 1 1237 | ine 1 1238 | very 1 1239 | av 1 1240 | ▁no 1 1241 | ther 1 1242 | ure 1 1243 | ▁said 1 1244 | ▁tw 1 1245 | ect 1 1246 | ▁out 1 1247 | ▁W 1 1248 | ip 1 1249 | one 1 1250 | ▁co 1 1251 | ▁fe 1 1252 | ish 1 1253 | ▁been 1 1254 | ▁lo 1 1255 | ▁their 1 
1256 | ▁can 1 1257 | ▁E 1 1258 | ▁us 1 1259 | ▁which 1 1260 | ▁bo 1 1261 | ▁there 1 1262 | ▁F 1 1263 | ▁cl 1 1264 | ▁T 1 1265 | ry 1 1266 | ▁un 1 1267 | ian 1 1268 | pt 1 1269 | ound 1 1270 | ▁ar 1 1271 | ▁N 1 1272 | ▁sp 1 1273 | itt 1 1274 | ff 1 1275 | are 1 1276 | ace 1 1277 | ▁ro 1 1278 | ▁them 1 1279 | hed 1 1280 | age 1 1281 | ▁who 1 1282 | ass 1 1283 | cc 1 1284 | ▁te 1 1285 | au 1 1286 | And 1 1287 | ree 1 1288 | ia 1 1289 | ice 1 1290 | fter 1 1291 | cl 1 1292 | ount 1 1293 | per 1 1294 | ime 1 1295 | ater 1 1296 | ▁would 1 1297 | sel 1 1298 | ars 1 1299 | ber 1 1300 | ary 1 1301 | ions 1 1302 | ▁comp 1 1303 | wn 1 1304 | ite 1 1305 | ire 1 1306 | ▁pr 1 1307 | able 1 1308 | ▁J 1 1309 | ▁ad 1 1310 | ord 1 1311 | ▁cont 1 1312 | ▁two 1 1313 | ▁en 1 1314 | ▁will 1 1315 | ake 1 1316 | This 1 1317 | ▁other 1 1318 | te 1 1319 | ial 1 1320 | ue 1 1321 | ▁ag 1 1322 | ▁im 1 1323 | pl 1 1324 | ▁gr 1 1325 | ang 1 1326 | ▁into 1 1327 | ▁part 1 1328 | ub 1 1329 | way 1 1330 | ▁when 1 1331 | ance 1 1332 | ▁li 1 1333 | ▁qu 1 1334 | ark 1 1335 | nder 1 1336 | ach 1 1337 | In 1 1338 | act 1 1339 | orn 1 1340 | ▁some 1 1341 | ra 1 1342 | ▁dis 1 1343 | uc 1 1344 | ▁its 1 1345 | og 1 1346 | ved 1 1347 | ose 1 1348 | ▁what 1 1349 | ick 1 1350 | ood 1 1351 | ▁more 1 1352 | irst 1 1353 | iz 1 1354 | ▁then 1 1355 | ade 1 1356 | ans 1 1357 | ild 1 1358 | ence 1 1359 | reat 1 1360 | ". 1 1361 | other 1 1362 | ▁time 1 1363 | ▁fl 1 1364 | ▁per 1 1365 | urn 1 1366 | ▁br 1 1367 | ib 1 1368 | ▁bec 1 1369 | ied 1 1370 | ings 1 1371 | ical 1 1372 | ▁res 1 1373 | ▁ye 1 1374 | ▁am 1 1375 | ▁off 1 1376 | ▁sc 1 1377 | ▁bl 1 1378 | ven 1 1379 | ▁ser 1 1380 | ▁first 1 1381 | ep 1 1382 | ign 1 1383 | ▁over 1 1384 | ult 1 1385 | ▁after 1 1386 | ren 1 1387 | self 1 1388 | ▁K 1 1389 | ▁if 1 1390 | ▁about 1 1391 | ▁any 1 1392 | oll 1 1393 | we 1 1394 | ▁could 1 1395 | ond 1 1396 | ▁pe 1 1397 | ile 1 1398 | ress 1 1399 | ▁comm 1 1400 | ▁des 1 1401 | ▁did 1 1402 | ▁your 1 1403 | ▁now 1 1404 | ▁spe 1 1405 | ail 1 1406 | ious 1 1407 | ▁only 1 1408 | iss 1 1409 | ool 1 1410 | orm 1 1411 | ▁work 1 1412 | ink 1 1413 | ▁than 1 1414 | ations 1 1415 | They 1 1416 | vers 1 1417 | ▁O 1 1418 | ▁very 1 1419 | ▁thou 1 1420 | land 1 1421 | ition 1 1422 | ▁like 1 1423 | ng 1 1424 | fore 1 1425 | ft 1 1426 | ne 1 1427 | Wh 1 1428 | ict 1 1429 | She 1 1430 | ens 1 1431 | ru 1 1432 | hing 1 1433 | ather 1 1434 | ▁pre 1 1435 | ▁play 1 1436 | ▁rec 1 1437 | ▁know 1 1438 | ever 1 1439 | ▁po 1 1440 | ▁day 1 1441 | ▁rem 1 1442 | ew 1 1443 | man 1 1444 | int 1 1445 | ory 1 1446 | fe 1 1447 | les 1 1448 | ors 1 1449 | ▁how 1 1450 | igh 1 1451 | ward 1 1452 | ▁sm 1 1453 | mer 1 1454 | ▁look 1 1455 | ife 1 1456 | rough 1 1457 | There 1 1458 | anc 1 1459 | ily 1 1460 | ▁Ch 1 1461 | ull 1 1462 | ▁again 1 1463 | ath 1 1464 | ▁under 1 1465 | ase 1 1466 | ittle 1 1467 | ▁made 1 1468 | oth 1 1469 | ons 1 1470 | ▁cons 1 1471 | ▁ind 1 1472 | air 1 1473 | ▁ri 1 1474 | chool 1 1475 | ction 1 1476 | ▁see 1 1477 | ▁through 1 1478 | ▁loc 1 1479 | ▁most 1 1480 | ▁well 1 1481 | ert 1 1482 | ▁may 1 1483 | clud 1 1484 | ▁bet 1 1485 | But 1 1486 | ▁these 1 1487 | ▁new 1 1488 | ▁call 1 1489 | ents 1 1490 | ▁little 1 1491 | ▁wr 1 1492 | ▁cr 1 1493 | ▁imp 1 1494 | rent 1 1495 | ▁St 1 1496 | ▁mo 1 1497 | ▁U 1 1498 | round 1 1499 | ered 1 1500 | ▁long 1 1501 | ▁old 1 1502 | ▁hand 1 1503 | ise 1 1504 | aw 1 1505 | gin 1 1506 | ked 1 1507 | uch 1 1508 | ople 1 1509 | ▁includ 1 1510 | amp 1 1511 | hes 1 1512 | ▁many 1 1513 | ▁used 1 1514 | vel 1 1515 | ▁ap 1 1516 | end 1 1517 | 
▁down 1 1518 | ates 1 1519 | ▁V 1 1520 | ▁back 1 1521 | led 1 1522 | ell 1 1523 | ister 1 1524 | ty 1 1525 | pen 1 1526 | ▁good 1 1527 | ▁three 1 1528 | ks 1 1529 | ▁boy 1 1530 | tern 1 1531 | ▁never 1 1532 | ▁app 1 1533 | ful 1 1534 | ried 1 1535 | ▁own 1 1536 | ix 1 1537 | ▁before 1 1538 | ▁every 1 1539 | ▁found 1 1540 | ▁should 1 1541 | ▁str 1 1542 | ▁such 1 1543 | ▁ret 1 1544 | oy 1 1545 | ular 1 1546 | gan 1 1547 | ited 1 1548 | ▁sub 1 1549 | ouse 1 1550 | ▁say 1 1551 | ▁acc 1 1552 | ▁great 1 1553 | ▁people 1 1554 | form 1 1555 | lic 1 1556 | ▁car 1 1557 | ▁act 1 1558 | aint 1 1559 | ough 1 1560 | ▁bu 1 1561 | rib 1 1562 | ▁came 1 1563 | ▁mar 1 1564 | ▁att 1 1565 | ▁ob 1 1566 | ▁ac 1 1567 | ▁pres 1 1568 | omet 1 1569 | hen 1 1570 | ▁ear 1 1571 | ▁name 1 1572 | ▁just 1 1573 | oss 1 1574 | ▁rep 1 1575 | uring 1 1576 | ▁where 1 1577 | ▁count 1 1578 | His 1 1579 | ▁dist 1 1580 | ▁way 1 1581 | ▁ent 1 1582 | ▁form 1 1583 | ▁reg 1 1584 | ener 1 1585 | mb 1 1586 | ▁get 1 1587 | ▁much 1 1588 | ▁our 1 1589 | vent 1 1590 | ▁sec 1 1591 | ▁still 1 1592 | ▁sl 1 1593 | ▁sur 1 1594 | ▁must 1 1595 | ss 1 1596 | ock 1 1597 | ec 1 1598 | ves 1 1599 | als 1 1600 | ▁thought 1 1601 | ▁think 1 1602 | ollow 1 1603 | ▁child 1 1604 | old 1 1605 | ▁sy 1 1606 | ▁upon 1 1607 | ared 1 1608 | ▁lar 1 1609 | ness 1 1610 | ph 1 1611 | ▁rel 1 1612 | ased 1 1613 | ished 1 1614 | any 1 1615 | ▁dif 1 1616 | ▁set 1 1617 | amed 1 1618 | ained 1 1619 | ause 1 1620 | ont 1 1621 | ▁town 1 1622 | ject 1 1623 | You 1 1624 | ▁years 1 1625 | ▁appe 1 1626 | ▁went 1 1627 | ▁known 1 1628 | ▁want 1 1629 | How 1 1630 | ▁here 1 1631 | ible 1 1632 | ince 1 1633 | ▁gl 1 1634 | ▁cur 1 1635 | ▁bel 1 1636 | ▁end 1 1637 | ▁even 1 1638 | ▁char 1 1639 | ▁head 1 1640 | ▁later 1 1641 | ady 1 1642 | ▁called 1 1643 | ▁wom 1 1644 | ange 1 1645 | ived 1 1646 | ▁don 1 1647 | enc 1 1648 | ual 1 1649 | ▁Un 1 1650 | ently 1 1651 | ▁come 1 1652 | ▁ke 1 1653 | ton 1 1654 | ▁fam 1 1655 | ▁men 1 1656 | ▁ph 1 1657 | ains 1 1658 | ters 1 1659 | ities 1 1660 | outh 1 1661 | ▁pass 1 1662 | imes 1 1663 | ▁prod 1 1664 | ▁supp 1 1665 | ▁ass 1 1666 | ▁inst 1 1667 | ob 1 1668 | ▁life 1 1669 | ▁ele 1 1670 | und 1 1671 | ▁same 1 1672 | ually 1 1673 | ▁four 1 1674 | ten 1 1675 | orth 1 1676 | sh 1 1677 | ational 1 1678 | ▁follow 1 1679 | iet 1 1680 | ower 1 1681 | ▁somet 1 1682 | erm 1 1683 | ative 1 1684 | ric 1 1685 | ▁let 1 1686 | ▁became 1 1687 | ▁miss 1 1688 | ures 1 1689 | ows 1 1690 | ▁fri 1 1691 | ▁himself 1 1692 | ▁ey 1 1693 | ▁gu 1 1694 | inc 1 1695 | ale 1 1696 | ating 1 1697 | ▁near 1 1698 | ▁last 1 1699 | rew 1 1700 | ier 1 1701 | ▁school 1 1702 | urned 1 1703 | ▁met 1 1704 | ▁pub 1 1705 | its 1 1706 | app 1 1707 | ident 1 1708 | xt 1 1709 | ▁inc 1 1710 | alk 1 1711 | ▁mon 1 1712 | ▁being 1 1713 | oh 1 1714 | ram 1 1715 | ific 1 1716 | As 1 1717 | ▁make 1 1718 | ank 1 1719 | ▁inter 1 1720 | ning 1 1721 | ▁both 1 1722 | az 1 1723 | ▁stud 1 1724 | ▁ed 1 1725 | ▁seem 1 1726 | ▁young 1 1727 | ▁min 1 1728 | ▁mus 1 1729 | That 1 1730 | cept 1 1731 | ▁too 1 1732 | ▁high 1 1733 | ▁dec 1 1734 | ble 1 1735 | aking 1 1736 | ▁gener 1 1737 | ▁em 1 1738 | ▁num 1 1739 | cted 1 1740 | gg 1 1741 | ," 1 1742 | ross 1 1743 | io 1 1744 | ▁spec 1 1745 | ▁cour 1 1746 | ▁house 1 1747 | ▁home 1 1748 | ▁cap 1 1749 | ▁open 1 1750 | ivers 1 1751 | ▁born 1 1752 | ween 1 1753 | get 1 1754 | ▁took 1 1755 | arm 1 1756 | ics 1 1757 | ▁word 1 1758 | ▁fr 1 1759 | ▁right 1 1760 | ield 1 1761 | ▁ref 1 1762 | ired 1 1763 | ▁pers 1 1764 | ▁small 1 1765 | ▁art 1 1766 | ret 1 1767 | ▁those 1 1768 
| ▁away 1 1769 | ▁fin 1 1770 | ah 1 1771 | ere 1 1772 | ▁left 1 1773 | ove 1 1774 | ▁add 1 1775 | ▁pol 1 1776 | What 1 1777 | ody 1 1778 | ▁while 1 1779 | ▁place 1 1780 | ▁between 1 1781 | ▁contin 1 1782 | ▁sever 1 1783 | olog 1 1784 | ised 1 1785 | ▁father 1 1786 | ways 1 1787 | ▁mat 1 1788 | ner 1 1789 | ▁might 1 1790 | ▁world 1 1791 | ▁night 1 1792 | ▁book 1 1793 | ▁cent 1 1794 | up 1 1795 | ling 1 1796 | son 1 1797 | oun 1 1798 | ined 1 1799 | ives 1 1800 | less 1 1801 | ▁Ar 1 1802 | ▁several 1 1803 | ▁without 1 1804 | ists 1 1805 | ism 1 1806 | ley 1 1807 | ▁sim 1 1808 | ▁city 1 1809 | vern 1 1810 | ▁another 1 1811 | ▁occ 1 1812 | ▁use 1 1813 | oot 1 1814 | ▁year 1 1815 | ▁ma 1 1816 | ▁friend 1 1817 | ▁inv 1 1818 | ▁find 1 1819 | We 1 1820 | ▁water 1 1821 | ▁going 1 1822 | ▁Br 1 1823 | ▁trans 1 1824 | ▁exp 1 1825 | ▁saw 1 1826 | att 1 1827 | ments 1 1828 | irl 1 1829 | ▁prov 1 1830 | ▁take 1 1831 | ▁each 1 1832 | ▁named 1 1833 | ason 1 1834 | ided 1 1835 | hem 1 1836 | ▁point 1 1837 | ▁ev 1 1838 | thing 1 1839 | li 1 1840 | ird 1 1841 | ▁mod 1 1842 | ▁often 1 1843 | ature 1 1844 | ▁serv 1 1845 | ctor 1 1846 | ▁disc 1 1847 | ized 1 1848 | These 1 1849 | ▁hel 1 1850 | ▁col 1 1851 | ▁rest 1 1852 | An 1 1853 | ▁care 1 1854 | rigin 1 1855 | ▁var 1 1856 | osed 1 1857 | ted 1 1858 | ▁light 1 1859 | ern 1 1860 | However 1 1861 | ral 1 1862 | augh 1 1863 | ▁run 1 1864 | ery 1 1865 | roup 1 1866 | ▁arm 1 1867 | ins 1 1868 | ▁large 1 1869 | ▁fil 1 1870 | ▁sign 1 1871 | ▁coll 1 1872 | ▁diffe 1 1873 | ▁dire 1 1874 | ee 1 1875 | ▁began 1 1876 | ▁located 1 1877 | ▁sing 1 1878 | ys 1 1879 | ▁dr 1 1880 | ▁inf 1 1881 | ▁bus 1 1882 | ames 1 1883 | ▁Y 1 1884 | ient 1 1885 | ▁ins 1 1886 | ▁second 1 1887 | ▁once 1 1888 | ants 1 1889 | For 1 1890 | ▁tell 1 1891 | ▁air 1 1892 | lish 1 1893 | aving 1 1894 | akes 1 1895 | ck 1 1896 | ▁New 1 1897 | ple 1 1898 | ▁does 1 1899 | ines 1 1900 | uth 1 1901 | ▁show 1 1902 | der 1 1903 | iness 1 1904 | ▁during 1 1905 | ▁sw 1 1906 | ▁pop 1 1907 | ▁got 1 1908 | ▁stand 1 1909 | ute 1 1910 | ▁always 1 1911 | ▁nothing 1 1912 | ork 1 1913 | ting 1 1914 | uck 1 1915 | ured 1 1916 | den 1 1917 | ▁need 1 1918 | After 1 1919 | row 1 1920 | ▁tre 1 1921 | ▁wind 1 1922 | ▁mister 1 1923 | tend 1 1924 | ▁origin 1 1925 | ▁eyes 1 1926 | ▁red 1 1927 | ▁asked 1 1928 | ▁area 1 1929 | ", 1 1930 | At 1 1931 | ▁room 1 1932 | ▁put 1 1933 | When 1 1934 | ued 1 1935 | ▁conf 1 1936 | cess 1 1937 | ohn 1 1938 | ▁heart 1 1939 | ov 1 1940 | ▁eff 1 1941 | ages 1 1942 | ▁woman 1 1943 | ▁sk 1 1944 | ▁el 1 1945 | ▁few 1 1946 | ▁things 1 1947 | iew 1 1948 | ▁gre 1 1949 | erv 1 1950 | ote 1 1951 | ▁children 1 1952 | ▁land 1 1953 | ▁seen 1 1954 | ides 1 1955 | eg 1 1956 | ash 1 1957 | velop 1 1958 | urch 1 1959 | ts 1 1960 | ▁war 1 1961 | stem 1 1962 | ▁eng 1 1963 | ▁because 1 1964 | ▁Eng 1 1965 | ically 1 1966 | ▁mother 1 1967 | ▁par 1 1968 | by 1 1969 | here 1 1970 | ▁current 1 1971 | ▁ever 1 1972 | ook 1 1973 | aring 1 1974 | ream 1 1975 | ane 1 1976 | ately 1 1977 | ched 1 1978 | ▁poss 1 1979 | enn 1 1980 | ▁girl 1 1981 | ▁music 1 1982 | ertain 1 1983 | ▁group 1 1984 | ▁against 1 1985 | ution 1 1986 | ▁oper 1 1987 | ▁love 1 1988 | ▁ext 1 1989 | ▁start 1 1990 | ▁lead 1 1991 | ▁School 1 1992 | All 1 1993 | iversity 1 1994 | ann 1 1995 | ▁help 1 1996 | ▁family 1 1997 | ▁bre 1 1998 | br 1 1999 | ▁however 1 2000 | ither 1 2001 | ▁film 1 2002 | ▁song 1 2003 | ▁const 1 2004 | ets 1 2005 | ▁rece 1 2006 | ▁shall 1 2007 | ▁cle 1 2008 | pect 1 2009 | ▁read 1 2010 | ▁main 1 2011 | ▁govern 1 2012 | ▁vis 1 2013 | 
▁soon 1 2014 | On 1 2015 | ▁fore 1 2016 | ▁six 1 2017 | ▁looked 1 2018 | wards 1 2019 | ene 1 2020 | ▁son 1 2021 | ccess 1 2022 | ▁alb 1 2023 | ug 1 2024 | ices 1 2025 | ural 1 2026 | rop 1 2027 | ▁dep 1 2028 | most 1 2029 | alf 1 2030 | ces 1 2031 | day 1 2032 | arch 1 2033 | ▁commun 1 2034 | ▁king 1 2035 | ▁build 1 2036 | ▁che 1 2037 | ▁around 1 2038 | iver 1 2039 | ▁det 1 2040 | ▁told 1 2041 | ▁ans 1 2042 | ially 1 2043 | ▁bro 1 2044 | omm 1 2045 | ▁hum 1 2046 | Some 1 2047 | ▁side 1 2048 | ▁five 1 2049 | ▁different 1 2050 | port 1 2051 | ▁comple 1 2052 | ▁develop 1 2053 | ▁belie 1 2054 | ▁person 1 2055 | ger 1 2056 | aster 1 2057 | ▁album 1 2058 | ott 1 2059 | aced 1 2060 | ▁ple 1 2061 | ▁The 1 2062 | ▁system 1 2063 | ▁something 1 2064 | ▁public 1 2065 | ▁state 1 2066 | ▁fell 1 2067 | ▁prom 1 2068 | ues 1 2069 | ▁among 1 2070 | ience 1 2071 | ▁list 1 2072 | ▁face 1 2073 | ▁ho 1 2074 | ▁yet 1 2075 | rict 1 2076 | ▁better 1 2077 | ility 1 2078 | rap 1 2079 | iting 1 2080 | ▁next 1 2081 | yp 1 2082 | meric 1 2083 | ▁repl 1 2084 | ▁along 1 2085 | ▁hard 1 2086 | ▁far 1 2087 | aken 1 2088 | ▁mind 1 2089 | hip 1 2090 | To 1 2091 | ▁pri 1 2092 | ▁white 1 2093 | ▁thing 1 2094 | ▁number 1 2095 | ying 1 2096 | ▁mark 1 2097 | ▁full 1 2098 | ▁hist 1 2099 | ▁unt 1 2100 | ▁present 1 2101 | ▁best 1 2102 | ▁short 1 2103 | ▁feel 1 2104 | ▁common 1 2105 | ron 1 2106 | ▁hor 1 2107 | pend 1 2108 | ▁black 1 2109 | ▁sil 1 2110 | ▁heard 1 2111 | used 1 2112 | ▁es 1 2113 | ▁beh 1 2114 | ▁within 1 2115 | ▁design 1 2116 | ▁mom 1 2117 | ▁give 1 2118 | ▁Al 1 2119 | jor 1 2120 | ▁term 1 2121 | ▁pur 1 2122 | ▁aff 1 2123 | ▁mem 1 2124 | ▁law 1 2125 | ▁perform 1 2126 | ▁vill 1 2127 | Al 1 2128 | ▁pat 1 2129 | ision 1 2130 | ▁success 1 2131 | err 1 2132 | ▁though 1 2133 | ▁av 1 2134 | ▁produ 1 2135 | ▁He 1 2136 | ▁record 1 2137 | iving 1 2138 | ▁seemed 1 2139 | ape 1 2140 | ▁sat 1 2141 | ison 1 2142 | ense 1 2143 | ▁include 1 2144 | ▁played 1 2145 | ▁charac 1 2146 | Its 1 2147 | elt 1 2148 | ▁since 1 2149 | ▁cond 1 2150 | aps 1 2151 | aut 1 2152 | ▁mil 1 2153 | ▁mean 1 2154 | rist 1 2155 | ital 1 2156 | oice 1 2157 | uced 1 2158 | ats 1 2159 | ax 1 2160 | arn 1 2161 | ▁team 1 2162 | ▁prog 1 2163 | ▁bur 1 2164 | ▁cor 1 2165 | ▁death 1 2166 | ▁power 1 2167 | unt 1 2168 | ▁In 1 2169 | ▁half 1 2170 | ▁expl 1 2171 | ony 1 2172 | ▁partic 1 2173 | ▁Count 1 2174 | ▁won 1 2175 | iam 1 2176 | ▁done 1 2177 | ▁adv 1 2178 | ▁med 1 2179 | ced 1 2180 | oke 1 2181 | ▁inde 1 2182 | ▁days 1 2183 | ilt 1 2184 | ▁sent 1 2185 | ians 1 2186 | ▁rele 1 2187 | ury 1 2188 | ▁brother 1 2189 | ▁station 1 2190 | ▁adm 1 2191 | ract 1 2192 | ▁gave 1 2193 | ior 1 2194 | angu 1 2195 | ▁hour 1 2196 | ▁knew 1 2197 | ▁result 1 2198 | erest 1 2199 | ▁door 1 2200 | ▁real 1 2201 | ▁kind 1 2202 | uss 1 2203 | ung 1 2204 | ▁ste 1 2205 | ▁conc 1 2206 | atch 1 2207 | ▁sound 1 2208 | oney 1 2209 | ▁import 1 2210 | ering 1 2211 | ▁ide 1 2212 | ▁val 1 2213 | ▁est 1 2214 | pr 1 2215 | aces 1 2216 | aid 1 2217 | med 1 2218 | ▁pos 1 2219 | ▁def 1 2220 | ief 1 2221 | ▁bri 1 2222 | ▁wife 1 2223 | ▁tri 1 2224 | ▁Le 1 2225 | ▁why 1 2226 | ▁Americ 1 2227 | ▁Sh 1 2228 | ▁band 1 2229 | ▁fact 1 2230 | ▁dri 1 2231 | If 1 2232 | ▁less 1 2233 | ▁University 1 2234 | ▁interest 1 2235 | use 1 2236 | ▁hold 1 2237 | ▁foot 1 2238 | ably 1 2239 | ▁sun 1 2240 | ances 1 2241 | aim 1 2242 | ▁class 1 2243 | yl 1 2244 | ▁major 1 2245 | ▁fall 1 2246 | ▁almost 1 2247 | ▁soc 1 2248 | undred 1 2249 | ▁moment 1 2250 | ▁Ind 1 2251 | ium 1 2252 | ▁certain 1 2253 | St 1 2254 | hern 1 2255 | ▁sit 1 
2256 | ged 1 2257 | iter 1 2258 | ▁wa 1 2259 | ▁underst 1 2260 | irc 1 2261 | oint 1 2262 | ▁eight 1 2263 | rought 1 2264 | ▁hands 1 2265 | ▁sol 1 2266 | ▁desc 1 2267 | ional 1 2268 | So 1 2269 | ball 1 2270 | be 1 2271 | reet 1 2272 | ege 1 2273 | ster 1 2274 | ▁feat 1 2275 | ▁line 1 2276 | ▁sn 1 2277 | ▁cre 1 2278 | iven 1 2279 | ▁fire 1 2280 | ▁eas 1 2281 | ▁County 1 2282 | ▁country 1 2283 | ▁opp 1 2284 | omp 1 2285 | ▁until 1 2286 | ▁daugh 1 2287 | ▁times 1 2288 | ▁early 1 2289 | ened 1 2290 | ▁equ 1 2291 | ina 1 2292 | por 1 2293 | ▁served 1 2294 | side 1 2295 | ▁heav 1 2296 | ▁round 1 2297 | erman 1 2298 | ▁requ 1 2299 | ▁held 1 2300 | One 1 2301 | illed 1 2302 | ▁unc 1 2303 | ▁prof 1 2304 | ▁become 1 2305 | ▁quite 1 2306 | ▁Fr 1 2307 | ▁game 1 2308 | ving 1 2309 | ards 1 2310 | Ch 1 2311 | ▁bar 1 2312 | iety 1 2313 | ▁turn 1 2314 | ▁stop 1 2315 | orning 1 2316 | ik 1 2317 | ▁langu 1 2318 | ▁live 1 2319 | ▁exper 1 2320 | With 1 2321 | els 1 2322 | ator 1 2323 | ached 1 2324 | ▁turned 1 2325 | ones 1 2326 | ▁company 1 2327 | ▁top 1 2328 | ention 1 2329 | ▁didn 1 2330 | ▁Cl 1 2331 | ex 1 2332 | ▁whole 1 2333 | ▁toget 1 2334 | ▁together 1 2335 | ▁business 1 2336 | bert 1 2337 | idge 1 2338 | ▁ca 1 2339 | ▁course 1 2340 | ▁mer 1 2341 | ▁hundred 1 2342 | een 1 2343 | ▁god 1 2344 | ship 1 2345 | ▁Sc 1 2346 | ▁having 1 2347 | ▁sir 1 2348 | Many 1 2349 | ▁suff 1 2350 | ▁obs 1 2351 | ▁offic 1 2352 | ▁strong 1 2353 | Then 1 2354 | itions 1 2355 | ipp 1 2356 | ressed 1 2357 | rod 1 2358 | enty 1 2359 | ▁north 1 2360 | ▁chang 1 2361 | ▁div 1 2362 | ully 1 2363 | ▁camp 1 2364 | ▁dou 1 2365 | ▁fre 1 2366 | ▁brought 1 2367 | ▁mag 1 2368 | ▁attend 1 2369 | ▁local 1 2370 | ▁Th 1 2371 | selves 1 2372 | ▁leg 1 2373 | ries 1 2374 | ▁hon 1 2375 | ▁words 1 2376 | ead 1 2377 | raw 1 2378 | ection 1 2379 | ▁enough 1 2380 | ▁road 1 2381 | ▁morning 1 2382 | itten 1 2383 | ▁received 1 2384 | ▁government 1 2385 | ▁tem 1 2386 | ider 1 2387 | ▁keep 1 2388 | atic 1 2389 | umb 1 2390 | ▁others 1 2391 | ▁ve 1 2392 | wo 1 2393 | ▁close 1 2394 | ▁really 1 2395 | ology 1 2396 | ▁wee 1 2397 | oul 1 2398 | eng 1 2399 | lished 1 2400 | ▁appear 1 2401 | ▁alc 1 2402 | ▁polit 1 2403 | ▁English 1 2404 | ▁Is 1 2405 | ▁organ 1 2406 | co 1 2407 | eor 1 2408 | ille 1 2409 | Her 1 2410 | not 1 2411 | ▁popular 1 2412 | ▁mount 1 2413 | ▁appro 1 2414 | ▁dark 1 2415 | ▁port 1 2416 | me 1 2417 | ▁felt 1 2418 | ▁effect 1 2419 | ize 1 2420 | idd 1 2421 | ▁princ 1 2422 | ▁move 1 2423 | ▁village 1 2424 | ▁character 1 2425 | con 1 2426 | ▁south 1 2427 | rence 1 2428 | aining 1 2429 | ▁happ 1 2430 | que 1 2431 | ▁yes 1 2432 | ▁dest 1 2433 | ler 1 2434 | ▁former 1 2435 | ▁ground 1 2436 | ots 1 2437 | ham 1 2438 | ped 1 2439 | ▁win 1 2440 | aged 1 2441 | ▁following 1 2442 | uro 1 2443 | ▁return 1 2444 | ▁continued 1 2445 | ▁money 1 2446 | ▁married 1 2447 | ▁support 1 2448 | ▁An 1 2449 | ▁returned 1 2450 | ▁voice 1 2451 | ▁missus 1 2452 | gy 1 2453 | ▁sure 1 2454 | sequ 1 2455 | ▁month 1 2456 | ▁talk 1 2457 | ▁circ 1 2458 | hy 1 2459 | ▁bas 1 2460 | ▁sometimes 1 2461 | par 1 2462 | ▁fur 1 2463 | ▁dem 1 2464 | fic 1 2465 | ▁wood 1 2466 | ▁given 1 2467 | ▁quest 1 2468 | ▁Mar 1 2469 | let 1 2470 | ▁body 1 2471 | ▁1 1 2472 | ▁John 1 2473 | No 1 2474 | ified 1 2475 | ▁lay 1 2476 | ▁died 1 2477 | ctions 1 2478 | aves 1 2479 | ▁low 1 2480 | ▁appeared 1 2481 | ▁looking 1 2482 | ▁women 1 2483 | ▁ann 1 2484 | ▁anything 1 2485 | lu 1 2486 | asure 1 2487 | ▁ill 1 2488 | ▁gold 1 2489 | ▁rev 1 2490 | arent 1 2491 | ▁taken 1 2492 | oon 1 2493 | ▁poor 1 2494 | 
▁ra 1 2495 | arr 1 2496 | ches 1 2497 | ograp 1 2498 | ▁field 1 2499 | ▁program 1 2500 | ▁ir 1 2501 | ▁jo 1 2502 | ▁included 1 2503 | ▁fun 1 2504 | eth 1 2505 | ▁sle 1 2506 | ox 1 2507 | ev 1 2508 | ality 1 2509 | read 1 2510 | ▁quick 1 2511 | ▁sett 1 2512 | ases 1 2513 | ▁speak 1 2514 | fully 1 2515 | ceed 1 2516 | ald 1 2517 | ▁capt 1 2518 | ▁beaut 1 2519 | ▁sum 1 2520 | alth 1 2521 | iously 1 2522 | ination 1 2523 | iod 1 2524 | ▁prop 1 2525 | owed 1 2526 | ▁educ 1 2527 | ▁across 1 2528 | ▁series 1 2529 | ▁pap 1 2530 | ▁contro 1 2531 | ilar 1 2532 | ▁written 1 2533 | ▁hap 1 2534 | ▁general 1 2535 | ▁' 1 2536 | ford 1 2537 | ▁begin 1 2538 | ored 1 2539 | ▁trad 1 2540 | ney 1 2541 | ▁ten 1 2542 | ▁prot 1 2543 | ▁free 1 2544 | attle 1 2545 | ▁important 1 2546 | ey 1 2547 | ▁built 1 2548 | ▁means 1 2549 | ▁ship 1 2550 | ▁swe 1 2551 | chn 1 2552 | ▁news 1 2553 | arl 1 2554 | ▁cross 1 2555 | ▁daughter 1 2556 | ▁aw 1 2557 | ▁behind 1 2558 | men 1 2559 | ▁repres 1 2560 | ▁nat 1 2561 | wered 1 2562 | ▁event 1 2563 | ▁meet 1 2564 | ▁single 1 2565 | ese 1 2566 | ▁happen 1 2567 | ▁sold 1 2568 | ▁party 1 2569 | idered 1 2570 | ▁stood 1 2571 | aught 1 2572 | ected 1 2573 | ▁bed 1 2574 | ▁memb 1 2575 | ▁rather 1 2576 | ▁sal 1 2577 | ▁case 1 2578 | Two 1 2579 | eter 1 2580 | ike 1 2581 | oor 1 2582 | ▁wall 1 2583 | ▁usually 1 2584 | lled 1 2585 | ▁story 1 2586 | bs 1 2587 | ready 1 2588 | ▁order 1 2589 | oci 1 2590 | ▁Coll 1 2591 | ▁answered 1 2592 | ef 1 2593 | ▁vers 1 2594 | ▁passed 1 2595 | of 1 2596 | rench 1 2597 | ▁auth 1 2598 | ▁church 1 2599 | ▁alchem 1 2600 | ▁sud 1 2601 | irt 1 2602 | During 1 2603 | ▁pract 1 2604 | ▁United 1 2605 | uff 1 2606 | ▁tit 1 2607 | ▁sudden 1 2608 | ▁third 1 2609 | ▁lady 1 2610 | ▁river 1 2611 | ▁community 1 2612 | ▁considered 1 2613 | ament 1 2614 | ▁dev 1 2615 | ▁lost 1 2616 | ▁walk 1 2617 | ▁River 1 2618 | My 1 2619 | ▁already 1 2620 | ▁matter 1 2621 | ▁herself 1 2622 | ▁Ro 1 2623 | ▁anim 1 2624 | ▁building 1 2625 | Do 1 2626 | ai 1 2627 | ms 1 2628 | ▁Saint 1 2629 | ▁grow 1 2630 | empt 1 2631 | ▁incre 1 2632 | Other 1 2633 | ▁moved 1 2634 | ▁table 1 2635 | orts 1 2636 | ▁rail 1 2637 | ▁dog 1 2638 | ▁members 1 2639 | ▁post 1 2640 | ▁themselves 1 2641 | ▁horse 1 2642 | ▁island 1 2643 | ency 1 2644 | ▁Can 1 2645 | ▁period 1 2646 | ▁cannot 1 2647 | uted 1 2648 | wer 1 2649 | ▁myself 1 2650 | anger 1 2651 | ▁reason 1 2652 | ills 1 2653 | ▁ran 1 2654 | ▁twenty 1 2655 | ▁plan 1 2656 | ▁exist 1 2657 | oud 1 2658 | ude 1 2659 | ▁nor 1 2660 | ▁German 1 2661 | ories 1 2662 | ▁occur 1 2663 | ls 1 2664 | ▁view 1 2665 | ▁2 1 2666 | urope 1 2667 | erc 1 2668 | pec 1 2669 | ▁past 1 2670 | ▁del 1 2671 | itted 1 2672 | ▁remem 1 2673 | ▁big 1 2674 | ▁true 1 2675 | ▁prim 1 2676 | ▁whom 1 2677 | haps 1 2678 | ising 1 2679 | ▁City 1 2680 | ▁service 1 2681 | ▁member 1 2682 | ▁proper 1 2683 | ▁itself 1 2684 | ▁led 1 2685 | ▁describ 1 2686 | ▁lab 1 2687 | ▁Park 1 2688 | ▁earth 1 2689 | itting 1 2690 | sc 1 2691 | ▁particular 1 2692 | ought 1 2693 | ▁entire 1 2694 | ▁blue 1 2695 | ▁stru 1 2696 | ▁tour 1 2697 | fort 1 2698 | gest 1 2699 | ▁Sp 1 2700 | ousand 1 2701 | ounc 1 2702 | airs 1 2703 | ▁everything 1 2704 | ▁understand 1 2705 | ▁trave 1 2706 | ▁above 1 2707 | bur 1 2708 | body 1 2709 | ming 1 2710 | ele 1 2711 | ▁track 1 2712 | ▁North 1 2713 | ▁mot 1 2714 | ▁clear 1 2715 | ▁county 1 2716 | ▁currently 1 2717 | osp 1 2718 | aff 1 2719 | ▁green 1 2720 | ▁human 1 2721 | ▁seat 1 2722 | bor 1 2723 | ▁lim 1 2724 | ▁conn 1 2725 | ▁grand 1 2726 | ulation 1 2727 | ays 1 2728 | ument 1 2729 
| onder 1 2730 | gr 1 2731 | ush 1 2732 | ▁crit 1 2733 | ▁based 1 2734 | ▁Pr 1 2735 | ▁exc 1 2736 | ▁park 1 2737 | atter 1 2738 | ▁dear 1 2739 | ▁either 1 2740 | ▁hear 1 2741 | ump 1 2742 | ▁fif 1 2743 | inn 1 2744 | omin 1 2745 | ague 1 2746 | ▁hur 1 2747 | ze 1 2748 | ▁wide 1 2749 | ▁ant 1 2750 | ▁believe 1 2751 | ▁ord 1 2752 | ges 1 2753 | urb 1 2754 | ▁Aust 1 2755 | val 1 2756 | ficult 1 2757 | ▁sea 1 2758 | ▁history 1 2759 | ▁cal 1 2760 | ▁difficult 1 2761 | ▁resp 1 2762 | ▁process 1 2763 | ▁represent 1 2764 | ▁original 1 2765 | lf 1 2766 | ▁cut 1 2767 | ▁various 1 2768 | ▁leave 1 2769 | ▁thousand 1 2770 | ▁cast 1 2771 | ▁doub 1 2772 | ▁fail 1 2773 | ▁feet 1 2774 | ville 1 2775 | ▁idea 1 2776 | ▁American 1 2777 | ids 1 2778 | ▁avail 1 2779 | ▁object 1 2780 | ▁dream 1 2781 | ▁followed 1 2782 | Most 1 2783 | ▁else 1 2784 | ▁Bl 1 2785 | ▁Ph 1 2786 | ▁sister 1 2787 | ▁cry 1 2788 | ▁gone 1 2789 | ▁priv 1 2790 | ▁typ 1 2791 | ission 1 2792 | ▁language 1 2793 | ▁hot 1 2794 | ▁War 1 2795 | ▁wal 1 2796 | ▁stri 1 2797 | ▁similar 1 2798 | ively 1 2799 | ▁season 1 2800 | yle 1 2801 | ▁influ 1 2802 | Be 1 2803 | ▁project 1 2804 | ▁wanted 1 2805 | ▁species 1 2806 | ▁sand 1 2807 | ▁High 1 2808 | to 1 2809 | ▁spr 1 2810 | ▁inform 1 2811 | ▁Brit 1 2812 | ording 1 2813 | wh 1 2814 | ▁enc 1 2815 | ▁arri 1 2816 | medi 1 2817 | ▁Cal 1 2818 | ▁worked 1 2819 | ▁invol 1 2820 | the 1 2821 | une 1 2822 | ▁club 1 2823 | though 1 2824 | ▁wrote 1 2825 | ear 1 2826 | ▁able 1 2827 | ▁profess 1 2828 | ius 1 2829 | ▁least 1 2830 | ▁author 1 2831 | ▁esc 1 2832 | ▁career 1 2833 | ▁started 1 2834 | ▁rese 1 2835 | ▁fac 1 2836 | ▁alone 1 2837 | ▁desert 1 2838 | ▁Ed 1 2839 | ▁possible 1 2840 | ▁alchemist 1 2841 | utes 1 2842 | ▁prev 1 2843 | mp 1 2844 | ▁front 1 2845 | ▁pie 1 2846 | Sh 1 2847 | ▁dead 1 2848 | ▁nature 1 2849 | raft 1 2850 | idence 1 2851 | ▁learn 1 2852 | ▁works 1 2853 | ▁seven 1 2854 | ▁including 1 2855 | ▁techn 1 2856 | ode 1 2857 | ences 1 2858 | ▁indust 1 2859 | che 1 2860 | reen 1 2861 | ▁abs 1 2862 | ▁gent 1 2863 | aur 1 2864 | ▁South 1 2865 | ▁command 1 2866 | tered 1 2867 | Not 1 2868 | ▁fight 1 2869 | ▁ep 1 2870 | ploy 1 2871 | ▁Tr 1 2872 | ▁mor 1 2873 | ▁street 1 2874 | ▁Con 1 2875 | ▁squ 1 2876 | ▁host 1 2877 | ▁Comm 1 2878 | ▁hope 1 2879 | ▁proble 1 2880 | yn 1 2881 | ▁bad 1 2882 | ▁longer 1 2883 | ▁tele 1 2884 | ▁subject 1 2885 | oman 1 2886 | ▁fav 1 2887 | ▁replied 1 2888 | Both 1 2889 | ▁soft 1 2890 | ▁deep 1 2891 | ▁question 1 2892 | ▁further 1 2893 | ▁York 1 2894 | oe 1 2895 | antly 1 2896 | ript 1 2897 | ▁except 1 2898 | ▁oh 1 2899 | ▁office 1 2900 | fect 1 2901 | ▁produced 1 2902 | ▁gen 1 2903 | ▁addition 1 2904 | ▁step 1 2905 | ▁super 1 2906 | ▁College 1 2907 | kn 1 2908 | ization 1 2909 | ▁hous 1 2910 | ▁ru 1 2911 | ▁grad 1 2912 | ▁lived 1 2913 | lin 1 2914 | ▁district 1 2915 | ample 1 2916 | ▁books 1 2917 | ▁bit 1 2918 | over 1 2919 | ▁available 1 2920 | ▁hig 1 2921 | ▁attack 1 2922 | ises 1 2923 | ▁cried 1 2924 | ▁consider 1 2925 | ▁arch 1 2926 | ▁Christ 1 2927 | erred 1 2928 | ▁friends 1 2929 | ▁remained 1 2930 | ▁till 1 2931 | ▁iss 1 2932 | bo 1 2933 | ▁non 1 2934 | ▁fort 1 2935 | conom 1 2936 | ▁compet 1 2937 | ▁trou 1 2938 | ara 1 2939 | ▁ask 1 2940 | ▁occas 1 2941 | ▁perhaps 1 2942 | ▁special 1 2943 | iful 1 2944 | ▁making 1 2945 | ▁direct 1 2946 | ▁released 1 2947 | estern 1 2948 | itary 1 2949 | ▁fa 1 2950 | inct 1 2951 | na 1 2952 | iddle 1 2953 | ▁cult 1 2954 | ▁level 1 2955 | ▁decl 1 2956 | ayed 1 2957 | Their 1 2958 | ▁mis 1 2959 | ouncil 1 2960 | ▁complete 1 2961 | ▁creat 
1 2962 | ▁Europe 1 2963 | itch 1 2964 | wood 1 2965 | ▁kept 1 2966 | ▁shop 1 2967 | ests 1 2968 | field 1 2969 | ▁change 1 2970 | aul 1 2971 | ▁dro 1 2972 | ▁towards 1 2973 | ▁fair 1 2974 | my 1 2975 | ager 1 2976 | illiam 1 2977 | ▁Be 1 2978 | ▁z 1 2979 | ▁playing 1 2980 | rest 1 2981 | band 1 2982 | fact 1 2983 | ▁due 1 2984 | ▁thus 1 2985 | lling 1 2986 | ▁slow 1 2987 | ▁wonder 1 2988 | ▁generally 1 2989 | ▁rock 1 2990 | acy 1 2991 | ▁separ 1 2992 | ▁Cent 1 2993 | ▁attempt 1 2994 | ▁doctor 1 2995 | ogn 1 2996 | ▁rad 1 2997 | ▁estab 1 2998 | arily 1 2999 | ▁buried 1 3000 | ▁cold 1 3001 | ▁sus 1 3002 | ▁neigh 1 3003 | itor 1 3004 | hod 1 3005 | ▁fish 1 3006 | ▁States 1 3007 | ▁situ 1 3008 | teen 1 3009 | ▁bird 1 3010 | ement 1 3011 | From 1 3012 | ▁attended 1 3013 | ▁intern 1 3014 | ▁gentle 1 3015 | inist 1 3016 | ▁defe 1 3017 | ▁position 1 3018 | ression 1 3019 | ▁wish 1 3020 | ▁fig 1 3021 | ▁makes 1 3022 | iding 1 3023 | ▁east 1 3024 | ▁ty 1 3025 | ▁pal 1 3026 | ▁elect 1 3027 | ival 1 3028 | ▁necess 1 3029 | ada 1 3030 | ▁coming 1 3031 | ▁features 1 3032 | ▁grew 1 3033 | owers 1 3034 | ▁example 1 3035 | ett 1 3036 | ▁Wor 1 3037 | ivid 1 3038 | ▁says 1 3039 | ▁bring 1 3040 | af 1 3041 | ▁activ 1 3042 | ▁break 1 3043 | ▁corn 1 3044 | ales 1 3045 | ▁Par 1 3046 | ▁aut 1 3047 | ▁indeed 1 3048 | ▁frequ 1 3049 | iol 1 3050 | ▁emb 1 3051 | ▁West 1 3052 | ▁gard 1 3053 | ites 1 3054 | ▁fear 1 3055 | ▁spir 1 3056 | ▁master 1 3057 | ▁west 1 3058 | Is 1 3059 | ▁jud 1 3060 | line 1 3061 | ▁quickly 1 3062 | ▁food 1 3063 | pped 1 3064 | hood 1 3065 | Ad 1 3066 | ▁added 1 3067 | ▁lives 1 3068 | ▁vol 1 3069 | ▁cat 1 3070 | go 1 3071 | lor 1 3072 | ▁Se 1 3073 | apan 1 3074 | iled 1 3075 | ▁sour 1 3076 | isted 1 3077 | ondon 1 3078 | ▁outside 1 3079 | ▁doing 1 3080 | ken 1 3081 | ▁star 1 3082 | ▁princip 1 3083 | ▁National 1 3084 | ▁throughout 1 3085 | icip 1 3086 | ▁prob 1 3087 | ▁opened 1 3088 | Who 1 3089 | ▁bes 1 3090 | itive 1 3091 | bour 1 3092 | ▁mater 1 3093 | ▁econom 1 3094 | ▁immedi 1 3095 | ▁includes 1 3096 | isc 1 3097 | ▁wat 1 3098 | ires 1 3099 | ma 1 3100 | reek 1 3101 | ▁Gr 1 3102 | ▁months 1 3103 | add 1 3104 | ▁living 1 3105 | ixed 1 3106 | ▁ball 1 3107 | iment 1 3108 | ▁therefore 1 3109 | riage 1 3110 | ask 1 3111 | ▁accept 1 3112 | ▁constru 1 3113 | ▁national 1 3114 | ublic 1 3115 | ▁mult 1 3116 | ▁tou 1 3117 | ration 1 3118 | ▁age 1 3119 | ▁secret 1 3120 | ▁Col 1 3121 | ▁mur 1 3122 | head 1 3123 | ▁court 1 3124 | road 1 3125 | erson 1 3126 | ▁compan 1 3127 | anch 1 3128 | erved 1 3129 | ▁recogn 1 3130 | ▁novel 1 3131 | ▁taking 1 3132 | ining 1 3133 | ▁bright 1 3134 | ▁using 1 3135 | ▁surpr 1 3136 | ▁changed 1 3137 | ▁hours 1 3138 | omes 1 3139 | ▁Pro 1 3140 | ▁anc 1 3141 | ▁cause 1 3142 | alt 1 3143 | ▁purp 1 3144 | vered 1 3145 | ▁depart 1 3146 | sy 1 3147 | ▁Car 1 3148 | ▁phys 1 3149 | ▁suddenly 1 3150 | ▁De 1 3151 | ▁sleep 1 3152 | ▁London 1 3153 | ▁associ 1 3154 | ▁thr 1 3155 | ▁watch 1 3156 | ston 1 3157 | ▁decided 1 3158 | ▁jour 1 3159 | ▁pay 1 3160 | ▁self 1 3161 | Where 1 3162 | ▁regard 1 3163 | emy 1 3164 | pected 1 3165 | ▁train 1 3166 | ata 1 3167 | where 1 3168 | ▁reached 1 3169 | ▁whose 1 3170 | ▁comb 1 3171 | ▁method 1 3172 | orthern 1 3173 | ▁paint 1 3174 | ole 1 3175 | ▁histor 1 3176 | ▁Ad 1 3177 | ▁lord 1 3178 | work 1 3179 | ▁Ste 1 3180 | aimed 1 3181 | ▁dam 1 3182 | ▁week 1 3183 | ▁soul 1 3184 | ▁surround 1 3185 | reg 1 3186 | ▁evening 1 3187 | ▁hus 1 3188 | ▁largest 1 3189 | ▁account 1 3190 | owd 1 3191 | ▁acqu 1 3192 | ▁trees 1 3193 | ▁employ 1 3194 | ▁farm 1 3195 | 
▁mill 1 3196 | ested 1 3197 | ▁site 1 3198 | bers 1 3199 | ades 1 3200 | overed 1 3201 | ▁tree 1 3202 | ▁blood 1 3203 | ▁clim 1 3204 | Don 1 3205 | Now 1 3206 | dom 1 3207 | ▁working 1 3208 | ▁ver 1 3209 | By 1 3210 | ▁tradition 1 3211 | ▁surv 1 3212 | ▁British 1 3213 | iff 1 3214 | ▁wild 1 3215 | ▁trib 1 3216 | ▁Or 1 3217 | ▁developed 1 3218 | reme 1 3219 | aries 1 3220 | ustom 1 3221 | ▁strange 1 3222 | ▁beautiful 1 3223 | ivil 1 3224 | ▁region 1 3225 | ▁Austral 1 3226 | ▁cou 1 3227 | lo 1 3228 | ▁late 1 3229 | ▁comes 1 3230 | ▁op 1 3231 | ▁songs 1 3232 | ili 1 3233 | sequently 1 3234 | ▁terr 1 3235 | istic 1 3236 | ▁Dist 1 3237 | ▁whether 1 3238 | erve 1 3239 | ▁space 1 3240 | ▁disp 1 3241 | ▁It 1 3242 | ▁chief 1 3243 | ▁carried 1 3244 | ▁ut 1 3245 | ▁natural 1 3246 | ▁spoke 1 3247 | ▁draw 1 3248 | ▁surf 1 3249 | ▁imm 1 3250 | ▁tried 1 3251 | cul 1 3252 | ▁deg 1 3253 | cts 1 3254 | Every 1 3255 | bed 1 3256 | ▁eye 1 3257 | ▁political 1 3258 | ino 1 3259 | ▁eat 1 3260 | ▁perfect 1 3261 | ▁mist 1 3262 | ▁doubt 1 3263 | ▁deb 1 3264 | apt 1 3265 | iforn 1 3266 | ▁arr 1 3267 | well 1 3268 | joy 1 3269 | ▁sugg 1 3270 | ▁engine 1 3271 | ▁Geor 1 3272 | pendent 1 3273 | ▁forg 1 3274 | outhern 1 3275 | lex 1 3276 | ▁letter 1 3277 | ▁stars 1 3278 | ▁style 1 3279 | iced 1 3280 | ▁carry 1 3281 | ▁arms 1 3282 | ▁remains 1 3283 | ▁broad 1 3284 | ▁sort 1 3285 | ▁report 1 3286 | ▁published 1 3287 | While 1 3288 | ▁hair 1 3289 | ▁impro 1 3290 | ▁today 1 3291 | ▁box 1 3292 | yd 1 3293 | ▁formed 1 3294 | ▁sight 1 3295 | ▁standing 1 3296 | ▁Mc 1 3297 | ifornia 1 3298 | ▁covered 1 3299 | ana 1 3300 | ▁below 1 3301 | ▁forest 1 3302 | ▁Town 1 3303 | ▁modern 1 3304 | inal 1 3305 | ▁famous 1 3306 | erous 1 3307 | ▁husband 1 3308 | ▁relig 1 3309 | ▁market 1 3310 | olic 1 3311 | ▁har 1 3312 | lement 1 3313 | cast 1 3314 | ▁quar 1 3315 | ▁cost 1 3316 | ▁ident 1 3317 | ▁Man 1 3318 | osition 1 3319 | ▁World 1 3320 | ▁King 1 3321 | vision 1 3322 | ▁pict 1 3323 | ▁cover 1 3324 | ▁raised 1 3325 | ▁den 1 3326 | ▁Wh 1 3327 | orpor 1 3328 | ▁thir 1 3329 | ▁title 1 3330 | ▁signific 1 3331 | abit 1 3332 | ▁material 1 3333 | ▁glass 1 3334 | ▁lies 1 3335 | ▁sitting 1 3336 | ▁French 1 3337 | ▁af 1 3338 | ▁deal 1 3339 | ▁sens 1 3340 | ▁mad 1 3341 | ingly 1 3342 | inet 1 3343 | ▁rich 1 3344 | ▁provided 1 3345 | ean 1 3346 | ▁Af 1 3347 | ift 1 3348 | uses 1 3349 | iers 1 3350 | ength 1 3351 | ▁points 1 3352 | pecially 1 3353 | ▁Sw 1 3354 | ▁bound 1 3355 | ▁students 1 3356 | ▁quiet 1 3357 | ▁District 1 3358 | ips 1 3359 | Why 1 3360 | ▁schools 1 3361 | ▁described 1 3362 | ▁areas 1 3363 | ▁repe 1 3364 | ▁floor 1 3365 | ▁running 1 3366 | ▁ready 1 3367 | ▁role 1 3368 | ▁entered 1 3369 | ators 1 3370 | Un 1 3371 | aly 1 3372 | ▁regular 1 3373 | rupt 1 3374 | ▁rout 1 3375 | ▁bene 1 3376 | ▁Lake 1 3377 | ▁prep 1 3378 | ▁pick 1 3379 | ously 1 3380 | rs 1 3381 | ▁lad 1 3382 | ▁viol 1 3383 | ▁mouth 1 3384 | ▁remain 1 3385 | llow 1 3386 | ▁ter 1 3387 | oph 1 3388 | olution 1 3389 | ▁determ 1 3390 | uge 1 3391 | ▁phil 1 3392 | ply 1 3393 | ▁Re 1 3394 | eder 1 3395 | ▁answer 1 3396 | unicip 1 3397 | ▁prison 1 3398 | Which 1 3399 | ▁imag 1 3400 | ▁happy 1 3401 | acher 1 3402 | ▁respect 1 3403 | ▁pret 1 3404 | ference 1 3405 | inary 1 3406 | ▁England 1 3407 | ▁pain 1 3408 | ops 1 3409 | ▁la 1 3410 | rate 1 3411 | ▁enjoy 1 3412 | ▁window 1 3413 | ▁ren 1 3414 | ▁houses 1 3415 | ▁Japan 1 3416 | enced 1 3417 | ocr 1 3418 | ▁mass 1 3419 | ▁exam 1 3420 | gn 1 3421 | ▁fast 1 3422 | ▁rose 1 3423 | ▁services 1 3424 | ▁excl 1 3425 | iction 1 3426 | 
▁California 1 3427 | ▁killed 1 3428 | ica 1 3429 | itude 1 3430 | ▁football 1 3431 | anced 1 3432 | ▁social 1 3433 | ▁bord 1 3434 | ▁Pl 1 3435 | ▁liter 1 3436 | key 1 3437 | known 1 3438 | ▁Har 1 3439 | ▁fut 1 3440 | ▁William 1 3441 | ▁try 1 3442 | ps 1 3443 | ▁type 1 3444 | oming 1 3445 | ▁individ 1 3446 | ▁manner 1 3447 | ▁da 1 3448 | ails 1 3449 | aded 1 3450 | cend 1 3451 | ▁paper 1 3452 | asc 1 3453 | ▁stat 1 3454 | resh 1 3455 | ha 1 3456 | omb 1 3457 | ▁Island 1 3458 | ▁movement 1 3459 | ▁battle 1 3460 | ▁development 1 3461 | ▁capital 1 3462 | ▁created 1 3463 | ▁crowd 1 3464 | ▁placed 1 3465 | su 1 3466 | ▁treat 1 3467 | ledge 1 3468 | ▁base 1 3469 | Of 1 3470 | ▁maint 1 3471 | irm 1 3472 | Each 1 3473 | ▁drink 1 3474 | roy 1 3475 | aped 1 3476 | ▁gun 1 3477 | ▁fine 1 3478 | ensive 1 3479 | ▁parent 1 3480 | like 1 3481 | useum 1 3482 | ▁bank 1 3483 | ▁information 1 3484 | ▁captain 1 3485 | ▁struct 1 3486 | hib 1 3487 | hel 1 3488 | ▁introd 1 3489 | gu 1 3490 | irth 1 3491 | ▁especially 1 3492 | ▁mov 1 3493 | ▁tom 1 3494 | Well 1 3495 | ▁heavy 1 3496 | ▁parts 1 3497 | ▁fellow 1 3498 | uct 1 3499 | ▁center 1 3500 | ▁prince 1 3501 | cy 1 3502 | cher 1 3503 | ▁prec 1 3504 | ▁text 1 3505 | ▁laugh 1 3506 | het 1 3507 | ▁conv 1 3508 | ▁hall 1 3509 | ▁toward 1 3510 | ights 1 3511 | ▁names 1 3512 | itar 1 3513 | ume 1 3514 | iling 1 3515 | ▁store 1 3516 | ▁vict 1 3517 | ▁player 1 3518 | ▁remember 1 3519 | ilities 1 3520 | hic 1 3521 | ▁search 1 3522 | ▁version 1 3523 | ▁private 1 3524 | ▁final 1 3525 | unk 1 3526 | ▁mid 1 3527 | ▁truth 1 3528 | ▁possess 1 3529 | ▁Gu 1 3530 | ▁forward 1 3531 | ▁All 1 3532 | ▁der 1 3533 | ▁walked 1 3534 | ▁minutes 1 3535 | urs 1 3536 | off 1 3537 | ▁danger 1 3538 | aching 1 3539 | ological 1 3540 | ▁instr 1 3541 | oes 1 3542 | amb 1 3543 | ▁immediately 1 3544 | ned 1 3545 | ceeded 1 3546 | emet 1 3547 | ▁As 1 3548 | ointed 1 3549 | ▁suc 1 3550 | ▁health 1 3551 | ▁instead 1 3552 | ▁sense 1 3553 | ening 1 3554 | ?" 
1 3555 | ▁appearance 1 3556 | ▁chair 1 3557 | ▁control 1 3558 | ▁force 1 3559 | ▁board 1 3560 | ▁Air 1 3561 | ▁visit 1 3562 | Sever 1 3563 | ▁intell 1 3564 | ▁State 1 3565 | ▁exec 1 3566 | ▁worth 1 3567 | pper 1 3568 | ▁happened 1 3569 | ▁events 1 3570 | ▁length 1 3571 | ▁clot 1 3572 | ▁stra 1 3573 | ▁contains 1 3574 | ▁rul 1 3575 | rel 1 3576 | ▁vide 1 3577 | gar 1 3578 | uel 1 3579 | ▁yellow 1 3580 | ▁brown 1 3581 | ▁closed 1 3582 | ▁nine 1 3583 | oses 1 3584 | ▁plat 1 3585 | ▁test 1 3586 | ▁exact 1 3587 | ▁er 1 3588 | ▁feeling 1 3589 | Or 1 3590 | ▁ge 1 3591 | burg 1 3592 | ▁lot 1 3593 | ▁fle 1 3594 | ▁dress 1 3595 | ▁di 1 3596 | Give 1 3597 | aker 1 3598 | usion 1 3599 | ▁takes 1 3600 | ▁distance 1 3601 | pap 1 3602 | ▁bab 1 3603 | orrow 1 3604 | ▁init 1 3605 | emetery 1 3606 | ique 1 3607 | ▁remark 1 3608 | ended 1 3609 | ▁mountain 1 3610 | ▁eventually 1 3611 | ▁saying 1 3612 | de 1 3613 | Can 1 3614 | ech 1 3615 | ▁dut 1 3616 | ▁merch 1 3617 | ▁artic 1 3618 | ▁slight 1 3619 | uk 1 3620 | ▁plant 1 3621 | ▁occasion 1 3622 | ▁relations 1 3623 | ▁press 1 3624 | ils 1 3625 | ▁uncle 1 3626 | Several 1 3627 | mon 1 3628 | ▁phot 1 3629 | ▁die 1 3630 | anish 1 3631 | ▁forms 1 3632 | ▁attract 1 3633 | ▁games 1 3634 | ▁joined 1 3635 | ▁stay 1 3636 | ▁stopped 1 3637 | ▁future 1 3638 | ▁administ 1 3639 | arty 1 3640 | ▁Bar 1 3641 | ▁occup 1 3642 | oid 1 3643 | Later 1 3644 | ▁stren 1 3645 | ground 1 3646 | ▁production 1 3647 | ability 1 3648 | ▁probably 1 3649 | ▁education 1 3650 | ▁din 1 3651 | ▁president 1 3652 | ▁nearly 1 3653 | ym 1 3654 | ▁meaning 1 3655 | ▁fem 1 3656 | ▁cell 1 3657 | ▁successful 1 3658 | ancy 1 3659 | ▁slowly 1 3660 | ▁military 1 3661 | udd 1 3662 | ▁pull 1 3663 | ▁sheep 1 3664 | ▁getting 1 3665 | ▁displ 1 3666 | oted 1 3667 | ▁transl 1 3668 | aled 1 3669 | ▁invest 1 3670 | ▁official 1 3671 | iant 1 3672 | vere 1 3673 | ▁seems 1 3674 | ▁snow 1 3675 | ▁pred 1 3676 | ▁cand 1 3677 | ▁pot 1 3678 | ▁mention 1 3679 | ▁miles 1 3680 | ▁studied 1 3681 | zz 1 3682 | ▁numerous 1 3683 | ▁nearby 1 3684 | ▁inj 1 3685 | ky 1 3686 | aven 1 3687 | ▁At 1 3688 | lect 1 3689 | ▁spirit 1 3690 | avid 1 3691 | ▁summer 1 3692 | nal 1 3693 | agon 1 3694 | ny 1 3695 | ▁contrib 1 3696 | west 1 3697 | ▁writing 1 3698 | ▁rare 1 3699 | ▁Indian 1 3700 | ▁stone 1 3701 | ▁simple 1 3702 | ▁neg 1 3703 | icy 1 3704 | ▁desp 1 3705 | ▁doll 1 3706 | rated 1 3707 | Another 1 3708 | ▁recorded 1 3709 | stit 1 3710 | ushed 1 3711 | ▁spent 1 3712 | ae 1 3713 | ▁mach 1 3714 | ▁streng 1 3715 | ▁Go 1 3716 | ▁ang 1 3717 | ▁inn 1 3718 | ▁sweet 1 3719 | ▁constit 1 3720 | king 1 3721 | ▁ago 1 3722 | ▁action 1 3723 | rog 1 3724 | ▁Intern 1 3725 | ▁allowed 1 3726 | ▁bott 1 3727 | room 1 3728 | ▁belong 1 3729 | ▁temper 1 3730 | ▁contr 1 3731 | ▁pen 1 3732 | ▁Afric 1 3733 | ▁tal 1 3734 | ▁performed 1 3735 | unch 1 3736 | ▁grave 1 3737 | ords 1 3738 | iation 1 3739 | ▁produc 1 3740 | rit 1 3741 | ▁garden 1 3742 | atory 1 3743 | ▁hop 1 3744 | ▁key 1 3745 | ▁House 1 3746 | ▁states 1 3747 | iny 1 3748 | ▁research 1 3749 | ▁active 1 3750 | olf 1 3751 | ▁agre 1 3752 | ▁couldn 1 3753 | ▁comput 1 3754 | ▁peace 1 3755 | ▁originally 1 3756 | ▁designed 1 3757 | ▁mere 1 3758 | ▁please 1 3759 | Con 1 3760 | ▁path 1 3761 | eds 1 3762 | ▁civil 1 3763 | ▁wearing 1 3764 | ▁Co 1 3765 | ▁direction 1 3766 | ples 1 3767 | ▁Church 1 3768 | ▁fly 1 3769 | ▁municip 1 3770 | ▁cred 1 3771 | anded 1 3772 | ▁groups 1 3773 | ▁caused 1 3774 | ▁breat 1 3775 | ▁established 1 3776 | enth 1 3777 | ▁attention 1 3778 | owing 1 3779 | ▁warm 1 3780 | ▁inside 1 3781 
| olly 1 3782 | osing 1 3783 | ▁rob 1 3784 | ospital 1 3785 | ▁thinking 1 3786 | ▁pretty 1 3787 | ▁referred 1 3788 | ▁scar 1 3789 | icated 1 3790 | raid 1 3791 | ander 1 3792 | ▁rap 1 3793 | ▁ess 1 3794 | ▁race 1 3795 | ▁newspap 1 3796 | ▁fund 1 3797 | ▁weather 1 3798 | ▁lif 1 3799 | ▁stage 1 3800 | ▁middle 1 3801 | ▁satis 1 3802 | ▁army 1 3803 | ▁leaves 1 3804 | ▁hat 1 3805 | ▁television 1 3806 | cer 1 3807 | iles 1 3808 | ▁shows 1 3809 | ▁cop 1 3810 | ▁suppose 1 3811 | ▁El 1 3812 | Ar 1 3813 | ware 1 3814 | ▁send 1 3815 | cell 1 3816 | sey 1 3817 | ▁glad 1 3818 | aret 1 3819 | ▁none 1 3820 | ▁und 1 3821 | ▁although 1 3822 | ▁color 1 3823 | ▁showed 1 3824 | ▁cab 1 3825 | ▁purpose 1 3826 | ▁trust 1 3827 | ▁aud 1 3828 | ▁john 1 3829 | ▁finished 1 3830 | ▁mostly 1 3831 | ▁featured 1 3832 | asons 1 3833 | igned 1 3834 | ▁lower 1 3835 | ▁straight 1 3836 | ▁mess 1 3837 | ▁kill 1 3838 | right 1 3839 | ▁net 1 3840 | ▁certainly 1 3841 | ential 1 3842 | ▁pack 1 3843 | ▁bill 1 3844 | ▁problem 1 3845 | ▁college 1 3846 | cent 1 3847 | ▁specific 1 3848 | ▁experience 1 3849 | aug 1 3850 | inger 1 3851 | ▁East 1 3852 | ▁honor 1 3853 | ▁lie 1 3854 | mith 1 3855 | ios 1 3856 | ▁cy 1 3857 | ▁poet 1 3858 | ▁noted 1 3859 | olk 1 3860 | Only 1 3861 | ▁Pres 1 3862 | ▁thy 1 3863 | mit 1 3864 | aches 1 3865 | ▁wrong 1 3866 | ris 1 3867 | ervation 1 3868 | ael 1 3869 | idents 1 3870 | ▁extend 1 3871 | uments 1 3872 | ▁access 1 3873 | ▁review 1 3874 | wise 1 3875 | itect 1 3876 | ▁bear 1 3877 | ▁simply 1 3878 | ▁ancient 1 3879 | ▁convers 1 3880 | ▁shot 1 3881 | ▁Roman 1 3882 | ▁piece 1 3883 | ▁filled 1 3884 | ▁railway 1 3885 | ught 1 3886 | oo 1 3887 | uce 1 3888 | ▁And 1 3889 | fer 1 3890 | ▁Phil 1 3891 | ▁trying 1 3892 | ▁extreme 1 3893 | Ex 1 3894 | ▁meant 1 3895 | ▁write 1 3896 | ▁assist 1 3897 | light 1 3898 | ▁population 1 3899 | ▁birth 1 3900 | ▁shar 1 3901 | ▁according 1 3902 | erry 1 3903 | ▁youth 1 3904 | craft 1 3905 | Re 1 3906 | ▁joy 1 3907 | ▁talking 1 3908 | ▁replaced 1 3909 | ▁gradu 1 3910 | ▁chance 1 3911 | Bec 1 3912 | ▁knowledge 1 3913 | hest 1 3914 | ▁latter 1 3915 | ope 1 3916 | ▁mine 1 3917 | Cl 1 3918 | asant 1 3919 | ▁surface 1 3920 | Pl 1 3921 | ▁afraid 1 3922 | Tr 1 3923 | ▁entirely 1 3924 | ▁offered 1 3925 | ▁industry 1 3926 | amm 1 3927 | ipped 1 3928 | itation 1 3929 | ▁award 1 3930 | ▁hen 1 3931 | ux 1 3932 | Am 1 3933 | ▁commerc 1 3934 | ▁girls 1 3935 | ▁norm 1 3936 | ▁distinct 1 3937 | ▁police 1 3938 | ging 1 3939 | ▁colon 1 3940 | ▁range 1 3941 | Let 1 3942 | oned 1 3943 | ergy 1 3944 | ▁tro 1 3945 | ▁neither 1 3946 | ▁rain 1 3947 | ▁Gl 1 3948 | ▁custom 1 3949 | ▁centre 1 3950 | ▁thirty 1 3951 | vert 1 3952 | ama 1 3953 | ▁opin 1 3954 | ees 1 3955 | what 1 3956 | ether 1 3957 | ▁wasn 1 3958 | ingu 1 3959 | itz 1 3960 | ▁Val 1 3961 | ▁Canada 1 3962 | omas 1 3963 | edd 1 3964 | ▁claim 1 3965 | cle 1 3966 | ▁fru 1 3967 | ▁walking 1 3968 | ▁parents 1 3969 | yond 1 3970 | olit 1 3971 | ▁sevent 1 3972 | ufact 1 3973 | ▁beyond 1 3974 | ▁hast 1 3975 | ▁tax 1 3976 | Three 1 3977 | ▁personal 1 3978 | ▁expect 1 3979 | ▁owned 1 3980 | chie 1 3981 | ▁theory 1 3982 | iber 1 3983 | eb 1 3984 | ▁stream 1 3985 | ago 1 3986 | ▁destroy 1 3987 | ▁marriage 1 3988 | ification 1 3989 | ables 1 3990 | ounds 1 3991 | orses 1 3992 | ▁sch 1 3993 | ▁places 1 3994 | ▁Ber 1 3995 | ▁video 1 3996 | ▁runs 1 3997 | ▁feature 1 3998 | ▁San 1 3999 | ▁bell 1 4000 | ▁christ 1 4001 | ▁appears 1 4002 | ▁society 1 4003 | ▁animals 1 4004 | ▁Mount 1 4005 | van 1 4006 | ▁ce 1 4007 | ▁save 1 4008 | ▁required 1 4009 | ify 1 4010 
| ▁env 1 4011 | ▁concern 1 4012 | ▁gri 1 4013 | ▁necessary 1 4014 | orough 1 4015 | Ne 1 4016 | ▁study 1 4017 | ▁Am 1 4018 | ricult 1 4019 | ▁leaving 1 4020 | Tod 1 4021 | ▁winter 1 4022 | ▁whis 1 4023 | ▁model 1 4024 | ishes 1 4025 | ▁plays 1 4026 | ▁yourself 1 4027 | ▁coun 1 4028 | ▁uses 1 4029 | ounded 1 4030 | ▁ring 1 4031 | ▁caught 1 4032 | bered 1 4033 | ▁epis 1 4034 | xim 1 4035 | ▁knows 1 4036 | ▁wait 1 4037 | ▁believed 1 4038 | ▁serves 1 4039 | apped 1 4040 | ▁boat 1 4041 | ▁laid 1 4042 | ▁lake 1 4043 | Oh 1 4044 | yal 1 4045 | ▁possib 1 4046 | ems 1 4047 | ▁circum 1 4048 | ▁higher 1 4049 | time 1 4050 | ▁release 1 4051 | ▁countries 1 4052 | ▁couple 1 4053 | ▁complex 1 4054 | ▁involved 1 4055 | ▁pray 1 4056 | ▁particularly 1 4057 | ▁northern 1 4058 | ▁safe 1 4059 | ▁trouble 1 4060 | ping 1 4061 | irgin 1 4062 | ▁Bo 1 4063 | ▁bow 1 4064 | ▁Mich 1 4065 | ▁easy 1 4066 | ▁consc 1 4067 | asing 1 4068 | ▁influence 1 4069 | ▁lines 1 4070 | ▁altern 1 4071 | ▁consists 1 4072 | rey 1 4073 | uke 1 4074 | elebr 1 4075 | ▁moon 1 4076 | ▁twel 1 4077 | ▁terms 1 4078 | ico 1 4079 | oved 1 4080 | ▁opt 1 4081 | ▁forgot 1 4082 | ▁dropped 1 4083 | ▁bridge 1 4084 | ▁finally 1 4085 | ▁Ex 1 4086 | ▁spread 1 4087 | ▁buildings 1 4088 | Acc 1 4089 | ners 1 4090 | ▁boys 1 4091 | esp 1 4092 | ▁ought 1 4093 | ▁spot 1 4094 | ula 1 4095 | ▁arg 1 4096 | anks 1 4097 | ▁Cemetery 1 4098 | ▁tast 1 4099 | board 1 4100 | ▁Ir 1 4101 | semb 1 4102 | ▁Min 1 4103 | ▁suit 1 4104 | utions 1 4105 | ours 1 4106 | ▁fat 1 4107 | anged 1 4108 | ▁strength 1 4109 | ▁cloud 1 4110 | ros 1 4111 | eneral 1 4112 | ▁architect 1 4113 | Today 1 4114 | ▁inh 1 4115 | ▁pet 1 4116 | Because 1 4117 | inder 1 4118 | ▁movie 1 4119 | ▁practice 1 4120 | ashing 1 4121 | ▁learned 1 4122 | ▁stre 1 4123 | ▁foc 1 4124 | ▁faith 1 4125 | ▁waiting 1 4126 | ▁beginning 1 4127 | ▁drew 1 4128 | ▁easily 1 4129 | ▁southern 1 4130 | Even 1 4131 | ounced 1 4132 | ▁degree 1 4133 | inent 1 4134 | ▁hill 1 4135 | atholic 1 4136 | ▁minute 1 4137 | ▁property 1 4138 | ▁manufact 1 4139 | for 1 4140 | lt 1 4141 | ▁grass 1 4142 | ▁bag 1 4143 | ▁corner 1 4144 | ▁Nor 1 4145 | oly 1 4146 | ▁function 1 4147 | agn 1 4148 | ▁doesn 1 4149 | chan 1 4150 | ols 1 4151 | ▁applic 1 4152 | ▁jack 1 4153 | ▁broken 1 4154 | ▁marked 1 4155 | ▁highly 1 4156 | ▁failed 1 4157 | ▁Ab 1 4158 | ▁Road 1 4159 | ▁dig 1 4160 | ▁figure 1 4161 | ▁tea 1 4162 | ▁related 1 4163 | ▁Russ 1 4164 | ▁cases 1 4165 | ▁silence 1 4166 | ▁fourth 1 4167 | ching 1 4168 | str 1 4169 | ouis 1 4170 | ▁thick 1 4171 | ▁horses 1 4172 | ▁thee 1 4173 | ▁understood 1 4174 | ▁condition 1 4175 | ▁merchant 1 4176 | appe 1 4177 | olute 1 4178 | ▁critic 1 4179 | ▁results 1 4180 | ule 1 4181 | ▁arrived 1 4182 | ena 1 4183 | ▁dim 1 4184 | ▁source 1 4185 | ▁apart 1 4186 | ▁coast 1 4187 | Like 1 4188 | ▁discuss 1 4189 | ▁frequently 1 4190 | ▁central 1 4191 | ▁systems 1 4192 | ▁Char 1 4193 | ▁loss 1 4194 | ▁mainly 1 4195 | ▁stories 1 4196 | ▁magaz 1 4197 | eer 1 4198 | atin 1 4199 | umed 1 4200 | ▁Soc 1 4201 | ▁Arab 1 4202 | ▁hunt 1 4203 | acing 1 4204 | ▁forget 1 4205 | ▁gas 1 4206 | umn 1 4207 | ette 1 4208 | ▁adj 1 4209 | oured 1 4210 | ▁pit 1 4211 | ▁serve 1 4212 | edy 1 4213 | Also 1 4214 | ▁Rep 1 4215 | raz 1 4216 | ▁elected 1 4217 | unity 1 4218 | ▁flowers 1 4219 | abor 1 4220 | ▁sac 1 4221 | iring 1 4222 | ▁celebr 1 4223 | ▁subsequently 1 4224 | ompany 1 4225 | ▁plain 1 4226 | ▁spring 1 4227 | uous 1 4228 | ▁Mon 1 4229 | ▁significant 1 4230 | ▁primary 1 4231 | ▁Australia 1 4232 | ▁Ass 1 4233 | ▁dry 1 4234 | ▁respons 1 4235 | 
lig 1 4236 | rec 1 4237 | lected 1 4238 | ▁larger 1 4239 | agement 1 4240 | Did 1 4241 | house 1 4242 | ▁fool 1 4243 | ▁twice 1 4244 | ▁job 1 4245 | ▁clean 1 4246 | ▁ended 1 4247 | ▁comfort 1 4248 | ▁doc 1 4249 | ▁loved 1 4250 | ▁Sec 1 4251 | ▁depend 1 4252 | ▁completely 1 4253 | liament 1 4254 | ▁size 1 4255 | quar 1 4256 | ▁century 1 4257 | ▁Pol 1 4258 | ▁iron 1 4259 | ▁holding 1 4260 | ▁guitar 1 4261 | antic 1 4262 | ▁speaking 1 4263 | ▁platform 1 4264 | ▁amb 1 4265 | ▁Rich 1 4266 | heast 1 4267 | ▁huge 1 4268 | ography 1 4269 | ipe 1 4270 | bury 1 4271 | ▁obt 1 4272 | ▁Mary 1 4273 | ▁aftern 1 4274 | ▁international 1 4275 | ▁Hall 1 4276 | ▁helped 1 4277 | ▁changes 1 4278 | ▁compl 1 4279 | ▁variety 1 4280 | ▁favor 1 4281 | ▁shad 1 4282 | uction 1 4283 | mark 1 4284 | ▁sky 1 4285 | atively 1 4286 | ▁pan 1 4287 | ▁eld 1 4288 | ▁concept 1 4289 | ▁rapid 1 4290 | Add 1 4291 | ▁bi 1 4292 | ▁particip 1 4293 | amer 1 4294 | rible 1 4295 | ▁dinner 1 4296 | ▁restaur 1 4297 | ▁primarily 1 4298 | rook 1 4299 | ▁reb 1 4300 | ▁spok 1 4301 | ▁shut 1 4302 | ▁earl 1 4303 | ▁educated 1 4304 | ▁ice 1 4305 | ▁Port 1 4306 | ▁council 1 4307 | ▁goes 1 4308 | ▁weeks 1 4309 | ores 1 4310 | ▁collect 1 4311 | ▁reve 1 4312 | ▁guard 1 4313 | ▁greater 1 4314 | isk 1 4315 | ▁commercial 1 4316 | ▁allow 1 4317 | Since 1 4318 | ▁birds 1 4319 | ▁older 1 4320 | rian 1 4321 | olved 1 4322 | ▁appl 1 4323 | ▁leading 1 4324 | ▁distrib 1 4325 | '. 1 4326 | ▁Street 1 4327 | ▁flat 1 4328 | ▁contract 1 4329 | ams 1 4330 | ▁content 1 4331 | Br 1 4332 | alian 1 4333 | Thus 1 4334 | ▁Roy 1 4335 | ▁individual 1 4336 | ▁taught 1 4337 | ▁offer 1 4338 | ▁Ac 1 4339 | ▁tall 1 4340 | ▁edge 1 4341 | ▁James 1 4342 | ▁shown 1 4343 | ▁succeeded 1 4344 | ▁languages 1 4345 | ora 1 4346 | ▁amount 1 4347 | ▁silver 1 4348 | ▁widely 1 4349 | ▁Cr 1 4350 | ▁delight 1 4351 | ▁princess 1 4352 | ▁firm 1 4353 | ▁presence 1 4354 | icient 1 4355 | ▁enter 1 4356 | ▁location 1 4357 | ▁imposs 1 4358 | ▁somewhat 1 4359 | abb 1 4360 | ▁Sch 1 4361 | ▁walls 1 4362 | ▁fresh 1 4363 | win 1 4364 | icks 1 4365 | ▁needed 1 4366 | ▁America 1 4367 | ▁twelve 1 4368 | ▁performance 1 4369 | ▁fal 1 4370 | ▁achie 1 4371 | ▁commonly 1 4372 | ▁mel 1 4373 | ▁sad 1 4374 | ▁van 1 4375 | arrow 1 4376 | berg 1 4377 | ▁multip 1 4378 | ▁fifty 1 4379 | ▁invent 1 4380 | fl 1 4381 | May 1 4382 | hered 1 4383 | ▁prevent 1 4384 | iron 1 4385 | ▁sem 1 4386 | ▁emer 1 4387 | ▁interp 1 4388 | iqu 1 4389 | ▁struck 1 4390 | ▁om 1 4391 | ceived 1 4392 | ▁treasure 1 4393 | Pe 1 4394 | eal 1 4395 | Somet 1 4396 | oking 1 4397 | ▁trade 1 4398 | view 1 4399 | ▁sons 1 4400 | Before 1 4401 | ▁speed 1 4402 | ▁continue 1 4403 | ▁hardly 1 4404 | ▁separate 1 4405 | ▁instit 1 4406 | aign 1 4407 | come 1 4408 | illa 1 4409 | ▁hom 1 4410 | ▁egg 1 4411 | Was 1 4412 | ▁ves 1 4413 | ▁harm 1 4414 | ▁leader 1 4415 | ican 1 4416 | ▁anx 1 4417 | ▁repeated 1 4418 | ▁Germany 1 4419 | ▁previous 1 4420 | ▁gentleman 1 4421 | ograph 1 4422 | ▁forced 1 4423 | arth 1 4424 | Although 1 4425 | ▁unknown 1 4426 | ▁neighbor 1 4427 | ▁hung 1 4428 | ▁May 1 4429 | ▁threat 1 4430 | back 1 4431 | ▁Party 1 4432 | ▁gives 1 4433 | ▁reach 1 4434 | ▁reading 1 4435 | ▁adop 1 4436 | ington 1 4437 | ▁Council 1 4438 | ▁agricult 1 4439 | ▁Med 1 4440 | shire 1 4441 | idden 1 4442 | ▁dism 1 4443 | ▁mechan 1 4444 | ▁Me 1 4445 | ▁clos 1 4446 | ▁wine 1 4447 | ▁divided 1 4448 | met 1 4449 | ▁afternoon 1 4450 | aph 1 4451 | rawn 1 4452 | ▁Mor 1 4453 | ▁queen 1 4454 | ▁Pe 1 4455 | ▁cryst 1 4456 | ▁science 1 4457 | ▁Met 1 4458 | ▁pock 1 4459 | 
▁construction 1 4460 | ▁lack 1 4461 | ▁flight 1 4462 | According 1 4463 | ▁evidence 1 4464 | ▁training 1 4465 | ▁Mus 1 4466 | ▁gall 1 4467 | ▁rule 1 4468 | ▁directly 1 4469 | ▁discovered 1 4470 | ▁pleasure 1 4471 | ▁hit 1 4472 | ▁tot 1 4473 | ▁neare 1 4474 | ▁scene 1 4475 | ▁strugg 1 4476 | orporated 1 4477 | ▁Christian 1 4478 | ▁giving 1 4479 | aine 1 4480 | ▁alt 1 4481 | ▁route 1 4482 | ▁George 1 4483 | ▁- 1 4484 | lands 1 4485 | ▁data 1 4486 | ▁subst 1 4487 | ▁parish 1 4488 | ▁tone 1 4489 | ▁hol 1 4490 | ▁philos 1 4491 | verse 1 4492 | ▁France 1 4493 | ulf 1 4494 | ▁speech 1 4495 | ▁characters 1 4496 | ▁isn 1 4497 | ▁descend 1 4498 | ▁charge 1 4499 | ▁Ne 1 4500 | Will 1 4501 | ▁smile 1 4502 | ▁directed 1 4503 | ▁radio 1 4504 | ▁independent 1 4505 | ▁hospital 1 4506 | ▁sex 1 4507 | ▁cit 1 4508 | ▁exhib 1 4509 | ▁ninet 1 4510 | Sc 1 4511 | ension 1 4512 | ▁provide 1 4513 | ▁settled 1 4514 | ▁majority 1 4515 | ▁anal 1 4516 | ▁everyone 1 4517 | ▁associated 1 4518 | ari 1 4519 | nds 1 4520 | more 1 4521 | ▁Bay 1 4522 | ▁thin 1 4523 | ▁provides 1 4524 | ket 1 4525 | ▁beat 1 4526 | pic 1 4527 | ▁artist 1 4528 | ▁fright 1 4529 | ▁wouldn 1 4530 | ▁teacher 1 4531 | ▁broke 1 4532 | ▁smaller 1 4533 | ▁standard 1 4534 | ▁operated 1 4535 | ▁ways 1 4536 | aughed 1 4537 | ▁films 1 4538 | ida 1 4539 | ▁champ 1 4540 | ▁english 1 4541 | ▁opinion 1 4542 | ▁expected 1 4543 | ▁Jack 1 4544 | ▁cho 1 4545 | ▁forth 1 4546 | ▁teams 1 4547 | ▁listed 1 4548 | ▁impossible 1 4549 | ▁hole 1 4550 | isions 1 4551 | Just 1 4552 | ▁laughed 1 4553 | lers 1 4554 | ourt 1 4555 | ▁forces 1 4556 | ugust 1 4557 | ▁sell 1 4558 | atever 1 4559 | ▁brief 1 4560 | ▁shoot 1 4561 | ▁musical 1 4562 | ▁Rec 1 4563 | ▁bal 1 4564 | ▁tur 1 4565 | ▁western 1 4566 | ▁represented 1 4567 | cks 1 4568 | ocks 1 4569 | rect 1 4570 | ▁cup 1 4571 | fast 1 4572 | alled 1 4573 | ▁escape 1 4574 | ▁eighteen 1 4575 | ▁neck 1 4576 | ▁shape 1 4577 | ashion 1 4578 | ▁storm 1 4579 | ▁univers 1 4580 | isl 1 4581 | utive 1 4582 | ▁link 1 4583 | ▁circumst 1 4584 | ▁sen 1 4585 | ▁card 1 4586 | ▁Jew 1 4587 | ▁sor 1 4588 | ▁valley 1 4589 | ▁problems 1 4590 | ▁wise 1 4591 | ▁shirt 1 4592 | ▁seeing 1 4593 | ▁actually 1 4594 | atives 1 4595 | ▁dressed 1 4596 | ▁meeting 1 4597 | ▁Em 1 4598 | ▁aband 1 4599 | ancing 1 4600 | ▁metal 1 4601 | ▁likely 1 4602 | ▁native 1 4603 | ▁perm 1 4604 | Yes 1 4605 | ▁silent 1 4606 | ▁greatest 1 4607 | ▁appointed 1 4608 | ▁structure 1 4609 | ▁professional 1 4610 | sec 1 4611 | ▁insp 1 4612 | ▁scient 1 4613 | ▁Hill 1 4614 | ▁date 1 4615 | icket 1 4616 | ▁shout 1 4617 | garet 1 4618 | ▁energy 1 4619 | Bl 1 4620 | ▁inqu 1 4621 | ▁note 1 4622 | ▁magn 1 4623 | ▁reported 1 4624 | ▁beauty 1 4625 | ▁highest 1 4626 | oms 1 4627 | ▁looks 1 4628 | astic 1 4629 | ▁European 1 4630 | ▁ign 1 4631 | ocked 1 4632 | och 1 4633 | ▁vessel 1 4634 | ▁university 1 4635 | Mar 1 4636 | ▁hero 1 4637 | ▁libr 1 4638 | ▁liber 1 4639 | ▁continues 1 4640 | ▁En 1 4641 | hens 1 4642 | ▁You 1 4643 | ▁voc 1 4644 | ▁opening 1 4645 | ▁value 1 4646 | ▁League 1 4647 | ▁suffered 1 4648 | ▁observed 1 4649 | ▁traditional 1 4650 | ek 1 4651 | ▁wit 1 4652 | orter 1 4653 | ulated 1 4654 | ▁recent 1 4655 | ▁picture 1 4656 | ▁situated 1 4657 | ▁burn 1 4658 | avy 1 4659 | ▁towns 1 4660 | ▁afterwards 1 4661 | atre 1 4662 | rences 1 4663 | ▁amer 1 4664 | ▁cere 1 4665 | ▁print 1 4666 | New 1 4667 | osite 1 4668 | worth 1 4669 | ▁hang 1 4670 | tr 1 4671 | kin 1 4672 | ▁types 1 4673 | ▁react 1 4674 | istor 1 4675 | arsh 1 4676 | ▁experi 1 4677 | ▁express 1 4678 | ▁Catholic 1 4679 | 
▁composed 1 4680 | imate 1 4681 | ibr 1 4682 | ademy 1 4683 | ▁stones 1 4684 | adian 1 4685 | ▁watched 1 4686 | ▁castle 1 4687 | ▁brothers 1 4688 | rea 1 4689 | ▁Ange 1 4690 | ▁Club 1 4691 | rial 1 4692 | ▁baby 1 4693 | ▁expression 1 4694 | ▁limited 1 4695 | ▁David 1 4696 | ▁Su 1 4697 | ruction 1 4698 | ▁transport 1 4699 | yr 1 4700 | ▁pun 1 4701 | ication 1 4702 | ▁Mart 1 4703 | ▁director 1 4704 | ka 1 4705 | aval 1 4706 | ▁weak 1 4707 | ▁investig 1 4708 | ishop 1 4709 | ▁aunt 1 4710 | ▁clothes 1 4711 | sych 1 4712 | ▁Bel 1 4713 | ▁Paul 1 4714 | ▁International 1 4715 | De 1 4716 | ▁Te 1 4717 | ▁gar 1 4718 | ▁disappe 1 4719 | ▁Township 1 4720 | ▁mentioned 1 4721 | anted 1 4722 | isher 1 4723 | ▁catch 1 4724 | ▁habit 1 4725 | uly 1 4726 | ker 1 4727 | ▁spoken 1 4728 | ▁typically 1 4729 | ateg 1 4730 | lease 1 4731 | fied 1 4732 | gypt 1 4733 | ▁serious 1 4734 | ▁President 1 4735 | ▁activities 1 4736 | ember 1 4737 | eff 1 4738 | ▁brand 1 4739 | ▁Ant 1 4740 | ▁virt 1 4741 | ▁beside 1 4742 | ▁proved 1 4743 | ▁mut 1 4744 | ▁saint 1 4745 | ero 1 4746 | ▁campaign 1 4747 | ▁Reg 1 4748 | ictor 1 4749 | ▁aircraft 1 4750 | can 1 4751 | ▁relationship 1 4752 | ▁foreign 1 4753 | ray 1 4754 | ▁chem 1 4755 | ▁opport 1 4756 | ▁sword 1 4757 | ▁dance 1 4758 | ▁cities 1 4759 | ▁heaven 1 4760 | ▁connect 1 4761 | ▁windows 1 4762 | See 1 4763 | John 1 4764 | ▁singing 1 4765 | ▁apparent 1 4766 | Are 1 4767 | ▁cool 1 4768 | ▁nice 1 4769 | ▁cous 1 4770 | Se 1 4771 | ▁dom 1 4772 | ▁Green 1 4773 | ▁defin 1 4774 | ▁growing 1 4775 | ▁increased 1 4776 | ▁conversation 1 4777 | ▁newspaper 1 4778 | ▁becomes 1 4779 | lem 1 4780 | ▁Miss 1 4781 | ▁remembered 1 4782 | ▁distingu 1 4783 | ▁eastern 1 4784 | ▁paid 1 4785 | bridge 1 4786 | ▁faint 1 4787 | aked 1 4788 | ▁tears 1 4789 | ▁studies 1 4790 | ▁introduced 1 4791 | ▁tar 1 4792 | ▁swim 1 4793 | ronic 1 4794 | ▁animal 1 4795 | gs 1 4796 | ici 1 4797 | ▁jump 1 4798 | ▁situation 1 4799 | ittee 1 4800 | ▁coach 1 4801 | ▁border 1 4802 | ▁fashion 1 4803 | urity 1 4804 | attered 1 4805 | ▁tracks 1 4806 | stand 1 4807 | Here 1 4808 | mond 1 4809 | father 1 4810 | ▁letters 1 4811 | ▁ven 1 4812 | ▁stands 1 4813 | ▁Do 1 4814 | ▁surprise 1 4815 | unn 1 4816 | ▁fields 1 4817 | ▁disapp 1 4818 | ▁spect 1 4819 | ▁tim 1 4820 | ▁excell 1 4821 | ▁cru 1 4822 | ▁sports 1 4823 | ▁thoughts 1 4824 | ▁palace 1 4825 | quarters 1 4826 | ▁Pal 1 4827 | ▁soldiers 1 4828 | ▁exclaimed 1 4829 | ▁listen 1 4830 | ▁autom 1 4831 | ▁prepared 1 4832 | vey 1 4833 | ▁heat 1 4834 | ▁tail 1 4835 | ▁wear 1 4836 | olitan 1 4837 | ▁enemy 1 4838 | ▁throw 1 4839 | ▁Robert 1 4840 | ▁younger 1 4841 | ician 1 4842 | ▁Pat 1 4843 | ▁fixed 1 4844 | ▁produce 1 4845 | Have 1 4846 | ▁managed 1 4847 | inese 1 4848 | ria 1 4849 | ▁magazine 1 4850 | ▁wealth 1 4851 | ▁od 1 4852 | ▁conditions 1 4853 | ▁vent 1 4854 | mar 1 4855 | ▁minister 1 4856 | ▁existence 1 4857 | ▁cart 1 4858 | ▁advant 1 4859 | rim 1 4860 | ▁planet 1 4861 | ▁ships 1 4862 | ▁connected 1 4863 | ndon 1 4864 | ▁religious 1 4865 | ▁vehic 1 4866 | ▁determined 1 4867 | resp 1 4868 | ▁memory 1 4869 | alls 1 4870 | ▁minor 1 4871 | ▁sick 1 4872 | ▁sharp 1 4873 | ▁intellig 1 4874 | ▁bull 1 4875 | ▁legs 1 4876 | ▁section 1 4877 | ▁Dep 1 4878 | Desp 1 4879 | ▁scarce 1 4880 | ▁stret 1 4881 | omy 1 4882 | aling 1 4883 | ▁double 1 4884 | omen 1 4885 | ▁empt 1 4886 | ▁usual 1 4887 | odes 1 4888 | ▁err 1 4889 | ses 1 4890 | ▁smoke 1 4891 | ▁Fl 1 4892 | aging 1 4893 | ▁lear 1 4894 | ▁products 1 4895 | cean 1 4896 | ▁Mont 1 4897 | ▁marry 1 4898 | estival 1 4899 | ▁neighbour 1 4900 | 
▁deli 1 4901 | ▁grace 1 4902 | ▁schol 1 4903 | ▁kit 1 4904 | ▁Govern 1 4905 | Such 1 4906 | night 1 4907 | icted 1 4908 | ▁accepted 1 4909 | ▁female 1 4910 | ▁butter 1 4911 | ▁She 1 4912 | bon 1 4913 | ▁oil 1 4914 | amber 1 4915 | nding 1 4916 | ▁luck 1 4917 | ▁journey 1 4918 | ▁desire 1 4919 | ▁contem 1 4920 | ▁numbers 1 4921 | ▁Republic 1 4922 | ▁Englishman 1 4923 | ▁founded 1 4924 | ▁require 1 4925 | ▁Sen 1 4926 | illy 1 4927 | ▁narrow 1 4928 | ani 1 4929 | ▁Bur 1 4930 | ailed 1 4931 | ▁beg 1 4932 | ▁dise 1 4933 | ▁qual 1 4934 | ▁sides 1 4935 | ▁fallen 1 4936 | rem 1 4937 | ▁evil 1 4938 | ▁approach 1 4939 | list 1 4940 | ▁kiss 1 4941 | ▁companies 1 4942 | ▁blind 1 4943 | Despite 1 4944 | ▁england 1 4945 | reland 1 4946 | ▁dread 1 4947 | ▁ly 1 4948 | gged 1 4949 | ▁previously 1 4950 | Ab 1 4951 | ▁upper 1 4952 | ▁mixed 1 4953 | ▁Valley 1 4954 | ▁contain 1 4955 | ▁Ma 1 4956 | ▁Spanish 1 4957 | ▁loud 1 4958 | ▁cyl 1 4959 | ▁flag 1 4960 | ▁supported 1 4961 | ompl 1 4962 | ▁rate 1 4963 | ▁issue 1 4964 | ▁instrument 1 4965 | ▁One 1 4966 | ▁exactly 1 4967 | wr 1 4968 | hire 1 4969 | ▁drum 1 4970 | ▁breath 1 4971 | ▁plants 1 4972 | ▁Japanese 1 4973 | otion 1 4974 | ▁gate 1 4975 | ▁guess 1 4976 | ring 1 4977 | ▁extra 1 4978 | ima 1 4979 | ▁royal 1 4980 | ▁quarter 1 4981 | ▁God 1 4982 | while 1 4983 | ▁shore 1 4984 | ancis 1 4985 | stone 1 4986 | ogether 1 4987 | ▁Africa 1 4988 | ▁measure 1 4989 | izes 1 4990 | izing 1 4991 | ▁dollars 1 4992 | ▁buy 1 4993 | rid 1 4994 | ▁Union 1 4995 | ▁players 1 4996 | gl 1 4997 | ▁vo 1 4998 | ▁tor 1 4999 | ▁occurred 1 5000 | ▁pure 1 5001 | ▁heads 1 5002 | ▁Royal 1 5003 | ▁honour 1 5004 | ▁offers 1 5005 | ▁pieces 1 5006 | hold 1 5007 | ▁diam 1 5008 | ror 1 5009 | rance 1 5010 | ▁vari 1 5011 | ault 1 5012 | Never 1 5013 | irginia 1 5014 | ▁noticed 1 5015 | ▁Ham 1 5016 | ▁earlier 1 5017 | ▁Dem 1 5018 | ▁Hen 1 5019 | anned 1 5020 | ▁pros 1 5021 | ▁oblig 1 5022 | ▁slightly 1 5023 | hew 1 5024 | under 1 5025 | ▁crew 1 5026 | ▁product 1 5027 | mas 1 5028 | stead 1 5029 | ▁notable 1 5030 | Ind 1 5031 | ented 1 5032 | ▁rough 1 5033 | ▁whatever 1 5034 | ▁presented 1 5035 | astern 1 5036 | ▁judge 1 5037 | ▁basket 1 5038 | ▁ast 1 5039 | ▁progress 1 5040 | ki 1 5041 | eness 1 5042 | ▁powerful 1 5043 | inated 1 5044 | ▁moving 1 5045 | ▁surrounding 1 5046 | ▁Tex 1 5047 | clock 1 5048 | ▁bread 1 5049 | ▁indic 1 5050 | iance 1 5051 | ▁Scott 1 5052 | ▁delic 1 5053 | ▁collection 1 5054 | ▁rub 1 5055 | ashed 1 5056 | ▁declared 1 5057 | ▁match 1 5058 | iar 1 5059 | ▁environ 1 5060 | ▁mary 1 5061 | ▁journal 1 5062 | ota 1 5063 | asis 1 5064 | ▁computer 1 5065 | Hey 1 5066 | porary 1 5067 | ares 1 5068 | ental 1 5069 | ▁maid 1 5070 | ▁consist 1 5071 | ▁mix 1 5072 | kes 1 5073 | ▁student 1 5074 | ▁dust 1 5075 | fortun 1 5076 | ▁notice 1 5077 | ini 1 5078 | ▁fit 1 5079 | ▁webs 1 5080 | ▁network 1 5081 | Le 1 5082 | ario 1 5083 | ▁motion 1 5084 | ▁opposite 1 5085 | ▁beach 1 5086 | oseph 1 5087 | More 1 5088 | ▁coat 1 5089 | ▁subs 1 5090 | ▁India 1 5091 | ▁fully 1 5092 | ▁trail 1 5093 | ims 1 5094 | ona 1 5095 | ▁holds 1 5096 | ▁concert 1 5097 | ▁refused 1 5098 | ▁turning 1 5099 | orial 1 5100 | ▁furn 1 5101 | ibility 1 5102 | ▁stick 1 5103 | ▁worse 1 5104 | ▁extremely 1 5105 | orks 1 5106 | ▁Mac 1 5107 | ▁unless 1 5108 | ▁owner 1 5109 | eper 1 5110 | book 1 5111 | roud 1 5112 | ▁eth 1 5113 | ▁driver 1 5114 | ▁favour 1 5115 | ▁ordered 1 5116 | oper 1 5117 | uries 1 5118 | ▁symb 1 5119 | ▁Louis 1 5120 | ▁portion 1 5121 | ▁mountains 1 5122 | iana 1 5123 | itory 1 5124 | ▁rank 1 5125 | zen 
1 5126 | ache 1 5127 | ▁ble 1 5128 | ▁ideas 1 5129 | ▁purch 1 5130 | ▁weight 1 5131 | aunt 1 5132 | etic 1 5133 | ropolitan 1 5134 | ▁accident 1 5135 | ety 1 5136 | apter 1 5137 | ▁anyone 1 5138 | ▁pointed 1 5139 | ▁hyp 1 5140 | ▁removed 1 5141 | la 1 5142 | ▁Black 1 5143 | ▁weap 1 5144 | ▁authority 1 5145 | ▁ash 1 5146 | ▁threw 1 5147 | ▁airport 1 5148 | ▁largely 1 5149 | ▁Creek 1 5150 | ▁Peter 1 5151 | ▁citiz 1 5152 | jo 1 5153 | owned 1 5154 | ▁hy 1 5155 | ▁Jer 1 5156 | ▁Airport 1 5157 | ▁headquarters 1 5158 | Pro 1 5159 | ▁gaz 1 5160 | ▁sav 1 5161 | ▁fruit 1 5162 | ▁steps 1 5163 | ▁Mal 1 5164 | ▁officer 1 5165 | ▁proposed 1 5166 | ▁grown 1 5167 | ▁falling 1 5168 | ▁mir 1 5169 | eding 1 5170 | Follow 1 5171 | ▁lying 1 5172 | ▁surprised 1 5173 | osen 1 5174 | ishing 1 5175 | ▁Her 1 5176 | ▁diss 1 5177 | ▁duty 1 5178 | ▁via 1 5179 | ▁mort 1 5180 | ▁share 1 5181 | ▁philosop 1 5182 | ▁physical 1 5183 | unior 1 5184 | ▁noise 1 5185 | place 1 5186 | ▁issues 1 5187 | ▁golden 1 5188 | ▁receive 1 5189 | bing 1 5190 | ▁Sal 1 5191 | ▁rid 1 5192 | ▁rot 1 5193 | ▁skin 1 5194 | ▁tower 1 5195 | ▁museum 1 5196 | ▁becoming 1 5197 | ▁fighting 1 5198 | ▁pale 1 5199 | ▁greatly 1 5200 | ▁drawn 1 5201 | ▁rom 1 5202 | ▁elements 1 5203 | ▁staff 1 5204 | ▁contained 1 5205 | ompan 1 5206 | ▁square 1 5207 | ▁ham 1 5208 | ▁domin 1 5209 | ▁carrying 1 5210 | cles 1 5211 | ▁spl 1 5212 | ▁bought 1 5213 | ▁creature 1 5214 | aved 1 5215 | thur 1 5216 | ▁ded 1 5217 | ▁municipality 1 5218 | Gr 1 5219 | ception 1 5220 | ▁retired 1 5221 | ▁Rail 1 5222 | ▁poly 1 5223 | ashington 1 5224 | ▁election 1 5225 | ▁officers 1 5226 | ▁log 1 5227 | ▁mic 1 5228 | ▁pra 1 5229 | pered 1 5230 | ▁recomm 1 5231 | rol 1 5232 | ▁lips 1 5233 | ▁writer 1 5234 | peror 1 5235 | ▁stated 1 5236 | ▁pil 1 5237 | ▁dram 1 5238 | ▁roof 1 5239 | ▁french 1 5240 | ▁address 1 5241 | ▁programs 1 5242 | ▁Per 1 5243 | ▁earn 1 5244 | ▁Paris 1 5245 | ▁labor 1 5246 | ▁arrest 1 5247 | ▁interview 1 5248 | ▁ones 1 5249 | com 1 5250 | ▁effects 1 5251 | ▁bless 1 5252 | ▁stock 1 5253 | ▁division 1 5254 | ▁interpret 1 5255 | ▁Mil 1 5256 | ▁flo 1 5257 | ▁merely 1 5258 | ▁persons 1 5259 | ▁bod 1 5260 | ▁plans 1 5261 | ▁travel 1 5262 | ▁element 1 5263 | ias 1 5264 | icago 1 5265 | ▁grey 1 5266 | ption 1 5267 | ▁Great 1 5268 | ionship 1 5269 | ription 1 5270 | ▁accompl 1 5271 | ▁terrible 1 5272 | ▁psych 1 5273 | ▁episode 1 5274 | ▁politics 1 5275 | onst 1 5276 | ▁gro 1 5277 | ▁culture 1 5278 | Please 1 5279 | ▁editor 1 5280 | ▁passing 1 5281 | ▁restaurant 1 5282 | Per 1 5283 | bled 1 5284 | iel 1 5285 | ▁Associ 1 5286 | ▁Museum 1 5287 | ▁resist 1 5288 | ▁unf 1 5289 | ▁agreed 1 5290 | ▁sail 1 5291 | ▁nation 1 5292 | ▁severe 1 5293 | ▁ben 1 5294 | ▁Academy 1 5295 | ▁positive 1 5296 | ▁Mass 1 5297 | ▁thank 1 5298 | respond 1 5299 | ▁wished 1 5300 | ▁Company 1 5301 | ▁streets 1 5302 | ▁visible 1 5303 | ▁image 1 5304 | ▁intended 1 5305 | using 1 5306 | ▁cousin 1 5307 | ▁effort 1 5308 | ▁General 1 5309 | ▁Cor 1 5310 | hester 1 5311 | ▁completed 1 5312 | ▁difference 1 5313 | bl 1 5314 | ▁convent 1 5315 | ▁leaders 1 5316 | ▁interested 1 5317 | aud 1 5318 | ilies 1 5319 | lls 1 5320 | uls 1 5321 | ▁anne 1 5322 | Had 1 5323 | ait 1 5324 | ▁id 1 5325 | ▁coff 1 5326 | ▁unus 1 5327 | oma 1 5328 | wick 1 5329 | ▁signed 1 5330 | chen 1 5331 | ▁Sam 1 5332 | ▁economy 1 5333 | ▁sym 1 5334 | ▁candid 1 5335 | ▁machine 1 5336 | ola 1 5337 | ▁withd 1 5338 | theless 1 5339 | ▁follows 1 5340 | ▁explained 1 5341 | change 1 5342 | ▁circumstances 1 5343 | yer 1 5344 | ▁dogs 1 5345 | ▁advent 1 5346 | 
▁faces 1 5347 | lla 1 5348 | ▁bottom 1 5349 | ▁Alex 1 5350 | ▁bare 1 5351 | ▁Italian 1 5352 | ▁controvers 1 5353 | Said 1 5354 | lymp 1 5355 | ▁Penn 1 5356 | ▁transform 1 5357 | ▁rise 1 5358 | ▁electric 1 5359 | ▁daughters 1 5360 | ▁nick 1 5361 | ▁drive 1 5362 | ▁sched 1 5363 | ▁library 1 5364 | ▁operations 1 5365 | kins 1 5366 | ▁perfectly 1 5367 | .. 1 5368 | ▁pa 1 5369 | iano 1 5370 | Among 1 5371 | set 1 5372 | venue 1 5373 | ▁survived 1 5374 | Up 1 5375 | ▁renamed 1 5376 | ▁recently 1 5377 | ▁Victor 1 5378 | ▁passion 1 5379 | ▁destroyed 1 5380 | ▁forty 1 5381 | ▁encour 1 5382 | ▁illust 1 5383 | ▁instant 1 5384 | ▁rocks 1 5385 | Following 1 5386 | ients 1 5387 | ▁tired 1 5388 | ▁papers 1 5389 | ▁signal 1 5390 | ▁concent 1 5391 | ▁fishing 1 5392 | ▁province 1 5393 | ste 1 5394 | ▁For 1 5395 | ▁laws 1 5396 | anta 1 5397 | ▁male 1 5398 | ▁freed 1 5399 | ▁needs 1 5400 | rees 1 5401 | ▁Central 1 5402 | ▁decision 1 5403 | Sim 1 5404 | Look 1 5405 | ▁avoid 1 5406 | ▁correspond 1 5407 | Ph 1 5408 | ▁Frank 1 5409 | ▁rights 1 5410 | ▁sought 1 5411 | ▁sources 1 5412 | ▁happiness 1 5413 | ▁principal 1 5414 | ▁Red 1 5415 | ▁choice 1 5416 | ▁economic 1 5417 | mosp 1 5418 | aming 1 5419 | ▁agree 1 5420 | ▁observ 1 5421 | ▁reviews 1 5422 | ▁drop 1 5423 | ▁Charles 1 5424 | ▁tend 1 5425 | ▁Texas 1 5426 | ▁crystal 1 5427 | apes 1 5428 | ▁pair 1 5429 | ▁Greek 1 5430 | ▁questions 1 5431 | ▁stro 1 5432 | ▁counter 1 5433 | iliar 1 5434 | uts 1 5435 | herd 1 5436 | tt 1 5437 | ederal 1 5438 | ▁pleasant 1 5439 | ▁Mark 1 5440 | ▁join 1 5441 | ▁regarded 1 5442 | Over 1 5443 | rie 1 5444 | abeth 1 5445 | adium 1 5446 | terday 1 5447 | ▁activity 1 5448 | eral 1 5449 | ▁blow 1 5450 | ▁meas 1 5451 | enger 1 5452 | ▁manager 1 5453 | ▁screen 1 5454 | ▁supposed 1 5455 | Sp 1 5456 | wing 1 5457 | ▁height 1 5458 | pet 1 5459 | ▁pocket 1 5460 | azz 1 5461 | adem 1 5462 | ties 1 5463 | ▁shadow 1 5464 | ▁operates 1 5465 | ▁ox 1 5466 | ▁literature 1 5467 | nament 1 5468 | ads 1 5469 | enes 1 5470 | ▁conce 1 5471 | ▁organization 1 5472 | ▁cow 1 5473 | ▁software 1 5474 | bles 1 5475 | ▁wid 1 5476 | ▁begins 1 5477 | ▁establish 1 5478 | ▁wel 1 5479 | ▁spite 1 5480 | ▁Mex 1 5481 | ▁Wood 1 5482 | ▁condu 1 5483 | Those 1 5484 | ▁compos 1 5485 | ▁waters 1 5486 | ▁ur 1 5487 | ading 1 5488 | ▁atmosp 1 5489 | ▁demand 1 5490 | host 1 5491 | ▁london 1 5492 | ▁We 1 5493 | ▁Rock 1 5494 | ▁orders 1 5495 | utch 1 5496 | rous 1 5497 | ▁exerc 1 5498 | ▁historical 1 5499 | ▁stations 1 5500 | lict 1 5501 | ▁employed 1 5502 | ▁breakfast 1 5503 | Mc 1 5504 | ▁mission 1 5505 | ooth 1 5506 | otland 1 5507 | ▁Canadian 1 5508 | ▁recognized 1 5509 | People 1 5510 | ▁fro 1 5511 | achus 1 5512 | ▁Area 1 5513 | ▁Latin 1 5514 | ▁eager 1 5515 | ▁beneath 1 5516 | ano 1 5517 | iments 1 5518 | ▁quality 1 5519 | ado 1 5520 | ▁pattern 1 5521 | ▁Western 1 5522 | ▁cylinder 1 5523 | ▁conscious 1 5524 | ▁Thomas 1 5525 | ▁absolute 1 5526 | ▁environment 1 5527 | uals 1 5528 | ▁Old 1 5529 | ▁crim 1 5530 | ▁ends 1 5531 | ▁vary 1 5532 | Though 1 5533 | ▁suggested 1 5534 | ▁Society 1 5535 | ▁advert 1 5536 | ▁hyd 1 5537 | ▁competition 1 5538 | ▁Jewish 1 5539 | onic 1 5540 | ▁policy 1 5541 | ▁Richard 1 5542 | ▁arrange 1 5543 | ▁methods 1 5544 | arian 1 5545 | ▁useful 1 5546 | ▁darkness 1 5547 | odd 1 5548 | ▁Fore 1 5549 | ▁impl 1 5550 | ▁occurs 1 5551 | ▁pulled 1 5552 | ▁closely 1 5553 | ▁explain 1 5554 | ▁improve 1 5555 | ▁township 1 5556 | fit 1 5557 | igan 1 5558 | ▁batt 1 5559 | ntario 1 5560 | ▁haven 1 5561 | ▁interrupt 1 5562 | ysis 1 5563 | ▁empty 1 5564 | ▁lit 1 
5565 | ▁shook 1 5566 | ▁circle 1 5567 | year 1 5568 | ▁ult 1 5569 | stitute 1 5570 | ushing 1 5571 | ▁quant 1 5572 | ▁document 1 5573 | ▁jew 1 5574 | ▁reasons 1 5575 | ▁operation 1 5576 | ▁Washington 1 5577 | ▁commission 1 5578 | ▁Oh 1 5579 | ▁legisl 1 5580 | ▁lights 1 5581 | iform 1 5582 | ▁check 1 5583 | ▁learning 1 5584 | za 1 5585 | mber 1 5586 | annel 1 5587 | pherd 1 5588 | urance 1 5589 | ye 1 5590 | overy 1 5591 | ▁reply 1 5592 | ▁surrounded 1 5593 | ▁label 1 5594 | ▁teach 1 5595 | ▁besides 1 5596 | ▁wonderful 1 5597 | nes 1 5598 | ▁ah 1 5599 | eorge 1 5600 | Cur 1 5601 | usted 1 5602 | ▁chall 1 5603 | ▁nobody 1 5604 | ▁Ap 1 5605 | ▁vac 1 5606 | ▁multiple 1 5607 | rics 1 5608 | ▁Tor 1 5609 | based 1 5610 | ▁daily 1 5611 | ▁mayor 1 5612 | ▁increase 1 5613 | osh 1 5614 | uit 1 5615 | ruct 1 5616 | ▁lean 1 5617 | ▁purs 1 5618 | ▁parties 1 5619 | ▁principle 1 5620 | ▁lose 1 5621 | ▁growth 1 5622 | ▁heavily 1 5623 | ▁Carol 1 5624 | ▁suggest 1 5625 | ▁stranger 1 5626 | ▁Mo 1 5627 | etts 1 5628 | ▁Center 1 5629 | ▁versions 1 5630 | ▁legal 1 5631 | ▁gained 1 5632 | ▁Court 1 5633 | ▁financ 1 5634 | ▁First 1 5635 | ▁Ontario 1 5636 | ▁proceed 1 5637 | ▁background 1 5638 | icked 1 5639 | ▁someone 1 5640 | ▁prominent 1 5641 | ▁additional 1 5642 | ▁importance 1 5643 | pro 1 5644 | rig 1 5645 | ▁Chinese 1 5646 | ▁excited 1 5647 | war 1 5648 | ▁decor 1 5649 | ▁extended 1 5650 | ▁Wil 1 5651 | ▁calm 1 5652 | ferred 1 5653 | urally 1 5654 | ▁communities 1 5655 | ▁wants 1 5656 | uary 1 5657 | uten 1 5658 | lection 1 5659 | iday 1 5660 | ▁eggs 1 5661 | ▁alive 1 5662 | alle 1 5663 | ▁anti 1 5664 | izabeth 1 5665 | ▁Minister 1 5666 | ▁Law 1 5667 | ▁articles 1 5668 | ▁oldest 1 5669 | ▁suspic 1 5670 | ▁percent 1 5671 | ingers 1 5672 | ▁models 1 5673 | ... 1 5674 | Man 1 5675 | ▁ath 1 5676 | ocratic 1 5677 | ▁protect 1 5678 | ▁shoulder 1 5679 | esh 1 5680 | ller 1 5681 | ▁conver 1 5682 | ▁forgotten 1 5683 | born 1 5684 | ▁lic 1 5685 | ▁saf 1 5686 | ▁drawing 1 5687 | ternal 1 5688 | gent 1 5689 | gel 1 5690 | onent 1 5691 | burgh 1 5692 | ▁crater 1 5693 | ▁picked 1 5694 | ▁familiar 1 5695 | ▁instruments 1 5696 | ▁tong 1 5697 | Ste 1 5698 | ▁sequ 1 5699 | ▁adopted 1 5700 | ▁inhabit 1 5701 | ▁religion 1 5702 | uzz 1 5703 | Yet 1 5704 | ▁integ 1 5705 | ▁troops 1 5706 | leep 1 5707 | ▁cas 1 5708 | ▁playlist 1 5709 | ▁considerable 1 5710 | arts 1 5711 | ▁diver 1 5712 | ▁touch 1 5713 | ishment 1 5714 | ▁tender 1 5715 | ▁motor 1 5716 | ▁trains 1 5717 | ▁angry 1 5718 | ▁unique 1 5719 | ▁setting 1 5720 | ▁responsible 1 5721 | ▁knight 1 5722 | ▁ladies 1 5723 | gia 1 5724 | ▁riding 1 5725 | ▁article 1 5726 | ▁conduct 1 5727 | ▁million 1 5728 | ylvan 1 5729 | sequent 1 5730 | enny 1 5731 | oston 1 5732 | ▁acquaint 1 5733 | ▁facilities 1 5734 | ▁actor 1 5735 | ▁awarded 1 5736 | ▁engaged 1 5737 | ▁figures 1 5738 | aks 1 5739 | ▁fill 1 5740 | ▁Olymp 1 5741 | ati 1 5742 | ▁suburb 1 5743 | ▁abandoned 1 5744 | ▁goal 1 5745 | encies 1 5746 | ▁brain 1 5747 | ▁carefully 1 5748 | ▁steel 1 5749 | ▁colour 1 5750 | ▁presently 1 5751 | ▁proud 1 5752 | ▁Swed 1 5753 | ▁writ 1 5754 | ▁medical 1 5755 | ril 1 5756 | ugg 1 5757 | ▁gray 1 5758 | ▁nickn 1 5759 | ▁pecul 1 5760 | ▁veget 1 5761 | water 1 5762 | ▁finding 1 5763 | ▁advanced 1 5764 | ▁Cam 1 5765 | ▁cook 1 5766 | ▁media 1 5767 | ▁sorry 1 5768 | ▁classes 1 5769 | ▁display 1 5770 | ▁treatment 1 5771 | ▁Association 1 5772 | anes 1 5773 | asts 1 5774 | ▁dial 1 5775 | ▁peculiar 1 5776 | ▁administr 1 5777 | ▁milk 1 5778 | ▁China 1 5779 | ita 1 5780 | ▁resemb 1 5781 | ▁occasionally 1 5782 
| ▁fought 1 5783 | ▁Don 1 5784 | ▁behav 1 5785 | ▁website 1 5786 | ▁remaining 1 5787 | ▁Ireland 1 5788 | ▁Michael 1 5789 | ▁assemb 1 5790 | ▁causes 1 5791 | Play 1 5792 | inity 1 5793 | ▁hotel 1 5794 | ▁details 1 5795 | ▁difficulty 1 5796 | ▁imper 1 5797 | cope 1 5798 | ▁emph 1 5799 | ▁touched 1 5800 | ▁scr 1 5801 | after 1 5802 | ▁artists 1 5803 | ▁derived 1 5804 | isation 1 5805 | oring 1 5806 | orted 1 5807 | ▁estate 1 5808 | ▁margaret 1 5809 | ▁Sy 1 5810 | achusetts 1 5811 | ▁excellent 1 5812 | ▁fifth 1 5813 | ▁Columb 1 5814 | ▁correct 1 5815 | ▁fifteen 1 5816 | ▁resident 1 5817 | otes 1 5818 | ▁ing 1 5819 | ▁salt 1 5820 | ▁guest 1 5821 | ▁campus 1 5822 | ▁disease 1 5823 | ▁families 1 5824 | ▁settlement 1 5825 | ▁ears 1 5826 | ▁block 1 5827 | ▁score 1 5828 | ▁carriage 1 5829 | gress 1 5830 | ▁elder 1 5831 | ▁acting 1 5832 | ▁otherwise 1 5833 | Te 1 5834 | ▁Res 1 5835 | ▁accompan 1 5836 | ▁normal 1 5837 | ▁fortune 1 5838 | ▁Tur 1 5839 | ▁Flor 1 5840 | ▁flows 1 5841 | ▁tremb 1 5842 | ▁alternative 1 5843 | ▁Chicago 1 5844 | ▁Mel 1 5845 | ▁kingdom 1 5846 | ▁watching 1 5847 | ▁une 1 5848 | Through 1 5849 | ▁powers 1 5850 | ▁smiled 1 5851 | esse 1 5852 | ▁glo 1 5853 | ▁aston 1 5854 | ylvania 1 5855 | ▁singer 1 5856 | acks 1 5857 | ▁ax 1 5858 | ▁notes 1 5859 | ▁pictures 1 5860 | ▁somewhere 1 5861 | utation 1 5862 | ▁promise 1 5863 | ▁Australian 1 5864 | ▁victory 1 5865 | ▁ban 1 5866 | ▁quietly 1 5867 | ▁advantage 1 5868 | ▁confess 1 5869 | ▁railroad 1 5870 | ▁Cast 1 5871 | ▁earned 1 5872 | ▁equal 1 5873 | ▁Russian 1 5874 | ▁positions 1 5875 | ▁aver 1 5876 | iration 1 5877 | hedral 1 5878 | ▁unit 1 5879 | ala 1 5880 | rant 1 5881 | ▁narr 1 5882 | ▁worn 1 5883 | ▁comment 1 5884 | ▁dru 1 5885 | ▁hun 1 5886 | tering 1 5887 | ▁applied 1 5888 | fall 1 5889 | onse 1 5890 | ▁shortly 1 5891 | ▁objects 1 5892 | ▁companion 1 5893 | ▁rode 1 5894 | ▁taste 1 5895 | athered 1 5896 | ▁occupied 1 5897 | ▁cars 1 5898 | ▁hidden 1 5899 | ▁morrow 1 5900 | ga 1 5901 | iot 1 5902 | ▁cheer 1 5903 | ▁schem 1 5904 | ▁efforts 1 5905 | ▁Cons 1 5906 | ▁Egypt 1 5907 | ▁theatre 1 5908 | ▁unusual 1 5909 | ▁stuff 1 5910 | ▁Ill 1 5911 | ▁Los 1 5912 | ▁trained 1 5913 | ▁promised 1 5914 | Four 1 5915 | ▁eleven 1 5916 | ▁atmosphere 1 5917 | acle 1 5918 | riend 1 5919 | ▁create 1 5920 | ▁passes 1 5921 | ▁scarcely 1 5922 | Once 1 5923 | orous 1 5924 | ▁intellect 1 5925 | ▁approached 1 5926 | erg 1 5927 | ▁announced 1 5928 | ▁department 1 5929 | rier 1 5930 | ▁flower 1 5931 | ▁Records 1 5932 | ▁annual 1 5933 | ▁Virginia 1 5934 | ▁defeated 1 5935 | asp 1 5936 | ▁doors 1 5937 | ▁governor 1 5938 | ▁professor 1 5939 | iate 1 5940 | umber 1 5941 | ▁code 1 5942 | ▁pron 1 5943 | Americ 1 5944 | ▁constructed 1 5945 | ▁total 1 5946 | ▁syn 1 5947 | ▁teeth 1 5948 | ▁gets 1 5949 | ▁noble 1 5950 | itable 1 5951 | ▁wings 1 5952 | lam 1 5953 | uan 1 5954 | lend 1 5955 | ▁visited 1 5956 | ▁relatively 1 5957 | umes 1 5958 | acity 1 5959 | ▁solid 1 5960 | ▁matters 1 5961 | iated 1 5962 | anches 1 5963 | ▁waited 1 5964 | bre 1 5965 | ▁bitter 1 5966 | ▁league 1 5967 | ▁flood 1 5968 | ▁Army 1 5969 | ▁drag 1 5970 | ▁albums 1 5971 | ▁shel 1 5972 | ▁target 1 5973 | rose 1 5974 | ules 1 5975 | ▁fault 1 5976 | ▁branch 1 5977 | ▁shepherd 1 5978 | ▁proof 1 5979 | ▁script 1 5980 | ront 1 5981 | ▁Ang 1 5982 | ▁bass 1 5983 | ▁conflict 1 5984 | ▁perceived 1 5985 | ▁wet 1 5986 | ▁categ 1 5987 | ▁chamber 1 5988 | ▁Mem 1 5989 | estic 1 5990 | ▁burst 1 5991 | ▁Northern 1 5992 | ▁hurry 1 5993 | ▁prior 1 5994 | ▁rapidly 1 5995 | ▁uns 1 5996 | ▁false 1 5997 | sylvania 1 
5998 | Under 1 5999 | ▁Ohio 1 6000 | ▁rules 1 6001 | ▁chosen 1 6002 | ▁clearly 1 6003 | lim 1 6004 | asy 1 6005 | esty 1 6006 | ▁page 1 6007 | ▁aware 1 6008 | Pr 1 6009 | ▁rooms 1 6010 | ▁Dan 1 6011 | ▁emp 1 6012 | ▁contact 1 6013 | ▁Art 1 6014 | ▁Com 1 6015 | ▁guy 1 6016 | ▁calling 1 6017 | ▁Fort 1 6018 | ▁affect 1 6019 | ▁interesting 1 6020 | enance 1 6021 | ▁yesterday 1 6022 | ▁opportunity 1 6023 | ▁sam 1 6024 | ▁wounded 1 6025 | ▁approxim 1 6026 | ▁incl 1 6027 | ▁broadcast 1 6028 | how 1 6029 | ▁wave 1 6030 | ▁medic 1 6031 | ▁consequ 1 6032 | ▁shr 1 6033 | ▁symp 1 6034 | ▁vast 1 6035 | ▁rising 1 6036 | ▁bul 1 6037 | chief 1 6038 | ▁roads 1 6039 | ▁undert 1 6040 | ▁consult 1 6041 | Ed 1 6042 | ▁hills 1 6043 | ▁nomin 1 6044 | ▁flying 1 6045 | ▁studio 1 6046 | ▁respond 1 6047 | isp 1 6048 | itated 1 6049 | ▁damage 1 6050 | ▁murder 1 6051 | ▁crossed 1 6052 | ▁islands 1 6053 | ▁inse 1 6054 | ▁adult 1 6055 | ▁tribes 1 6056 | ▁Del 1 6057 | ▁Smith 1 6058 | ▁White 1 6059 | utenant 1 6060 | ▁industrial 1 6061 | arly 1 6062 | fess 1 6063 | ▁union 1 6064 | ▁Doctor 1 6065 | ▁aspect 1 6066 | ▁connection 1 6067 | Fur 1 6068 | icing 1 6069 | ▁hurt 1 6070 | ▁compar 1 6071 | ▁residents 1 6072 | ▁neighborhood 1 6073 | upp 1 6074 | ▁Mr 1 6075 | ▁san 1 6076 | ▁wore 1 6077 | ▁records 1 6078 | ▁wooden 1 6079 | ▁wheel 1 6080 | ▁perman 1 6081 | ▁formerly 1 6082 | ▁rat 1 6083 | iest 1 6084 | aj 1 6085 | eles 1 6086 | ▁Cat 1 6087 | ▁elev 1 6088 | ▁Henry 1 6089 | ▁Francis 1 6090 | ▁uniform 1 6091 | Therefore 1 6092 | yram 1 6093 | ▁claimed 1 6094 | rif 1 6095 | ▁fold 1 6096 | ▁fingers 1 6097 | ipl 1 6098 | ▁Pet 1 6099 | ulous 1 6100 | ▁gain 1 6101 | ▁stir 1 6102 | ▁omens 1 6103 | ▁peter 1 6104 | Instead 1 6105 | ▁nort 1 6106 | ivered 1 6107 | ▁critics 1 6108 | ▁everybody 1 6109 | ▁grandfather 1 6110 | ▁waves 1 6111 | ▁constant 1 6112 | va 1 6113 | ▁Angeles 1 6114 | ivity 1 6115 | ▁Gold 1 6116 | ▁pink 1 6117 | ▁column 1 6118 | hus 1 6119 | min 1 6120 | onym 1 6121 | ▁mal 1 6122 | ▁onto 1 6123 | ▁trial 1 6124 | ▁possibly 1 6125 | oln 1 6126 | urd 1 6127 | acific 1 6128 | ▁levels 1 6129 | ▁distant 1 6130 | ▁dreams 1 6131 | ▁bold 1 6132 | ▁liqu 1 6133 | ▁bodies 1 6134 | ▁volunt 1 6135 | ▁pleased 1 6136 | rael 1 6137 | ▁mile 1 6138 | ▁murd 1 6139 | ▁ceased 1 6140 | ▁witness 1 6141 | ansas 1 6142 | pping 1 6143 | ▁Earth 1 6144 | ▁genus 1 6145 | ▁feelings 1 6146 | tle 1 6147 | ▁Liber 1 6148 | ▁sport 1 6149 | ▁detect 1 6150 | ▁entrance 1 6151 | ▁Mag 1 6152 | ▁My 1 6153 | ▁experiment 1 6154 | ▁maintained 1 6155 | ▁disappeared 1 6156 | ▁freedom 1 6157 | ▁smiling 1 6158 | ▁feed 1 6159 | ▁sounds 1 6160 | Sometimes 1 6161 | ▁mas 1 6162 | ▁chest 1 6163 | ▁crown 1 6164 | ▁affairs 1 6165 | ▁hunting 1 6166 | lly 1 6167 | Upon 1 6168 | ▁ride 1 6169 | ▁recording 1 6170 | wich 1 6171 | ▁Ser 1 6172 | ▁Grand 1 6173 | ▁kitchen 1 6174 | ▁interior 1 6175 | ▁Haw 1 6176 | ▁justice 1 6177 | ▁satisfact 1 6178 | ▁Spr 1 6179 | ▁ward 1 6180 | ▁alarm 1 6181 | erk 1 6182 | ▁Irish 1 6183 | ien 1 6184 | ▁unh 1 6185 | ▁folk 1 6186 | ▁nose 1 6187 | ▁thro 1 6188 | ▁suppl 1 6189 | idently 1 6190 | ▁exists 1 6191 | usc 1 6192 | Soon 1 6193 | ueen 1 6194 | ▁Atl 1 6195 | play 1 6196 | gment 1 6197 | ▁persu 1 6198 | ▁gathered 1 6199 | ▁financial 1 6200 | ▁initially 1 6201 | ato 1 6202 | ▁kid 1 6203 | ▁meat 1 6204 | ▁defend 1 6205 | ▁mathem 1 6206 | ▁Florida 1 6207 | ▁actions 1 6208 | adel 1 6209 | olt 1 6210 | ▁La 1 6211 | ▁flu 1 6212 | ▁knowing 1 6213 | ▁dangerous 1 6214 | ▁functions 1 6215 | athy 1 6216 | inite 1 6217 | ▁affected 1 6218 | nam 1 6219 | 
▁Che 1 6220 | ▁task 1 6221 | ▁safety 1 6222 | ulations 1 6223 | Go 1 6224 | iffe 1 6225 | ▁seek 1 6226 | ▁impact 1 6227 | ▁parliament 1 6228 | nic 1 6229 | lying 1 6230 | ▁risk 1 6231 | ▁basis 1 6232 | ▁application 1 6233 | ▁plate 1 6234 | ▁strang 1 6235 | ▁instance 1 6236 | avis 1 6237 | ▁Day 1 6238 | ▁coal 1 6239 | ▁horiz 1 6240 | ucks 1 6241 | ▁hes 1 6242 | ▁fier 1 6243 | ▁crime 1 6244 | ▁breast 1 6245 | ▁gather 1 6246 | Addition 1 6247 | ▁passage 1 6248 | ▁pressure 1 6249 | ▁management 1 6250 | ▁scientific 1 6251 | ▁technology 1 6252 | ▁ens 1 6253 | ▁temple 1 6254 | ▁Massachusetts 1 6255 | ▁Rev 1 6256 | ennis 1 6257 | ▁reject 1 6258 | ▁imagine 1 6259 | ▁woods 1 6260 | ▁poetry 1 6261 | Further 1 6262 | ▁security 1 6263 | ▁concerned 1 6264 | artment 1 6265 | ▁reform 1 6266 | ▁inspired 1 6267 | ua 1 6268 | oles 1 6269 | ▁Rob 1 6270 | ▁critical 1 6271 | vy 1 6272 | ▁Stud 1 6273 | ▁units 1 6274 | ▁constitu 1 6275 | ▁independ 1 6276 | unes 1 6277 | osis 1 6278 | ▁Mic 1 6279 | ▁ocean 1 6280 | ▁attrib 1 6281 | onia 1 6282 | awa 1 6283 | ouch 1 6284 | emies 1 6285 | ▁busy 1 6286 | ▁Kingdom 1 6287 | ▁escaped 1 6288 | ▁advice 1 6289 | ▁glance 1 6290 | ▁painted 1 6291 | Our 1 6292 | Would 1 6293 | ilton 1 6294 | ▁August 1 6295 | ▁Middle 1 6296 | ava 1 6297 | erves 1 6298 | ▁asleep 1 6299 | ▁courage 1 6300 | ▁influenced 1 6301 | ▁response 1 6302 | ▁arts 1 6303 | ▁dawn 1 6304 | ▁pier 1 6305 | ▁showing 1 6306 | ▁analysis 1 6307 | ▁frame 1 6308 | ▁obtained 1 6309 | ▁painting 1 6310 | rehens 1 6311 | iser 1 6312 | umph 1 6313 | ▁lamp 1 6314 | ▁reduced 1 6315 | ▁resc 1 6316 | ▁Board 1 6317 | Everyone 1 6318 | ▁requires 1 6319 | uild 1 6320 | ▁Sim 1 6321 | pton 1 6322 | ▁colonel 1 6323 | ▁gentlemen 1 6324 | ▁pool 1 6325 | ▁innoc 1 6326 | ▁Pennsylvania 1 6327 | ▁wond 1 6328 | ▁utter 1 6329 | ▁academ 1 6330 | ▁retire 1 6331 | ▁ancest 1 6332 | ▁hide 1 6333 | ▁aside 1 6334 | ▁flour 1 6335 | ▁terror 1 6336 | ▁nearest 1 6337 | Your 1 6338 | ▁Camp 1 6339 | ▁price 1 6340 | ▁allows 1 6341 | ▁regret 1 6342 | ▁prove 1 6343 | ▁corpor 1 6344 | ▁ordinary 1 6345 | Sub 1 6346 | gian 1 6347 | hetic 1 6348 | ▁Mill 1 6349 | ▁fate 1 6350 | ▁wing 1 6351 | istical 1 6352 | ▁offices 1 6353 | ▁interests 1 6354 | ▁municipal 1 6355 | ▁Im 1 6356 | ▁Ben 1 6357 | ▁examples 1 6358 | ▁inhabitants 1 6359 | ▁phr 1 6360 | acent 1 6361 | ▁tent 1 6362 | ▁Award 1 6363 | unction 1 6364 | ▁sisters 1 6365 | ▁ourselves 1 6366 | mes 1 6367 | ▁reference 1 6368 | ▁servant 1 6369 | ya 1 6370 | ▁yard 1 6371 | ▁status 1 6372 | ervative 1 6373 | iable 1 6374 | ▁wand 1 6375 | ▁smell 1 6376 | ▁traff 1 6377 | ▁returning 1 6378 | ▁explan 1 6379 | ▁discover 1 6380 | ▁liked 1 6381 | ▁theme 1 6382 | nderson 1 6383 | ▁hurried 1 6384 | ▁materials 1 6385 | ▁chain 1 6386 | ▁equip 1 6387 | ▁message 1 6388 | wed 1 6389 | ▁dyn 1 6390 | urt 1 6391 | ▁disturb 1 6392 | ▁destiny 1 6393 | ▁citizens 1 6394 | ▁apparently 1 6395 | istics 1 6396 | ▁Italy 1 6397 | ▁spend 1 6398 | ▁mounted 1 6399 | ▁listened 1 6400 | Rec 1 6401 | rum 1 6402 | wide 1 6403 | ▁Serv 1 6404 | ▁failure 1 6405 | ▁selected 1 6406 | ▁sixty 1 6407 | ferences 1 6408 | gen 1 6409 | ▁calcul 1 6410 | ▁expressed 1 6411 | oir 1 6412 | ▁Off 1 6413 | ▁Brown 1 6414 | ▁faced 1 6415 | ▁flash 1 6416 | ▁entertain 1 6417 | anda 1 6418 | inate 1 6419 | imately 1 6420 | ▁voices 1 6421 | ▁continu 1 6422 | ▁frightened 1 6423 | lets 1 6424 | ctors 1 6425 | rical 1 6426 | ▁fond 1 6427 | ▁conqu 1 6428 | ▁thrown 1 6429 | ▁shoulders 1 6430 | ▁stairs 1 6431 | ▁purposes 1 6432 | ▁belief 1 6433 | ▁wedd 1 6434 | apping 1 6435 | 
▁incident 1 6436 | ▁regularly 1 6437 | onn 1 6438 | razil 1 6439 | ▁fled 1 6440 | ▁obsc 1 6441 | ▁eating 1 6442 | ▁stayed 1 6443 | ▁shouted 1 6444 | ▁separated 1 6445 | ▁guns 1 6446 | rently 1 6447 | ▁projects 1 6448 | ▁consisted 1 6449 | itors 1 6450 | ▁collabor 1 6451 | urse 1 6452 | ▁camer 1 6453 | ▁servants 1 6454 | ▁Ox 1 6455 | ▁Hel 1 6456 | ▁calls 1 6457 | ▁conclud 1 6458 | ▁degrees 1 6459 | appy 1 6460 | ▁Kent 1 6461 | ▁piano 1 6462 | ▁gently 1 6463 | ▁Southern 1 6464 | ▁tub 1 6465 | ▁Cong 1 6466 | ▁debut 1 6467 | ▁sugar 1 6468 | ▁affection 1 6469 | ▁dedicated 1 6470 | oped 1 6471 | ▁coffee 1 6472 | ▁temperature 1 6473 | ▁simpl 1 6474 | ▁audience 1 6475 | ▁pill 1 6476 | incoln 1 6477 | ▁talked 1 6478 | outheast 1 6479 | ▁villages 1 6480 | ▁impr 1 6481 | ▁branches 1 6482 | ▁desk 1 6483 | ▁outd 1 6484 | ▁teaching 1 6485 | ▁distinguished 1 6486 | ▁goods 1 6487 | ▁tongue 1 6488 | ▁scholars 1 6489 | ▁swimming 1 6490 | Maybe 1 6491 | inois 1 6492 | oviet 1 6493 | ▁tight 1 6494 | ▁maj 1 6495 | ▁Scotland 1 6496 | ▁yield 1 6497 | ▁regions 1 6498 | ▁Arch 1 6499 | ▁teles 1 6500 | ▁cultiv 1 6501 | ▁views 1 6502 | ▁climate 1 6503 | ▁moments 1 6504 | ▁thousands 1 6505 | ulate 1 6506 | ▁propos 1 6507 | ▁obliged 1 6508 | ▁historic 1 6509 | ij 1 6510 | ▁pursu 1 6511 | ▁laun 1 6512 | ▁armed 1 6513 | ▁finds 1 6514 | ▁unable 1 6515 | ▁enemies 1 6516 | ▁soldier 1 6517 | ta 1 6518 | ▁frag 1 6519 | ▁smooth 1 6520 | Any 1 6521 | Mean 1 6522 | ▁era 1 6523 | ▁vain 1 6524 | bar 1 6525 | ▁Sir 1 6526 | ▁surely 1 6527 | inter 1 6528 | omer 1 6529 | ▁aid 1 6530 | ealand 1 6531 | ▁overl 1 6532 | ▁capable 1 6533 | Im 1 6534 | ▁Tim 1 6535 | andom 1 6536 | ▁dick 1 6537 | ▁phone 1 6538 | ▁claims 1 6539 | ▁enthus 1 6540 | ▁splend 1 6541 | ▁traffic 1 6542 | ▁understanding 1 6543 | En 1 6544 | ▁strike 1 6545 | umm 1 6546 | aver 1 6547 | ▁Jersey 1 6548 | ▁settle 1 6549 | ▁chemical 1 6550 | aws 1 6551 | ▁Martin 1 6552 | ▁lyrics 1 6553 | ▁Carolina 1 6554 | Ag 1 6555 | rill 1 6556 | ▁hosts 1 6557 | vis 1 6558 | ▁mob 1 6559 | ▁homes 1 6560 | ▁symbol 1 6561 | ucle 1 6562 | ▁Inde 1 6563 | ▁myst 1 6564 | ▁saved 1 6565 | ▁preced 1 6566 | ▁Sl 1 6567 | ▁Mad 1 6568 | ▁cycl 1 6569 | ▁craft 1 6570 | ▁lifted 1 6571 | ▁reality 1 6572 | ▁actual 1 6573 | ▁provinc 1 6574 | ▁seasons 1 6575 | ▁thorough 1 6576 | sp 1 6577 | ▁falls 1 6578 | standing 1 6579 | ildren 1 6580 | ▁producer 1 6581 | ▁recovered 1 6582 | Car 1 6583 | vin 1 6584 | hess 1 6585 | ounce 1 6586 | erable 1 6587 | ▁relief 1 6588 | ▁captured 1 6589 | Ro 1 6590 | ban 1 6591 | ▁mart 1 6592 | ▁ghost 1 6593 | ▁online 1 6594 | ▁agricultural 1 6595 | ▁Libr 1 6596 | ▁mistake 1 6597 | ▁solution 1 6598 | ▁reputation 1 6599 | iger 1 6600 | ulpt 1 6601 | umped 1 6602 | ▁sake 1 6603 | ▁chose 1 6604 | ▁strip 1 6605 | ▁resour 1 6606 | ▁meteor 1 6607 | ▁equipment 1 6608 | ras 1 6609 | iner 1 6610 | ▁shoes 1 6611 | ▁Public 1 6612 | ▁supply 1 6613 | ▁camel 1 6614 | ▁scale 1 6615 | ▁rivers 1 6616 | ▁Carl 1 6617 | ▁exped 1 6618 | ▁phen 1 6619 | ▁concl 1 6620 | ▁permitted 1 6621 | onto 1 6622 | ▁dish 1 6623 | ▁soil 1 6624 | ▁cattle 1 6625 | ▁naturally 1 6626 | ▁keeping 1 6627 | ▁territory 1 6628 | ▁murm 1 6629 | ▁pride 1 6630 | ▁rang 1 6631 | ▁strongly 1 6632 | osc 1 6633 | ▁hal 1 6634 | ▁lect 1 6635 | ▁prem 1 6636 | ▁drift 1 6637 | ▁trick 1 6638 | incorporated 1 6639 | Gu 1 6640 | arge 1 6641 | ▁sin 1 6642 | ▁Second 1 6643 | ▁madame 1 6644 | ▁Institute 1 6645 | ▁dare 1 6646 | ▁prompt 1 6647 | izations 1 6648 | ▁ahead 1 6649 | ▁sites 1 6650 | ▁waste 1 6651 | bec 1 6652 | ▁Den 1 6653 | ▁rac 1 6654 
| ockey 1 6655 | ▁cliff 1 6656 | ▁pounds 1 6657 | cel 1 6658 | ▁lat 1 6659 | ▁Prov 1 6660 | ▁covers 1 6661 | ▁duties 1 6662 | ▁finger 1 6663 | ▁patient 1 6664 | ▁extensive 1 6665 | ▁suffering 1 6666 | iques 1 6667 | ▁mode 1 6668 | ▁choose 1 6669 | ▁bay 1 6670 | ▁acts 1 6671 | ointment 1 6672 | ▁sufficient 1 6673 | ▁accompanied 1 6674 | ▁cos 1 6675 | ▁contrast 1 6676 | ▁fer 1 6677 | ▁lod 1 6678 | ▁Ag 1 6679 | ▁Sur 1 6680 | stein 1 6681 | ▁flew 1 6682 | ▁rarely 1 6683 | ▁emperor 1 6684 | ▁facility 1 6685 | ih 1 6686 | urches 1 6687 | Fl 1 6688 | ▁Lord 1 6689 | ▁drove 1 6690 | ownt 1 6691 | phone 1 6692 | ▁strict 1 6693 | American 1 6694 | ▁hearing 1 6695 | isms 1 6696 | ▁obv 1 6697 | ▁sets 1 6698 | ▁cruel 1 6699 | ▁festival 1 6700 | ucky 1 6701 | ▁aim 1 6702 | ▁orange 1 6703 | ▁treated 1 6704 | ▁gradually 1 6705 | iece 1 6706 | ▁Group 1 6707 | ▁sculpt 1 6708 | Meanwhile 1 6709 | ▁intention 1 6710 | oys 1 6711 | ▁myster 1 6712 | ▁priest 1 6713 | ▁seized 1 6714 | ▁administration 1 6715 | ▁Sund 1 6716 | ▁evident 1 6717 | ▁immense 1 6718 | ▁universal 1 6719 | ran 1 6720 | ogue 1 6721 | ▁unincorporated 1 6722 | phy 1 6723 | ▁Cross 1 6724 | ▁refers 1 6725 | ▁African 1 6726 | ▁increasing 1 6727 | lle 1 6728 | inem 1 6729 | ▁tra 1 6730 | mother 1 6731 | ▁minim 1 6732 | ▁suffer 1 6733 | ▁opposition 1 6734 | ▁unp 1 6735 | ▁united 1 6736 | ▁descent 1 6737 | ▁administrative 1 6738 | ▁cells 1 6739 | ▁Mexico 1 6740 | ▁fairly 1 6741 | ▁seated 1 6742 | ▁arranged 1 6743 | Us 1 6744 | ▁glob 1 6745 | ▁granted 1 6746 | ▁sleeping 1 6747 | izz 1 6748 | ▁lege 1 6749 | ▁nurse 1 6750 | ▁protection 1 6751 | rete 1 6752 | atherine 1 6753 | ▁starting 1 6754 | ▁fox 1 6755 | ▁compr 1 6756 | ▁leads 1 6757 | ▁handsome 1 6758 | ▁hundreds 1 6759 | ▁assistant 1 6760 | awn 1 6761 | seless 1 6762 | ▁crack 1 6763 | ▁fancy 1 6764 | ▁lawyer 1 6765 | ▁theore 1 6766 | ▁altogether 1 6767 | usse 1 6768 | raham 1 6769 | ▁split 1 6770 | rain 1 6771 | ▁telling 1 6772 | ▁luc 1 6773 | ▁anxious 1 6774 | ▁regional 1 6775 | ▁executive 1 6776 | ▁subsequent 1 6777 | ▁Bell 1 6778 | ▁entering 1 6779 | yard 1 6780 | ▁pup 1 6781 | chest 1 6782 | ropri 1 6783 | ▁Columbia 1 6784 | ▁Pri 1 6785 | ▁Library 1 6786 | ▁graduated 1 6787 | know 1 6788 | ▁Vill 1 6789 | ▁Station 1 6790 | ▁du 1 6791 | ▁bent 1 6792 | ensity 1 6793 | ▁workers 1 6794 | akers 1 6795 | icken 1 6796 | urban 1 6797 | ▁Bill 1 6798 | ▁Wind 1 6799 | ▁Trans 1 6800 | Without 1 6801 | ▁philosophy 1 6802 | cca 1 6803 | olas 1 6804 | otic 1 6805 | onents 1 6806 | ▁diamond 1 6807 | ▁equival 1 6808 | ▁Fin 1 6809 | istan 1 6810 | itage 1 6811 | ▁Star 1 6812 | ▁fart 1 6813 | engers 1 6814 | nce 1 6815 | ▁Ke 1 6816 | anny 1 6817 | ▁dur 1 6818 | ▁sole 1 6819 | ▁intelligence 1 6820 | inth 1 6821 | ▁sed 1 6822 | istry 1 6823 | ▁remind 1 6824 | ▁improved 1 6825 | ▁dot 1 6826 | ▁odd 1 6827 | ▁amaz 1 6828 | ▁flow 1 6829 | bra 1 6830 | gor 1 6831 | ilty 1 6832 | ▁bat 1 6833 | ▁Foot 1 6834 | ▁dull 1 6835 | ▁trip 1 6836 | ▁Super 1 6837 | ▁Soviet 1 6838 | ▁instru 1 6839 | ▁lovely 1 6840 | ▁acquired 1 6841 | riff 1 6842 | ▁turns 1 6843 | ▁american 1 6844 | ▁satisfied 1 6845 | Miss 1 6846 | ▁pion 1 6847 | ▁Brook 1 6848 | ▁images 1 6849 | ▁titles 1 6850 | ▁planned 1 6851 | ▁Boston 1 6852 | ▁initial 1 6853 | ▁sixteen 1 6854 | ▁achieved 1 6855 | ▁cheese 1 6856 | ▁literary 1 6857 | ▁stores 1 6858 | ▁rhy 1 6859 | ▁subjects 1 6860 | ▁possession 1 6861 | mony 1 6862 | ▁moral 1 6863 | ▁reign 1 6864 | William 1 6865 | ▁curious 1 6866 | ▁absor 1 6867 | ▁signs 1 6868 | ▁Forest 1 6869 | ▁Joseph 1 6870 | ▁proport 1 6871 | 
▁admitted 1 6872 | ▁earnest 1 6873 | ▁architecture 1 6874 | uns 1 6875 | esus 1 6876 | ture 1 6877 | ▁poll 1 6878 | ▁alleg 1 6879 | ▁advance 1 6880 | ▁Bob 1 6881 | ▁depth 1 6882 | ▁swift 1 6883 | ▁dismiss 1 6884 | ▁Rad 1 6885 | ▁horn 1 6886 | atically 1 6887 | ▁anybody 1 6888 | ▁laughing 1 6889 | ▁Department 1 6890 | ▁serving 1 6891 | ▁Sat 1 6892 | ▁exha 1 6893 | ▁yours 1 6894 | ▁evidently 1 6895 | ▁fet 1 6896 | ▁acid 1 6897 | ▁grat 1 6898 | ▁loose 1 6899 | lyn 1 6900 | ura 1 6901 | ▁Er 1 6902 | ▁prefer 1 6903 | ▁prisoner 1 6904 | ▁attempted 1 6905 | rior 1 6906 | ▁basic 1 6907 | ▁hoped 1 6908 | ▁federal 1 6909 | ▁putting 1 6910 | Book 1 6911 | ▁honest 1 6912 | ▁owners 1 6913 | ▁stable 1 6914 | ▁achieve 1 6915 | ▁tells 1 6916 | ▁colors 1 6917 | ▁tow 1 6918 | ▁issued 1 6919 | ▁frequent 1 6920 | ▁conducted 1 6921 | ▁agriculture 1 6922 | pent 1 6923 | ▁gay 1 6924 | ▁nerv 1 6925 | oration 1 6926 | ▁Junior 1 6927 | ▁Illinois 1 6928 | ▁discovery 1 6929 | irk 1 6930 | onna 1 6931 | ▁jul 1 6932 | aired 1 6933 | ▁astron 1 6934 | ▁sorrow 1 6935 | ibly 1 6936 | ▁fix 1 6937 | ▁inher 1 6938 | ▁stamp 1 6939 | owntown 1 6940 | ▁athlet 1 6941 | ▁ability 1 6942 | ▁adjacent 1 6943 | ▁sentence 1 6944 | uman 1 6945 | ▁sess 1 6946 | ▁roles 1 6947 | ▁baseball 1 6948 | ▁secondary 1 6949 | ▁Parliament 1 6950 | anna 1 6951 | ▁focus 1 6952 | ▁march 1 6953 | ▁Brazil 1 6954 | umin 1 6955 | ▁map 1 6956 | ▁Will 1 6957 | ▁advoc 1 6958 | ▁brave 1 6959 | ▁childhood 1 6960 | ▁Line 1 6961 | ▁begun 1 6962 | ▁devil 1 6963 | ▁outer 1 6964 | Much 1 6965 | ▁Kore 1 6966 | ▁Oxford 1 6967 | ▁complic 1 6968 | ▁tomorrow 1 6969 | ▁revolution 1 6970 | ▁contributed 1 6971 | ▁manufacture 1 6972 | ilm 1 6973 | uate 1 6974 | ▁gods 1 6975 | ▁nest 1 6976 | ▁clock 1 6977 | ▁interrupted 1 6978 | Born 1 6979 | ▁Tom 1 6980 | ▁landsc 1 6981 | ▁translation 1 6982 | aska 1 6983 | Event 1 6984 | ▁hate 1 6985 | ▁mood 1 6986 | ▁sang 1 6987 | ▁rural 1 6988 | ▁amongst 1 6989 | ▁enjoyed 1 6990 | ▁welcome 1 6991 | ▁arrested 1 6992 | agg 1 6993 | han 1 6994 | ▁Dire 1 6995 | ▁request 1 6996 | ▁Corn 1 6997 | lessly 1 6998 | ▁opposed 1 6999 | ▁Conn 1 7000 | ▁banks 1 7001 | ▁Britain 1 7002 | ▁defined 1 7003 | ▁struggle 1 7004 | ▁Sand 1 7005 | ▁clouds 1 7006 | eenth 1 7007 | yramids 1 7008 | ▁actress 1 7009 | ▁partner 1 7010 | ▁restrict 1 7011 | ▁whispered 1 7012 | ▁pin 1 7013 | ▁moder 1 7014 | ▁Albert 1 7015 | bell 1 7016 | raph 1 7017 | ▁pow 1 7018 | ▁web 1 7019 | ▁bore 1 7020 | ▁seats 1 7021 | ▁Stat 1 7022 | ▁sust 1 7023 | ▁occasions 1 7024 | ▁technical 1 7025 | First 1 7026 | onsie 1 7027 | ▁crop 1 7028 | borough 1 7029 | ▁friendly 1 7030 | ▁implement 1 7031 | Bes 1 7032 | ▁buck 1 7033 | ▁meal 1 7034 | ▁faster 1 7035 | ▁dancing 1 7036 | ▁ec 1 7037 | Good 1 7038 | ▁Charl 1 7039 | ▁magic 1 7040 | ▁glasses 1 7041 | abel 1 7042 | iral 1 7043 | ▁cott 1 7044 | ▁Champ 1 7045 | ▁Victoria 1 7046 | ▁identified 1 7047 | ▁convinc 1 7048 | ▁hanging 1 7049 | ▁attached 1 7050 | ▁churches 1 7051 | ▁visitors 1 7052 | onc 1 7053 | ▁ton 1 7054 | ▁engin 1 7055 | roc 1 7056 | ▁row 1 7057 | ▁myth 1 7058 | ▁thor 1 7059 | ▁fleet 1 7060 | ▁sales 1 7061 | ▁superior 1 7062 | ▁practical 1 7063 | ▁translated 1 7064 | ns 1 7065 | ▁pard 1 7066 | ▁worst 1 7067 | ▁affair 1 7068 | rep 1 7069 | ocal 1 7070 | ▁bands 1 7071 | ▁profession 1 7072 | ▁Kn 1 7073 | ▁steam 1 7074 | ▁german 1 7075 | ▁individuals 1 7076 | burn 1 7077 | ▁Land 1 7078 | ▁winds 1 7079 | ▁shared 1 7080 | ▁arrival 1 7081 | ▁ceremony 1 7082 | ▁alongside 1 7083 | Out 1 7084 | ▁dign 1 7085 | ▁straw 1 7086 | ▁wedding 1 7087 | 
▁basketball 1 7088 | ▁Div 1 7089 | aylor 1 7090 | eping 1 7091 | ▁pity 1 7092 | ▁vote 1 7093 | ▁Centre 1 7094 | ▁Israel 1 7095 | ▁Pacific 1 7096 | ▁Cap 1 7097 | ▁Wal 1 7098 | ▁awards 1 7099 | ▁driving 1 7100 | Eventually 1 7101 | iders 1 7102 | ▁Toronto 1 7103 | ▁grounds 1 7104 | ▁Williams 1 7105 | ▁preserved 1 7106 | ▁specifically 1 7107 | Gl 1 7108 | ▁pig 1 7109 | ▁vic 1 7110 | ▁cord 1 7111 | ▁ende 1 7112 | ieties 1 7113 | inking 1 7114 | ▁contemporary 1 7115 | ▁bath 1 7116 | ▁obst 1 7117 | ▁Georgia 1 7118 | order 1 7119 | works 1 7120 | ▁thanks 1 7121 | ▁vision 1 7122 | ador 1 7123 | ▁Gre 1 7124 | ▁Jim 1 7125 | ▁holy 1 7126 | ▁credited 1 7127 | ▁Dr 1 7128 | apel 1 7129 | dale 1 7130 | mouth 1 7131 | ▁scattered 1 7132 | inson 1 7133 | ▁unex 1 7134 | ifying 1 7135 | ▁guilty 1 7136 | ▁weapons 1 7137 | ▁organized 1 7138 | Rate 1 7139 | irmed 1 7140 | ▁cards 1 7141 | ▁knees 1 7142 | ▁Jackson 1 7143 | ao 1 7144 | oons 1 7145 | ▁sty 1 7146 | App 1 7147 | roll 1 7148 | rome 1 7149 | acles 1 7150 | bourne 1 7151 | ▁flame 1 7152 | ▁locked 1 7153 | ▁engines 1 7154 | uable 1 7155 | ▁brill 1 7156 | ▁hence 1 7157 | ▁finish 1 7158 | ▁jacket 1 7159 | ▁substant 1 7160 | ▁guests 1 7161 | ▁difficulties 1 7162 | oting 1 7163 | ourse 1 7164 | ▁formal 1 7165 | ▁borders 1 7166 | ▁decades 1 7167 | ▁semi 1 7168 | ▁reflect 1 7169 | ▁Str 1 7170 | ▁Gard 1 7171 | ▁jane 1 7172 | ▁acknow 1 7173 | ▁hollow 1 7174 | ▁cultural 1 7175 | ▁Commission 1 7176 | ois 1 7177 | ighed 1 7178 | ommon 1 7179 | ▁overt 1 7180 | ▁combat 1 7181 | ▁staring 1 7182 | inton 1 7183 | ▁Navy 1 7184 | ▁pipe 1 7185 | ▁cloth 1 7186 | ▁constitution 1 7187 | elly 1 7188 | ober 1 7189 | onsc 1 7190 | Chapter 1 7191 | ▁northwest 1 7192 | Fin 1 7193 | olis 1 7194 | vard 1 7195 | liest 1 7196 | ▁dece 1 7197 | ▁poem 1 7198 | ▁Music 1 7199 | ▁acted 1 7200 | ▁classical 1 7201 | Near 1 7202 | usive 1 7203 | uther 1 7204 | ▁duke 1 7205 | ▁consum 1 7206 | phia 1 7207 | ▁wis 1 7208 | iture 1 7209 | ▁cert 1 7210 | unning 1 7211 | ▁flesh 1 7212 | ▁henry 1 7213 | ▁joint 1 7214 | ▁knife 1 7215 | ▁hearts 1 7216 | ▁prepar 1 7217 | ▁describe 1 7218 | ▁authorities 1 7219 | aus 1 7220 | ▁Sil 1 7221 | ▁combined 1 7222 | igr 1 7223 | atics 1 7224 | ▁copy 1 7225 | ▁solo 1 7226 | Mister 1 7227 | ptember 1 7228 | ▁Castle 1 7229 | ▁senior 1 7230 | ▁Mer 1 7231 | ▁burned 1 7232 | ▁matches 1 7233 | ▁retained 1 7234 | ▁departure 1 7235 | Sm 1 7236 | wan 1 7237 | aters 1 7238 | otton 1 7239 | ▁cave 1 7240 | ▁rear 1 7241 | ▁wire 1 7242 | ▁depos 1 7243 | ▁remarked 1 7244 | ▁experienced 1 7245 | vant 1 7246 | ▁sits 1 7247 | ▁george 1 7248 | ▁Governor 1 7249 | ▁unsu 1 7250 | teenth 1 7251 | ▁frank 1 7252 | ▁biggest 1 7253 | ▁listening 1 7254 | oked 1 7255 | ▁Sar 1 7256 | ▁desper 1 7257 | ▁nearer 1 7258 | ▁supper 1 7259 | ▁charged 1 7260 | ▁triumph 1 7261 | ▁Michigan 1 7262 | ▁attempts 1 7263 | ▁interred 1 7264 | ▁Des 1 7265 | ▁ere 1 7266 | ▁por 1 7267 | ▁alcoh 1 7268 | ▁refer 1 7269 | ropical 1 7270 | ▁forests 1 7271 | ▁Football 1 7272 | ▁internal 1 7273 | ▁committee 1 7274 | ▁disappointed 1 7275 | ▁dozen 1 7276 | ▁transm 1 7277 | ▁operate 1 7278 | ▁Ann 1 7279 | asion 1 7280 | icide 1 7281 | ▁Rome 1 7282 | ▁bicy 1 7283 | ▁agent 1 7284 | ▁items 1 7285 | ▁novels 1 7286 | ▁plenty 1 7287 | ▁violence 1 7288 | Bet 1 7289 | opes 1 7290 | Nor 1 7291 | ▁landing 1 7292 | ▁attacked 1 7293 | ▁vehicles 1 7294 | ▁attracted 1 7295 | put 1 7296 | ndered 1 7297 | ▁graph 1 7298 | ▁crossing 1 7299 | force 1 7300 | ▁pushed 1 7301 | ▁immediate 1 7302 | acon 1 7303 | real 1 7304 | icity 1 7305 | pelled 1 
7306 | ▁equally 1 7307 | ▁fol 1 7308 | ▁rag 1 7309 | ▁relie 1 7310 | adelphia 1 7311 | ▁irrit 1 7312 | ▁handle 1 7313 | ▁careful 1 7314 | ▁Scottish 1 7315 | ▁dreadful 1 7316 | ▁resolved 1 7317 | fortunately 1 7318 | ▁confidence 1 7319 | anz 1 7320 | ▁pine 1 7321 | ▁costs 1 7322 | ▁shooting 1 7323 | ▁permanent 1 7324 | ▁How 1 7325 | rayed 1 7326 | ▁spark 1 7327 | ▁sacrif 1 7328 | ▁expanded 1 7329 | with 1 7330 | ▁rent 1 7331 | ▁shell 1 7332 | ▁credit 1 7333 | ▁farther 1 7334 | ▁paralle 1 7335 | ▁centuries 1 7336 | amin 1 7337 | ▁channel 1 7338 | ▁applications 1 7339 | crib 1 7340 | ipping 1 7341 | long 1 7342 | ▁Mur 1 7343 | ▁france 1 7344 | ▁instantly 1 7345 | ▁brid 1 7346 | ▁publ 1 7347 | ustomed 1 7348 | ▁paused 1 7349 | ▁despite 1 7350 | ▁fiction 1 7351 | ▁impress 1 7352 | greg 1 7353 | ▁fuel 1 7354 | ▁horror 1 7355 | iro 1 7356 | ▁Jan 1 7357 | ▁curt 1 7358 | inging 1 7359 | ▁extraord 1 7360 | Sch 1 7361 | arry 1 7362 | ications 1 7363 | ▁desired 1 7364 | ▁spirits 1 7365 | ▁Committee 1 7366 | bit 1 7367 | ▁Ge 1 7368 | ▁Cle 1 7369 | ▁assert 1 7370 | ▁commander 1 7371 | ▁converted 1 7372 | ▁oasis 1 7373 | ilst 1 7374 | ▁rust 1 7375 | ▁challen 1 7376 | Cont 1 7377 | ▁sup 1 7378 | ▁giant 1 7379 | ▁nucle 1 7380 | ▁briefly 1 7381 | ▁intense 1 7382 | ▁northeast 1 7383 | ▁Metropolitan 1 7384 | ▁cig 1 7385 | ▁ped 1 7386 | ▁vocal 1 7387 | ▁designs 1 7388 | ▁william 1 7389 | ▁mistress 1 7390 | ▁represents 1 7391 | orph 1 7392 | ▁Dev 1 7393 | ▁golf 1 7394 | ▁despair 1 7395 | ▁highway 1 7396 | ▁medicine 1 7397 | ▁reservation 1 7398 | fle 1 7399 | iary 1 7400 | ▁vig 1 7401 | ▁guide 1 7402 | ▁lands 1 7403 | ▁user 1 7404 | iences 1 7405 | ▁driven 1 7406 | ▁founder 1 7407 | ▁clothing 1 7408 | ▁creation 1 7409 | ▁passenger 1 7410 | ▁techniques 1 7411 | Co 1 7412 | Come 1 7413 | irty 1 7414 | onel 1 7415 | ▁ate 1 7416 | tical 1 7417 | action 1 7418 | ▁dying 1 7419 | ▁sixth 1 7420 | ▁denied 1 7421 | ▁worship 1 7422 | ▁compared 1 7423 | ▁committed 1 7424 | ova 1 7425 | ▁decre 1 7426 | ▁afford 1 7427 | ▁absolutely 1 7428 | ▁rejo 1 7429 | ▁burning 1 7430 | ▁climbed 1 7431 | ▁apartment 1 7432 | ▁formation 1 7433 | Additionally 1 7434 | enses 1 7435 | erior 1 7436 | leton 1 7437 | ▁sounded 1 7438 | ▁explanation 1 7439 | face 1 7440 | ▁arose 1 7441 | ▁manif 1 7442 | ▁appeal 1 7443 | ▁endeav 1 7444 | ▁partners 1 7445 | ▁reaction 1 7446 | acc 1 7447 | elia 1 7448 | ▁apply 1 7449 | ▁slaves 1 7450 | entially 1 7451 | ▁newspapers 1 7452 | ▁string 1 7453 | ▁academic 1 7454 | ▁contrary 1 7455 | ▁performing 1 7456 | ▁appearances 1 7457 | Saint 1 7458 | ▁drunk 1 7459 | ▁skill 1 7460 | ▁feather 1 7461 | ▁retreat 1 7462 | ▁Memorial 1 7463 | life 1 7464 | ▁Dutch 1 7465 | ▁cream 1 7466 | ▁resulted 1 7467 | ▁composition 1 7468 | ▁Hol 1 7469 | ▁Asia 1 7470 | ▁tiny 1 7471 | ▁excit 1 7472 | ▁absence 1 7473 | ▁Internet 1 7474 | Part 1 7475 | Besides 1 7476 | Similar 1 7477 | ▁liquid 1 7478 | ▁partly 1 7479 | ▁racing 1 7480 | ▁revealed 1 7481 | eck 1 7482 | ▁Mars 1 7483 | ▁subd 1 7484 | ▁slave 1 7485 | ▁Jo 1 7486 | Short 1 7487 | arris 1 7488 | ▁chap 1 7489 | ▁facts 1 7490 | ▁resumed 1 7491 | ▁Cambridge 1 7492 | ▁everywhere 1 7493 | chers 1 7494 | ▁regist 1 7495 | ▁bedroom 1 7496 | ▁domestic 1 7497 | leg 1 7498 | urg 1 7499 | ▁reck 1 7500 | ▁Wales 1 7501 | ▁shame 1 7502 | ▁distribution 1 7503 | aka 1 7504 | ▁Luc 1 7505 | assed 1 7506 | ▁deck 1 7507 | onsieur 1 7508 | ▁scheme 1 7509 | ▁assumed 1 7510 | ▁effective 1 7511 | ▁operating 1 7512 | via 1 7513 | ▁Emp 1 7514 | ▁demanded 1 7515 | ▁sund 1 7516 | ▁Edward 1 7517 | Fr 1 7518 | 
▁Sea 1 7519 | ▁grant 1 7520 | ▁Eastern 1 7521 | ▁exercise 1 7522 | ▁restored 1 7523 | ▁photograph 1 7524 | ▁Independent 1 7525 | ouri 1 7526 | ▁income 1 7527 | avig 1 7528 | oust 1 7529 | urday 1 7530 | ▁scenes 1 7531 | ▁distributed 1 7532 | ella 1 7533 | rive 1 7534 | ▁fan 1 7535 | aping 1 7536 | ▁Soul 1 7537 | ▁ideal 1 7538 | ▁stead 1 7539 | ▁portra 1 7540 | ▁nations 1 7541 | ▁adventure 1 7542 | ▁equivalent 1 7543 | ▁properties 1 7544 | enic 1 7545 | ansion 1 7546 | ▁ignor 1 7547 | ▁friendship 1 7548 | avan 1 7549 | ▁jim 1 7550 | ▁Fred 1 7551 | ▁maxim 1 7552 | ▁happens 1 7553 | ▁vehicle 1 7554 | ▁movements 1 7555 | iac 1 7556 | uty 1 7557 | ▁plot 1 7558 | ▁explos 1 7559 | ▁reports 1 7560 | ▁behavior 1 7561 | orse 1 7562 | ▁Post 1 7563 | ▁volume 1 7564 | ▁average 1 7565 | ▁Congress 1 7566 | ▁necessity 1 7567 | ▁accur 1 7568 | ▁willing 1 7569 | Currently 1 7570 | ▁somebody 1 7571 | ▁statement 1 7572 | ▁Way 1 7573 | ▁kings 1 7574 | ▁mental 1 7575 | ▁remarkable 1 7576 | ais 1 7577 | Stud 1 7578 | ▁Squ 1 7579 | ▁divor 1 7580 | rail 1 7581 | ▁Mid 1 7582 | ▁Budd 1 7583 | ▁pitch 1 7584 | ▁maintain 1 7585 | ▁sections 1 7586 | udge 1 7587 | ▁rein 1 7588 | ▁Spain 1 7589 | ▁Russia 1 7590 | ographer 1 7591 | ▁imprison 1 7592 | orporation 1 7593 | ▁approximately 1 7594 | mann 1 7595 | aptain 1 7596 | stairs 1 7597 | ▁fence 1 7598 | Nothing 1 7599 | ▁apprec 1 7600 | ▁Tri 1 7601 | itled 1 7602 | ▁hath 1 7603 | orious 1 7604 | ▁Avenue 1 7605 | ▁format 1 7606 | ▁nervous 1 7607 | ▁declined 1 7608 | ▁possessed 1 7609 | ▁uncertain 1 7610 | rick 1 7611 | ▁papa 1 7612 | ▁micro 1 7613 | ▁facult 1 7614 | ▁injury 1 7615 | ▁episodes 1 7616 | ▁synt 1 7617 | ▁recre 1 7618 | western 1 7619 | ▁bishop 1 7620 | ▁liberty 1 7621 | agle 1 7622 | ctober 1 7623 | ▁brick 1 7624 | ologist 1 7625 | inations 1 7626 | ▁shopping 1 7627 | ▁displayed 1 7628 | ▁association 1 7629 | ▁Es 1 7630 | ▁torn 1 7631 | ▁yards 1 7632 | ▁Berlin 1 7633 | ▁borough 1 7634 | ▁predomin 1 7635 | athan 1 7636 | ▁fans 1 7637 | ▁cabin 1 7638 | ▁tools 1 7639 | ▁lasted 1 7640 | ▁purple 1 7641 | ▁assigned 1 7642 | ▁telescope 1 7643 | ▁constantly 1 7644 | col 1 7645 | lemn 1 7646 | ▁tank 1 7647 | ▁Portug 1 7648 | ▁realized 1 7649 | wart 1 7650 | ▁Cup 1 7651 | unnel 1 7652 | ▁sale 1 7653 | ▁tort 1 7654 | ▁inner 1 7655 | ▁bottle 1 7656 | ▁observe 1 7657 | ext 1 7658 | law 1 7659 | ▁Os 1 7660 | class 1 7661 | icate 1 7662 | ▁leaf 1 7663 | ▁deeply 1 7664 | ▁throat 1 7665 | ▁wishes 1 7666 | ▁factory 1 7667 | ▁potential 1 7668 | tic 1 7669 | down 1 7670 | ▁attacks 1 7671 | ▁protest 1 7672 | ▁household 1 7673 | ▁residence 1 7674 | bey 1 7675 | atima 1 7676 | ▁manage 1 7677 | ▁eighty 1 7678 | ▁devoted 1 7679 | ▁stadium 1 7680 | ▁programming 1 7681 | ▁tied 1 7682 | ▁Chief 1 7683 | ▁cancer 1 7684 | ▁fierce 1 7685 | ▁solemn 1 7686 | ▁collected 1 7687 | ▁platforms 1 7688 | Tell 1 7689 | ewis 1 7690 | ▁gift 1 7691 | acious 1 7692 | ▁paris 1 7693 | ▁starred 1 7694 | ▁tournament 1 7695 | ba 1 7696 | ▁Dar 1 7697 | ▁elim 1 7698 | ▁gran 1 7699 | ▁mild 1 7700 | ▁cheap 1 7701 | ▁enorm 1 7702 | ▁slept 1 7703 | ensions 1 7704 | ▁camera 1 7705 | ▁pretend 1 7706 | ▁succeed 1 7707 | ▁nurs 1 7708 | ▁Secret 1 7709 | ▁consol 1 7710 | ▁creatures 1 7711 | ▁performances 1 7712 | ogy 1 7713 | omed 1 7714 | ▁Sing 1 7715 | ▁rushed 1 7716 | ▁monaster 1 7717 | oil 1 7718 | agan 1 7719 | asty 1 7720 | uled 1 7721 | ▁Dou 1 7722 | ester 1 7723 | ▁rabb 1 7724 | ▁chart 1 7725 | ▁extent 1 7726 | ▁Mountain 1 7727 | ▁capacity 1 7728 | Bel 1 7729 | Very 1 7730 | eless 1 7731 | ▁Jean 1 7732 | ▁Jones 1 
7733 | ▁conven 1 7734 | ▁structures 1 7735 | ▁forb 1 7736 | Origin 1 7737 | pecial 1 7738 | ▁clubs 1 7739 | ▁farmer 1 7740 | ▁instinct 1 7741 | ums 1 7742 | abet 1 7743 | ▁nep 1 7744 | ▁diet 1 7745 | ▁generation 1 7746 | osa 1 7747 | ▁et 1 7748 | ▁Lee 1 7749 | ▁dar 1 7750 | ▁mole 1 7751 | ▁rival 1 7752 | ▁resides 1 7753 | ▁stepped 1 7754 | ▁argument 1 7755 | ▁essential 1 7756 | ▁satisfaction 1 7757 | ▁Sun 1 7758 | atural 1 7759 | ▁fired 1 7760 | ▁knock 1 7761 | ographic 1 7762 | ▁dialect 1 7763 | ▁destination 1 7764 | ▁Act 1 7765 | uting 1 7766 | ▁kids 1 7767 | ▁locations 1 7768 | ▁thereafter 1 7769 | ▁residential 1 7770 | uz 1 7771 | ands 1 7772 | ▁McC 1 7773 | ▁keen 1 7774 | ▁Beach 1 7775 | ▁proceeded 1 7776 | ▁recommend 1 7777 | ▁excitement 1 7778 | ▁championship 1 7779 | iosity 1 7780 | ▁worry 1 7781 | ▁accustomed 1 7782 | ▁engineering 1 7783 | ▁12 1 7784 | utor 1 7785 | ▁Just 1 7786 | ▁flav 1 7787 | ▁distin 1 7788 | ▁legend 1 7789 | ▁random 1 7790 | ▁splendid 1 7791 | ▁searching 1 7792 | rev 1 7793 | rich 1 7794 | rove 1 7795 | ▁remote 1 7796 | ▁secure 1 7797 | ▁involve 1 7798 | ▁technique 1 7799 | ▁principles 1 7800 | vest 1 7801 | abled 1 7802 | ▁intr 1 7803 | ▁condem 1 7804 | ▁throne 1 7805 | iat 1 7806 | ▁Hills 1 7807 | ▁bench 1 7808 | ▁poison 1 7809 | ▁governed 1 7810 | ▁southwest 1 7811 | Take 1 7812 | icit 1 7813 | ▁cod 1 7814 | venge 1 7815 | ▁rene 1 7816 | ▁theater 1 7817 | erd 1 7818 | acob 1 7819 | ▁Minn 1 7820 | ▁tomb 1 7821 | ▁emerg 1 7822 | ▁steep 1 7823 | ▁thinks 1 7824 | ▁edition 1 7825 | Char 1 7826 | ▁fare 1 7827 | Little 1 7828 | ▁Civil 1 7829 | ▁Railway 1 7830 | ▁farming 1 7831 | ila 1 7832 | oop 1 7833 | una 1 7834 | istol 1 7835 | ▁rope 1 7836 | ▁spin 1 7837 | Unlike 1 7838 | ▁belonged 1 7839 | ldom 1 7840 | ▁hadn 1 7841 | ▁reward 1 7842 | enna 1 7843 | grad 1 7844 | ▁seventy 1 7845 | ▁expensive 1 7846 | ▁consequence 1 7847 | ▁accomplished 1 7848 | cho 1 7849 | ▁Bas 1 7850 | olate 1 7851 | ▁york 1 7852 | ▁Holly 1 7853 | ▁anger 1 7854 | ▁mining 1 7855 | ▁christian 1 7856 | ante 1 7857 | ▁Ray 1 7858 | ▁Matt 1 7859 | ▁cock 1 7860 | ▁truly 1 7861 | ▁Prince 1 7862 | ▁tennis 1 7863 | ▁illness 1 7864 | ▁boundary 1 7865 | ▁exchange 1 7866 | ▁No 1 7867 | Pres 1 7868 | ints 1 7869 | ▁haw 1 7870 | ▁diff 1 7871 | ▁hook 1 7872 | fessor 1 7873 | ▁shops 1 7874 | hesis 1 7875 | ▁Wars 1 7876 | illing 1 7877 | ▁limit 1 7878 | ▁polic 1 7879 | ▁hungry 1 7880 | ▁institutions 1 7881 | El 1 7882 | outs 1 7883 | agger 1 7884 | ▁adap 1 7885 | ▁roll 1 7886 | Doctor 1 7887 | Histor 1 7888 | ▁prize 1 7889 | ▁Bridge 1 7890 | ▁contest 1 7891 | ▁graduate 1 7892 | ▁valuable 1 7893 | Col 1 7894 | held 1 7895 | ▁Hon 1 7896 | ▁rim 1 7897 | ▁Book 1 7898 | ▁monk 1 7899 | ▁lover 1 7900 | ▁injured 1 7901 | ▁vessels 1 7902 | ▁seriously 1 7903 | lar 1 7904 | ippi 1 7905 | ▁ris 1 7906 | ▁asking 1 7907 | ▁nephew 1 7908 | bow 1 7909 | ▁ain 1 7910 | ▁boats 1 7911 | ▁prayer 1 7912 | ▁render 1 7913 | ▁changing 1 7914 | ▁patience 1 7915 | ▁bree 1 7916 | ▁dipl 1 7917 | ▁Broad 1 7918 | ▁batter 1 7919 | ▁rating 1 7920 | ▁Pyramids 1 7921 | ▁reception 1 7922 | ▁seventeen 1 7923 | egal 1 7924 | ▁avo 1 7925 | isible 1 7926 | ▁Princ 1 7927 | ▁violent 1 7928 | ▁prisoners 1 7929 | pool 1 7930 | ▁Wel 1 7931 | ▁summ 1 7932 | ▁plane 1 7933 | ▁worthy 1 7934 | ▁typical 1 7935 | ▁definition 1 7936 | ▁discussion 1 7937 | ▁ 1 7938 | e 1 7939 | t 1 7940 | a 1 7941 | o 1 7942 | i 1 7943 | n 1 7944 | s 1 7945 | r 1 7946 | h 1 7947 | l 1 7948 | d 1 7949 | c 1 7950 | u 1 7951 | m 1 7952 | f 1 7953 | w 1 7954 | g 1 7955 | y 1 7956 | 
p 1 7957 | b 1 7958 | . 1 7959 | v 1 7960 | k 1 7961 | , 1 7962 | T 1 7963 | A 1 7964 | I 1 7965 | ' 1 7966 | S 1 7967 | H 1 7968 | " 1 7969 | C 1 7970 | x 1 7971 | M 1 7972 | B 1 7973 | W 1 7974 | P 1 7975 | - 1 7976 | j 1 7977 | q 1 7978 | D 1 7979 | L 1 7980 | z 1 7981 | R 1 7982 | F 1 7983 | G 1 7984 | E 1 7985 | N 1 7986 | O 1 7987 | J 1 7988 | ? 1 7989 | K 1 7990 | Y 1 7991 | U 1 7992 | V 1 7993 | ! 1 7994 | 1 1 7995 | 2 1 7996 | : 1 7997 | --------------------------------------------------------------------------------
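The block above is the tail of `speechgpt/utils/text2unit/binary/dict.en.txt`, a fairseq-style dictionary evidently consumed by the fairseq-based text2unit model: one `token count` pair per line (every count here is 1), running from SentencePiece subwords (`▁` marks a word-initial piece) down to single-character and punctuation fallbacks at lines 7938 onward. A minimal sketch of reading it with the pinned `fairseq==0.12.2`, assuming only the standard `Dictionary` API (the path below is illustrative):

```python
# Minimal sketch: load the dictionary dumped above with fairseq's
# standard Dictionary loader. Assumes the usual `token count` line
# format; the path is illustrative, not prescriptive.
from fairseq.data import Dictionary

d = Dictionary.load("speechgpt/utils/text2unit/binary/dict.en.txt")

# Dictionary.__init__ prepends the special symbols (<s>, <pad>,
# </s>, <unk>) before the file entries are appended, so the final
# vocabulary is slightly larger than the 7,997 lines in the file.
print(len(d))               # total size, file entries plus specials
idx = d.index("▁model")     # id of a SentencePiece subword from the file
print(idx, d[idx])          # round-trip: id back to its symbol
```

Because the specials are prepended at load time, an entry's id is offset from its line number in the file; code that maps model outputs back to subwords should always go through the loaded `Dictionary` rather than the raw line positions.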