├── bert ├── __init__.py ├── train │ ├── __init__.py │ ├── model │ │ ├── __init__.py │ │ ├── gelu.py │ │ ├── test_model.py │ │ ├── embeddings.py │ │ ├── bert.py │ │ └── transformer.py │ ├── utils │ │ ├── __init__.py │ │ ├── pad.py │ │ ├── convert.py │ │ ├── log.py │ │ └── collate.py │ ├── datasets │ │ ├── __init__.py │ │ ├── classification.py │ │ └── pretraining.py │ ├── optimizers.py │ ├── metrics.py │ ├── loss_models.py │ ├── trainer.py │ └── train.py └── preprocess │ ├── utils.py │ ├── __init__.py │ ├── dictionary.py │ └── preprocess.py ├── requirements.txt ├── run.sh ├── data ├── example │ ├── val.txt │ └── train.txt ├── wiki-example │ └── wiki.txt └── SST-2 │ └── dev.tsv ├── configs └── mini-bert-pretraining.json ├── main.py ├── LICENSE ├── .gitignore └── README.md /bert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bert/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bert/train/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bert/train/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bert/train/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | sentencepiece 3 | nltk 4 | gensim 5 | numpy 6 | tqdm 7 | torch>=0.4.1 8 | -------------------------------------------------------------------------------- /bert/preprocess/utils.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | 4 | def prepend_data_dir(path, data_dir): 5 | return path if data_dir is None else join(data_dir, path) -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2 --directory-prefix data/wiki/ 2 | python main.py preprocess-all --data_dir data/wiki -------------------------------------------------------------------------------- /bert/train/utils/pad.py: -------------------------------------------------------------------------------- 1 | from bert.preprocess import PAD_INDEX 2 | 3 | 4 | def pad_masking(x): 5 | # x: (batch_size, seq_len) 6 | padded_positions = x == PAD_INDEX 7 | return padded_positions.unsqueeze(1) -------------------------------------------------------------------------------- /bert/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | PAD_TOKEN, PAD_INDEX = '[PAD]', 0 2 | UNK_TOKEN, UNK_INDEX = '[UNK]', 1 3 | MASK_TOKEN, MASK_INDEX = '[MASK]', 2 4 | CLS_TOKEN, CLS_INDEX = '[CLS]', 3 5 | SEP_TOKEN, SEP_INDEX = '[SEP]', 4 6 | -------------------------------------------------------------------------------- /bert/train/model/gelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import math 5 | 6 | 7 | class GELU(nn.Module): 8 | 9 | def forward(self, x): 10 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 11 | -------------------------------------------------------------------------------- /bert/train/utils/convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def convert_to_tensor(data, device): 5 | if type(data) == tuple: 6 | return tuple(torch.tensor(d, device=device) for d in data) 7 | else: 8 | return torch.tensor(data, device=device) 9 | 10 | 11 | def convert_to_array(data): 12 | if type(data) == tuple: 13 | return tuple(d.detach().cpu().numpy() for d in data) 14 | else: 15 | return data.detach().cpu().numpy() 16 | -------------------------------------------------------------------------------- /data/example/val.txt: -------------------------------------------------------------------------------- 1 | There were ten in a bed|And the little one said|Roll over, roll over|So they all rolled over|And one fell out 2 | The wheels on the bus go round and round|Round and round, Round and round|The wheels on the bus go round and round|All through the town 3 | This old man, he played one|He played knick-knack on my thumb|Knick-knack paddy whack|Give the dog the bone|This old man came rolling home 4 | Twinkle, twinkle, little star|How I wonder what you are|Up above the world so high|Like a diamond in the sky -------------------------------------------------------------------------------- /configs/mini-bert-pretraining.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "data/wiki", 3 | "train_path": "train.txt", 4 | "val_path": "val.txt", 5 | "dictionary_path": "dictionary.txt", 6 | "dataset_limit": null, 7 | "epochs": 8, 8 | "batch_size": 78, 9 | "print_every": 1, 10 | "save_every": 1, 11 | "vocabulary_size": 30000, 12 | "max_len": 512, 13 | "lr": 0.0002, 14 | "clip_grads": false, 15 | "layers_count": 6, 16 | "hidden_size": 384, 17 | "heads_count": 6, 18 | "d_ff": 1536, 19 | "dropout_prob": 0.1 20 | } -------------------------------------------------------------------------------- /bert/train/model/test_model.py: -------------------------------------------------------------------------------- 1 | from .bert import build_model 2 | 3 | import torch 4 | 5 | 6 | def test_encoder(): 7 | model = build_model(hidden_size=512, layers_count=6, heads_count=8, d_ff=1024, dropout_prob=0.1, max_len=512, 8 | vocabulary_size=100) 9 | 10 | example_sequence = torch.tensor([[1, 2, 3, 4, 5], [2, 1, 3, 0, 0]]) 11 | example_segment = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 0, 1, 1]]) 12 | 13 | token_predictions, classification_output = model((example_sequence, example_segment)) 14 | 15 | batch_size, seq_len, target_vocabulary_size = 2, 5, 100 16 | assert token_predictions.size() == (batch_size, seq_len, target_vocabulary_size) -------------------------------------------------------------------------------- /bert/train/datasets/classification.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class SST2IndexedDataset: 4 | 5 | def __init__(self, data_path, dictionary): 6 | 7 | self.data = [] 8 | with open(data_path) as file: 9 | assert file.readline() == 'sentence\tlabel\n' 10 | 11 | for line in file: 12 | tokenized_sentence, sentiment = line.strip().split('\t') 13 | indexed_sentence = [dictionary.token_to_index(token) for token in tokenized_sentence.split()] 14 | self.data.append((indexed_sentence, int(sentiment))) 15 | 16 | def __getitem__(self, item): 17 | indexed_text, sentiment = self.data[item] 18 | segment = [0] * len(indexed_text) 19 | return (indexed_text, segment), sentiment 20 | 21 | def __len__(self): 22 | return len(self.data) 23 | -------------------------------------------------------------------------------- /bert/train/optimizers.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Adam 2 | 3 | 4 | class NoamOptimizer(Adam): 5 | 6 | def __init__(self, params, d_model, factor=2, warmup_steps=4000, betas=(0.9, 0.98), weight_decay=0, eps=1e-9): 7 | self.d_model = d_model 8 | self.warmup_steps = warmup_steps 9 | self.lr = 0 10 | self.step_num = 0 11 | self.factor = factor 12 | 13 | super(NoamOptimizer, self).__init__(params, betas=betas, weight_decay=weight_decay, eps=eps) 14 | 15 | def step(self, closure=None): 16 | self.step_num += 1 17 | self.lr = self._learning_rate() 18 | for group in self.param_groups: 19 | group['lr'] = self.lr 20 | super(NoamOptimizer, self).step() 21 | 22 | def _learning_rate(self): 23 | return self.factor * self.d_model ** (-0.5) * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5)) -------------------------------------------------------------------------------- /bert/train/metrics.py: -------------------------------------------------------------------------------- 1 | from bert.preprocess import PAD_INDEX 2 | 3 | import numpy as np 4 | 5 | 6 | def mlm_accuracy(predictions, targets): 7 | mlm_predictions, nsp_predictions = predictions 8 | mlm_targets, is_nexts = targets 9 | 10 | relevent_indexes = np.where(mlm_targets != PAD_INDEX) 11 | relevent_predictions = mlm_predictions[relevent_indexes] 12 | relevent_targets = mlm_targets[relevent_indexes] 13 | 14 | corrects = np.equal(relevent_predictions, relevent_targets) 15 | return corrects.mean() 16 | 17 | 18 | def nsp_accuracy(predictions, targets): 19 | mlm_predictions, nsp_predictions = predictions 20 | mlm_targets, is_nexts = targets 21 | 22 | corrects = np.equal(nsp_predictions, is_nexts) 23 | return corrects.mean() 24 | 25 | 26 | def classification_accuracy(predictions, targets): 27 | corrects = np.equal(predictions, targets) 28 | return corrects.mean() 29 | -------------------------------------------------------------------------------- /bert/train/model/embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class PositionalEmbedding(nn.Module): 6 | 7 | def __init__(self, max_len, hidden_size, ): 8 | super(PositionalEmbedding, self).__init__() 9 | self.positional_embedding = nn.Embedding(max_len, hidden_size) 10 | positions = torch.arange(0, max_len) 11 | self.register_buffer('positions', positions) 12 | 13 | def forward(self, sequence): 14 | batch_size, seq_len = sequence.size() 15 | positions = self.positions[:seq_len].unsqueeze(0).repeat(batch_size, 1) 16 | return self.positional_embedding(positions) 17 | 18 | 19 | class SegmentEmbedding(nn.Module): 20 | 21 | def __init__(self, hidden_size): 22 | super(SegmentEmbedding, self).__init__() 23 | self.segment_embedding = nn.Embedding(2, hidden_size) 24 | 25 | def forward(self, segments): 26 | """segments: (batch_size, seq_len)""" 27 | return self.segment_embedding(segments) # (batch_size, seq_len, hidden_size) 28 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from bert.preprocess.preprocess import add_preprocess_parser 2 | from bert.train.train import add_pretrain_parser, add_finetune_parser 3 | 4 | import json 5 | from argparse import ArgumentParser 6 | 7 | 8 | def main(): 9 | parser = ArgumentParser('BERT') 10 | parser.add_argument('-c', '--config_path', type=str, default=None) 11 | subparsers = parser.add_subparsers() 12 | 13 | add_preprocess_parser(subparsers) 14 | add_pretrain_parser(subparsers) 15 | add_finetune_parser(subparsers) 16 | 17 | args = parser.parse_args() 18 | 19 | if args.config_path is not None: 20 | with open(args.config_path) as f: 21 | config = json.load(f) 22 | 23 | default_config = vars(args) 24 | for key, default_value in default_config.items(): 25 | if key not in config: 26 | config[key] = default_value 27 | else: 28 | config = vars(args) # convert to dictionary 29 | 30 | args.function(**config, config=config) 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /data/example/train.txt: -------------------------------------------------------------------------------- 1 | One, two, three, four, five,|Once I caught a fish alive,|Six, seven, eight, nine, ten,|Then I let go again. 2 | I’m a little teapot|Short and stout|Here is my handle|Here is my spout. 3 | Jack and Jill went up the hill|To fetch a pail of water.|Jack fell down and broke his crown,|And Jill came tumbling after. 4 | Five little ducks|Went out one day|Over the hill and far away|Mother duck said|Quack, quack, quack, quack.|But only four little ducks came back. 5 | Five little monkeys jumping on the bed,|One fell off and bumped his head.|Mama called the Doctor and the Doctor said,|No more monkeys jumping on the bed! 6 | Mary had a little lamb,|Little lamb, little lamb.|Mary had a little lamb,|Its fleece was white as snow. 7 | Old MacDonald had a farm|And on his farm he had a cow|With a moo moo here and a moo moo there|Here a moo, there a moo, everywhere a moo moo|Old MacDonald had a farm 8 | Ring around the rosie|A pocket full of posies|Atishoo, Atishoo|We all fall down 9 | Rock a bye baby, on the tree top|When the wind blows the cradle will rock|When the bough breaks the cradle will fall|And down will come baby, cradle and all 10 | Row, row, row your boat|Gently down the stream|Merrily, merrily, merrily, merrily|Life is but a dream -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /bert/train/loss_models.py: -------------------------------------------------------------------------------- 1 | from bert.preprocess import PAD_INDEX 2 | 3 | from torch import nn 4 | 5 | 6 | class MLMNSPLossModel(nn.Module): 7 | 8 | def __init__(self, model): 9 | super(MLMNSPLossModel, self).__init__() 10 | 11 | self.model = model 12 | self.mlm_loss_function = nn.CrossEntropyLoss(ignore_index=PAD_INDEX) 13 | self.nsp_loss_function = nn.CrossEntropyLoss() 14 | 15 | def forward(self, inputs, targets): 16 | 17 | outputs = self.model(inputs) 18 | 19 | mlm_outputs, nsp_outputs = outputs 20 | mlm_targets, is_nexts = targets 21 | 22 | mlm_predictions, nsp_predictions = mlm_outputs.argmax(dim=2), nsp_outputs.argmax(dim=1) 23 | predictions = (mlm_predictions, nsp_predictions) 24 | 25 | batch_size, seq_len, vocabulary_size = mlm_outputs.size() 26 | 27 | mlm_outputs_flat = mlm_outputs.view(batch_size * seq_len, vocabulary_size) 28 | mlm_targets_flat = mlm_targets.view(batch_size * seq_len) 29 | 30 | mlm_loss = self.mlm_loss_function(mlm_outputs_flat, mlm_targets_flat) 31 | nsp_loss = self.nsp_loss_function(nsp_outputs, is_nexts) 32 | 33 | loss = mlm_loss + nsp_loss 34 | 35 | return predictions, loss.unsqueeze(dim=0) 36 | 37 | 38 | class ClassificationLossModel(nn.Module): 39 | 40 | def __init__(self, model): 41 | super(ClassificationLossModel, self).__init__() 42 | 43 | self.model = model 44 | self.loss_function = nn.CrossEntropyLoss() 45 | 46 | def forward(self, inputs, targets): 47 | 48 | outputs = self.model(inputs) 49 | predictions = outputs.argmax(dim=1) 50 | loss = self.loss_function(outputs, targets) 51 | 52 | return predictions, loss.unsqueeze(dim=0) 53 | -------------------------------------------------------------------------------- /bert/train/utils/log.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os.path import dirname, abspath, exists, join 3 | import os 4 | import logging 5 | from datetime import datetime 6 | 7 | 8 | def make_run_name(format, phase, config): 9 | return format.format( 10 | **config, 11 | phase=format, 12 | timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 13 | ) 14 | 15 | 16 | def make_logger(run_name, log_output): 17 | logger = logging.getLogger(run_name) 18 | log_filepath = log_output if log_output is not None else join('logs', f'{run_name}.log') 19 | 20 | log_dir = dirname(abspath(log_filepath)) 21 | if not exists(log_dir): 22 | os.makedirs(log_dir) 23 | 24 | if not logger.handlers: # execute only if logger doesn't already exist 25 | file_handler = logging.FileHandler(log_filepath, 'a', 'utf-8') 26 | stream_handler = logging.StreamHandler(os.sys.stdout) 27 | 28 | formatter = logging.Formatter('[%(levelname)s] %(asctime)s > %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 29 | 30 | file_handler.setFormatter(formatter) 31 | stream_handler.setFormatter(formatter) 32 | 33 | logger.addHandler(file_handler) 34 | logger.addHandler(stream_handler) 35 | logger.setLevel(logging.INFO) 36 | 37 | return logger 38 | 39 | 40 | def make_checkpoint_dir(checkpoint_dir, run_name, config): 41 | checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else join('checkpoints', run_name) 42 | if not exists(abspath(checkpoint_dir)): 43 | os.makedirs(checkpoint_dir) 44 | 45 | config_output = join(checkpoint_dir, 'config.json') 46 | with open(config_output, 'w') as config_file: 47 | del config['function'] 48 | json.dump(config, config_file) 49 | 50 | return checkpoint_dir 51 | -------------------------------------------------------------------------------- /bert/train/utils/collate.py: -------------------------------------------------------------------------------- 1 | from bert.preprocess import PAD_INDEX 2 | 3 | 4 | def pretraining_collate_function(batch): 5 | 6 | targets = [target for _, (target, is_next) in batch] 7 | longest_target = max(targets, key=lambda target: len(target)) 8 | max_length = len(longest_target) 9 | 10 | padded_sequences = [] 11 | padded_segments = [] 12 | padded_targets = [] 13 | is_nexts = [] 14 | 15 | for (sequence, segment), (target, is_next) in batch: 16 | length = len(sequence) 17 | padding = [PAD_INDEX] * (max_length - length) 18 | padded_sequence = sequence + padding 19 | padded_segment = segment + padding 20 | padded_target = target + padding 21 | 22 | padded_sequences.append(padded_sequence) 23 | padded_segments.append(padded_segment) 24 | padded_targets.append(padded_target) 25 | is_nexts.append(is_next) 26 | 27 | count = 0 28 | for target in targets: 29 | for token in target: 30 | if token != PAD_INDEX: 31 | count += 1 32 | 33 | return (padded_sequences, padded_segments), (padded_targets, is_nexts), count 34 | 35 | 36 | def classification_collate_function(batch): 37 | 38 | lengths = [len(sequence) for (sequence, _), _ in batch] 39 | max_length = max(lengths) 40 | 41 | padded_sequences = [] 42 | padded_segments = [] 43 | labels = [] 44 | 45 | for (sequence, segment), label in batch: 46 | length = len(sequence) 47 | padding = [PAD_INDEX] * (max_length - length) 48 | padded_sequence = sequence + padding 49 | padded_segment = segment + padding 50 | 51 | padded_sequences.append(padded_sequence) 52 | padded_segments.append(padded_segment) 53 | labels.append(label) 54 | 55 | count = len(labels) 56 | 57 | return (padded_sequences, padded_segments), labels, count 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | logs/ 3 | data/* 4 | !data/SST-2/*.tsv 5 | !data/example/*.txt 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | \.idea/ 113 | 114 | credentials\.json 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERT-pytorch 2 | PyTorch implementation of BERT in "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" (https://arxiv.org/abs/1810.04805) 3 | 4 | ## Requirements 5 | - Python 3.6+ 6 | - [PyTorch 4.1+](http://pytorch.org/) 7 | - [tqdm](https://github.com/tqdm/tqdm) 8 | 9 | All dependencies can be installed via: 10 | 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Quickstart 16 | 17 | ### Prepare data 18 | First things first, you need to prepare your data in an appropriate format. 19 | Your corpus is assumed to follow the below constraints. 20 | 21 | - Each line is a *document*. 22 | - A *document* consists of *sentences*, seperated by vertical bar (|). 23 | - A *sentence* is assumed to be already tokenized. Tokens are seperated by space. 24 | - A *sentence* has no more than 256 tokens. 25 | - A *document* has at least 2 sentences. 26 | - You have two distinct data files, one for train data and the other for val data. 27 | 28 | This repo comes with example data for pretraining in data/example directory. 29 | Here is the content of data/example/train.txt file. 30 | 31 | ``` 32 | One, two, three, four, five,|Once I caught a fish alive,|Six, seven, eight, nine, ten,|Then I let go again. 33 | I’m a little teapot|Short and stout|Here is my handle|Here is my spout. 34 | Jack and Jill went up the hill|To fetch a pail of water.|Jack fell down and broke his crown,|And Jill came tumbling after. 35 | ``` 36 | 37 | Also, this repo includes SST-2 data in data/SST-2 directory for sentiment classification. 38 | 39 | ### Build dictionary 40 | ``` 41 | python bert.py preprocess-index data/example/train.txt --dictionary=dictionary.txt 42 | ``` 43 | Running the above command produces dictionary.txt file in your current directory. 44 | 45 | ### Pre-train the model 46 | ``` 47 | python bert.py pretrain --train_data data/example/train.txt --val_data data/example/val.txt --checkpoint_output model.pth 48 | ``` 49 | This step trains BERT model with unsupervised objective. Also this step does: 50 | - logs the training procedure for every epoch 51 | - outputs model checkpoint periodically 52 | - reports the best checkpoint based on validation metric 53 | 54 | ### Fine-tune the model 55 | You can fine-tune pretrained BERT model with downstream task. 56 | For example, you can fine-tune your model with SST-2 sentiment classification task. 57 | ``` 58 | python bert.py finetune --pretrained_checkpoint model.pth --train_data data/SST-2/train.tsv --val_data data/SST-2/dev.tsv 59 | ``` 60 | This command also logs the procedure, outputs checkpoint, and reports the best checkpoint. 61 | 62 | ## See also 63 | - [Transformer-pytorch](https://github.com/dreamgonfly/Transformer-pytorch) : My own implementation of Transformer. This BERT implementation is based on this repo. 64 | 65 | ## Author 66 | [@dreamgonfly](https://github.com/dreamgonfly) -------------------------------------------------------------------------------- /bert/train/model/bert.py: -------------------------------------------------------------------------------- 1 | from .embeddings import PositionalEmbedding, SegmentEmbedding 2 | from .transformer import TransformerEncoder 3 | from ..utils.pad import pad_masking 4 | 5 | from torch import nn 6 | 7 | 8 | def build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size): 9 | token_embedding = nn.Embedding(num_embeddings=vocabulary_size, embedding_dim=hidden_size) 10 | positional_embedding = PositionalEmbedding(max_len=max_len, hidden_size=hidden_size) 11 | segment_embedding = SegmentEmbedding(hidden_size=hidden_size) 12 | 13 | encoder = TransformerEncoder( 14 | layers_count=layers_count, 15 | d_model=hidden_size, 16 | heads_count=heads_count, 17 | d_ff=d_ff, 18 | dropout_prob=dropout_prob) 19 | 20 | bert = BERT( 21 | encoder=encoder, 22 | token_embedding=token_embedding, 23 | positional_embedding=positional_embedding, 24 | segment_embedding=segment_embedding, 25 | hidden_size=hidden_size, 26 | vocabulary_size=vocabulary_size) 27 | 28 | return bert 29 | 30 | 31 | class FineTuneModel(nn.Module): 32 | 33 | def __init__(self, pretrained_model, hidden_size, num_classes): 34 | super(FineTuneModel, self).__init__() 35 | 36 | self.pretrained_model = pretrained_model 37 | 38 | new_classification_layer = nn.Linear(hidden_size, num_classes) 39 | self.pretrained_model.classification_layer = new_classification_layer 40 | 41 | def forward(self, inputs): 42 | sequence, segment = inputs 43 | token_predictions, classification_outputs = self.pretrained_model((sequence, segment)) 44 | return classification_outputs 45 | 46 | 47 | class BERT(nn.Module): 48 | 49 | def __init__(self, encoder, token_embedding, positional_embedding, segment_embedding, hidden_size, vocabulary_size): 50 | super(BERT, self).__init__() 51 | 52 | self.encoder = encoder 53 | self.token_embedding = token_embedding 54 | self.positional_embedding = positional_embedding 55 | self.segment_embedding = segment_embedding 56 | self.token_prediction_layer = nn.Linear(hidden_size, vocabulary_size) 57 | self.classification_layer = nn.Linear(hidden_size, 2) 58 | 59 | def forward(self, inputs): 60 | sequence, segment = inputs 61 | token_embedded = self.token_embedding(sequence) 62 | positional_embedded = self.positional_embedding(sequence) 63 | segment_embedded = self.segment_embedding(segment) 64 | embedded_sources = token_embedded + positional_embedded + segment_embedded 65 | 66 | mask = pad_masking(sequence) 67 | encoded_sources = self.encoder(embedded_sources, mask) 68 | token_predictions = self.token_prediction_layer(encoded_sources) 69 | classification_embedding = encoded_sources[:, 0, :] 70 | classification_output = self.classification_layer(classification_embedding) 71 | return token_predictions, classification_output 72 | -------------------------------------------------------------------------------- /bert/preprocess/dictionary.py: -------------------------------------------------------------------------------- 1 | from . import PAD_TOKEN, UNK_TOKEN, MASK_TOKEN, CLS_TOKEN, SEP_TOKEN 2 | 3 | from collections import Counter 4 | 5 | 6 | class IndexDictionary: 7 | 8 | def __init__(self, vocabulary_size=None): 9 | 10 | self.special_tokens = [PAD_TOKEN, UNK_TOKEN, MASK_TOKEN, CLS_TOKEN, SEP_TOKEN] 11 | self.vocabulary_size = vocabulary_size 12 | self.vocab_tokens, self.token_counts = None, None 13 | self.token_index_dict = None 14 | 15 | def build_vocabulary(self, iterable): 16 | 17 | counter = Counter(iterable) 18 | 19 | n = self.vocabulary_size - len(self.special_tokens) if self.vocabulary_size is not None else None 20 | most_commons = counter.most_common(n) 21 | frequent_tokens = [token for token, count in most_commons] 22 | self.vocab_tokens = self.special_tokens + frequent_tokens 23 | self.token_counts = [0] * len(self.special_tokens) + [count for token, count in most_commons] 24 | 25 | self.vocabulary_size = len(self.vocab_tokens) 26 | self.token_index_dict = {token: index for index, token in enumerate(self.vocab_tokens)} 27 | 28 | def __len__(self): 29 | return len(self.vocab_tokens) 30 | 31 | def token_to_index(self, token): 32 | try: 33 | return self.token_index_dict[token] 34 | except KeyError: 35 | return self.token_index_dict[UNK_TOKEN] 36 | 37 | def index_to_token(self, index): 38 | if index >= self.vocabulary_size: 39 | return UNK_TOKEN 40 | else: 41 | return self.vocab_tokens[index] 42 | 43 | def index_sentence(self, sentence): 44 | return [self.token_to_index(token) for token in sentence] 45 | 46 | def tokenify_indexes(self, token_indexes): 47 | return [self.index_to_token(token_index) for token_index in token_indexes] 48 | 49 | def save(self, dictionary_path): 50 | with open(dictionary_path, 'w') as file: 51 | for vocab_index, (vocab_token, count) in enumerate(zip(self.vocab_tokens, self.token_counts)): 52 | file.write(str(vocab_index) + '\t' + vocab_token + '\t' + str(count) + '\n') 53 | 54 | @classmethod 55 | def load(cls, dictionary_path, vocabulary_size=None): 56 | vocab_tokens = {} 57 | token_counts = [] 58 | 59 | with open(dictionary_path) as file: 60 | for line in file: 61 | vocab_index, vocab_token, count = line.strip().split('\t') 62 | vocab_index = int(vocab_index) 63 | vocab_tokens[vocab_index] = vocab_token 64 | token_counts.append(int(count)) 65 | 66 | if vocabulary_size is not None: 67 | vocab_tokens = {k: v for k, v in vocab_tokens.items() if k < vocabulary_size} 68 | token_counts = token_counts[:vocabulary_size] 69 | 70 | instance = cls() 71 | instance.vocab_tokens = vocab_tokens 72 | instance.token_counts = token_counts 73 | instance.token_index_dict = {token: index for index, token in vocab_tokens.items()} 74 | instance.vocabulary_size = len(vocab_tokens) 75 | 76 | return instance 77 | -------------------------------------------------------------------------------- /bert/train/datasets/pretraining.py: -------------------------------------------------------------------------------- 1 | from bert.preprocess import PAD_INDEX, MASK_INDEX, CLS_INDEX, SEP_INDEX 2 | 3 | from tqdm import tqdm 4 | 5 | from random import random, randint 6 | 7 | 8 | class IndexedCorpus: 9 | def __init__(self, data_path, dictionary, dataset_limit=None): 10 | self.indexed_documents = [] 11 | with open(data_path) as file: 12 | for document in tqdm(file): 13 | indexed_document = [] 14 | for sentence in document.split('|'): 15 | indexed_sentence = [] 16 | for token in sentence.strip().split(): 17 | indexed_token = dictionary.token_to_index(token) 18 | indexed_sentence.append(indexed_token) 19 | if len(indexed_sentence) < 1: 20 | continue 21 | indexed_document.append(indexed_sentence) 22 | if len(indexed_document) < 2: 23 | continue 24 | self.indexed_documents.append(indexed_document) 25 | 26 | if dataset_limit is not None and len(self.indexed_documents) >= dataset_limit: 27 | break 28 | 29 | def __getitem__(self, item): 30 | return self.indexed_documents[item] 31 | 32 | def __len__(self): 33 | return len(self.indexed_documents) 34 | 35 | 36 | class MaskedDocument: 37 | def __init__(self, sentences, vocabulary_size): 38 | self.sentences = sentences 39 | self.vocabulary_size = vocabulary_size 40 | self.THRESHOLD = 0.15 41 | 42 | def __getitem__(self, item): 43 | """Get a masked sentence and the corresponding target. 44 | 45 | For wiki-example, [5,6,MASK_INDEX,8,9], [0,0,7,0,0] 46 | """ 47 | sentence = self.sentences[item] 48 | 49 | masked_sentence = [] 50 | target_sentence = [] 51 | 52 | for token_index in sentence: 53 | r = random() 54 | if r < self.THRESHOLD: # we mask 15% of all tokens in each sequence at random. 55 | if r < self.THRESHOLD * 0.8: # 80% of the time: Replace the word with the [MASK] token 56 | masked_sentence.append(MASK_INDEX) 57 | target_sentence.append(token_index) 58 | elif r < self.THRESHOLD * 0.9: # 10% of the time: Replace the word with a random word 59 | random_token_index = randint(5, self.vocabulary_size-1) 60 | masked_sentence.append(random_token_index) 61 | target_sentence.append(token_index) 62 | else: # 10% of the time: Keep the word unchanged 63 | masked_sentence.append(token_index) 64 | target_sentence.append(token_index) 65 | else: 66 | masked_sentence.append(token_index) 67 | target_sentence.append(PAD_INDEX) 68 | 69 | return masked_sentence, target_sentence 70 | 71 | def __len__(self): 72 | return len(self.sentences) 73 | 74 | 75 | class MaskedCorpus: 76 | 77 | def __init__(self, data_path, dictionary, dataset_limit=None): 78 | source_corpus = IndexedCorpus(data_path, dictionary, dataset_limit=dataset_limit) 79 | 80 | self.sentences_count = 0 81 | self.masked_documents = [] 82 | for indexed_document in source_corpus: 83 | masked_document = MaskedDocument(indexed_document, vocabulary_size=len(dictionary)) 84 | self.masked_documents.append(masked_document) 85 | 86 | self.sentences_count += len(masked_document) 87 | 88 | def __getitem__(self, item): 89 | return self.masked_documents[item] 90 | 91 | def __len__(self): 92 | return len(self.masked_documents) 93 | 94 | 95 | class PairedDataset: 96 | 97 | def __init__(self, data_path, dictionary, dataset_limit=None): 98 | self.source_corpus = MaskedCorpus(data_path, dictionary, dataset_limit=dataset_limit) 99 | self.dataset_size = self.source_corpus.sentences_count 100 | self.corpus_size = len(self.source_corpus) 101 | 102 | def __getitem__(self, item): 103 | 104 | document_index = randint(0, self.corpus_size-1) 105 | document = self.source_corpus[document_index] 106 | sentence_index = randint(0, len(document) - 2) 107 | A_masked_sentence, A_target_sentence = document[sentence_index] 108 | 109 | if random() < 0.5: # 50% of the time B is the actual next sentence that follows A 110 | B_masked_sentence, B_target_sentence = document[sentence_index + 1] 111 | is_next = 1 112 | else: # 50% of the time it is a random sentence from the corpus 113 | random_document_index = randint(0, self.corpus_size-1) 114 | random_document = self.source_corpus[random_document_index] 115 | random_sentence_index = randint(0, len(random_document)-1) 116 | B_masked_sentence, B_target_sentence = random_document[random_sentence_index] 117 | is_next = 0 118 | 119 | sequence = [CLS_INDEX] + A_masked_sentence + [SEP_INDEX] + B_masked_sentence + [SEP_INDEX] 120 | 121 | # segment : something like [0,0,0,0,0,1,1,1,1,1,1,1]) 122 | segment = [0] + [0] * len(A_masked_sentence) + [0] + [1] * len(B_masked_sentence) + [1] 123 | 124 | target = [PAD_INDEX] + A_target_sentence + [PAD_INDEX] + B_target_sentence + [PAD_INDEX] 125 | 126 | return (sequence, segment), (target, is_next) 127 | 128 | def __len__(self): 129 | return self.dataset_size 130 | -------------------------------------------------------------------------------- /bert/train/trainer.py: -------------------------------------------------------------------------------- 1 | from .utils.convert import convert_to_tensor, convert_to_array 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from os.path import join 7 | from datetime import datetime 8 | 9 | SAVE_FORMAT = 'epoch={epoch:0>3}-val_loss={val_loss:<.3}-val_metrics={val_metrics}.pth' 10 | 11 | LOG_FORMAT = ( 12 | "Epoch: {epoch:>3} " 13 | "Progress: {progress:<.1%} " 14 | "Elapsed: {elapsed} " 15 | "Examples/second: {per_second:<.1} " 16 | "Train Loss: {train_loss:<.6} " 17 | "Val Loss: {val_loss:<.6} " 18 | "Train Metrics: {train_metrics} " 19 | "Val Metrics: {val_metrics} " 20 | "Learning rate: {current_lr:<.4} " 21 | ) 22 | 23 | 24 | class Trainer: 25 | 26 | def __init__(self, loss_model, train_dataloader, val_dataloader, 27 | metric_functions, device, optimizer, clip_grads, 28 | logger, checkpoint_dir, print_every, save_every): 29 | 30 | self.device = device 31 | 32 | self.loss_model = loss_model.to(self.device) 33 | self.train_dataloader = train_dataloader 34 | self.val_dataloader = val_dataloader 35 | 36 | self.metric_functions = metric_functions 37 | self.optimizer = optimizer 38 | self.clip_grads = clip_grads 39 | 40 | self.logger = logger 41 | self.checkpoint_dir = checkpoint_dir 42 | 43 | self.print_every = print_every 44 | self.save_every = save_every 45 | 46 | self.epoch = 0 47 | self.history = [] 48 | 49 | self.start_time = datetime.now() 50 | 51 | self.best_val_metric = None 52 | self.best_checkpoint_output_path = None 53 | 54 | def run_epoch(self, dataloader, mode='train'): 55 | 56 | epoch_loss = 0 57 | epoch_count = 0 58 | epoch_metrics = [0 for _ in range(len(self.metric_functions))] 59 | 60 | for inputs, targets, batch_count in tqdm(dataloader): 61 | inputs = convert_to_tensor(inputs, self.device) 62 | targets = convert_to_tensor(targets, self.device) 63 | 64 | predictions, batch_losses = self.loss_model(inputs, targets) 65 | predictions = convert_to_array(predictions) 66 | targets = convert_to_array(targets) 67 | 68 | batch_loss = batch_losses.mean() 69 | 70 | if mode == 'train': 71 | self.optimizer.zero_grad() 72 | batch_loss.backward() 73 | if self.clip_grads: 74 | torch.nn.utils.clip_grad_norm_(self.loss_model.parameters(), 1) 75 | self.optimizer.step() 76 | 77 | epoch_loss = (epoch_loss * epoch_count + batch_loss.item() * batch_count) / (epoch_count + batch_count) 78 | 79 | batch_metrics = [metric_function(predictions, targets) for metric_function in self.metric_functions] 80 | epoch_metrics = [(epoch_metric * epoch_count + batch_metric * batch_count) / (epoch_count + batch_count) 81 | for epoch_metric, batch_metric in zip(epoch_metrics, batch_metrics)] 82 | 83 | epoch_count += batch_count 84 | 85 | if self.epoch == 0: # for testing 86 | return float('inf'), [float('inf')] 87 | 88 | return epoch_loss, epoch_metrics 89 | 90 | def run(self, epochs=10): 91 | 92 | for epoch in range(self.epoch, epochs + 1): 93 | self.epoch = epoch 94 | 95 | self.loss_model.train() 96 | 97 | epoch_start_time = datetime.now() 98 | train_epoch_loss, train_epoch_metrics = self.run_epoch(self.train_dataloader, mode='train') 99 | epoch_end_time = datetime.now() 100 | 101 | self.loss_model.eval() 102 | 103 | val_epoch_loss, val_epoch_metrics = self.run_epoch(self.val_dataloader, mode='val') 104 | 105 | if epoch % self.print_every == 0 and self.logger: 106 | per_second = len(self.train_dataloader.dataset) / ((epoch_end_time - epoch_start_time).seconds + 1) 107 | current_lr = self.optimizer.param_groups[0]['lr'] 108 | log_message = LOG_FORMAT.format(epoch=epoch, 109 | progress=epoch / epochs, 110 | per_second=per_second, 111 | train_loss=train_epoch_loss, 112 | val_loss=val_epoch_loss, 113 | train_metrics=[round(metric, 4) for metric in train_epoch_metrics], 114 | val_metrics=[round(metric, 4) for metric in val_epoch_metrics], 115 | current_lr=current_lr, 116 | elapsed=self._elapsed_time() 117 | ) 118 | 119 | self.logger.info(log_message) 120 | 121 | if epoch % self.save_every == 0: 122 | self._save_model(epoch, train_epoch_loss, val_epoch_loss, train_epoch_metrics, val_epoch_metrics) 123 | 124 | def _save_model(self, epoch, train_epoch_loss, val_epoch_loss, train_epoch_metrics, val_epoch_metrics): 125 | 126 | checkpoint_name = SAVE_FORMAT.format( 127 | epoch=epoch, 128 | val_loss=val_epoch_loss, 129 | val_metrics='-'.join(['{:<.3}'.format(v) for v in val_epoch_metrics]) 130 | ) 131 | 132 | checkpoint_output_path = join(self.checkpoint_dir, checkpoint_name) 133 | 134 | save_state = { 135 | 'epoch': epoch, 136 | 'train_loss': train_epoch_loss, 137 | 'train_metrics': train_epoch_metrics, 138 | 'val_loss': val_epoch_loss, 139 | 'val_metrics': val_epoch_metrics, 140 | 'checkpoint': checkpoint_output_path, 141 | } 142 | if epoch > 0: 143 | self.history.append(save_state) 144 | 145 | if hasattr(self.loss_model, 'module'): # DataParallel 146 | save_state['state_dict'] = self.loss_model.module.state_dict() 147 | else: 148 | save_state['state_dict'] = self.loss_model.state_dict() 149 | 150 | torch.save(save_state, checkpoint_output_path) 151 | 152 | representative_val_metric = val_epoch_metrics[0] 153 | if self.best_val_metric is None or self.best_val_metric > representative_val_metric: 154 | self.best_val_metric = representative_val_metric 155 | self.val_metrics_at_best = val_epoch_metrics 156 | self.val_loss_at_best = val_epoch_loss 157 | self.train_metrics_at_best = train_epoch_metrics 158 | self.train_loss_at_best = train_epoch_loss 159 | self.best_checkpoint_output_path = checkpoint_output_path 160 | self.best_epoch = epoch 161 | 162 | if self.logger: 163 | self.logger.info("Saved model to {}".format(checkpoint_output_path)) 164 | self.logger.info("Current best model is {}".format(self.best_checkpoint_output_path)) 165 | 166 | def _elapsed_time(self): 167 | now = datetime.now() 168 | elapsed = now - self.start_time 169 | return str(elapsed).split('.')[0] # remove milliseconds 170 | -------------------------------------------------------------------------------- /bert/train/model/transformer.py: -------------------------------------------------------------------------------- 1 | from .gelu import GELU 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | 7 | 8 | class TransformerEncoder(nn.Module): 9 | 10 | def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob): 11 | super(TransformerEncoder, self).__init__() 12 | 13 | self.d_model = d_model 14 | self.encoder_layers = nn.ModuleList( 15 | [TransformerEncoderLayer(d_model, heads_count, d_ff, dropout_prob) for _ in range(layers_count)] 16 | ) 17 | 18 | def forward(self, sources, mask): 19 | """Transformer bidirectional encoder 20 | 21 | args: 22 | sources: embedded_sequence, (batch_size, seq_len, embed_size) 23 | """ 24 | for encoder_layer in self.encoder_layers: 25 | sources = encoder_layer(sources, mask) 26 | 27 | return sources 28 | 29 | 30 | class TransformerEncoderLayer(nn.Module): 31 | 32 | def __init__(self, d_model, heads_count, d_ff, dropout_prob): 33 | super(TransformerEncoderLayer, self).__init__() 34 | 35 | self.self_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob), d_model) 36 | self.pointwise_feedforward_layer = Sublayer(PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model) 37 | self.dropout = nn.Dropout(dropout_prob) 38 | 39 | def forward(self, sources, sources_mask): 40 | # x: (batch_size, seq_len, d_model) 41 | 42 | sources = self.self_attention_layer(sources, sources, sources, sources_mask) 43 | sources = self.dropout(sources) 44 | sources = self.pointwise_feedforward_layer(sources) 45 | 46 | return sources 47 | 48 | 49 | class Sublayer(nn.Module): 50 | 51 | def __init__(self, sublayer, d_model): 52 | super(Sublayer, self).__init__() 53 | 54 | self.sublayer = sublayer 55 | self.layer_normalization = LayerNormalization(d_model) 56 | 57 | def forward(self, *args): 58 | x = args[0] 59 | x = self.sublayer(*args) + x 60 | return self.layer_normalization(x) 61 | 62 | 63 | class LayerNormalization(nn.Module): 64 | 65 | def __init__(self, features_count, epsilon=1e-6): 66 | super(LayerNormalization, self).__init__() 67 | 68 | self.gain = nn.Parameter(torch.ones(features_count)) 69 | self.bias = nn.Parameter(torch.zeros(features_count)) 70 | self.epsilon = epsilon 71 | 72 | def forward(self, x): 73 | 74 | mean = x.mean(dim=-1, keepdim=True) 75 | std = x.std(dim=-1, keepdim=True) 76 | 77 | return self.gain * (x - mean) / (std + self.epsilon) + self.bias 78 | 79 | 80 | class MultiHeadAttention(nn.Module): 81 | 82 | def __init__(self, heads_count, d_model, dropout_prob, mode='self-attention'): 83 | super(MultiHeadAttention, self).__init__() 84 | 85 | assert d_model % heads_count == 0 86 | assert mode in ('self-attention', 'memory-attention') 87 | 88 | self.d_head = d_model // heads_count 89 | self.heads_count = heads_count 90 | self.mode = mode 91 | self.query_projection = nn.Linear(d_model, heads_count * self.d_head) 92 | self.key_projection = nn.Linear(d_model, heads_count * self.d_head) 93 | self.value_projection = nn.Linear(d_model, heads_count * self.d_head) 94 | self.final_projection = nn.Linear(d_model, heads_count * self.d_head) 95 | self.dropout = nn.Dropout(dropout_prob) 96 | self.softmax = nn.Softmax(dim=3) 97 | 98 | self.attention = None 99 | # For cache 100 | self.key_projected = None 101 | self.value_projected = None 102 | 103 | def forward(self, query, key, value, mask=None, layer_cache=None): 104 | """ 105 | 106 | Args: 107 | query: (batch_size, query_len, model_dim) 108 | key: (batch_size, key_len, model_dim) 109 | value: (batch_size, value_len, model_dim) 110 | mask: (batch_size, query_len, key_len) 111 | """ 112 | # print('attention mask', mask) 113 | batch_size, query_len, d_model = query.size() 114 | 115 | d_head = d_model // self.heads_count 116 | 117 | query_projected = self.query_projection(query) 118 | # print('query_projected', query_projected.shape) 119 | if layer_cache is None or layer_cache[self.mode] is None: # Don't use cache 120 | key_projected = self.key_projection(key) 121 | value_projected = self.value_projection(value) 122 | else: # Use cache 123 | if self.mode == 'self-attention': 124 | key_projected = self.key_projection(key) 125 | value_projected = self.value_projection(value) 126 | 127 | key_projected = torch.cat([key_projected, layer_cache[self.mode]['key_projected']], dim=1) 128 | value_projected = torch.cat([value_projected, layer_cache[self.mode]['value_projected']], dim=1) 129 | elif self.mode == 'memory-attention': 130 | key_projected = layer_cache[self.mode]['key_projected'] 131 | value_projected = layer_cache[self.mode]['value_projected'] 132 | 133 | # For cache 134 | self.key_projected = key_projected 135 | self.value_projected = value_projected 136 | 137 | batch_size, key_len, d_model = key_projected.size() 138 | batch_size, value_len, d_model = value_projected.size() 139 | 140 | query_heads = query_projected.view(batch_size, query_len, self.heads_count, d_head).transpose(1, 2) # (batch_size, heads_count, query_len, d_head) 141 | key_heads = key_projected.view(batch_size, key_len, self.heads_count, d_head).transpose(1, 2) # (batch_size, heads_count, key_len, d_head) 142 | value_heads = value_projected.view(batch_size, value_len, self.heads_count, d_head).transpose(1, 2) # (batch_size, heads_count, value_len, d_head) 143 | 144 | attention_weights = self.scaled_dot_product(query_heads, key_heads) # (batch_size, heads_count, query_len, key_len) 145 | 146 | if mask is not None: 147 | mask_expanded = mask.unsqueeze(1).expand_as(attention_weights) 148 | attention_weights = attention_weights.masked_fill(mask_expanded, -1e18) 149 | 150 | self.attention = self.softmax(attention_weights) # Save attention to the object 151 | attention_dropped = self.dropout(self.attention) 152 | context_heads = torch.matmul(attention_dropped, value_heads) # (batch_size, heads_count, query_len, d_head) 153 | context_sequence = context_heads.transpose(1, 2).contiguous() # (batch_size, query_len, heads_count, d_head) 154 | context = context_sequence.view(batch_size, query_len, d_model) # (batch_size, query_len, d_model) 155 | final_output = self.final_projection(context) 156 | 157 | return final_output 158 | 159 | def scaled_dot_product(self, query_heads, key_heads): 160 | """ 161 | 162 | Args: 163 | query_heads: (batch_size, heads_count, query_len, d_head) 164 | key_heads: (batch_size, heads_count, key_len, d_head) 165 | """ 166 | key_heads_transposed = key_heads.transpose(2, 3) 167 | dot_product = torch.matmul(query_heads, key_heads_transposed) # (batch_size, heads_count, query_len, key_len) 168 | attention_weights = dot_product / np.sqrt(self.d_head) 169 | return attention_weights 170 | 171 | 172 | class PointwiseFeedForwardNetwork(nn.Module): 173 | 174 | def __init__(self, d_ff, d_model, dropout_prob): 175 | super(PointwiseFeedForwardNetwork, self).__init__() 176 | 177 | self.feed_forward = nn.Sequential( 178 | nn.Linear(d_model, d_ff), 179 | nn.Dropout(dropout_prob), 180 | GELU(), 181 | nn.Linear(d_ff, d_model), 182 | nn.Dropout(dropout_prob), 183 | ) 184 | 185 | def forward(self, x): 186 | """ 187 | 188 | Args: 189 | x: (batch_size, seq_len, d_model) 190 | """ 191 | return self.feed_forward(x) 192 | -------------------------------------------------------------------------------- /bert/preprocess/preprocess.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import train_test_split 2 | 3 | from .dictionary import IndexDictionary 4 | from .utils import prepend_data_dir 5 | 6 | from tqdm import tqdm 7 | from gensim.corpora import WikiCorpus 8 | from nltk.tokenize import sent_tokenize 9 | import sentencepiece as spm 10 | 11 | import re 12 | 13 | NUMBERS = re.compile(r'\d+') 14 | TOKENIZATION = re.compile(r'(\w+)') 15 | 16 | 17 | def preprocess_all(data_dir, wiki_raw_path, raw_documents_path, sentences_detected_path, spm_input_path, 18 | spm_model_prefix, word_piece_vocab_size, prepared_documents_path, train_path, val_path, 19 | dictionary_path, **_): 20 | 21 | wiki_raw_path = prepend_data_dir(wiki_raw_path, data_dir) 22 | raw_documents_path = prepend_data_dir(raw_documents_path, data_dir) 23 | sentences_detected_path = prepend_data_dir(sentences_detected_path, data_dir) 24 | spm_input_path = prepend_data_dir(spm_input_path, data_dir) 25 | spm_model_prefix = prepend_data_dir(spm_model_prefix, data_dir) 26 | prepared_documents_path = prepend_data_dir(prepared_documents_path, data_dir) 27 | train_path = prepend_data_dir(train_path, data_dir) 28 | val_path = prepend_data_dir(val_path, data_dir) 29 | dictionary_path = prepend_data_dir(dictionary_path, data_dir) 30 | 31 | print('Extracting articles...') 32 | extract_articles_wiki(wiki_raw_path, raw_documents_path) 33 | print('Detecting sentences...') 34 | detect_sentences(raw_documents_path, sentences_detected_path) 35 | print('Splitting sentences...') 36 | split_sentences(sentences_detected_path, spm_input_path) 37 | print('Training tokenizer...') 38 | train_tokenizer(spm_input_path, spm_model_prefix, word_piece_vocab_size) 39 | print('Preparing documents...') 40 | prepare_documents(spm_model_prefix, sentences_detected_path, prepared_documents_path) 41 | print('Splitting train val data...') 42 | split_train_val(prepared_documents_path, train_path, val_path) 43 | print('Building dictionary...') 44 | build_dictionary(train_path, dictionary_path) 45 | 46 | 47 | def tokenize(text: str, lower: bool, **_): # token_min_len: int, token_max_len: int, 48 | if lower: 49 | text = text.lower() 50 | return text.split() 51 | 52 | 53 | def extract_articles_wiki(wiki_raw_path, raw_documents_path, **_): 54 | wiki_corpus = WikiCorpus(wiki_raw_path, lemmatize=False, dictionary={}, tokenizer_func=tokenize, lower=False) 55 | 56 | with open(raw_documents_path, 'w') as raw_documents_file: 57 | for text in tqdm(wiki_corpus.get_texts()): 58 | document = ' '.join(text) 59 | raw_documents_file.write(document + '\n') 60 | 61 | 62 | def detect_sentences(raw_documents_path, sentences_detected_path, **_): 63 | with open(raw_documents_path) as raw_documents_file, open(sentences_detected_path, 'w') as sentences_detected_file: 64 | for line in tqdm(raw_documents_file): 65 | sentences = sent_tokenize(line.strip()) 66 | tokenized_sentences = [] 67 | for sentence in sentences: 68 | sentence = sentence.lower() 69 | sentence = NUMBERS.sub('N', sentence) 70 | tokens = [match.group() for match in TOKENIZATION.finditer(sentence)] 71 | if not tokens: 72 | continue 73 | tokenized_sentences.append(' '.join(tokens)) 74 | 75 | output_line = '|'.join(tokenized_sentences) + '\n' 76 | sentences_detected_file.write(output_line) 77 | 78 | 79 | def split_sentences(sentences_detected_path, spm_input_path, **_): 80 | with open(sentences_detected_path) as sentences_detected_file, open(spm_input_path, 'w') as spm_input_file: 81 | for line in tqdm(sentences_detected_file): 82 | for sentence in line.strip().split('|'): 83 | words = sentence.split() 84 | for i in range(0, len(words), 254): 85 | sentence_segment = words[i:i+254] 86 | spm_input_file.write(' '.join(sentence_segment) + '\n') 87 | 88 | 89 | def train_tokenizer(spm_input_path, spm_model_prefix, word_piece_vocab_size, **_): 90 | spm.SentencePieceTrainer.Train(f'--input={spm_input_path} --model_prefix={spm_model_prefix} ' 91 | f'--vocab_size={word_piece_vocab_size} --hard_vocab_limit=false') 92 | 93 | 94 | def prepare_documents(spm_model_prefix, sentences_detected_path, prepared_documents_path, **_): 95 | spm_model = spm_model_prefix + '.model' 96 | sp_preprocessor = spm.SentencePieceProcessor() 97 | sp_preprocessor.Load(spm_model) 98 | 99 | with open(sentences_detected_path) as sentences_detected_file, \ 100 | open(prepared_documents_path, 'w') as prepared_documents_file: 101 | for document in tqdm(sentences_detected_file): 102 | prepared_sentences = [] 103 | pieces = [] 104 | for sentence in document.strip().split('|'): 105 | sentence_pieces = sp_preprocessor.EncodeAsPieces(sentence) 106 | 107 | if len(sentence_pieces) <= 254: 108 | 109 | if len(pieces) + len(sentence_pieces) >= 254: 110 | prepared_sentences.append(' '.join(pieces)) 111 | pieces = sentence_pieces 112 | else: 113 | pieces.extend(sentence_pieces) 114 | else: 115 | if len(pieces) > 0: 116 | prepared_sentences.append(' '.join(pieces)) 117 | for i in range(0, len(sentence_pieces), 254): 118 | sentence_pieces_segment = sentence_pieces[i:i+254] 119 | prepared_sentences.append(' '.join(sentence_pieces_segment)) 120 | pieces = [] 121 | if len(prepared_sentences) < 2: 122 | continue 123 | output_line = '|'.join(prepared_sentences) + '\n' 124 | prepared_documents_file.write(output_line) 125 | 126 | 127 | def split_train_val(prepared_documents_path, train_path, val_path, **_): 128 | with open(prepared_documents_path) as prepared_documents_file: 129 | documents = prepared_documents_file.readlines() 130 | 131 | train_data, val_data = train_test_split(documents, test_size=10000) 132 | with open(train_path, 'w') as train_file: 133 | for line in train_data: 134 | train_file.write(line) 135 | with open(val_path, 'w') as val_file: 136 | for line in val_data: 137 | val_file.write(line) 138 | 139 | 140 | def build_dictionary(train_path, dictionary_path, **_): 141 | 142 | def token_generator(data_path): 143 | with open(data_path) as file: 144 | for document in file: 145 | for sentence in document.strip().split('|'): 146 | for token in sentence.split(): 147 | yield token 148 | 149 | dictionary = IndexDictionary() 150 | dictionary.build_vocabulary(token_generator(train_path)) 151 | dictionary.save(dictionary_path) 152 | return dictionary 153 | 154 | 155 | def add_preprocess_parser(subparsers): 156 | preprocess_all_parser = subparsers.add_parser('preprocess-all') 157 | preprocess_all_parser.set_defaults(function=preprocess_all) 158 | preprocess_all_parser.add_argument('--data_dir', type=str, default=None) 159 | preprocess_all_parser.add_argument('--wiki_raw_path', type=str, default='enwiki-latest-pages-articles.xml.bz2') 160 | preprocess_all_parser.add_argument('--raw_documents_path', type=str, default='raw_documents.txt') 161 | preprocess_all_parser.add_argument('--sentences_detected_path', type=str, default='sentences_detected.txt') 162 | preprocess_all_parser.add_argument('--spm_input_path', type=str, default='spm_input.txt') 163 | preprocess_all_parser.add_argument('--spm_model_prefix', type=str, default='spm') 164 | preprocess_all_parser.add_argument('--word_piece_vocab_size', type=int, default=30000) 165 | preprocess_all_parser.add_argument('--prepared_documents_path', type=str, default='prepared_documents.txt') 166 | preprocess_all_parser.add_argument('--dictionary_path', type=str, default='dictionary.txt') 167 | 168 | extract_wiki_parser = subparsers.add_parser('extract-wiki') 169 | extract_wiki_parser.set_defaults(function=extract_articles_wiki) 170 | extract_wiki_parser.add_argument('wiki_raw_path', type=str) 171 | extract_wiki_parser.add_argument('raw_documents_path', nargs='?', type=str, default='raw_documents.txt') 172 | 173 | detect_sentences_parser = subparsers.add_parser('detect-sentences') 174 | detect_sentences_parser.set_defaults(function=detect_sentences) 175 | detect_sentences_parser.add_argument('raw_documents_path', type=str) 176 | detect_sentences_parser.add_argument('sentences_detected_path', nargs='?', type=str, 177 | default='sentences_detected.txt') 178 | 179 | split_sentences_parser = subparsers.add_parser('split-sentences') 180 | split_sentences_parser.set_defaults(function=split_sentences) 181 | split_sentences_parser.add_argument('sentences_detected_path', type=str) 182 | split_sentences_parser.add_argument('spm_input_path', nargs='?', type=str, default='spm_input.txt') 183 | 184 | train_tokenizer_parser = subparsers.add_parser('train-tokenizer') 185 | train_tokenizer_parser.set_defaults(function=train_tokenizer) 186 | train_tokenizer_parser.add_argument('spm_input_path', type=str) 187 | train_tokenizer_parser.add_argument('spm_model_prefix', nargs='?', type=str, default='spm') 188 | train_tokenizer_parser.add_argument('--word_piece_vocab_size', type=int, default=30000) 189 | 190 | prepare_documents_parser = subparsers.add_parser('prepare-documents') 191 | prepare_documents_parser.set_defaults(function=prepare_documents) 192 | prepare_documents_parser.add_argument('sentences_detected_path', type=str) 193 | prepare_documents_parser.add_argument('prepared_documents_path', nargs='?', type=str, 194 | default='prepared_documents.txt') 195 | prepare_documents_parser.add_argument('--spm_model_prefix', type=str, default='spm') 196 | 197 | split_train_test_parser = subparsers.add_parser('split-train-val') 198 | split_train_test_parser.set_defaults(function=split_train_val) 199 | split_train_test_parser.add_argument('prepared_documents_path', type=str) 200 | split_train_test_parser.add_argument('train_path', nargs='?', type=str, default='train.txt') 201 | split_train_test_parser.add_argument('val_path', nargs='?', type=str, default='val.txt') 202 | 203 | build_dictionary_parser = subparsers.add_parser('build-dictionary') 204 | build_dictionary_parser.set_defaults(function=build_dictionary) 205 | build_dictionary_parser.add_argument('train_path', type=str, default='train.txt') 206 | build_dictionary_parser.add_argument('dictionary_path', nargs='?', type=str, default='dictionary.txt') 207 | -------------------------------------------------------------------------------- /bert/train/train.py: -------------------------------------------------------------------------------- 1 | from bert.preprocess.dictionary import IndexDictionary 2 | from .model.bert import build_model, FineTuneModel 3 | from .loss_models import MLMNSPLossModel, ClassificationLossModel 4 | from .metrics import mlm_accuracy, nsp_accuracy, classification_accuracy 5 | from .datasets.pretraining import PairedDataset 6 | from .datasets.classification import SST2IndexedDataset 7 | from .trainer import Trainer 8 | from .utils.log import make_run_name, make_logger, make_checkpoint_dir 9 | from .utils.collate import pretraining_collate_function, classification_collate_function 10 | from .optimizers import NoamOptimizer 11 | 12 | import torch 13 | from torch.nn import DataParallel 14 | from torch.optim import Adam 15 | from torch.utils.data import DataLoader 16 | 17 | import random 18 | import numpy as np 19 | from os.path import join 20 | 21 | 22 | RUN_NAME_FORMAT = ( 23 | "BERT-" 24 | "{phase}-" 25 | "layers_count={layers_count}-" 26 | "hidden_size={hidden_size}-" 27 | "heads_count={heads_count}-" 28 | "{timestamp}" 29 | ) 30 | 31 | 32 | def pretrain(data_dir, train_path, val_path, dictionary_path, 33 | dataset_limit, vocabulary_size, batch_size, max_len, epochs, clip_grads, device, 34 | layers_count, hidden_size, heads_count, d_ff, dropout_prob, 35 | log_output, checkpoint_dir, print_every, save_every, config, run_name=None, **_): 36 | 37 | random.seed(0) 38 | np.random.seed(0) 39 | torch.manual_seed(0) 40 | 41 | train_path = train_path if data_dir is None else join(data_dir, train_path) 42 | val_path = val_path if data_dir is None else join(data_dir, val_path) 43 | dictionary_path = dictionary_path if data_dir is None else join(data_dir, dictionary_path) 44 | 45 | run_name = run_name if run_name is not None else make_run_name(RUN_NAME_FORMAT, phase='pretrain', config=config) 46 | logger = make_logger(run_name, log_output) 47 | logger.info('Run name : {run_name}'.format(run_name=run_name)) 48 | logger.info(config) 49 | 50 | logger.info('Constructing dictionaries...') 51 | dictionary = IndexDictionary.load(dictionary_path=dictionary_path, 52 | vocabulary_size=vocabulary_size) 53 | vocabulary_size = len(dictionary) 54 | logger.info(f'dictionary vocabulary : {vocabulary_size} tokens') 55 | 56 | logger.info('Loading datasets...') 57 | train_dataset = PairedDataset(data_path=train_path, dictionary=dictionary, dataset_limit=dataset_limit) 58 | val_dataset = PairedDataset(data_path=val_path, dictionary=dictionary, dataset_limit=dataset_limit) 59 | logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset))) 60 | 61 | logger.info('Building model...') 62 | model = build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size) 63 | 64 | logger.info(model) 65 | logger.info('{parameters_count} parameters'.format( 66 | parameters_count=sum([p.nelement() for p in model.parameters()]))) 67 | 68 | loss_model = MLMNSPLossModel(model) 69 | if torch.cuda.device_count() > 1: 70 | loss_model = DataParallel(loss_model, output_device=1) 71 | 72 | metric_functions = [mlm_accuracy, nsp_accuracy] 73 | 74 | train_dataloader = DataLoader( 75 | train_dataset, 76 | batch_size=batch_size, 77 | collate_fn=pretraining_collate_function) 78 | 79 | val_dataloader = DataLoader( 80 | val_dataset, 81 | batch_size=batch_size, 82 | collate_fn=pretraining_collate_function) 83 | 84 | optimizer = NoamOptimizer(model.parameters(), 85 | d_model=hidden_size, factor=2, warmup_steps=10000, betas=(0.9, 0.999), weight_decay=0.01) 86 | 87 | checkpoint_dir = make_checkpoint_dir(checkpoint_dir, run_name, config) 88 | 89 | logger.info('Start training...') 90 | trainer = Trainer( 91 | loss_model=loss_model, 92 | train_dataloader=train_dataloader, 93 | val_dataloader=val_dataloader, 94 | metric_functions=metric_functions, 95 | optimizer=optimizer, 96 | clip_grads=clip_grads, 97 | logger=logger, 98 | checkpoint_dir=checkpoint_dir, 99 | print_every=print_every, 100 | save_every=save_every, 101 | device=device 102 | ) 103 | 104 | trainer.run(epochs=epochs) 105 | return trainer 106 | 107 | 108 | def finetune(pretrained_checkpoint, 109 | data_dir, train_path, val_path, dictionary_path, 110 | vocabulary_size, batch_size, max_len, epochs, lr, clip_grads, device, 111 | layers_count, hidden_size, heads_count, d_ff, dropout_prob, 112 | log_output, checkpoint_dir, print_every, save_every, config, run_name=None, **_): 113 | 114 | random.seed(0) 115 | np.random.seed(0) 116 | torch.manual_seed(0) 117 | 118 | train_path = train_path if data_dir is None else join(data_dir, train_path) 119 | val_path = val_path if data_dir is None else join(data_dir, val_path) 120 | dictionary_path = dictionary_path if data_dir is None else join(data_dir, dictionary_path) 121 | 122 | run_name = run_name if run_name is not None else make_run_name(RUN_NAME_FORMAT, phase='finetune', config=config) 123 | logger = make_logger(run_name, log_output) 124 | logger.info('Run name : {run_name}'.format(run_name=run_name)) 125 | logger.info(config) 126 | 127 | logger.info('Constructing dictionaries...') 128 | dictionary = IndexDictionary.load(dictionary_path=dictionary_path, 129 | vocabulary_size=vocabulary_size) 130 | vocabulary_size = len(dictionary) 131 | logger.info(f'dictionary vocabulary : {vocabulary_size} tokens') 132 | 133 | logger.info('Loading datasets...') 134 | train_dataset = SST2IndexedDataset(data_path=train_path, dictionary=dictionary) 135 | val_dataset = SST2IndexedDataset(data_path=val_path, dictionary=dictionary) 136 | logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset))) 137 | 138 | logger.info('Building model...') 139 | pretrained_model = build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size) 140 | pretrained_model.load_state_dict(torch.load(pretrained_checkpoint, map_location='cpu')['state_dict']) 141 | 142 | model = FineTuneModel(pretrained_model, hidden_size, num_classes=2) 143 | 144 | logger.info(model) 145 | logger.info('{parameters_count} parameters'.format( 146 | parameters_count=sum([p.nelement() for p in model.parameters()]))) 147 | 148 | loss_model = ClassificationLossModel(model) 149 | metric_functions = [classification_accuracy] 150 | 151 | train_dataloader = DataLoader( 152 | train_dataset, 153 | batch_size=batch_size, 154 | collate_fn=classification_collate_function) 155 | 156 | val_dataloader = DataLoader( 157 | val_dataset, 158 | batch_size=batch_size, 159 | collate_fn=classification_collate_function) 160 | 161 | optimizer = Adam(model.parameters(), lr=lr) 162 | 163 | checkpoint_dir = make_checkpoint_dir(checkpoint_dir, run_name, config) 164 | 165 | logger.info('Start training...') 166 | trainer = Trainer( 167 | loss_model=loss_model, 168 | train_dataloader=train_dataloader, 169 | val_dataloader=val_dataloader, 170 | metric_functions=metric_functions, 171 | optimizer=optimizer, 172 | clip_grads=clip_grads, 173 | logger=logger, 174 | checkpoint_dir=checkpoint_dir, 175 | print_every=print_every, 176 | save_every=save_every, 177 | device=device 178 | ) 179 | 180 | trainer.run(epochs=epochs) 181 | return trainer 182 | 183 | 184 | def add_pretrain_parser(subparsers): 185 | pretrain_parser = subparsers.add_parser('pretrain') 186 | pretrain_parser.set_defaults(function=pretrain) 187 | 188 | pretrain_parser.add_argument('--data_dir', type=str, default=None) 189 | pretrain_parser.add_argument('--train_path', type=str, default='train.txt') 190 | pretrain_parser.add_argument('--val_path', type=str, default='val.txt') 191 | pretrain_parser.add_argument('--dictionary_path', type=str, default='dictionary.txt') 192 | 193 | pretrain_parser.add_argument('--checkpoint_dir', type=str, default=None) 194 | pretrain_parser.add_argument('--log_output', type=str, default=None) 195 | 196 | pretrain_parser.add_argument('--dataset_limit', type=int, default=None) 197 | pretrain_parser.add_argument('--epochs', type=int, default=100) 198 | pretrain_parser.add_argument('--batch_size', type=int, default=16) 199 | 200 | pretrain_parser.add_argument('--print_every', type=int, default=1) 201 | pretrain_parser.add_argument('--save_every', type=int, default=10) 202 | 203 | pretrain_parser.add_argument('--vocabulary_size', type=int, default=30000) 204 | pretrain_parser.add_argument('--max_len', type=int, default=512) 205 | 206 | pretrain_parser.add_argument('--lr', type=float, default=0.001) 207 | pretrain_parser.add_argument('--clip_grads', action='store_true') 208 | 209 | pretrain_parser.add_argument('--layers_count', type=int, default=1) 210 | pretrain_parser.add_argument('--hidden_size', type=int, default=128) 211 | pretrain_parser.add_argument('--heads_count', type=int, default=2) 212 | pretrain_parser.add_argument('--d_ff', type=int, default=128) 213 | pretrain_parser.add_argument('--dropout_prob', type=float, default=0.1) 214 | 215 | pretrain_parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 216 | 217 | 218 | def add_finetune_parser(subparsers): 219 | finetune_parser = subparsers.add_parser('finetune') 220 | finetune_parser.set_defaults(function=finetune) 221 | 222 | finetune_parser.add_argument('--pretrained_checkpoint', type=str, required=True) 223 | 224 | finetune_parser.add_argument('--data_dir', type=str, default=None) 225 | finetune_parser.add_argument('--train_data', type=str, default='train.tsv') 226 | finetune_parser.add_argument('--val_data', type=str, default='dev.tsv') 227 | finetune_parser.add_argument('--dictionary', type=str, default='dictionary.txt') 228 | 229 | finetune_parser.add_argument('--checkpoint_dir', type=str, default=None) 230 | finetune_parser.add_argument('--log_output', type=str, default=None) 231 | 232 | finetune_parser.add_argument('--dataset_limit', type=int, default=None) 233 | finetune_parser.add_argument('--epochs', type=int, default=100) 234 | finetune_parser.add_argument('--batch_size', type=int, default=16) 235 | 236 | finetune_parser.add_argument('--print_every', type=int, default=1) 237 | finetune_parser.add_argument('--save_every', type=int, default=10) 238 | 239 | finetune_parser.add_argument('--vocabulary_size', type=int, default=30000) 240 | finetune_parser.add_argument('--max_len', type=int, default=512) 241 | 242 | finetune_parser.add_argument('--lr', type=float, default=0.001) 243 | finetune_parser.add_argument('--clip_grads', action='store_true') 244 | 245 | finetune_parser.add_argument('--layers_count', type=int, default=1) 246 | finetune_parser.add_argument('--hidden_size', type=int, default=128) 247 | finetune_parser.add_argument('--heads_count', type=int, default=2) 248 | finetune_parser.add_argument('--d_ff', type=int, default=128) 249 | finetune_parser.add_argument('--dropout_prob', type=float, default=0.1) 250 | 251 | finetune_parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 252 | -------------------------------------------------------------------------------- /data/wiki-example/wiki.txt: -------------------------------------------------------------------------------- 1 | anarchism is political philosophy that advocates self governed societies based on voluntary institutions these are often described as stateless societies although several authors have defined them more specifically as institutions based on non hierarchical or free associations anarchism holds the state to be undesirable unnecessary and harmful while opposition to the state is central anarchism specifically entails opposing authority or hierarchical organisation in the conduct of all human relations anarchism is usually considered far left ideology and much of anarchist economics and anarchist legal philosophy reflects anti authoritarian interpretations of communism collectivism syndicalism mutualism or participatory economics anarchism does not offer fixed body of doctrine from single particular 2 | autism is developmental disorder characterized by troubles with social interaction and communication and by restricted and repetitive behavior parents usually notice signs in the first two or three years of their child life these signs often develop gradually though some children with autism reach their developmental milestones at normal pace and then worsen autism is associated with combination of genetic and environmental factors risk factors include certain infections during pregnancy such as rubella as well as valproic acid alcohol or cocaine use during pregnancy controversies surround other proposed environmental causes for example the vaccine hypotheses which have been disproven autism affects information processing in the brain by altering 3 | diffusely reflected sunlight relative to various surface conditions albedo meaning whiteness is the measure of the diffuse reflection of solar radiation out of the total solar radiation received by an astronomical body planet like earth it is dimensionless and measured on scale from corresponding to black body that absorbs all incident radiation to corresponding to body that reflects all incident radiation surface albedo is defined as the ratio of irradiance reflected to the irradiance received by surface the proportion reflected is not only determined by properties of the surface itself but also by the spectral and angular distribution of solar radiation reaching the earth surface these factors 4 | named plural as or aes is the first letter and the first vowel of the iso basic latin alphabet it is similar to the ancient greek letter alpha from which it derives the uppercase version consists of the two slanting sides of triangle crossed in the middle by horizontal bar the lowercase version can be written in two forms the double storey and single storey the latter is commonly used in handwriting and fonts based on it especially fonts intended to be read by children and is also found in italic type history egyptian cretan phoenician aleph semitic greek alpha etruscan roman cyrillic boeotian bc greek uncial latin ad 5 | alabama is state in the southeastern region of the united states it is bordered by tennessee to the north georgia to the east florida and the gulf of mexico to the south and mississippi to the west alabama is the th largest by area and the th most populous of the states with total of of inland waterways alabama has among the most of any state alabama is nicknamed the yellowhammer state after the state bird alabama is also known as the heart of dixie and the cotton state the state tree is the longleaf pine and the state flower is the camellia alabama capital is montgomery the largest 6 | achilles and agamemnon by gottlieb schick in greek mythology achilles or achilleus achilleus was greek hero of the trojan war and the central character and greatest warrior of homer iliad his mother was the immortal nereid thetis and his father the mortal peleus was the king of the myrmidons achilles most notable feat during the trojan war was the slaying of the trojan hero hector outside the gates of troy although the death of achilles is not presented in the iliad other sources concur that he was killed near the end of the trojan war by paris who shot him in the heel with an arrow later legends 7 | abraham lincoln february april was an american statesman and lawyer who served as the th president of the united states from march until his assassination in april lincoln led the united states through the american civil war its bloodiest war and perhaps its greatest moral constitutional and political crisis in doing so he preserved the union abolished slavery strengthened the federal government and modernized the economy born in hodgenville kentucky lincoln grew up on the western frontier in kentucky and indiana largely self educated he became lawyer in illinois whig party leader and was elected to the illinois house of representatives in which he served for eight years 8 | aristotle aristotélēs bc was an ancient greek philosopher and scientist born in the city of stagira chalkidiki in the north of classical greece along with plato aristotle is considered the father of western philosophy which inherited almost its entire lexicon from his teachings including problems and methods of inquiry so influencing almost all forms of knowledge little is known for certain about his life his father nicomachus died when aristotle was child and he was brought up by guardian at seventeen or eighteen years of age he joined plato academy in athens and remained there until the age of thirty seven bc his writings cover many subjects 9 | time key major relative acciaccatura mp acciaccatura acciaccatura acciaccatura acciaccatura c_ time clef bass partial cis markup italic scherzando dis eis fis gis fis eis dis fis eis acciaccatura eis dis cis cis acciaccatura eis dis cis themes from an american in paris an american in paris is jazz influenced orchestral piece by the american composer george gershwin written in inspired by the time gershwin had spent in paris it evokes the sights and energy of the french capital in the and is one of his best known compositions gershwin composed an american in paris on commission from the conductor walter damrosch he scored the piece for the standard 10 | the academy award for best production design recognizes achievement for art direction in film the category original name was best art direction but was changed to its current name in for the th academy awards this change resulted from the art director branch of the academy of motion picture arts and sciences ampas being renamed the designer branch since the award is shared with the set decorator it is awarded to the best interior design in film the films below are listed with their production year for example the academy award for best art direction is given to film from in the lists below the winner of the award 11 | the academy awards also known as the oscars are set of awards for artistic and technical merit in the film industry given annually by the academy of motion picture arts and sciences ampas to recognize excellence in cinematic achievements as assessed by the academy voting membership the various category winners are awarded copy of golden statuette officially called the academy award of merit although more commonly referred to by its nickname oscar the award was originally sculpted by george stanley from design sketch by cedric gibbons ampas first presented it in at private dinner hosted by douglas fairbanks in the hollywood roosevelt hotel the academy awards ceremony was 12 | actresses catalan actrius is catalan language spanish drama film produced and directed by ventura pons and based on the award winning stage play by josep maria benet jornet the film has no male actors with all roles played by females the film was produced in synopsis in order to prepare herself to play role commemorating the life of legendary actress empar ribera young actress mercè pons interviews three established actresses who had been the ribera pupils the international diva glòria marc núria espert the television star assumpta roca rosa maria sardà and dubbing director maria caminal anna lizaran cast núria espert as glòria marc rosa maria sardà as 13 | animalia is an illustrated children book by graeme base it was originally published in followed by tenth anniversary edition in and th anniversary edition in over three million copies have been sold special numbered and signed anniversary edition was also published in with an embossed gold jacket synopsis animalia is an alliterative alphabet book and contains twenty six illustrations one for each letter of the alphabet each illustration features an animal from the animal kingdom is for alligator is for butterfly etc along with short poem utilizing the letter of the page for many of the words the illustrations contain many other objects beginning with that letter that the 14 | international atomic time tai from the french name is high precision atomic coordinate time standard based on the notional passage of proper time on earth geoid it is the principal realisation of terrestrial time except for fixed offset of epoch it is also the basis for coordinated universal time utc which is used for civil timekeeping all over the earth surface when another leap second was added tai is exactly seconds ahead of utc the seconds results from the initial difference of seconds at the start of plus leap seconds in utc since tai may be reported using traditional means of specifying days carried over from non 15 | giving alms to the poor is often considered an altruistic action altruism is the principle and moral practice of concern for happiness of other human beings resulting in quality of life both material and spiritual it is traditional virtue in many cultures and core aspect of various religious traditions and secular worldviews though the concept of others toward whom concern should be directed can vary among cultures and religions in an extreme case altruism may become synonym of selflessness which is the opposite of selfishness in common way of living it doesn deny the singular nature of the subject but realizes the traits of the individual personality in relation 16 | ayn rand born alisa zinovyevna rosenbaum march was russian american novelist playwright screenwriter and philosopher she is known for her two best selling novels the fountainhead and atlas shrugged and for developing philosophical system she named objectivism educated in russia she moved to the united states in she had play produced on broadway in and after two early novels that were initially unsuccessful she achieved fame with her novel the fountainhead in rand published her best known work the novel atlas shrugged afterward she turned to non fiction to promote her philosophy publishing her own periodicals and releasing several collections of essays until her death in rand advocated reason 17 | alain connes born april is french mathematician currently professor at the collège de france ihés ohio state university and vanderbilt university he was an invited professor at the conservatoire national des arts et métiers work alain connes studies operator algebras in his early work on von neumann algebras in the he succeeded in obtaining the almost complete classification of injective factors he also formulated the connes embedding problem following this he made contributions in operator theory and index theory which culminated in the baum connes conjecture he also introduced cyclic cohomology in the early as first step in the study of noncommutative differential geometry he was member of bourbaki 18 | allan dwan april december was pioneering canadian born american motion picture director producer and screenwriter early life born joseph aloysius dwan in toronto ontario canada dwan was the younger son of commercial traveller of woolen clothing joseph michael dwan and his wife mary jane dwan née hunt the family moved to the united states when he was seven years old on december by ferry from windsor to detroit according to his naturalization petition of august his elder brother leo garnet dwan became physician at the university of notre dame allan dwan studied engineering and began working for lighting company in chicago however he had strong interest in the fledgling 19 | algeria familary algerian arabic officially the people democratic republic of algeria is country in north africa on the mediterranean coast the capital and most populous city is algiers located in the far north of the country with an area of algeria is the tenth largest country in the world and the largest in africa since south sudan became independent from sudan in algeria is bordered to the northeast by tunisia to the east by libya to the west by morocco to the southwest by the western saharan territory mauritania and mali to the southeast by niger and to the north by the mediterranean sea the country is semi 20 | this is list of characters in ayn rand novel atlas shrugged major characters the following are major characters from the novel protagonists dagny taggart dagny taggart is the protagonist of the novel she is vice president in charge of operations for taggart under her brother james taggart given james incompetence dagny is responsible for all the workings of the railroad francisco anconia francisco anconia is one of the central characters in atlas shrugged an owner by inheritance of the world largest copper mining operation he is childhood friend and the first love of dagny taggart child prodigy of exceptional talents francisco was dubbed the climax of the anconia line an 21 | anthropology is the study of humans and human behavior and societies in the past and present social anthropology and cultural anthropology study the norms and values of societies linguistic anthropology studies how language affects social life biological or physical anthropology studies the biological development of humans archaeology which studies past human cultures through investigation of physical evidence is thought of as branch of anthropology in the united states while in europe it is viewed as discipline in its own right or grouped under other related disciplines such as history origin and development of the term bernardino de sahagún is considered to be the founder of modern anthropology the 22 | agricultural science is broad field of biology that encompasses the parts of exact natural economic and social sciences that are used in the practice and understanding of agriculture veterinary science but not animal science is often excluded from the definition agriculture agricultural science and agronomy the three terms are often confused however they cover different concepts agriculture is the set of activities that transform the environment for the production of animals and plants for human use agriculture concerns techniques including the application of agronomic research agronomy is research and development related to studying and improving plant based crops agricultural sciences include research and development on plant breeding 23 | kimiya yi sa ādat the alchemy of happiness text on persian islamic philosophy and spiritual alchemy by al ghazālī alchemy from arabic al kīmiyā is philosophical and protoscientific tradition practiced throughout europe africa and asia it aims to purify mature and perfect certain objects common aims were chrysopoeia the transmutation of base metals lead into noble metals particularly gold the creation of an elixir of immortality the creation of panaceas able to cure any disease and the development of an alkahest universal solvent the perfection of the human body and soul was thought to permit or result from the alchemical magnum opus and in the hellenistic and western tradition the 24 | alien primarily refers to life life which does not originate from earth specifically intelligent beings see list of alleged beings alien law person in country who is not national of that country alien or the alien may also refer to science and technology introduced species species not native to its environment alien file converter linux program alien alice environment grid framework alien technology manufacturer of rfid technology aliens newsletter of the iucn invasive species specialist group arts and entertainment alien franchise media franchise alien creature in alien franchise films alien film film by ridley scott aliens film the sequel by james cameron alien third film in the series 25 | the astronomer by johannes vermeer an astronomer is scientist in the field of astronomy who focuses their studies on specific question or field outside the scope of earth they observe astronomical objects such as stars planets moons comets and galaxies in either observational by analyzing the data or theoretical astronomy examples of topics or fields astronomers study include planetary science solar astronomy the origin or evolution of stars or the formation of galaxies related but distinct subjects like physical cosmology which studies the universe as whole astronomers usually fall under either of two main types observational and theoretical observational astronomers make direct observations of celestial objects and analyze the -------------------------------------------------------------------------------- /data/SST-2/dev.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | it 's a charming and often affecting journey . 1 3 | unflinchingly bleak and desperate 0 4 | allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker . 1 5 | the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales . 1 6 | it 's slow -- very , very slow . 0 7 | although laced with humor and a few fanciful touches , the film is a refreshingly serious look at young women . 1 8 | a sometimes tedious film . 0 9 | or doing last year 's taxes with your ex-wife . 0 10 | you do n't have to know about music to appreciate the film 's easygoing blend of comedy and romance . 1 11 | in exactly 89 minutes , most of which passed as slowly as if i 'd been sitting naked on an igloo , formula 51 sank from quirky to jerky to utter turkey . 0 12 | the mesmerizing performances of the leads keep the film grounded and keep the audience riveted . 1 13 | it takes a strange kind of laziness to waste the talents of robert forster , anne meara , eugene levy , and reginald veljohnson all in the same movie . 0 14 | ... the film suffers from a lack of humor ( something needed to balance out the violence ) ... 0 15 | we root for ( clara and paul ) , even like them , though perhaps it 's an emotion closer to pity . 1 16 | even horror fans will most likely not find what they 're seeking with trouble every day ; the movie lacks both thrills and humor . 0 17 | a gorgeous , high-spirited musical from india that exquisitely blends music , dance , song , and high drama . 1 18 | the emotions are raw and will strike a nerve with anyone who 's ever had family trauma . 1 19 | audrey tatou has a knack for picking roles that magnify her outrageous charm , and in this literate french comedy , she 's as morning-glory exuberant as she was in amélie . 1 20 | ... the movie is just a plain old monster . 0 21 | in its best moments , resembles a bad high school production of grease , without benefit of song . 0 22 | pumpkin takes an admirable look at the hypocrisy of political correctness , but it does so with such an uneven tone that you never know when humor ends and tragedy begins . 0 23 | the iditarod lasts for days - this just felt like it did . 0 24 | holden caulfield did it better . 0 25 | a delectable and intriguing thriller filled with surprises , read my lips is an original . 1 26 | seldom has a movie so closely matched the spirit of a man and his work . 1 27 | nicks , seemingly uncertain what 's going to make people laugh , runs the gamut from stale parody to raunchy sex gags to formula romantic comedy . 0 28 | the action switches between past and present , but the material link is too tenuous to anchor the emotional connections that purport to span a 125-year divide . 0 29 | it 's an offbeat treat that pokes fun at the democratic exercise while also examining its significance for those who take part . 1 30 | it 's a cookie-cutter movie , a cut-and-paste job . 0 31 | i had to look away - this was god awful . 0 32 | thanks to scott 's charismatic roger and eisenberg 's sweet nephew , roger dodger is one of the most compelling variations on in the company of men . 1 33 | ... designed to provide a mix of smiles and tears , `` crossroads '' instead provokes a handful of unintentional howlers and numerous yawns . 0 34 | a gorgeous , witty , seductive movie . 1 35 | if the movie succeeds in instilling a wary sense of ` there but for the grace of god , ' it is far too self-conscious to draw you deeply into its world . 0 36 | it does n't believe in itself , it has no sense of humor ... it 's just plain bored . 0 37 | a sequence of ridiculous shoot - 'em - up scenes . 0 38 | the weight of the piece , the unerring professionalism of the chilly production , and the fascination embedded in the lurid topic prove recommendation enough . 1 39 | ( w ) hile long on amiable monkeys and worthy environmentalism , jane goodall 's wild chimpanzees is short on the thrills the oversize medium demands . 0 40 | as surreal as a dream and as detailed as a photograph , as visually dexterous as it is at times imaginatively overwhelming . 1 41 | escaping the studio , piccoli is warmly affecting and so is this adroitly minimalist movie . 1 42 | there 's ... tremendous energy from the cast , a sense of playfulness and excitement that seems appropriate . 1 43 | this illuminating documentary transcends our preconceived vision of the holy land and its inhabitants , revealing the human complexities beneath . 1 44 | the subtle strength of `` elling '' is that it never loses touch with the reality of the grim situation . 1 45 | holm ... embodies the character with an effortlessly regal charisma . 1 46 | the title not only describes its main characters , but the lazy people behind the camera as well . 0 47 | it offers little beyond the momentary joys of pretty and weightless intellectual entertainment . 0 48 | a synthesis of cliches and absurdities that seems positively decadent in its cinematic flash and emptiness . 0 49 | a subtle and well-crafted ( for the most part ) chiller . 1 50 | has a lot of the virtues of eastwood at his best . 1 51 | it 's hampered by a lifetime-channel kind of plot and a lead actress who is out of her depth . 0 52 | it feels like an after-school special gussied up with some fancy special effects , and watching its rote plot points connect is about as exciting as gazing at an egg timer for 93 minutes . 0 53 | for the most part , director anne-sophie birot 's first feature is a sensitive , extraordinarily well-acted drama . 1 54 | mr. tsai is a very original artist in his medium , and what time is it there ? 1 55 | sade is an engaging look at the controversial eponymous and fiercely atheistic hero . 1 56 | so devoid of any kind of intelligible story that it makes films like xxx and collateral damage seem like thoughtful treatises 0 57 | a tender , heartfelt family drama . 1 58 | ... a hollow joke told by a cinematic gymnast having too much fun embellishing the misanthropic tale to actually engage it . 0 59 | the cold turkey would 've been a far better title . 0 60 | manages to be both repulsively sadistic and mundane . 0 61 | it 's just disappointingly superficial -- a movie that has all the elements necessary to be a fascinating , involving character study , but never does more than scratch the surface . 0 62 | this is a story of two misfits who do n't stand a chance alone , but together they are magnificent . 1 63 | schaeffer has to find some hook on which to hang his persistently useless movies , and it might as well be the resuscitation of the middle-aged character . 0 64 | the primitive force of this film seems to bubble up from the vast collective memory of the combatants . 1 65 | on this tricky topic , tadpole is very much a step in the right direction , with its blend of frankness , civility and compassion . 1 66 | the script kicks in , and mr. hartley 's distended pace and foot-dragging rhythms follow . 0 67 | you wonder why enough was n't just a music video rather than a full-length movie . 0 68 | if you 're hard up for raunchy college humor , this is your ticket right here . 1 69 | a fast , funny , highly enjoyable movie . 1 70 | good old-fashioned slash-and-hack is back ! 1 71 | this one is definitely one to skip , even for horror movie fanatics . 0 72 | for all its impressive craftsmanship , and despite an overbearing series of third-act crescendos , lily chou-chou never really builds up a head of emotional steam . 0 73 | exquisitely nuanced in mood tics and dialogue , this chamber drama is superbly acted by the deeply appealing veteran bouquet and the chilling but quite human berling . 1 74 | uses high comedy to evoke surprising poignance . 1 75 | one of creepiest , scariest movies to come along in a long , long time , easily rivaling blair witch or the others . 1 76 | a string of rehashed sight gags based in insipid vulgarity . 0 77 | among the year 's most intriguing explorations of alientation . 1 78 | the movie fails to live up to the sum of its parts . 0 79 | the son 's room is a triumph of gentility that earns its moments of pathos . 1 80 | there is nothing outstanding about this film , but it is good enough and will likely be appreciated most by sailors and folks who know their way around a submarine . 1 81 | this is a train wreck of an action film -- a stupefying attempt by the filmmakers to force-feed james bond into the mindless xxx mold and throw 40 years of cinematic history down the toilet in favor of bright flashes and loud bangs . 0 82 | the draw ( for `` big bad love '' ) is a solid performance by arliss howard . 1 83 | green might want to hang onto that ski mask , as robbery may be the only way to pay for his next project . 0 84 | it 's one pussy-ass world when even killer-thrillers revolve around group therapy sessions . 0 85 | though it 's become almost redundant to say so , major kudos go to leigh for actually casting people who look working-class . 1 86 | the band 's courage in the face of official repression is inspiring , especially for aging hippies ( this one included ) . 1 87 | the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . 1 88 | the film flat lines when it should peak and is more missed opportunity and trifle than dark , decadent truffle . 0 89 | jaglom ... put ( s ) the audience in the privileged position of eavesdropping on his characters 1 90 | fresnadillo 's dark and jolting images have a way of plying into your subconscious like the nightmare you had a week ago that wo n't go away . 1 91 | we know the plot 's a little crazy , but it held my interest from start to finish . 1 92 | it 's a scattershot affair , but when it hits its mark it 's brilliant . 1 93 | hardly a masterpiece , but it introduces viewers to a good charitable enterprise and some interesting real people . 1 94 | you wo n't like roger , but you will quickly recognize him . 0 95 | if steven soderbergh 's ` solaris ' is a failure it is a glorious failure . 1 96 | byler reveals his characters in a way that intrigues and even fascinates us , and he never reduces the situation to simple melodrama . 1 97 | this riveting world war ii moral suspense story deals with the shadow side of american culture : racial prejudice in its ugly and diverse forms . 0 98 | it 's difficult to imagine the process that produced such a script , but here 's guessing that spray cheese and underarm noises played a crucial role . 0 99 | no sophomore slump for director sam mendes , who segues from oscar winner to oscar-winning potential with a smooth sleight of hand . 1 100 | on the whole , the movie lacks wit , feeling and believability to compensate for its incessant coarseness and banality . 0 101 | why make a documentary about these marginal historical figures ? 0 102 | neither parker nor donovan is a typical romantic lead , but they bring a fresh , quirky charm to the formula . 1 103 | his last movie was poetically romantic and full of indelible images , but his latest has nothing going for it . 0 104 | does paint some memorable images ... , but makhmalbaf keeps her distance from the characters 1 105 | a gripping movie , played with performances that are all understated and touching . 1 106 | it 's one of those baseball pictures where the hero is stoic , the wife is patient , the kids are as cute as all get-out and the odds against success are long enough to intimidate , but short enough to make a dream seem possible . 1 107 | combining quick-cut editing and a blaring heavy metal much of the time , beck seems to be under the illusion that he 's shooting the latest system of a down video . 0 108 | the movie 's relatively simple plot and uncomplicated morality play well with the affable cast . 1 109 | what the director ca n't do is make either of val kilmer 's two personas interesting or worth caring about . 0 110 | too often , the viewer is n't reacting to humor so much as they are wincing back in repugnance . 0 111 | it 's great escapist fun that recreates a place and time that will never happen again . 1 112 | scores no points for originality , wit , or intelligence . 0 113 | there is n't nearly enough fun here , despite the presence of some appealing ingredients . 0 114 | hilariously inept and ridiculous . 1 115 | this movie is maddening . 0 116 | it haunts you , you ca n't forget it , you admire its conception and are able to resolve some of the confusions you had while watching it . 1 117 | sam mendes has become valedictorian at the school for soft landings and easy ways out . 0 118 | one of the smartest takes on singles culture i 've seen in a long time . 1 119 | moody , heartbreaking , and filmed in a natural , unforced style that makes its characters seem entirely convincing even when its script is not . 1 120 | every nanosecond of the the new guy reminds you that you could be doing something else far more pleasurable . 0 121 | comes ... uncomfortably close to coasting in the treads of the bicycle thief . 0 122 | warm water under a red bridge is a quirky and poignant japanese film that explores the fascinating connections between women , water , nature , and sexuality . 1 123 | it seems to me the film is about the art of ripping people off without ever letting them consciously know you have done so 0 124 | old-form moviemaking at its best . 1 125 | turns potentially forgettable formula into something strangely diverting . 1 126 | ( lawrence bounces ) all over the stage , dancing , running , sweating , mopping his face and generally displaying the wacky talent that brought him fame in the first place . 1 127 | a movie that reminds us of just how exciting and satisfying the fantasy cinema can be when it 's approached with imagination and flair . 1 128 | confirms the nagging suspicion that ethan hawke would be even worse behind the camera than he is in front of it . 0 129 | in the end , we are left with something like two ships passing in the night rather than any insights into gay love , chinese society or the price one pays for being dishonest . 0 130 | montias ... pumps a lot of energy into his nicely nuanced narrative and surrounds himself with a cast of quirky -- but not stereotyped -- street characters . 1 131 | it provides the grand , intelligent entertainment of a superior cast playing smart people amid a compelling plot . 1 132 | suffers from the lack of a compelling or comprehensible narrative . 0 133 | in execution , this clever idea is far less funny than the original , killers from space . 0 134 | scooby dooby doo / and shaggy too / you both look and sound great . 1 135 | the tale of tok ( andy lau ) , a sleek sociopath on the trail of o ( takashi sorimachi ) , the most legendary of asian hitmen , is too scattershot to take hold . 0 136 | it all drags on so interminably it 's like watching a miserable relationship unfold in real time . 0 137 | pumpkin means to be an outrageous dark satire on fraternity life , but its ambitions far exceed the abilities of writer adam larson broder and his co-director , tony r. abrams , in their feature debut . 0 138 | looks and feels like a project better suited for the small screen . 0 139 | forced , familiar and thoroughly condescending . 0 140 | that is a compliment to kuras and miller . 1 141 | it 's not the ultimate depression-era gangster movie . 0 142 | sacrifices the value of its wealth of archival foot-age with its less-than-objective stance . 0 143 | the character of zigzag is not sufficiently developed to support a film constructed around him . 0 144 | what better message than ` love thyself ' could young women of any size receive ? 1 145 | a solid film ... but more conscientious than it is truly stirring . 1 146 | while ( hill ) has learned new tricks , the tricks alone are not enough to salvage this lifeless boxing film . 0 147 | the best that can be said about the work here of scottish director ritchie ... is that he obviously does n't have his heart in it . 0 148 | about a manga-like heroine who fights back at her abusers , it 's energetic and satisfying if not deep and psychological . 1 149 | the talented and clever robert rodriguez perhaps put a little too much heart into his first film and did n't reserve enough for his second . 0 150 | feels too formulaic and too familiar to produce the transgressive thrills of early underground work . 0 151 | the volatile dynamics of female friendship is the subject of this unhurried , low-key film that is so off-hollywood that it seems positively french in its rhythms and resonance . 1 152 | overall very good for what it 's trying to do . 1 153 | a big , gorgeous , sprawling swashbuckler that delivers its diversions in grand , uncomplicated fashion . 1 154 | a difficult , absorbing film that manages to convey more substance despite its repetitions and inconsistencies than do most films than are far more pointed and clear . 1 155 | the heavy-handed film is almost laughable as a consequence . 0 156 | a solid examination of the male midlife crisis . 1 157 | a nightmare date with a half-formed wit done a great disservice by a lack of critical distance and a sad trust in liberal arts college bumper sticker platitudes . 0 158 | manages to transcend the sex , drugs and show-tunes plot into something far richer . 1 159 | it takes talent to make a lifeless movie about the most heinous man who ever lived . 0 160 | by getting myself wrapped up in the visuals and eccentricities of many of the characters , i found myself confused when it came time to get to the heart of the movie . 0 161 | like leon , it 's frustrating and still oddly likable . 1 162 | uncommonly stylish but equally silly ... the picture fails to generate much suspense , nor does it ask searching enough questions to justify its pretensions . 0 163 | not exactly the bees knees 0 164 | there seems to be no clear path as to where the story 's going , or how long it 's going to take to get there . 0 165 | slapstick buffoonery can tickle many a preschooler 's fancy , but when it costs a family of four about $ 40 to see a film in theaters , why spend money on a dog like this when you can rent a pedigree instead ? 0 166 | a woman 's pic directed with resonance by ilya chaiken . 1 167 | may reawaken discussion of the kennedy assassination but this fictional film looks made for cable rather than for the big screen . 0 168 | characters still need to function according to some set of believable and comprehensible impulses , no matter how many drugs they do or how much artistic license avary employs . 0 169 | the end result is a film that 's neither . 0 170 | manages to be sweet and wickedly satisfying at the same time . 1 171 | leigh 's film is full of memorable performances from top to bottom . 1 172 | it 's also , clearly , great fun . 1 173 | rarely has leukemia looked so shimmering and benign . 0 174 | it seems like i have been waiting my whole life for this movie and now i ca n't wait for the sequel . 1 175 | determined to be fun , and bouncy , with energetic musicals , the humor did n't quite engage this adult . 0 176 | if you dig on david mamet 's mind tricks ... rent this movie and enjoy ! 1 177 | bleakly funny , its characters all the more touching for refusing to pity or memorialize themselves . 1 178 | delivers the same old same old , tarted up with latin flava and turned out by hollywood playas . 0 179 | does n't offer much besides glib soullessness , raunchy language and a series of brutal set pieces ... that raise the bar on stylized screen violence . 0 180 | it made me want to wrench my eyes out of my head and toss them at the screen . 0 181 | the film 's performances are thrilling . 1 182 | unfortunately , it 's not silly fun unless you enjoy really bad movies . 0 183 | it 's a bad thing when a movie has about as much substance as its end credits blooper reel . 0 184 | i sympathize with the plight of these families , but the movie does n't do a very good job conveying the issue at hand . 0 185 | the lower your expectations , the more you 'll enjoy it . 0 186 | though perry and hurley make inspiring efforts to breathe life into the disjointed , haphazard script by jay scherick and david ronn , neither the actors nor director reginald hudlin can make it more than fitfully entertaining . 0 187 | a must-see for the david mamet enthusiast and for anyone who appreciates intelligent , stylish moviemaking . 1 188 | pacino is brilliant as the sleep-deprived dormer , his increasing weariness as much existential as it is physical . 1 189 | ` de niro ... is a veritable source of sincere passion that this hollywood contrivance orbits around . ' 1 190 | a misogynistic piece of filth that attempts to pass itself off as hip , young adult entertainment . 0 191 | its story may be a thousand years old , but why did it have to seem like it took another thousand to tell it to us ? 0 192 | try as i may , i ca n't think of a single good reason to see this movie , even though everyone in my group extemporaneously shouted , ` thank you ! ' 0 193 | the movie is beautiful to behold and engages one in a sense of epic struggle -- inner and outer -- that 's all too rare in hollywood 's hastier productions . 1 194 | a celebration of quirkiness , eccentricity , and certain individuals ' tendency to let it all hang out , and damn the consequences . 1 195 | morton uses her face and her body language to bring us morvern 's soul , even though the character is almost completely deadpan . 1 196 | instead of a hyperbolic beat-charged urban western , it 's an unpretentious , sociologically pointed slice of life . 1 197 | my thoughts were focused on the characters . 1 198 | so , too , is this comedy about mild culture clashing in today 's new delhi . 1 199 | for starters , the story is just too slim . 0 200 | this is a winning ensemble comedy that shows canadians can put gentle laughs and equally gentle sentiments on the button , just as easily as their counterparts anywhere else in the world . 1 201 | at the very least , if you do n't know anything about derrida when you walk into the theater , you wo n't know much more when you leave . 0 202 | the format gets used best ... to capture the dizzying heights achieved by motocross and bmx riders , whose balletic hotdogging occasionally ends in bone-crushing screwups . 1 203 | inside the film 's conflict-powered plot there is a decent moral trying to get out , but it 's not that , it 's the tension that keeps you in your seat . 1 204 | there ought to be a directing license , so that ed burns can have his revoked . 0 205 | bad . 0 206 | that dogged good will of the parents and ` vain ' jia 's defoliation of ego , make the film touching despite some doldrums . 1 207 | falls neatly into the category of good stupid fun . 1 208 | an artful , intelligent film that stays within the confines of a well-established genre . 1 209 | smart , provocative and blisteringly funny . 1 210 | and the lesson , in the end , is nothing new . 0 211 | this is not the undisputed worst boxing movie ever , but it 's certainly not a champion - the big loser is the audience . 0 212 | not only is undercover brother as funny , if not more so , than both austin powers films , but it 's also one of the smarter , savvier spoofs to come along in some time . 1 213 | to say this was done better in wilder 's some like it hot is like saying the sun rises in the east . 0 214 | the entire movie is about a boring , sad man being boring and sad . 0 215 | this time mr. burns is trying something in the martin scorsese street-realist mode , but his self-regarding sentimentality trips him up again . 0 216 | perceptive in its vision of nascent industrialized world politics as a new art form , but far too clunky , didactic and saddled with scenes that seem simply an ill fit for this movie . 0 217 | the best revenge may just be living well because this film , unlike other dumas adaptations , is far more likened to a treasure than a lengthy jail sentence . 1 218 | the movie understands like few others how the depth and breadth of emotional intimacy give the physical act all of its meaning and most of its pleasure . 1 219 | once ( kim ) begins to overplay the shock tactics and bait-and-tackle metaphors , you may decide it 's too high a price to pay for a shimmering picture postcard . 0 220 | all that 's missing is the spontaneity , originality and delight . 0 221 | what the film lacks in general focus it makes up for in compassion , as corcuera manages to find the seeds of hope in the form of collective action . 1 222 | the socio-histo-political treatise is told in earnest strides ... ( and ) personal illusion is deconstructed with poignancy . 1 223 | my reaction in a word : disappointment . 0 224 | a psychological thriller with a genuinely spooky premise and an above-average cast , actor bill paxton 's directing debut is a creepy slice of gothic rural americana . 1 225 | corny , schmaltzy and predictable , but still manages to be kind of heartwarming , nonetheless . 1 226 | nothing 's at stake , just a twisty double-cross you can smell a mile away -- still , the derivative nine queens is lots of fun . 1 227 | far more imaginative and ambitious than the trivial , cash-in features nickelodeon has made from its other animated tv series . 1 228 | of course , by more objective measurements it 's still quite bad . 0 229 | as the two leads , lathan and diggs are charming and have chemistry both as friends and lovers . 1 230 | it provides an honest look at a community striving to anchor itself in new grounds . 1 231 | this movie seems to have been written using mad-libs . 0 232 | reign of fire looks as if it was made without much thought -- and is best watched that way . 1 233 | martin and barbara are complex characters -- sometimes tender , sometimes angry -- and the delicate performances by sven wollter and viveka seldahl make their hopes and frustrations vivid . 1 234 | it 's not that kung pow is n't funny some of the time -- it just is n't any funnier than bad martial arts movies are all by themselves , without all oedekerk 's impish augmentation . 0 235 | i 'd have to say the star and director are the big problems here . 0 236 | affleck and jackson are good sparring partners . 1 237 | whether you like rap music or loathe it , you ca n't deny either the tragic loss of two young men in the prime of their talent or the power of this movie . 1 238 | not since japanese filmmaker akira kurosawa 's ran have the savagery of combat and the specter of death been visualized with such operatic grandeur . 1 239 | a by-the-numbers effort that wo n't do much to enhance the franchise . 0 240 | an occasionally funny , but overall limp , fish-out-of-water story . 0 241 | brilliantly explores the conflict between following one 's heart and following the demands of tradition . 1 242 | despite the 2-d animation , the wild thornberrys movie makes for a surprisingly cinematic experience . 1 243 | it appears that something has been lost in the translation to the screen . 0 244 | it all feels like a monty python sketch gone horribly wrong . 0 245 | the film tunes into a grief that could lead a man across centuries . 1 246 | dazzles with its fully-written characters , its determined stylishness ( which always relates to characters and story ) and johnny dankworth 's best soundtrack in years . 1 247 | it 's a work by an artist so in control of both his medium and his message that he can improvise like a jazzman . 1 248 | it 's the chemistry between the women and the droll scene-stealing wit and wolfish pessimism of anna chancellor that makes this `` two weddings and a funeral '' fun . 1 249 | stealing harvard is evidence that the farrelly bros. -- peter and bobby -- and their brand of screen comedy are wheezing to an end , along with green 's half-hearted movie career . 0 250 | a full world has been presented onscreen , not some series of carefully structured plot points building to a pat resolution . 1 251 | huston nails both the glad-handing and the choking sense of hollow despair . 1 252 | one of the more intelligent children 's movies to hit theaters this year . 1 253 | the film tries too hard to be funny and tries too hard to be hip . 0 254 | blanchett 's performance confirms her power once again . 1 255 | if you believe any of this , i can make you a real deal on leftover enron stock that will double in value a week from friday . 0 256 | attempts by this ensemble film to impart a message are so heavy-handed that they instead pummel the audience . 0 257 | no one but a convict guilty of some truly heinous crime should have to sit through the master of disguise . 0 258 | rarely has so much money delivered so little entertainment . 0 259 | taylor appears to have blown his entire budget on soundtrack rights and had nothing left over for jokes . 0 260 | `` the time machine '' is a movie that has no interest in itself . 0 261 | a rarity among recent iranian films : it 's a comedy full of gentle humor that chides the absurdity of its protagonist 's plight . 1 262 | / but daphne , you 're too buff / fred thinks he 's tough / and velma - wow , you 've lost weight ! 0 263 | the very definition of the ` small ' movie , but it is a good stepping stone for director sprecher . 1 264 | it 's like every bad idea that 's ever gone into an after-school special compiled in one place , minus those daytime programs ' slickness and sophistication ( and who knew they even had any ? ) . 0 265 | chilling , well-acted , and finely directed : david jacobson 's dahmer . 1 266 | it ca n't decide if it wants to be a mystery/thriller , a romance or a comedy . 0 267 | paid in full is so stale , in fact , that its most vibrant scene is one that uses clips from brian de palma 's scarface . 0 268 | a coda in every sense , the pinochet case splits time between a minute-by-minute account of the british court 's extradition chess game and the regime 's talking-head survivors . 1 269 | it 's played in the most straight-faced fashion , with little humor to lighten things up . 0 270 | a dumb movie with dumb characters doing dumb things and you have to be really dumb not to see where this is going . 0 271 | with virtually no interesting elements for an audience to focus on , chelsea walls is a triple-espresso endurance challenge . 0 272 | dense with characters and contains some thrilling moments . 1 273 | as unseemly as its title suggests . 1 274 | it 's like watching a nightmare made flesh . 0 275 | minority report is exactly what the title indicates , a report . 1 276 | it 's hard to like a film about a guy who is utterly unlikeable , and shiner , starring michael caine as an aging british boxing promoter desperate for a taste of fame and fortune , is certainly that . 0 277 | an entertaining , colorful , action-filled crime story with an intimate heart . 1 278 | for this reason and this reason only -- the power of its own steadfast , hoity-toity convictions -- chelsea walls deserves a medal . 1 279 | it just may inspire a few younger moviegoers to read stevenson 's book , which is a treasure in and of itself . 1 280 | basically a static series of semi-improvised ( and semi-coherent ) raps between the stars . 0 281 | ... with `` the bourne identity '' we return to the more traditional action genre . 1 282 | it 's so good that its relentless , polished wit can withstand not only inept school productions , but even oliver parker 's movie adaptation . 1 283 | chokes on its own depiction of upper-crust decorum . 0 284 | while there 's something intrinsically funny about sir anthony hopkins saying ` get in the car , bitch , ' this jerry bruckheimer production has little else to offer 1 285 | a rewarding work of art for only the most patient and challenge-hungry moviegoers . 1 286 | directed in a paint-by-numbers manner . 0 287 | k-19 exploits our substantial collective fear of nuclear holocaust to generate cheap hollywood tension . 0 288 | at its best , queen is campy fun like the vincent price horror classics of the '60s . 1 289 | it 's a much more emotional journey than what shyamalan has given us in his past two movies , and gibson , stepping in for bruce willis , is the perfect actor to take us on the trip . 1 290 | the quality of the art combined with the humor and intelligence of the script allow the filmmakers to present the biblical message of forgiveness without it ever becoming preachy or syrupy . 1 291 | cool ? 1 292 | deliriously funny , fast and loose , accessible to the uninitiated , and full of surprises 1 293 | even with a green mohawk and a sheet of fire-red flame tattoos covering his shoulder , however , kilmer seems to be posing , rather than acting . 0 294 | the story and the friendship proceeds in such a way that you 're watching a soap opera rather than a chronicle of the ups and downs that accompany lifelong friendships . 0 295 | at a time when half the so-called real movies are little more than live-action cartoons , it 's refreshing to see a cartoon that knows what it is , and knows the form 's history . 1 296 | the old-world - meets-new mesh is incarnated in the movie 's soundtrack , a joyful effusion of disco bollywood that , by the end of monsoon wedding , sent my spirit soaring out of the theater . 1 297 | jones ... does offer a brutal form of charisma . 1 298 | its well of thorn and vinegar ( and simple humanity ) has long been plundered by similar works featuring the insight and punch this picture so conspicuously lacks . 0 299 | travels a fascinating arc from hope and euphoria to reality and disillusionment . 1 300 | serving sara does n't serve up a whole lot of laughs . 0 301 | the sort of film that makes me miss hitchcock , but also feel optimistic that there 's hope for popular cinema yet . 1 302 | fun , flip and terribly hip bit of cinematic entertainment . 1 303 | the x potion gives the quickly named blossom , bubbles and buttercup supernatural powers that include extraordinary strength and laser-beam eyes , which unfortunately do n't enable them to discern flimsy screenplays . 0 304 | the wild thornberrys movie is a jolly surprise . 1 305 | entertains by providing good , lively company . 1 306 | a densely constructed , highly referential film , and an audacious return to form that can comfortably sit among jean-luc godard 's finest work . 1 307 | what was once original has been co-opted so frequently that it now seems pedestrian . 0 308 | the story and structure are well-honed . 1 309 | macdowell , whose wifty southern charm has anchored lighter affairs ... brings an absolutely riveting conviction to her role . 1 310 | an intriguing cinematic omnibus and round-robin that occasionally is more interesting in concept than in execution . 1 311 | the second coming of harry potter is a film far superior to its predecessor . 1 312 | if you can stomach the rough content , it 's worth checking out for the performances alone . 1 313 | a warm , funny , engaging film . 1 314 | i 'll bet the video game is a lot more fun than the film . 0 315 | the best film about baseball to hit theaters since field of dreams . 1 316 | it is great summer fun to watch arnold and his buddy gerald bounce off a quirky cast of characters . 1 317 | complete lack of originality , cleverness or even visible effort 0 318 | awesome creatures , breathtaking scenery , and epic battle scenes add up to another ` spectacular spectacle . ' 1 319 | all-in-all , the film is an enjoyable and frankly told tale of a people who live among us , but not necessarily with us . 1 320 | hit and miss as far as the comedy goes and a big ole ' miss in the way of story . 0 321 | too much of it feels unfocused and underdeveloped . 0 322 | a deep and meaningful film . 1 323 | but it could have been worse . 0 324 | that 's pure pr hype . 0 325 | a painfully funny ode to bad behavior . 1 326 | you 'll gasp appalled and laugh outraged and possibly , watching the spectacle of a promising young lad treading desperately in a nasty sea , shed an errant tear . 1 327 | liotta put on 30 pounds for the role , and has completely transformed himself from his smooth , goodfellas image . 1 328 | a beguiling splash of pastel colors and prankish comedy from disney . 1 329 | it proves quite compelling as an intense , brooding character study . 1 330 | an unwise amalgam of broadcast news and vibes . 0 331 | utterly lacking in charm , wit and invention , roberto benigni 's pinocchio is an astonishingly bad film . 0 332 | and that leaves a hole in the center of the salton sea . 0 333 | the chateau cleverly probes the cross-cultural differences between gauls and yanks . 1 334 | broomfield turns his distinctive ` blundering ' style into something that could really help clear up the case . 1 335 | a pleasant enough romance with intellectual underpinnings , the kind of movie that entertains even as it turns maddeningly predictable . 1 336 | what really makes it special is that it pulls us into its world , gives us a hero whose suffering and triumphs we can share , surrounds him with interesting characters and sends us out of the theater feeling we 've shared a great adventure . 1 337 | with the exception of some fleetingly amusing improvisations by cedric the entertainer as perry 's boss , there is n't a redeeming moment here . 0 338 | having had the good sense to cast actors who are , generally speaking , adored by the movie-going public , khouri then gets terrific performances from them all . 1 339 | ... a boring parade of talking heads and technical gibberish that will do little to advance the linux cause . 0 340 | it 's of the quality of a lesser harrison ford movie - six days , seven nights , maybe , or that dreadful sabrina remake . 0 341 | if you enjoy more thoughtful comedies with interesting conflicted characters ; this one is for you . 1 342 | the most hopelessly monotonous film of the year , noteworthy only for the gimmick of being filmed as a single unbroken 87-minute take . 0 343 | it deserves to be seen by anyone with even a passing interest in the events shaping the world beyond their own horizons . 1 344 | in an effort , i suspect , not to offend by appearing either too serious or too lighthearted , it offends by just being wishy-washy . 0 345 | no way i can believe this load of junk . 0 346 | there is a fabric of complex ideas here , and feelings that profoundly deepen them . 1 347 | this tenth feature is a big deal , indeed -- at least the third-best , and maybe even a notch above the previous runner-up , nicholas meyer 's star trek vi : the undiscovered country . 1 348 | not only unfunny , but downright repellent . 0 349 | works hard to establish rounded characters , but then has nothing fresh or particularly interesting to say about them . 0 350 | just one bad idea after another . 0 351 | ... turns so unforgivably trite in its last 10 minutes that anyone without a fortified sweet tooth will likely go into sugar shock . 0 352 | his comedy premises are often hackneyed or just plain crude , calculated to provoke shocked laughter , without following up on a deeper level . 0 353 | ( næs ) directed the stage version of elling , and gets fine performances from his two leads who originated the characters on stage . 1 354 | a swashbuckling tale of love , betrayal , revenge and above all , faith . 1 355 | for movie lovers as well as opera lovers , tosca is a real treat . 1 356 | the film is quiet , threatening and unforgettable . 1 357 | there is no pleasure in watching a child suffer . 0 358 | jason x is positively anti-darwinian : nine sequels and 400 years later , the teens are none the wiser and jason still kills on auto-pilot . 0 359 | stealing harvard aspires to comedic grand larceny but stands convicted of nothing more than petty theft of your time . 0 360 | ( d ) oes n't bother being as cloying or preachy as equivalent evangelical christian movies -- maybe the filmmakers know that the likely audience will already be among the faithful . 1 361 | displaying about equal amounts of naiveté , passion and talent , beneath clouds establishes sen as a filmmaker of considerable potential . 1 362 | ` easily my choice for one of the year 's best films . ' 1 363 | a very long movie , dull in stretches , with entirely too much focus on meal preparation and igloo construction . 0 364 | as a first-time director , paxton has tapped something in himself as an actor that provides frailty with its dark soul . 1 365 | it 's a grab bag of genres that do n't add up to a whole lot of sense . 0 366 | instead of hiding pinocchio from critics , miramax should have hidden it from everyone . 0 367 | portentous and pretentious , the weight of water is appropriately titled , given the heavy-handedness of it drama . 0 368 | altogether , this is successful as a film , while at the same time being a most touching reconsideration of the familiar masterpiece . 1 369 | there has always been something likable about the marquis de sade . 1 370 | the humor is forced and heavy-handed , and occasionally simply unpleasant . 0 371 | without ever becoming didactic , director carlos carrera expertly weaves this novelistic story of entangled interrelationships and complex morality . 1 372 | partway through watching this saccharine , easter-egg-colored concoction , you realize that it is made up of three episodes of a rejected tv show . 0 373 | for the most part , it 's a work of incendiary genius , steering clear of knee-jerk reactions and quick solutions . 1 374 | the special effects and many scenes of weightlessness look as good or better than in the original , while the oscar-winning sound and james horner 's rousing score make good use of the hefty audio system . 1 375 | not since freddy got fingered has a major release been so painful to sit through . 0 376 | the movie is what happens when you blow up small potatoes to 10 times their natural size , and it ai n't pretty . 0 377 | we have n't seen such hilarity since say it is n't so ! 1 378 | to call the other side of heaven `` appalling '' would be to underestimate just how dangerous entertainments like it can be . 0 379 | nothing is sacred in this gut-buster . 0 380 | feels haphazard , as if the writers mistakenly thought they could achieve an air of frantic spontaneity by simply tossing in lots of characters doing silly stuff and stirring the pot . 0 381 | tries to add some spice to its quirky sentiments but the taste is all too familiar . 0 382 | at its worst , it implodes in a series of very bad special effects . 0 383 | with tightly organized efficiency , numerous flashbacks and a constant edge of tension , miller 's film is one of 2002 's involvingly adult surprises . 1 384 | a great ensemble cast ca n't lift this heartfelt enterprise out of the familiar . 0 385 | a warm but realistic meditation on friendship , family and affection . 1 386 | at times , the suspense is palpable , but by the end there 's a sense that the crux of the mystery hinges on a technicality that strains credulity and leaves the viewer haunted by the waste of potential . 0 387 | while the resident evil games may have set new standards for thrills , suspense , and gore for video games , the movie really only succeeds in the third of these . 0 388 | it 's a remarkably solid and subtly satirical tour de force . 1 389 | director andrew niccol ... demonstrates a wry understanding of the quirks of fame . 1 390 | when leguizamo finally plugged an irritating character late in the movie . 0 391 | thekids will probably stay amused at the kaleidoscope of big , colorful characters . 1 392 | mattei is tiresomely grave and long-winded , as if circularity itself indicated profundity . 0 393 | ... plays like somebody spliced random moments of a chris rock routine into what is otherwise a cliche-riddled but self-serious spy thriller . 0 394 | an overemphatic , would-be wacky , ultimately tedious sex farce . 0 395 | it all adds up to good fun . 1 396 | whether writer-director anne fontaine 's film is a ghost story , an account of a nervous breakdown , a trip down memory lane , all three or none of the above , it is as seductive as it is haunting . 1 397 | another in-your-face wallow in the lower depths made by people who have never sung those blues . 0 398 | a very well-made , funny and entertaining picture . 1 399 | it 's worth seeing just on the basis of the wisdom , and at times , the startling optimism , of the children . 1 400 | despite its title , punch-drunk love is never heavy-handed . 1 401 | if director michael dowse only superficially understands his characters , he does n't hold them in contempt . 0 402 | it 's refreshing to see a girl-power movie that does n't feel it has to prove anything . 1 403 | the film may appear naked in its narrative form ... but it goes deeper than that , to fundamental choices that include the complexity of the catholic doctrine 1 404 | however it may please those who love movies that blare with pop songs , young science fiction fans will stomp away in disgust . 0 405 | as vulgar as it is banal . 0 406 | zhang ... has done an amazing job of getting realistic performances from his mainly nonprofessional cast . 1 407 | outer-space buffs might love this film , but others will find its pleasures intermittent . 0 408 | maud and roland 's search for an unknowable past makes for a haunting literary detective story , but labute pulls off a neater trick in possession : he makes language sexy . 1 409 | more whiny downer than corruscating commentary . 0 410 | there are simply too many ideas floating around -- part farce , part sliding doors , part pop video -- and yet failing to exploit them . 0 411 | it 's another stale , kill-by-numbers flick , complete with blade-thin characters and terrible , pun-laden dialogue . 0 412 | what distinguishes time of favor from countless other thrillers is its underlying concern with the consequences of words and with the complicated emotions fueling terrorist acts . 1 413 | i do n't mind having my heartstrings pulled , but do n't treat me like a fool . 0 414 | the movie 's accumulated force still feels like an ugly knot tightening in your stomach . 0 415 | at least one scene is so disgusting that viewers may be hard pressed to retain their lunch . 0 416 | it has charm to spare , and unlike many romantic comedies , it does not alienate either gender in the audience . 1 417 | an operatic , sprawling picture that 's entertainingly acted , magnificently shot and gripping enough to sustain most of its 170-minute length . 1 418 | a giggle a minute . 1 419 | uses sharp humor and insight into human nature to examine class conflict , adolescent yearning , the roots of friendship and sexual identity . 1 420 | the continued good chemistry between carmen and juni is what keeps this slightly disappointing sequel going , with enough amusing banter -- blessedly curse-free -- to keep both kids and parents entertained . 1 421 | i 'm just too bored to care . 0 422 | one of the more irritating cartoons you will see this , or any , year . 0 423 | it 's one heck of a character study -- not of hearst or davies but of the unique relationship between them . 1 424 | it moves quickly , adroitly , and without fuss ; it does n't give you time to reflect on the inanity -- and the cold war datedness -- of its premise . 1 425 | i am sorry that i was unable to get the full brunt of the comedy . 0 426 | a good piece of work more often than not . 1 427 | while the ideas about techno-saturation are far from novel , they 're presented with a wry dark humor . 1 428 | charles ' entertaining film chronicles seinfeld 's return to stand-up comedy after the wrap of his legendary sitcom , alongside wannabe comic adams ' attempts to get his shot at the big time . 1 429 | an exhilarating futuristic thriller-noir , minority report twists the best of technology around a gripping story , delivering a riveting , pulse intensifying escapist adventure of the first order 1 430 | beautifully observed , miraculously unsentimental comedy-drama . 1 431 | the film 's hackneyed message is not helped by the thin characterizations , nonexistent plot and pretentious visual style . 0 432 | a breezy romantic comedy that has the punch of a good sitcom , while offering exceptionally well-detailed characters . 1 433 | should have been someone else - 0 434 | coughs and sputters on its own postmodern conceit . 0 435 | the lion king was a roaring success when it was released eight years ago , but on imax it seems better , not just bigger . 1 436 | almost gags on its own gore . 0 437 | a marvel like none you 've seen . 1 438 | trite , banal , cliched , mostly inoffensive . 0 439 | immersing us in the endlessly inventive , fiercely competitive world of hip-hop djs , the project is sensational and revelatory , even if scratching makes you itch . 1 440 | the movie has an infectious exuberance that will engage anyone with a passing interest in the skate/surf culture , the l.a. beach scene and the imaginative ( and sometimes illegal ) ways kids can make a playground out of the refuse of adults . 1 441 | yakusho and shimizu ... create engaging characterizations in imamura 's lively and enjoyable cultural mix . 1 442 | this is wild surreal stuff , but brilliant and the camera just kind of sits there and lets you look at this and its like you 're going from one room to the next and none of them have any relation to the other . 1 443 | there is very little dread or apprehension , and though i like the creepy ideas , they are not executed with anything more than perfunctory skill . 0 444 | the notion that bombing buildings is the funniest thing in the world goes entirely unexamined in this startlingly unfunny comedy . 0 445 | good car chases , great fight scenes , and a distinctive blend of european , american and asian influences . 1 446 | the last 20 minutes are somewhat redeeming , but most of the movie is the same teenage american road-trip drek we 've seen before - only this time you have to read the fart jokes 0 447 | even in its most tedious scenes , russian ark is mesmerizing . 1 448 | with its dogged hollywood naturalism and the inexorable passage of its characters toward sainthood , windtalkers is nothing but a sticky-sweet soap . 0 449 | generally , clockstoppers will fulfill your wildest fantasies about being a different kind of time traveler , while happily killing 94 minutes . 1 450 | something akin to a japanese alice through the looking glass , except that it seems to take itself far more seriously . 1 451 | oh come on . 0 452 | a moody , multi-dimensional love story and sci-fi mystery , solaris is a thought-provoking , haunting film that allows the seeds of the imagination to germinate . 1 453 | not only are the special effects and narrative flow much improved , and daniel radcliffe more emotionally assertive this time around as harry , but the film conjures the magic of author j.k. rowling 's books . 1 454 | it 's clear the filmmakers were n't sure where they wanted their story to go , and even more clear that they lack the skills to get us to this undetermined destination . 0 455 | ( t ) his beguiling belgian fable , very much its own droll and delicate little film , has some touching things to say about what is important in life and why . 1 456 | even on those rare occasions when the narrator stops yammering , miller 's hand often feels unsure . 0 457 | ( chaiken 's ) talent lies in an evocative , accurate observation of a distinctive milieu and in the lively , convincing dialogue she creates for her characters . 1 458 | sticky sweet sentimentality , clumsy plotting and a rosily myopic view of life in the wwii-era mississippi delta undermine this adaptation . 0 459 | it inspires a continuing and deeply satisfying awareness of the best movies as monumental ` picture shows . ' 1 460 | featuring a dangerously seductive performance from the great daniel auteuil , `` sade '' covers the same period as kaufmann 's `` quills '' with more unsettlingly realistic results . 1 461 | gives you the steady pulse of life in a beautiful city viewed through the eyes of a character who , in spite of tragic loss and increasing decrepitude , knows in his bones that he is one of the luckiest men alive . 1 462 | if you are an actor who can relate to the search for inner peace by dramatically depicting the lives of others onstage , then esther 's story is a compelling quest for truth . 1 463 | it 's too bad that the helping hand he uses to stir his ingredients is also a heavy one . 0 464 | yes , dull . 0 465 | some of their jokes work , but most fail miserably and in the end , pumpkin is far more offensive than it is funny . 0 466 | intriguing documentary which is emotionally diluted by focusing on the story 's least interesting subject . 1 467 | shaky close-ups of turkey-on-rolls , stubbly chins , liver spots , red noses and the filmmakers new bobbed do draw easy chuckles but lead nowhere . 0 468 | the inspirational screenplay by mike rich covers a lot of ground , perhaps too much , but ties things together , neatly , by the end . 1 469 | ramsay , as in ratcatcher , remains a filmmaker with an acid viewpoint and a real gift for teasing chilly poetry out of lives and settings that might otherwise seem drab and sordid . 1 470 | the characters are interesting and often very creatively constructed from figure to backstory . 1 471 | so unremittingly awful that labeling it a dog probably constitutes cruelty to canines . 0 472 | reggio 's continual visual barrage is absorbing as well as thought-provoking . 1 473 | adults will wish the movie were less simplistic , obvious , clumsily plotted and shallowly characterized . 0 474 | you will emerge with a clearer view of how the gears of justice grind on and the death report comes to share airtime alongside the farm report . 1 475 | thanks to haynes ' absolute control of the film 's mood , and buoyed by three terrific performances , far from heaven actually pulls off this stylistic juggling act . 1 476 | the problem with this film is that it lacks focus . 0 477 | belongs to daniel day-lewis as much as it belongs to martin scorsese ; it 's a memorable performance in a big , brassy , disturbing , unusual and highly successful film . 1 478 | involves two mysteries -- one it gives away and the other featuring such badly drawn characters that its outcome hardly matters . 0 479 | a tv style murder mystery with a few big screen moments ( including one that seems to be made for a different film altogether ) . 0 480 | a by-the-numbers patient/doctor pic that covers all the usual ground 0 481 | it 's a stunning lyrical work of considerable force and truth . 1 482 | while undisputed is n't exactly a high , it is a gripping , tidy little movie that takes mr. hill higher than he 's been in a while . 1 483 | funny but perilously slight . 1 484 | cq 's reflection of artists and the love of cinema-and-self suggests nothing less than a new voice that deserves to be considered as a possible successor to the best european directors . 1 485 | even if you do n't think ( kissinger 's ) any more guilty of criminal activity than most contemporary statesmen , he 'd sure make a courtroom trial great fun to watch . 1 486 | dazzling in its complexity , disturbing for its extraordinary themes , the piano teacher is a film that defies categorisation . 1 487 | a literate presentation that wonderfully weaves a murderous event in 1873 with murderous rage in 2002 . 1 488 | the script is n't very good ; not even someone as gifted as hoffman ( the actor ) can make it work . 0 489 | ( e ) ventually , every idea in this film is flushed down the latrine of heroism . 0 490 | a cartoon that 's truly cinematic in scope , and a story that 's compelling and heartfelt -- even if the heart belongs to a big , four-legged herbivore . 1 491 | it 's dumb , but more importantly , it 's just not scary . 0 492 | detox is ultimately a pointless endeavor . 0 493 | as a rumor of angels reveals itself to be a sudsy tub of supernatural hokum , not even ms. redgrave 's noblest efforts can redeem it from hopeless sentimentality . 0 494 | an exquisitely crafted and acted tale . 1 495 | this is so bad . 0 496 | it showcases carvey 's talent for voices , but not nearly enough and not without taxing every drop of one 's patience to get to the good stuff . 0 497 | light years / several warp speeds / levels and levels of dilithium crystals better than the pitiful insurrection . 1 498 | it 's about following your dreams , no matter what your parents think . 1 499 | the overall effect is less like a children 's movie than a recruitment film for future hollywood sellouts . 0 500 | anchored by friel and williams 's exceptional performances , the film 's power lies in its complexity . 1 501 | unlike the speedy wham-bam effect of most hollywood offerings , character development -- and more importantly , character empathy -- is at the heart of italian for beginners . 1 502 | a sequel that 's much too big for its britches . 0 503 | harrison 's flowers puts its heart in the right place , but its brains are in no particular place at all . 1 504 | deadeningly dull , mired in convoluted melodrama , nonsensical jargon and stiff-upper-lip laboriousness . 0 505 | dragonfly has no atmosphere , no tension -- nothing but costner , flailing away . 0 506 | the film is powerful , accessible and funny . 1 507 | and that 's a big part of why we go to the movies . 1 508 | crackerjack entertainment -- nonstop romance , music , suspense and action . 1 509 | the minor figures surrounding ( bobby ) ... form a gritty urban mosaic . 1 510 | a poignant and compelling story about relationships , food of love takes us on a bumpy but satisfying journey of the heart . 1 511 | a movie that successfully crushes a best selling novel into a timeframe that mandates that you avoid the godzilla sized soda . 1 512 | the vivid lead performances sustain interest and empathy , but the journey is far more interesting than the final destination . 1 513 | lapaglia 's ability to convey grief and hope works with weaver 's sensitive reactions to make this a two-actor master class . 1 514 | villeneuve spends too much time wallowing in bibi 's generic angst ( there are a lot of shots of her gazing out windows ) . 0 515 | care deftly captures the wonder and menace of growing up , but he never really embraces the joy of fuhrman 's destructive escapism or the grace-in-rebellion found by his characters . 0 516 | this is an egotistical endeavor from the daughter of horror director dario argento ( a producer here ) , but her raw performance and utter fearlessness make it strangely magnetic . 1 517 | if looking for a thrilling sci-fi cinematic ride , do n't settle for this imposter . 0 518 | plays like a volatile and overlong w magazine fashion spread . 0 519 | not far beneath the surface , this reconfigured tale asks disturbing questions about those things we expect from military epics . 1 520 | michael gerbosi 's script is economically packed with telling scenes . 1 521 | moretti 's compelling anatomy of grief and the difficult process of adapting to loss . 0 522 | so refreshingly incisive is grant that for the first time he 'll probably appeal more to guys than to their girlfriends who drag them to this movie for the hugh factor . 1 523 | comes off like a rejected abc afterschool special , freshened up by the dunce of a screenwriting 101 class . 0 524 | it has its moments of swaggering camaraderie , but more often just feels generic , derivative and done to death . 0 525 | a romantic comedy enriched by a sharp eye for manners and mores . 1 526 | the fly-on-the-wall method used to document rural french school life is a refreshing departure from the now more prevalent technique of the docu-makers being a visible part of their work . 1 527 | rare birds has more than enough charm to make it memorable . 1 528 | it 's a bit disappointing that it only manages to be decent instead of dead brilliant . 0 529 | it has all the excitement of eating oatmeal . 0 530 | it haunts , horrifies , startles and fascinates ; it is impossible to look away . 1 531 | for close to two hours the audience is forced to endure three terminally depressed , mostly inarticulate , hyper dysfunctional families for the price of one . 0 532 | a superbly acted and funny/gritty fable of the humanizing of one woman at the hands of the unseen forces of fate . 1 533 | ( t ) here 's only so much anyone can do with a florid , overplotted , anne rice rock 'n' roll vampire novel before the built-in silliness of the whole affair defeats them . 0 534 | for anyone unfamiliar with pentacostal practices in general and theatrical phenomenon of hell houses in particular , it 's an eye-opener . 1 535 | `` mostly martha '' is a bright , light modern day family parable that wears its heart on its sleeve for all to see . 1 536 | i just loved every minute of this film . 1 537 | a quiet , pure , elliptical film 1 538 | a disappointment for those who love alternate versions of the bard , particularly ones that involve deep fryers and hamburgers . 0 539 | a simple , but gritty and well-acted ensemble drama that encompasses a potent metaphor for a country still dealing with its fascist past . 1 540 | it 's so mediocre , despite the dynamic duo on the marquee , that we just ca n't get no satisfaction . 0 541 | do not see this film . 0 542 | binoche makes it interesting trying to find out . 1 543 | the most compelling wiseman epic of recent years . 1 544 | there 's no emotional pulse to solaris . 0 545 | for each chuckle there are at least 10 complete misses , many coming from the amazingly lifelike tara reid , whose acting skills are comparable to a cardboard cutout . 0 546 | although huppert 's intensity and focus has a raw exhilaration about it , the piano teacher is anything but fun . 0 547 | from the opening scenes , it 's clear that all about the benjamins is a totally formulaic movie . 0 548 | on the heels of the ring comes a similarly morose and humorless horror movie that , although flawed , is to be commended for its straight-ahead approach to creepiness . 1 549 | the film is based on truth and yet there is something about it that feels incomplete , as if the real story starts just around the corner . 0 550 | i 've always dreamed of attending cannes , but after seeing this film , it 's not that big a deal . 0 551 | a coarse and stupid gross-out . 0 552 | ... nothing scary here except for some awful acting and lame special effects . 0 553 | nothing in waking up in reno ever inspired me to think of its inhabitants as anything more than markers in a screenplay . 0 554 | here 's yet another studio horror franchise mucking up its storyline with glitches casual fans could correct in their sleep . 0 555 | so unassuming and pure of heart , you ca n't help but warmly extend your arms and yell ` safe ! ' 1 556 | it treats women like idiots . 0 557 | ... plot holes so large and obvious a marching band might as well be stomping through them in clown clothes , playing a college football fight song on untuned instruments . 0 558 | but it 's too long and too convoluted and it ends in a muddle . 0 559 | one of the best films of the year with its exploration of the obstacles to happiness faced by five contemporary individuals ... a psychological masterpiece . 1 560 | although german cooking does not come readily to mind when considering the world 's best cuisine , mostly martha could make deutchland a popular destination for hungry tourists . 1 561 | if the first men in black was money , the second is small change . 0 562 | do n't be fooled by the impressive cast list - eye see you is pure junk . 0 563 | another one of those estrogen overdose movies like `` divine secrets of the ya ya sisterhood , '' except that the writing , acting and character development are a lot better . 1 564 | scorsese does n't give us a character worth giving a damn about . 0 565 | with rabbit-proof fence , noyce has tailored an epic tale into a lean , economical movie . 1 566 | the plot convolutions ultimately add up to nothing more than jerking the audience 's chain . 0 567 | there are some wonderfully fresh moments that smooth the moral stiffness with human kindness and hopefulness . 1 568 | does little more than play an innocuous game of fill-in - the-blanks with a tragic past . 0 569 | feature debuter d.j. caruso directs a crack ensemble cast , bringing screenwriter tony gayton 's narcotics noir to life . 1 570 | it does nothing new with the old story , except to show fisticuffs in this sort of stop-go slow motion that makes the gang rumbles look like they 're being streamed over a 28k modem . 0 571 | one of those energetic surprises , an original that pleases almost everyone who sees it . 1 572 | seldahl 's barbara is a precise and moving portrait of someone whose world is turned upside down , first by passion and then by illness . 1 573 | passable entertainment , but it 's the kind of motion picture that wo n't make much of a splash when it 's released , and will not be remembered long afterwards . 0 574 | the film 's tone and pacing are off almost from the get-go . 0 575 | lovely and poignant . 1 576 | a broad , melodramatic estrogen opera that 's pretty toxic in its own right . 0 577 | ( director ) o'fallon manages to put some lovely pictures up on the big screen , but his skill at telling a story -- he also contributed to the screenplay -- falls short . 0 578 | offers very little genuine romance and even fewer laughs ... a sad sitcom of a movie , largely devoid of charm . 0 579 | though only 60 minutes long , the film is packed with information and impressions . 1 580 | just not campy enough 0 581 | every dance becomes about seduction , where backstabbing and betrayals are celebrated , and sex is currency . 0 582 | it takes a certain kind of horror movie to qualify as ` worse than expected , ' but ghost ship somehow manages to do exactly that . 0 583 | it can not be enjoyed , even on the level that one enjoys a bad slasher flick , primarily because it is dull . 0 584 | despite all evidence to the contrary , this clunker has somehow managed to pose as an actual feature movie , the kind that charges full admission and gets hyped on tv and purports to amuse small children and ostensible adults . 0 585 | it 's just filler . 0 586 | a hamfisted romantic comedy that makes our girl the hapless facilitator of an extended cheap shot across the mason-dixon line . 0 587 | one of those pictures whose promising , if rather precious , premise is undercut by amateurish execution . 0 588 | the humor is n't as sharp , the effects not as innovative , nor the story as imaginative as in the original . 0 589 | director uwe boll and the actors provide scant reason to care in this crude '70s throwback . 0 590 | ... a story we have n't seen on the big screen before , and it 's a story that we as americans , and human beings , should know . 1 591 | if your taste runs to ` difficult ' films you absolutely ca n't miss it . 1 592 | this movie is something of an impostor itself , stretching and padding its material in a blur of dead ends and distracting camera work . 0 593 | i got a headache watching this meaningless downer . 0 594 | zaidan 's script has barely enough plot to string the stunts together and not quite enough characterization to keep the faces straight . 0 595 | the terrific and bewilderingly underrated campbell scott gives a star performance that is nothing short of mesmerizing . 1 596 | building slowly and subtly , the film , sporting a breezy spontaneity and realistically drawn characterizations , develops into a significant character study that is both moving and wise . 1 597 | all the amped-up tony hawk-style stunts and thrashing rap-metal ca n't disguise the fact that , really , we 've been here , done that . 0 598 | the director knows how to apply textural gloss , but his portrait of sex-as-war is strictly sitcom . 0 599 | visually rather stunning , but ultimately a handsome-looking bore , the true creativity would have been to hide treasure planet entirely and completely reimagine it . 0 600 | nonsensical , dull `` cyber-horror '' flick is a grim , hollow exercise in flat scares and bad acting . 0 601 | big fat waste of time . 0 602 | professionally speaking , it 's tempting to jump ship in january to avoid ridiculous schlock like this shoddy suspense thriller . 0 603 | fancy a real downer ? 0 604 | instead , he shows them the respect they are due . 1 605 | first-time writer-director serry shows a remarkable gift for storytelling with this moving , effective little film . 1 606 | vera 's technical prowess ends up selling his film short ; he smoothes over hard truths even as he uncovers them . 0 607 | puts a human face on a land most westerners are unfamiliar with . 1 608 | makes for a pretty unpleasant viewing experience . 0 609 | ahhhh ... revenge is sweet ! 1 610 | highbrow self-appointed guardians of culture need not apply , but those who loved cool as ice have at last found a worthy follow-up . 1 611 | nine queens is not only than a frighteningly capable debut and genre piece , but also a snapshot of a dangerous political situation on the verge of coming to a head . 1 612 | like you could n't smell this turkey rotting from miles away . 0 613 | the performances take the movie to a higher level . 1 614 | davis ... is so enamored of her own creation that she ca n't see how insufferable the character is . 0 615 | ... takes the beauty of baseball and melds it with a story that could touch anyone regardless of their familiarity with the sport 1 616 | against all odds in heaven and hell , it creeped me out just fine . 1 617 | as the latest bid in the tv-to-movie franchise game , i spy makes its big-screen entry with little of the nervy originality of its groundbreaking small-screen progenitor . 0 618 | there 's really only one good idea in this movie , but the director runs with it and presents it with an unforgettable visual panache . 1 619 | the so-inept - it 's - surreal dubbing ( featuring the voices of glenn close , regis philbin and breckin meyer ) brings back memories of cheesy old godzilla flicks . 0 620 | without non-stop techno or the existential overtones of a kieslowski morality tale , maelström is just another winter sleepers . 0 621 | an unclassifiably awful study in self - and audience-abuse . 0 622 | the moviegoing equivalent of going to a dinner party and being forced to watch the host and hostess 's home video of their baby 's birth . 0 623 | kinnear does n't aim for our sympathy , but rather delivers a performance of striking skill and depth . 1 624 | few films capture so perfectly the hopes and dreams of little boys on baseball fields as well as the grown men who sit in the stands . 1 625 | it is amusing , and that 's all it needs to be . 1 626 | challenging , intermittently engrossing and unflaggingly creative . 1 627 | for the most part stevens glides through on some solid performances and witty dialogue . 1 628 | this flick is about as cool and crowd-pleasing as a documentary can get . 1 629 | nervous breakdowns are not entertaining . 0 630 | writer-director 's mehta 's effort has tons of charm and the whimsy is in the mixture , the intoxicating masala , of cultures and film genres . 1 631 | excessive , profane , packed with cartoonish violence and comic-strip characters . 0 632 | a taut psychological thriller that does n't waste a moment of its two-hour running time . 1 633 | burns never really harnesses to full effect the energetic cast . 0 634 | just embarrassment and a vague sense of shame . 0 635 | still , as a visual treat , the film is almost unsurpassed . 1 636 | harris commands the screen , using his frailty to suggest the ravages of a life of corruption and ruthlessness . 1 637 | or emptying rat traps . 0 638 | offers much to enjoy ... and a lot to mull over in terms of love , loyalty and the nature of staying friends . 1 639 | the piquant story needs more dramatic meat on its bones . 0 640 | my wife is an actress is an utterly charming french comedy that feels so american in sensibility and style it 's virtually its own hollywood remake . 1 641 | indifferently implausible popcorn programmer of a movie . 0 642 | an important movie , a reminder of the power of film to move us and to make us examine our values . 1 643 | the magic of the film lies not in the mysterious spring but in the richness of its performances . 1 644 | this re-do is so dumb and so exploitative in its violence that , ironically , it becomes everything that the rather clumsy original was railing against . 0 645 | the jabs it employs are short , carefully placed and dead-center . 1 646 | while locals will get a kick out of spotting cleveland sites , the rest of the world will enjoy a fast-paced comedy with quirks that might make the award-winning coen brothers envious . 1 647 | the words , ` frankly , my dear , i do n't give a damn , ' have never been more appropriate . 0 648 | the longer the movie goes , the worse it gets , but it 's actually pretty good in the first few minutes . 0 649 | too much of the humor falls flat . 0 650 | further proof that the epicenter of cool , beautiful , thought-provoking foreign cinema is smack-dab in the middle of dubya 's axis of evil . 1 651 | the film 's few ideas are stretched to the point of evaporation ; the whole central section is one big chase that seems to have no goal and no urgency . 0 652 | too slow , too long and too little happens . 0 653 | due to some script weaknesses and the casting of the director 's brother , the film trails off into inconsequentiality . 0 654 | very bad . 0 655 | a lackluster , unessential sequel to the classic disney adaptation of j.m. barrie 's peter pan . 0 656 | a science-fiction pastiche so lacking in originality that if you stripped away its inspirations there would be precious little left . 0 657 | birthday girl is an amusing joy ride , with some surprisingly violent moments . 1 658 | so much facile technique , such cute ideas , so little movie . 1 659 | expect the same-old , lame-old slasher nonsense , just with different scenery . 0 660 | a smart , witty follow-up . 1 661 | chabrol has taken promising material for a black comedy and turned it instead into a somber chamber drama . 0 662 | like mike is a winner for kids , and no doubt a winner for lil bow wow , who can now add movies to the list of things he does well . 1 663 | it 's another video movie photographed like a film , with the bad lighting that 's often written off as indie film naturalism . 0 664 | there 's not enough here to justify the almost two hours . 0 665 | it will grip even viewers who are n't interested in rap , as it cuts to the heart of american society in an unnerving way . 1 666 | the film is beautifully mounted , but , more to the point , the issues are subtly presented , managing to walk a fine line with regard to the question of joan 's madness . 1 667 | and if you 're not nearly moved to tears by a couple of scenes , you 've got ice water in your veins . 1 668 | richard gere and diane lane put in fine performances as does french actor oliver martinez . 1 669 | good film , but very glum . 1 670 | there are plot holes big enough for shamu the killer whale to swim through . 0 671 | preaches to two completely different choirs at the same time , which is a pretty amazing accomplishment . 1 672 | verbinski implements every hack-artist trick to give us the ooky-spookies . 0 673 | two hours fly by -- opera 's a pleasure when you do n't have to endure intermissions -- and even a novice to the form comes away exhilarated . 1 674 | in all , this is a watchable movie that 's not quite the memorable experience it might have been . 0 675 | drops you into a dizzying , volatile , pressure-cooker of a situation that quickly snowballs out of control , while focusing on the what much more than the why . 1 676 | atom egoyan has conjured up a multilayered work that tackles any number of fascinating issues 1 677 | slick piece of cross-promotion . 1 678 | well-nigh unendurable ... though the picture strains to become cinematic poetry , it remains depressingly prosaic and dull . 0 679 | majidi is an unconventional storyteller , capable of finding beauty in the most depressing places . 1 680 | the movie is n't just hilarious : it 's witty and inventive , too , and in hindsight , it is n't even all that dumb . 1 681 | filmmakers who can deftly change moods are treasures and even marvels . 1 682 | the vitality of the actors keeps the intensity of the film high , even as the strafings blend together . 1 683 | ( a ) shapeless blob of desperate entertainment . 0 684 | in the end , the movie collapses on its shaky foundation despite the best efforts of director joe carnahan . 0 685 | true tale of courage -- and complicity -- at auschwitz is a harrowing drama that tries to tell of the unspeakable . 1 686 | a study in shades of gray , offering itself up in subtle plot maneuvers ... 1 687 | no screen fantasy-adventure in recent memory has the showmanship of clones ' last 45 minutes . 1 688 | more romantic , more emotional and ultimately more satisfying than the teary-eyed original . 1 689 | this is a shameless sham , calculated to cash in on the popularity of its stars . 0 690 | if you 've ever entertained the notion of doing what the title of this film implies , what sex with strangers actually shows may put you off the idea forever . 0 691 | once the 50 year old benigni appears as the title character , we find ourselves longing for the block of wood to come back . 0 692 | stultifyingly , dumbfoundingly , mind-numbingly bad . 0 693 | an effectively creepy , fear-inducing ( not fear-reducing ) film from japanese director hideo nakata , who takes the superstitious curse on chain letters and actually applies it . 1 694 | sustains its dreamlike glide through a succession of cheesy coincidences and voluptuous cheap effects , not the least of which is rebecca romijn-stamos . 0 695 | because of an unnecessary and clumsy last scene , ` swimfan ' left me with a very bad feeling . 0 696 | no aspirations to social import inform the movie version . 0 697 | sit through this one , and you wo n't need a magic watch to stop time ; your dvd player will do it for you . 0 698 | for the first time in years , de niro digs deep emotionally , perhaps because he 's been stirred by the powerful work of his co-stars . 1 699 | not since tom cruise in risky business has an actor made such a strong impression in his underwear . 1 700 | an interesting story with a pertinent ( cinematically unique ) message , told fairly well and scored to perfection , i found myself struggling to put my finger on that elusive `` missing thing . '' 1 701 | ... routine , harmless diversion and little else . 1 702 | is the time really ripe for a warmed-over james bond adventure , with a village idiot as the 007 clone ? 0 703 | even the finest chef ca n't make a hotdog into anything more than a hotdog , and robert de niro ca n't make this movie anything more than a trashy cop buddy comedy . 0 704 | the reality of the new live-action pinocchio he directed , cowrote and starred in borders on the grotesque . 0 705 | samira makhmalbaf 's new film blackboards is much like the ethos of a stream of consciousness , although , it 's unfortunate for the viewer that the thoughts and reflections coming through are torpid and banal 0 706 | a bloated gasbag thesis grotesquely impressed by its own gargantuan aura of self-importance ... 0 707 | every time you look , sweet home alabama is taking another bummer of a wrong turn . 0 708 | how do you spell cliché ? 0 709 | no telegraphing is too obvious or simplistic for this movie . 0 710 | director of photography benoit delhomme shot the movie in delicious colors , and the costumes and sets are grand . 1 711 | i thought my own watch had stopped keeping time as i slogged my way through clockstoppers . 0 712 | less dizzying than just dizzy , the jaunt is practically over before it begins . 0 713 | overall the film feels like a low-budget tv pilot that could not find a buyer to play it on the tube . 0 714 | they should have called it gutterball . 0 715 | corpus collosum -- while undeniably interesting -- wore out its welcome well before the end credits rolled about 45 minutes in . 0 716 | ( a ) n utterly charming and hilarious film that reminded me of the best of the disney comedies from the 60s . 1 717 | there 's too much falseness to the second half , and what began as an intriguing look at youth fizzles into a dull , ridiculous attempt at heart-tugging . 0 718 | wince-inducing dialogue , thrift-shop costumes , prosthetic makeup by silly putty and kmart blue-light-special effects all conspire to test trekkie loyalty . 0 719 | a rigorously structured and exquisitely filmed drama about a father and son connection that is a brief shooting star of love . 1 720 | this is human comedy at its most amusing , interesting and confirming . 1 721 | whaley 's determination to immerse you in sheer , unrelenting wretchedness is exhausting . 0 722 | worth watching for dong jie 's performance -- and for the way it documents a culture in the throes of rapid change . 1 723 | a poignant , artfully crafted meditation on mortality . 1 724 | too restrained to be a freak show , too mercenary and obvious to be cerebral , too dull and pretentious to be engaging ... the isle defies an easy categorization . 0 725 | vera 's three actors -- mollà , gil and bardem -- excel in insightful , empathetic performances . 1 726 | it 's everything you do n't go to the movies for . 0 727 | ( grant 's ) bumbling magic takes over the film , and it turns out to be another winning star vehicle . 1 728 | it gets onto the screen just about as much of the novella as one could reasonably expect , and is engrossing and moving in its own right . 1 729 | velocity represents everything wrong with '' independent film '' as a commodified , sold-out concept on the american filmmaking scene . 0 730 | but taken as a stylish and energetic one-shot , the queen of the damned can not be said to suck . 1 731 | the piece plays as well as it does thanks in large measure to anspaugh 's three lead actresses . 1 732 | suffocated by its fussy script and uptight characters , this musty adaptation is all the more annoying since it 's been packaged and sold back to us by hollywood . 0 733 | but what are adults doing in the theater at all ? 0 734 | like being trapped at a perpetual frat party ... how can something so gross be so boring ? 0 735 | the man from elysian fields is a cold , bliss-less work that groans along thinking itself some important comment on how life throws us some beguiling curves . 0 736 | this is n't even madonna 's swept away . 0 737 | the experience of going to a film festival is a rewarding one ; the experiencing of sampling one through this movie is not . 0 738 | american chai encourages rueful laughter at stereotypes only an indian-american would recognize . 0 739 | my big fat greek wedding uses stereotypes in a delightful blend of sweet romance and lovingly dished out humor . 1 740 | ... a magnificent drama well worth tracking down . 1 741 | oscar wilde 's masterpiece , the importance of being earnest , may be the best play of the 19th century . 1 742 | jose campanella delivers a loosely autobiographical story brushed with sentimentality but brimming with gentle humor , bittersweet pathos , and lyric moments that linger like snapshots of memory . 1 743 | but it still jingles in the pocket . 1 744 | it 's a demented kitsch mess ( although the smeary digital video does match the muddled narrative ) , but it 's savvy about celebrity and has more guts and energy than much of what will open this year . 1 745 | you really have to wonder how on earth anyone , anywhere could have thought they 'd make audiences guffaw with a script as utterly diabolical as this . 0 746 | one from the heart . 1 747 | made with no discernible craft and monstrously sanctimonious in dealing with childhood loss . 0 748 | people cinema at its finest . 1 749 | what 's surprising about full frontal is that despite its overt self-awareness , parts of the movie still manage to break past the artifice and thoroughly engage you . 1 750 | and when you 're talking about a slapstick comedy , that 's a pretty big problem . 0 751 | a working class `` us vs. them '' opera that leaves no heartstring untugged and no liberal cause unplundered . 1 752 | nelson 's brutally unsentimental approach ... sucks the humanity from the film , leaving behind an horrific but weirdly unemotional spectacle . 0 753 | one long string of cliches . 0 754 | like watching a dress rehearsal the week before the show goes up : everything 's in place but something 's just a little off-kilter . 0 755 | it 's hard to imagine alan arkin being better than he is in this performance . 1 756 | the film will play equally well on both the standard and giant screens . 1 757 | ... a fun little timewaster , helped especially by the cool presence of jean reno . 1 758 | something like scrubbing the toilet . 0 759 | there 's enough melodrama in this magnolia primavera to make pta proud yet director muccino 's characters are less worthy of puccini than they are of daytime television . 0 760 | may be far from the best of the series , but it 's assured , wonderfully respectful of its past and thrilling enough to make it abundantly clear that this movie phenomenon has once again reinvented itself for a new generation . 1 761 | if there 's one thing this world needs less of , it 's movies about college that are written and directed by people who could n't pass an entrance exam . 0 762 | writer/director joe carnahan 's grimy crime drama is a manual of precinct cliches , but it moves fast enough to cover its clunky dialogue and lapses in logic . 1 763 | the result is a gaudy bag of stale candy , something from a halloween that died . 0 764 | it 's a lovely film with lovely performances by buy and accorsi . 1 765 | there 's something auspicious , and daring , too , about the artistic instinct that pushes a majority-oriented director like steven spielberg to follow a.i. with this challenging report so liable to unnerve the majority . 1 766 | movie fans , get ready to take off ... the other direction . 0 767 | not an objectionable or dull film ; it merely lacks everything except good intentions . 0 768 | while its careful pace and seemingly opaque story may not satisfy every moviegoer 's appetite , the film 's final scene is soaringly , transparently moving . 1 769 | a film about a young man finding god that is accessible and touching to the marrow . 1 770 | a compelling spanish film about the withering effects of jealousy in the life of a young monarch whose sexual passion for her husband becomes an obsession . 1 771 | an infectious cultural fable with a tasty balance of family drama and frenetic comedy . 1 772 | i do n't think i laughed out loud once . 0 773 | very special effects , brilliantly bold colors and heightened reality ca n't hide the giant achilles ' heel in `` stuart little 2 `` : there 's just no story , folks . 0 774 | not the kind of film that will appeal to a mainstream american audience , but there is a certain charm about the film that makes it a suitable entry into the fest circuit . 1 775 | it 's a beautiful madness . 1 776 | when the film ended , i felt tired and drained and wanted to lie on my own deathbed for a while . 0 777 | not really bad so much as distasteful : we need kidnapping suspense dramas right now like we need doomsday thrillers . 0 778 | as ` chick flicks ' go , this one is pretty miserable , resorting to string-pulling rather than legitimate character development and intelligent plotting . 0 779 | by candidly detailing the politics involved in the creation of an extraordinary piece of music , ( jones ) calls our attention to the inherent conflict between commerce and creativity . 1 780 | one of the most significant moviegoing pleasures of the year . 1 781 | prurient playthings aside , there 's little to love about this english trifle . 0 782 | a grimly competent and stolid and earnest military courtroom drama . 1 783 | at once half-baked and overheated . 0 784 | the structure the film takes may find matt damon and ben affleck once again looking for residuals as this officially completes a good will hunting trilogy that was never planned . 1 785 | the movie does a good job of laying out some of the major issues that we encounter as we journey through life . 1 786 | very psychoanalytical -- provocatively so -- and also refreshingly literary . 1 787 | aside from minor tinkering , this is the same movie you probably loved in 1994 , except that it looks even better . 1 788 | the film makes a fatal mistake : it asks us to care about a young man whose only apparent virtue is that he is not quite as unpleasant as some of the people in his life . 0 789 | a valueless kiddie paean to pro basketball underwritten by the nba . 0 790 | based on a devilishly witty script by heather mcgowan and niels mueller , the film gets great laughs , but never at the expense of its characters 1 791 | it 's as if you 're watching a movie that was made in 1978 but not released then because it was so weak , and it has been unearthed and released now , when it has become even weaker . 0 792 | that 's a cheat . 0 793 | it 's somewhat clumsy and too lethargically paced -- but its story about a mysterious creature with psychic abilities offers a solid build-up , a terrific climax , and some nice chills along the way . 0 794 | it 's fun lite . 1 795 | ... an otherwise intense , twist-and-turn thriller that certainly should n't hurt talented young gaghan 's resume . 1 796 | it confirms fincher 's status as a film maker who artfully bends technical know-how to the service of psychological insight . 1 797 | the film contains no good jokes , no good scenes , barely a moment when carvey 's saturday night live-honed mimicry rises above the level of embarrassment . 0 798 | a fitfully amusing romp that , if nothing else , will appeal to fans of malcolm in the middle and its pubescent star , frankie muniz . 1 799 | it 's not original , and , robbed of the element of surprise , it does n't have any huge laughs in its story of irresponsible cops who love to play pranks . 0 800 | though moonlight mile is replete with acclaimed actors and actresses and tackles a subject that 's potentially moving , the movie is too predictable and too self-conscious to reach a level of high drama . 0 801 | a tender , witty , captivating film about friendship , love , memory , trust and loyalty . 1 802 | for all its technical virtuosity , the film is so mired in juvenile and near-xenophobic pedagogy that it 's enough to make one pine for the day when godard can no longer handle the rigors of filmmaking . 0 803 | this film seems thirsty for reflection , itself taking on adolescent qualities . 0 804 | this nickleby thing might have more homosexual undertones than an eddie murphy film . 0 805 | bogdanovich tantalizes by offering a peep show into the lives of the era 's creme de la celluloid . 1 806 | but the power of these ( subjects ) is obscured by the majority of the film that shows a stationary camera on a subject that could be mistaken for giving a public oration , rather than contributing to a film 's narrative . 0 807 | irwin is a man with enough charisma and audacity to carry a dozen films , but this particular result is ultimately held back from being something greater . 0 808 | griffiths proves she 's that rare luminary who continually raises the standard of her profession . 1 809 | just as moving , uplifting and funny as ever . 1 810 | enormously entertaining for moviegoers of any age . 1 811 | a lean , deftly shot , well-acted , weirdly retro thriller that recalls a raft of '60s and '70s european-set spy pictures . 1 812 | this is a good script , good dialogue , funny even for adults . 1 813 | the affectionate loopiness that once seemed congenital to demme 's perspective has a tough time emerging from between the badly dated cutesy-pie mystery scenario and the newfangled hollywood post-production effects . 0 814 | this surreal gilliam-esque film is also a troubling interpretation of ecclesiastes . 1 815 | i can take infantile humor ... but this is the sort of infantile that makes you wonder about changing the director and writer 's diapers . 0 816 | this piece of channel 5 grade trash is , quite frankly , an insult to the intelligence of the true genre enthusiast . 0 817 | a delightful coming-of-age story . 1 818 | a spellbinding african film about the modern condition of rootlessness , a state experienced by millions around the globe . 1 819 | a strangely compelling and brilliantly acted psychological drama . 1 820 | the only excitement comes when the credits finally roll and you get to leave the theater . 0 821 | the movie is dawn of the dead crossed with john carpenter 's ghosts of mars , with zombies not as ghoulish as the first and trains not as big as the second . 0 822 | it has the charm of the original american road movies , feasting on the gorgeous , ramshackle landscape of the filmmaker 's motherland . 1 823 | exciting and direct , with ghost imagery that shows just enough to keep us on our toes . 1 824 | it 's a buggy drag . 0 825 | it wants to tweak them with a taste of tangy new humor . 1 826 | late marriage 's stiffness is unlikely to demonstrate the emotional clout to sweep u.s. viewers off their feet . 0 827 | candid and comfortable ; a film that deftly balances action and reflection as it lets you grasp and feel the passion others have for their work . 1 828 | a quiet treasure -- a film to be savored . 1 829 | the movie , directed by mick jackson , leaves no cliche unturned , from the predictable plot to the characters straight out of central casting . 0 830 | this is the sort of burly action flick where one coincidence pummels another , narrative necessity is a drunken roundhouse , and whatever passes for logic is a factor of the last plot device left standing . 0 831 | teen movies have really hit the skids . 0 832 | woody allen 's latest is an ambling , broad comedy about all there is to love -- and hate -- about the movie biz . 1 833 | it 's made with deftly unsettling genre flair . 1 834 | manages to show life in all of its banality when the intention is quite the opposite . 0 835 | ultimately feels empty and unsatisfying , like swallowing a communion wafer without the wine . 0 836 | collateral damage finally delivers the goods for schwarzenegger fans . 1 837 | a giggle-inducing comedy with snappy dialogue and winning performances by an unlikely team of oscar-winners : susan sarandon and goldie hawn . 1 838 | it 's too self-important and plodding to be funny , and too clipped and abbreviated to be an epic . 0 839 | will amuse and provoke adventurous adults in specialty venues . 1 840 | sometimes seems less like storytelling than something the otherwise compelling director needed to get off his chest . 0 841 | but this films lacks the passion required to sell the material . 0 842 | what is 100 % missing here is a script of even the most elemental literacy , an inkling of genuine wit , and anything resembling acting . 0 843 | there 's a wickedly subversive bent to the best parts of birthday girl . 1 844 | a better title , for all concerned , might be swept under the rug . 0 845 | a wildly inconsistent emotional experience . 0 846 | given how heavy-handed and portent-heavy it is , this could be the worst thing soderbergh has ever done . 0 847 | despite the evocative aesthetics evincing the hollow state of modern love life , the film never percolates beyond a monotonous whine . 0 848 | an absurdist comedy about alienation , separation and loss . 0 849 | ... mafia , rap stars and hood rats butt their ugly heads in a regurgitation of cinematic violence that gives brutal birth to an unlikely , but likable , hero . ' 1 850 | his healthy sense of satire is light and fun ... 1 851 | trademark american triteness and simplicity are tossed out the window with the intelligent french drama that deftly explores the difficult relationship between a father and son . 1 852 | miller is playing so free with emotions , and the fact that children are hostages to fortune , that he makes the audience hostage to his swaggering affectation of seriousness . 1 853 | impostor has a handful of thrilling moments and a couple of good performances , but the movie does n't quite fly . 0 854 | part low rent godfather . 0 855 | ( serry ) wants to blend politics and drama , an admirable ambition . 1 856 | for all the writhing and wailing , tears , rage and opium overdoses , there 's no sense of actual passion being washed away in love 's dissolution . 0 857 | it 's fascinating to see how bettany and mcdowell play off each other . 1 858 | in a way , the film feels like a breath of fresh air , but only to those that allow it in . 1 859 | visually imaginative , thematically instructive and thoroughly delightful , it takes us on a roller-coaster ride from innocence to experience without even a hint of that typical kiddie-flick sentimentality . 1 860 | the film 's welcome breeziness and some unbelievably hilarious moments -- most portraying the idiocy of the film industry -- make it mostly worth the trip . 1 861 | add yet another hat to a talented head , clooney 's a good director . 1 862 | stephen rea , aidan quinn , and alan bates play desmond 's legal eagles , and when joined by brosnan , the sight of this grandiloquent quartet lolling in pretty irish settings is a pleasant enough thing , ` tis . 1 863 | bennett 's naturalistic performance speaks volumes more truth than any ` reality ' show , and anybody contemplating their own drastic life changes should watch some body first . 1 864 | it 's inoffensive , cheerful , built to inspire the young people , set to an unending soundtrack of beach party pop numbers and aside from its remarkable camerawork and awesome scenery , it 's about as exciting as a sunburn . 0 865 | while it 's genuinely cool to hear characters talk about early rap records ( sugar hill gang , etc. ) , the constant referencing of hip-hop arcana can alienate even the savviest audiences . 0 866 | dull , lifeless , and amateurishly assembled . 0 867 | mcconaughey 's fun to watch , the dragons are okay , not much fire in the script . 1 868 | the far future may be awesome to consider , but from period detail to matters of the heart , this film is most transporting when it stays put in the past . 1 869 | has all the depth of a wading pool . 0 870 | a movie with a real anarchic flair . 1 871 | a subject like this should inspire reaction in its audience ; the pianist does not . 0 872 | ... is an arthritic attempt at directing by callie khouri . 0 873 | looking aristocratic , luminous yet careworn in jane hamilton 's exemplary costumes , rampling gives a performance that could not be improved upon . ' 1 874 | --------------------------------------------------------------------------------