├── models ├── __init__.py └── network.py ├── asset ├── mouth_000.png ├── mouth_001.png ├── mouth_074.png ├── s2_bbbf7p_000.png ├── s2_bbbf7p_001.png ├── s2_bbbf7p_074.png └── network_structure.png ├── requirements.txt ├── utils ├── __init__.py ├── common.py ├── align.py ├── multi.py ├── download_data.py ├── run_preprocess.ipynb ├── preprocess_data.py └── run_preprocess_single_process.ipynb ├── checkpoint └── __init__.py ├── tests └── test_beamsearch.py ├── .gitignore ├── main.py ├── infer.py ├── data_loader.py ├── BeamSearch.py ├── trainer.py └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asset/mouth_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ski-net/lipnet/HEAD/asset/mouth_000.png -------------------------------------------------------------------------------- /asset/mouth_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ski-net/lipnet/HEAD/asset/mouth_001.png -------------------------------------------------------------------------------- /asset/mouth_074.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ski-net/lipnet/HEAD/asset/mouth_074.png -------------------------------------------------------------------------------- /asset/s2_bbbf7p_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ski-net/lipnet/HEAD/asset/s2_bbbf7p_000.png -------------------------------------------------------------------------------- /asset/s2_bbbf7p_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ski-net/lipnet/HEAD/asset/s2_bbbf7p_001.png 
-------------------------------------------------------------------------------- /asset/s2_bbbf7p_074.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ski-net/lipnet/HEAD/asset/s2_bbbf7p_074.png -------------------------------------------------------------------------------- /asset/network_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ski-net/lipnet/HEAD/asset/network_structure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dlib==19.15.0 2 | Pillow==4.1.0 3 | scipy==0.19.0 4 | scikit-image==0.13.1 5 | scikit-video==1.1.11 6 | sk-video==1.1.10 7 | tqdm 8 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
17 | -------------------------------------------------------------------------------- /checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | -------------------------------------------------------------------------------- /tests/test_beamsearch.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. 
# See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 |

"""Unit test for the CTC beam-search decoder provided by BeamSearch.py.

Ref:
https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py
"""

import unittest

import numpy as np

from BeamSearch import ctcBeamSearch


class TestBeamSearch(unittest.TestCase):
    """Checks ctcBeamSearch on a tiny hand-built probability matrix."""

    def test_ctc_beam_search(self):
        """The best beam for this two-step matrix is expected to be 'a'."""
        alphabet = 'ab'
        # One row per time step, one column per class; the third column is
        # presumably the CTC blank (TODO confirm against BeamSearch.py).
        frame = [0.4, 0, 0.6]
        probs = np.array([frame, list(frame)])
        print('Test beam search')
        want = 'a'
        beams = ctcBeamSearch(probs, alphabet, None, k=2, beamWidth=3)
        got = beams[0]
        print('Expected: "' + want + '"')
        print('Actual: "' + got + '"')
        self.assertEqual(want, got)


if __name__ == '__main__':
    unittest.main()

# -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | utils/*.dat 107 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. 
# You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 |

"""
Description : main module to run the lipnet training code
"""


import argparse
from trainer import Train


def main():
    """
    Description : parse the command line, then build, load and train LipNet

    Every option is registered from a (flag, type, default) table, in the
    same order as the table, and the parsed namespace is handed to Train.
    """
    options = (
        ('--batch_size', int, 64),
        ('--epochs', int, 100),
        ('--image_path', str, './data/datasets/'),
        ('--align_path', str, './data/align/'),
        ('--dr_rate', float, 0.5),
        ('--num_gpus', int, 1),
        ('--num_workers', int, 0),
        ('--model_path', str, None),
    )
    parser = argparse.ArgumentParser()
    for flag, arg_type, default in options:
        parser.add_argument(flag, type=arg_type, default=default)
    config = parser.parse_args()

    trainer = Train(config)
    trainer.build_model(dr_rate=config.dr_rate, path=config.model_path)
    trainer.load_dataloader()
    trainer.run(epochs=config.epochs)


if __name__ == "__main__":
    main()

# -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership.
# The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 |

"""
Description : main module to run the lipnet inference code
"""


import argparse
from trainer import Train


def main():
    """
    Description : run lipnet inference code using argument info

    Builds the model (passing --model_path to Train.build_model, presumably a
    checkpoint to load), creates the data loaders and runs
    Train.infer_batch on the split selected by --data_type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--image_path', type=str, default='./data/datasets/')
    parser.add_argument('--align_path', type=str, default='./data/align/')
    parser.add_argument('--num_gpus', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=0)
    # Only the two splits handled below are accepted; any other value used to
    # leave `data_loader` unbound and crash with UnboundLocalError.
    parser.add_argument('--data_type', type=str, default='valid',
                        choices=['train', 'valid'])
    parser.add_argument('--model_path', type=str, default=None)
    config = parser.parse_args()
    trainer = Train(config)
    trainer.build_model(path=config.model_path)
    trainer.load_dataloader()

    if config.data_type == 'train':
        data_loader = trainer.train_dataloader
    else:
        data_loader = trainer.valid_dataloader

    trainer.infer_batch(data_loader)


if __name__ == "__main__":
    main()

# -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | # Licensed to the
# Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 |

"""
Module: This module contains common conversion functions

Characters 'a'-'z' map to 0-25 and the space character maps to 26. Any
other index (negative padding values, 27, ...) has no character.
"""


def char2int(char):
    """
    Convert character to integer.

    Returns 0-25 for 'a'-'z', 26 for a space and None for anything else.
    """
    if 'a' <= char <= 'z':
        return ord(char) - ord('a')
    if char == ' ':
        return 26
    return None


def int2char(num):
    """
    Convert integer to character.

    Inverse of char2int: returns 'a'-'z' for 0-25, a space for 26 and None
    for any other value.
    """
    if 0 <= num < 26:
        return chr(num + ord('a'))
    if num == 26:
        return ' '
    return None


def _decode(num):
    """Map *num* to its character, or '' when int2char has no mapping."""
    char = int2char(num)
    return '' if char is None else char


def word_to_vector(word):
    """
    Convert character vectors to integer vectors.

    Each character is mapped with char2int, so unsupported characters become
    None entries.
    """
    return [char2int(char) for char in word]


def vector_to_word(vector):
    """
    Convert integer vectors to character vectors.

    Values without a character mapping (e.g. -1 padding or index 27) are
    skipped instead of raising TypeError, matching char_conv.
    """
    return ''.join(_decode(int(vec)) for vec in vector)


def char_conv(out):
    """
    Convert integer vectors to character vectors for batch.

    *out* is a 2-D array-like (rows x steps) supporting ``shape`` and
    ``out[i][j]`` indexing. Each row is decoded into one string; negative
    values, index 27 and any other unmapped index contribute nothing.

    :return: list with one decoded string per row.
    """
    out_conv = []
    for i in range(out.shape[0]):
        row = out[i]
        # Convert each element once; unmapped indices decode to ''.
        out_conv.append(''.join(_decode(int(row[j]))
                                for j in range(out.shape[1])))
    return out_conv

# -------------------------------------------------------------------------------- /utils/align.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License.
"""
Module: align
Parse `.align` transcription files into word lists and label vectors.
This is used by LipsDataset (data_loader.py) to generate sentence labels.
"""

import numpy as np
from .common import word_to_vector


class Align(object):
    """
    Preprocess for Align

    Reads an alignment file whose lines have the form ``<start> <end> <word>``
    and keeps every word that is not in ``skip_list``.

    Attributes:
        words: list of (start, end, word) tuples in file order.
        n_words: number of kept words.
        sentence_str: kept words joined by single spaces.
        sentence_length: len(sentence_str).
    """
    # Tokens dropped from the transcription (presumably silence / short-pause
    # markers).
    skip_list = ['sil', 'sp']

    def __init__(self, align_path):
        self.build(align_path)

    def build(self, align_path):
        """
        Build the align array

        :param align_path: path of the alignment file to read.
        :raises OSError: if the file cannot be opened.
        :raises ValueError: if a non-blank line does not have exactly three
            whitespace-separated fields or its bounds are not integers.
        """
        # `with` guarantees the file is closed even if reading fails.
        with open(align_path, 'r') as file:
            lines = file.readlines()
        # words: list([op, ed, word])
        words = []
        for line in lines:
            # split() tolerates repeated spaces, tabs and '\r'.
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing empty line).
                continue
            _op, _ed, word = fields
            if word not in Align.skip_list:
                words.append((int(_op), int(_ed), word))
        self.words = words
        self.n_words = len(words)
        self.sentence_str = " ".join([w[2] for w in self.words])
        self.sentence_length = len(self.sentence_str)

    def sentence(self, padding=75):
        """
        Get sentence

        Returns the sentence encoded with word_to_vector and right-padded with
        -1 up to *padding* entries, as an int32 array. A sentence longer than
        *padding* is returned unpadded (it is not truncated).
        """
        vec = word_to_vector(self.sentence_str)
        vec += [-1] * (padding - self.sentence_length)
        return np.array(vec, dtype=np.int32)

    def word(self, _id, padding=75):
        """
        Get words

        Returns word number *_id* encoded with word_to_vector and right-padded
        with -1 up to *padding* entries, as an int32 array.
        """
        word = self.words[_id][2]
        vec = word_to_vector(word)
        vec += [-1] * (padding - len(vec))
        return np.array(vec, dtype=np.int32)

    def word_length(self, _id):
        """
        Get the length of words

        Returns the number of characters in word number *_id*.
        """
        return len(self.words[_id][2])

    def word_frame_pos(self, _id):
        """
        Get the position of words

        Returns (left, right): the word's start and end bounds divided by 1000
        and truncated, with right forced to be at least left + 1.
        """
        left = int(self.words[_id][0]/1000)
        right = max(left+1, int(self.words[_id][1]/1000))
        return (left, right)

# -------------------------------------------------------------------------------- /models/network.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements.
# See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 |

"""
Description : LipNet module using gluon
"""

from mxnet.gluon import nn, rnn
# pylint: disable=too-many-instance-attributes
class LipNet(nn.HybridBlock):
    """
    Description : LipNet network using gluon
    dr_rate : Dropout rate

    Three Conv3D -> InstanceNorm -> ReLU -> Dropout -> MaxPool3D stages,
    followed by two bidirectional GRUs and a per-step Dense layer with
    27 + 1 outputs; hybrid_forward returns log-softmax scores over them.
    """
    def __init__(self, dr_rate, **kwargs):
        super(LipNet, self).__init__(**kwargs)
        # NOTE: gluon derives parameter-name prefixes from the order in which
        # child blocks are created inside name_scope(); reordering these lines
        # may break loading of previously saved parameters.
        with self.name_scope():
            # Stage 1: strides (1, 2, 2) halve the last two axes of the input
            # and keep the first spatial axis (presumably time).
            self.conv1 = nn.Conv3D(32, kernel_size=(3, 5, 5), strides=(1, 2, 2), padding=(1, 2, 2))
            # Named bn*, but these are InstanceNorm layers, not BatchNorm.
            self.bn1 = nn.InstanceNorm(in_channels=32)
            # The dropout mask is shared along axes 1 and 2 (gluon Dropout
            # `axes` argument).
            self.dr1 = nn.Dropout(dr_rate, axes=(1, 2))
            self.pool1 = nn.MaxPool3D((1, 2, 2), (1, 2, 2))
            # Stage 2: stride 1; padding (1, 2, 2) preserves the sizes for a
            # (3, 5, 5) kernel.
            self.conv2 = nn.Conv3D(64, kernel_size=(3, 5, 5), strides=(1, 1, 1), padding=(1, 2, 2))
            self.bn2 = nn.InstanceNorm(in_channels=64)
            self.dr2 = nn.Dropout(dr_rate, axes=(1, 2))
            self.pool2 = nn.MaxPool3D((1, 2, 2), (1, 2, 2))
            # NOTE(review): a (3, 3, 3) kernel with padding (1, 2, 2) grows the
            # last two axes by 2 ((1, 1, 1) would preserve them). Changing it
            # alters the GRU input size and so any saved checkpoint; confirm
            # intent before touching.
            self.conv3 = nn.Conv3D(96, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding=(1, 2, 2))
            self.bn3 = nn.InstanceNorm(in_channels=96)
            self.dr3 = nn.Dropout(dr_rate, axes=(1, 2))
            self.pool3 = nn.MaxPool3D((1, 2, 2), (1, 2, 2))
            # Two stacked bidirectional GRUs, 256 hidden units per direction.
            self.gru1 = rnn.GRU(256, bidirectional=True)
            self.gru2 = rnn.GRU(256, bidirectional=True)
            # 27 symbols (utils/common.py maps 0-25 to 'a'-'z' and 26 to space)
            # plus one extra class; char_conv treats index 27 as empty, so it
            # is presumably the CTC blank. flatten=False applies the layer
            # independently at every step.
            self.dense = nn.Dense(27+1, flatten=False)

    # pylint: disable=arguments-differ
    def hybrid_forward(self, F, x):
        """
        Forward pass.

        :param F: mx.nd or mx.sym, supplied by HybridBlock.
        :param x: input clip; assumed (batch, channel, time, height, width)
            -- TODO confirm against the dataset transform.
        :return: log-softmax scores; (batch, time, 28) under that assumption.
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.dr1(out)
        out = self.pool1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)
        out = self.dr2(out)
        out = self.pool2(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = F.relu(out)
        out = self.dr3(out)
        out = self.pool3(out)
        # Move axis 2 to the front: (N, C, T, H, W) -> (T, N, C, H, W) under
        # the assumed layout; gluon rnn.GRU defaults to the time-major 'TNC'
        # layout.
        out = F.transpose(out, (2, 0, 1, 3, 4))
        # pylint: disable=no-member
        # Keep the first two axes and flatten the rest into one feature axis.
        out = out.reshape((0, 0, -1))
        out = self.gru1(out)
        out = self.gru2(out)
        out = self.dense(out)
        # Normalise over the class axis.
        out = F.log_softmax(out, axis=2)
        # Swap the first two axes back (batch-major under the assumed layout).
        out = F.transpose(out, (1, 0, 2))
        return out

# -------------------------------------------------------------------------------- /utils/multi.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License.
"""
Module: preprocess with multi-process
"""


def multi_p_run(tot_num, _func, worker, params, n_process):
    """
    Run _func with multi-process using params.

    The index range [0, tot_num) is split into at most n_process contiguous
    chunks (see split_seq). For each chunk a daemon process runs
    ``_func(worker, from_idx, to_idx, params, out_q)``; _func is expected to
    put exactly one result on out_q, as put_worker does.

    :return: list of the values read from out_q (one per process, in arrival
        order), or -1 when interrupted with Ctrl-C.
    """
    from multiprocessing import Process, Queue
    out_q = Queue()
    procs = []

    split_num = split_seq(list(range(0, tot_num)), n_process)

    print(tot_num, ">>", split_num)

    # Never start more processes than there are chunks (0 for tot_num == 0).
    split_len = len(split_num)
    if n_process > split_len:
        n_process = split_len

    for i in range(n_process):
        from_idx, to_idx = split_num[i]
        _p = Process(target=_func,
                     args=(worker, from_idx, to_idx, params, out_q))
        _p.daemon = True
        procs.append(_p)
        _p.start()

    try:
        # Drain the queue before joining: a child that put data on a queue
        # does not exit until that data has been consumed.
        result = [out_q.get() for _ in range(n_process)]
        for proc in procs:
            proc.join()
    except KeyboardInterrupt:
        print('Killing all the children in the pool.')
        for proc in procs:
            proc.terminate()
            proc.join()
        return -1

    # Print anything left on the queue beyond the expected results.
    while not out_q.empty():
        print(out_q.get(block=False))

    return result


def split_seq(sam_num, n_tile):
    """
    Split the number(sam_num) into numbers by n_tile

    Splits the index list *sam_num* (expected to be ``list(range(N))``) into
    at most *n_tile* contiguous ``[start, end)`` pairs of near-equal size.
    An empty *sam_num* yields an empty list.

    >>> split_seq(list(range(10)), 3)
    [[0, 4], [4, 8], [8, 10]]
    """
    import math
    if not sam_num:
        # ceil(0 / n_tile) == 0 would make the slice step zero (ValueError).
        return []
    chunk = int(math.ceil(len(sam_num) / n_tile))
    start_num = sam_num[0::chunk]
    end_num = start_num[1:] + [len(sam_num)]
    return [[i, j] for i, j in zip(start_num, end_num)]


def put_worker(func, from_idx, to_idx, params, out_q):
    """
    put worker

    Run ``func(from_idx, to_idx, params)``, which must return a
    ``(succ, fail)`` pair, and put ``{'succ': succ, 'fail': fail}`` on out_q.
    """
    succ, fail = func(from_idx, to_idx, params)
    return out_q.put({'succ': succ, 'fail': fail})


def test_worker(from_idx, to_idx, params):
    """
    the worker to test multi-process

    Marks every index in [from_idx, to_idx) as succeeded; params is unused.

    :return: (succ, fail) sets of indices.
    """
    del params  # unused; accepted to match the worker signature
    succ = set()
    fail = set()
    for idx in range(from_idx, to_idx):
        try:
            succ.add(idx)
        except ValueError:
            fail.add(idx)
    return (succ, fail)


# Despite its name this is a worker helper, not a pytest test; stop pytest
# from collecting it when the module is star-imported into a test file.
test_worker.__test__ = False


if __name__ == '__main__':
    RES = multi_p_run(35, put_worker, test_worker, params={}, n_process=5)
    print(RES)

# -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Description : Set DataSet module for lip images 3 | """ 4 | # Licensed to the Apache Software Foundation (ASF) under one 5 | # or more contributor license agreements. See the NOTICE file 6 | # distributed with this work for additional information 7 | # regarding copyright ownership. The ASF licenses this file 8 | # to you under the Apache License, Version 2.0 (the 9 | # "License"); you may not use this file except in compliance 10 | # with the License. You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, 15 | # software distributed under the License is distributed on an 16 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 17 | # KIND, either express or implied. See the License for the 18 | # specific language governing permissions and limitations 19 | # under the License.
import os
import glob
from mxnet import nd
import mxnet.gluon.data.dataset as dataset
from mxnet.gluon.data.vision.datasets import image
from utils.align import Align

# pylint: disable=too-many-instance-attributes, too-many-arguments
class LipsDataset(dataset.Dataset):
    """
    Description : DataSet class for lip images

    Each item is one clip folder ``<root>/s<N>/<clip>`` holding exactly
    ``seq_len`` image files; its label is read from
    ``<align_root>/<clip>.align``.

    :param root: directory containing the ``s<N>`` folders.
    :param align_root: directory containing the ``.align`` files.
    :param flag: passed to ``image.imread``.
    :param mode: 'train' (s1-s34 except s1, s2, s20, s21, s22) or
        'valid' (s1, s2, s20, s22).
    :param transform: optional callable applied to every frame.
    :param seq_len: required number of frames per clip; also the label
        padding length.
    """
    def __init__(self, root, align_root, flag=1,
                 mode='train', transform=None, seq_len=75):
        assert mode in ['train', 'valid']
        self._root = os.path.expanduser(root)
        # Expand '~' here too, consistently with `root`.
        self._align_root = os.path.expanduser(align_root)
        self._flag = flag
        self._transform = transform
        self._exts = ['.jpg', '.jpeg', '.png']
        self._seq_len = seq_len
        self._mode = mode
        self._list_images(self._root)

    def _is_image(self, path):
        """Return True if *path* ends with one of ``self._exts`` (any case)."""
        return os.path.splitext(path)[1].lower() in self._exts

    def _list_images(self, root):
        """
        Description : generate list for lip images

        Fills ``self.items`` with ``(sorted_frame_paths, clip_name)`` tuples
        for every clip folder of the selected subjects that contains exactly
        ``seq_len`` image files.
        """
        self.labels = []
        self.items = []

        valid_unseen_sub_idx = [1, 2, 20, 22]
        skip_sub_idx = [21]

        if self._mode == 'train':
            sub_idx = ['s' + str(i) for i in range(1, 35)
                       if i not in valid_unseen_sub_idx + skip_sub_idx]
        elif self._mode == 'valid':
            sub_idx = ['s' + str(i) for i in valid_unseen_sub_idx]

        folder_path = []
        for i in sub_idx:
            folder_path.extend(glob.glob(os.path.join(root, i, "*")))

        for folder in folder_path:
            # Count only real frames: stray files (e.g. Thumbs.db) would
            # otherwise change the count or be handed to imread.
            filename = [path for path in glob.glob(os.path.join(folder, "*"))
                        if self._is_image(path)]
            if len(filename) != self._seq_len:
                continue
            filename.sort()
            label = os.path.split(folder)[-1]
            self.items.append((filename, label))

    def align_generation(self, file_nm, padding=75):
        """
        Description : Align to lip position

        Load ``<align_root>/<file_nm>.align`` and return its sentence label,
        padded with -1 to *padding* entries, as an NDArray.
        """
        align = Align(os.path.join(self._align_root, file_nm + '.align'))
        return nd.array(align.sentence(padding))

    def __getitem__(self, idx):
        """
        Return ``(frames, label)`` for item *idx*.

        Frames are stacked along a new first axis and then axes 0 and 1 are
        swapped; if each (transformed) frame is (C, H, W) the result is
        (C, seq_len, H, W) -- TODO confirm against the transform in use.
        """
        img = list()
        for image_name in self.items[idx][0]:
            tmp_img = image.imread(image_name, self._flag)
            if self._transform is not None:
                tmp_img = self._transform(tmp_img)
            img.append(tmp_img)
        img = nd.stack(*img)
        img = nd.transpose(img, (1, 0, 2, 3))
        label = self.align_generation(self.items[idx][1],
                                      padding=self._seq_len)
        return img, label

    def __len__(self):
        return len(self.items)

# -------------------------------------------------------------------------------- /utils/download_data.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License.
"""
Module: download_data
This module provides utilities for downloading the datasets for training LipNet
"""

import os
from os.path import exists
from multi import multi_p_run, put_worker


def _run_command(command):
    """Echo and run *command* in a shell; return True iff it exited with 0."""
    print(command)
    # os.system reports failure through its return status rather than by
    # raising OSError, so the status must be checked explicitly.
    status = os.system(command)
    if status != 0:
        print('command failed (status {}): {}'.format(status, command))
    return status == 0


def download_mp4(from_idx, to_idx, _params):
    """
    download mp4s

    For each index idx in [from_idx, to_idx) except 0, download and unzip
    ``s<idx>.mpg_vcd.zip`` inside ``_params['src_path']``. Indices whose
    ``s<idx>`` folder already exists there are skipped.

    :return: (succ, fail) sets of indices whose shell command exited with
        status 0 / non-zero; skipped indices appear in neither.
    """
    succ = set()
    fail = set()
    for idx in range(from_idx, to_idx):
        name = 's' + str(idx)
        save_folder = '{src_path}/{nm}'.format(src_path=_params['src_path'], nm=name)
        if idx == 0 or os.path.isdir(save_folder):
            continue
        # NOTE(review): plain-HTTP download without an integrity check.
        script = "http://spandh.dcs.shef.ac.uk/gridcorpus/{nm}/video/{nm}.mpg_vcd.zip".format(
            nm=name)
        down_sc = ('cd {src_path} && curl {script} --output {nm}.mpg_vcd.zip && '
                   'unzip {nm}.mpg_vcd.zip').format(script=script,
                                                    nm=name,
                                                    src_path=_params['src_path'])
        if _run_command(down_sc):
            succ.add(idx)
        else:
            fail.add(idx)
    return (succ, fail)


def download_align(from_idx, to_idx, _params):
    """
    download aligns

    For each index idx in [from_idx, to_idx) except 0, download and extract
    ``s<idx>.tar`` inside ``_params['align_path']``.

    :return: (succ, fail) sets of indices whose shell command exited with
        status 0 / non-zero.
    """
    succ = set()
    fail = set()
    for idx in range(from_idx, to_idx):
        name = 's' + str(idx)
        if idx == 0:
            continue
        script = "http://spandh.dcs.shef.ac.uk/gridcorpus/{nm}/align/{nm}.tar".format(nm=name)
        down_sc = ('cd {align_path} && wget {script} && '
                   'tar -xvf {nm}.tar').format(script=script,
                                               nm=name,
                                               align_path=_params['align_path'])
        if _run_command(down_sc):
            succ.add(idx)
        else:
            fail.add(idx)
    return (succ, fail)


if __name__ == '__main__':
    import argparse
    PARSER = argparse.ArgumentParser()
    PARSER.add_argument('--src_path', type=str, default='../data/mp4s')
    PARSER.add_argument('--align_path', type=str, default='../data')
    PARSER.add_argument('--n_process', type=int, default=1)
    CONFIG = PARSER.parse_args()
    PARAMS = {'src_path': CONFIG.src_path, 'align_path': CONFIG.align_path}
    N_PROCESS = CONFIG.n_process

    # Fetch shape_predictor_68_face_landmarks.dat into the working directory
    # if it is not already there.
    if not exists('./shape_predictor_68_face_landmarks.dat'):
        os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 && '
                  'bzip2 -d shape_predictor_68_face_landmarks.dat.bz2')

    os.makedirs(PARAMS['src_path'], exist_ok=True)
    os.makedirs(PARAMS['align_path'], exist_ok=True)

    if N_PROCESS == 1:
        RES = download_mp4(0, 35, PARAMS)
        RES = download_align(0, 35, PARAMS)
    else:
        # download movie files
        RES = multi_p_run(tot_num=35, _func=put_worker, worker=download_mp4,
                          params=PARAMS, n_process=N_PROCESS)

        # download align files
        RES = multi_p_run(tot_num=35, _func=put_worker, worker=download_align,
                          params=PARAMS, n_process=N_PROCESS)

    # Remove the downloaded archives and stray Thumbs.db files.
    os.system('rm -f {src_path}/*.zip && rm -f {src_path}/*/Thumbs.db'.format(
        src_path=PARAMS['src_path']))
    os.system('rm -f {align_path}/*.tar && rm -f {align_path}/Thumbs.db'.format(
        align_path=PARAMS['align_path']))

# -------------------------------------------------------------------------------- /utils/run_preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from download_data import multi_p_run, put_worker, _worker, download_mp4, download_align" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## TEST" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]\n", 31 | "5\n", 32 | "35 >> [[0, 7], [7, 14], [14, 21], [21, 28], [28, 35]]\n", 33 | "[{'succ': {0, 1, 2, 3, 4, 5, 6}, 'fail': set()}, {'succ': {7, 8, 9, 10, 11, 12, 13}, 'fail': set()}, {'succ': {14, 15, 16, 17, 18, 19, 20}, 'fail': set()}, {'succ': {21, 22, 23, 24, 25, 26, 27}, 'fail': set()}, {'succ': {32, 33, 34, 28, 29, 30, 31}, 'fail': set()}]\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "res = multi_p_run(35, put_worker, _worker, 5)\n", 39 | "print (res)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## Download Data" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "## down\n", 56 | "import os\n", 57 | "os.makedirs('./datasets', exist_ok=True)\n", 58 | "#os.system('rm -rf ./datasets/*')\n", 59 | "\n", 60 | "res = multi_p_run(35, put_worker, download_align, 9)\n", 61 | "print (res)\n", 62 | "\n", 63 | "os.system('rm -f datasets/*.tar && rm -f datasets/align/Thumbs.db')" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "res = multi_p_run(35, put_worker, download_mp4, 9)\n", 73 | "print (res)\n", 74 | "\n", 75 | "os.system('rm -f datasets/*.zip && rm -f datasets/*/Thumbs.db')" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "## download single 22 th dir\n", 85 | "#download_data.py(22, 22)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "## Preprocess Data" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "from preprocess_data import preprocess, find_files, Video" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 
106 | "execution_count": 4, 107 | "metadata": { 108 | "scrolled": true 109 | }, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/plain": [ 114 | "0" 115 | ] 116 | }, 117 | "execution_count": 4, 118 | "metadata": {}, 119 | "output_type": "execute_result" 120 | } 121 | ], 122 | "source": [ 123 | "import os\n", 124 | "os.makedirs('./TARGET', exist_ok=True)\n", 125 | "os.system('rm -rf ./TARGET/*')" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]\n", 138 | "9\n", 139 | "35 >> [[0, 4], [4, 8], [8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 32], [32, 35]]\n", 140 | "Processing: datasets/s1/prwq3s.mpg\n", 141 | "Processing: datasets/s4/lrix7n.mpg\n", 142 | "Processing: datasets/s8/pgbyza.mpg\n", 143 | "Processing: datasets/s12/brik7n.mpg\n", 144 | "Processing: datasets/s16/sgit7p.mpg\n", 145 | "Processing: datasets/s20/lrbp8a.mpg\n", 146 | "Processing: datasets/s24/sbik8a.mpg\n", 147 | "Processing: datasets/s28/srwf8a.mpg\n", 148 | "Processing: datasets/s32/pbbm1n.mpg\n", 149 | "Processing: datasets/s12/sbbaza.mpg\n", 150 | "Processing: datasets/s28/lbit7n.mpg\n", 151 | "Processing: datasets/s32/pbwm7p.mpg\n", 152 | "Processing: datasets/s8/bril2s.mpg\n", 153 | "Processing: datasets/s20/bway7n.mpg\n", 154 | "Processing: datasets/s1/pbib8p.mpg\n", 155 | "Processing: datasets/s16/lwaj7n.mpg\n", 156 | "Processing: datasets/s24/bwwl6a.mpg\n", 157 | "Processing: datasets/s4/bbwf7n.mpg\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "res = multi_p_run(35, put_worker, preprocess, 9)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [] 171 | } 172 | ], 173 | "metadata": { 174 | 
"kernelspec": { 175 | "display_name": "Python 3", 176 | "language": "python", 177 | "name": "python3" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.6.6" 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 2 194 | } 195 | -------------------------------------------------------------------------------- /BeamSearch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | """ 21 | Module : CTC beam-search decoding, optionally scored with a character-bigram language model; adapted from 22 | https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py 23 | """ 24 | 25 | from __future__ import division 26 | from __future__ import print_function 27 | import numpy as np 28 | 29 | class BeamEntry: 30 | """ 31 | information about one single beam at specific time-step 32 | """ 33 | def __init__(self): 34 | self.prTotal = 0 # blank and non-blank 35 | self.prNonBlank = 0 # non-blank 36 | self.prBlank = 0 # blank 37 | self.prText = 1 # LM score 38 | self.lmApplied = False # flag if LM was already applied to this beam 39 | self.labeling = () # beam-labeling 40 | 41 | class BeamState: 42 | """ 43 | information about the beams at specific time-step 44 | """ 45 | def __init__(self): 46 | self.entries = {} 47 | 48 | def norm(self): 49 | """ 50 | length-normalise the LM score of every beam in place: prText becomes prText ** (1 / len(labeling)), using exponent 1 for the empty labeling 51 | """ 52 | for (k, _) in self.entries.items(): 53 | labelingLen = len(self.entries[k].labeling) 54 | self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0)) 55 | 56 | def sort(self): 57 | """ 58 | return the beam-labelings sorted by prTotal * prText, most probable first 59 | """ 60 | beams = [v for (_, v) in self.entries.items()] 61 | sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText) 62 | return [x.labeling for x in sortedBeams] 63 | 64 | def applyLM(parentBeam, childBeam, classes, lm): 65 | """ 66 | calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars. No-op when lm is falsy or the child's LM score was already applied; lm must provide getCharBigram(c1, c2). For an empty parent labeling the space character is used as the first char, so classes must contain ' '. 67 | """ 68 | if lm and not childBeam.lmApplied: 69 | c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')] # first char 70 | c2 = classes[childBeam.labeling[-1]] # second char 71 | lmFactor = 0.01 # influence of language model 72 | bigramProb = lm.getCharBigram(c1, c2) ** lmFactor # probability of seeing first and second char next to each other 73 | childBeam.prText = parentBeam.prText * bigramProb # 
probability of char sequence 74 | childBeam.lmApplied = True # only apply LM once per beam entry 75 | 76 | def addBeam(beamState, labeling): 77 | """ 78 | add beam if it does not yet exist 79 | """ 80 | if labeling not in beamState.entries: 81 | beamState.entries[labeling] = BeamEntry() 82 | 83 | def ctcBeamSearch(mat, classes, lm, k, beamWidth): 84 | """ 85 | beam search as described by the paper of Hwang et al. and the paper of Graves et al. mat: 2-D array of per-time-step probabilities indexed [t, c]; column len(classes) is read as the CTC blank while columns 0..maxC-2 are tried as characters (these agree only when mat has len(classes) + 1 columns). classes: sequence of characters, classes[i] being the character of label i. lm: optional language model (falsy disables it). k: number of labelings to return. beamWidth: number of best beams extended per time-step. Returns a list of up to k decoded strings, most probable first. 86 | """ 87 | 88 | blankIdx = len(classes) 89 | maxT, maxC = mat.shape 90 | 91 | # initialise beam state 92 | last = BeamState() 93 | labeling = () 94 | last.entries[labeling] = BeamEntry() 95 | last.entries[labeling].prBlank = 1 96 | last.entries[labeling].prTotal = 1 97 | 98 | # go over all time-steps 99 | for t in range(maxT): 100 | curr = BeamState() 101 | 102 | # get beam-labelings of best beams 103 | bestLabelings = last.sort()[0:beamWidth] 104 | 105 | # go over best beams 106 | for labeling in bestLabelings: 107 | 108 | # probability of paths ending with a non-blank 109 | prNonBlank = 0 110 | # in case of non-empty beam 111 | if labeling: 112 | # probability of paths with repeated last char at the end 113 | try: 114 | prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]] 115 | except FloatingPointError: 116 | prNonBlank = 0 117 | 118 | # probability of paths ending with a blank 119 | prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx] 120 | 121 | # add beam at current time-step if needed 122 | addBeam(curr, labeling) 123 | 124 | # fill in data 125 | curr.entries[labeling].labeling = labeling 126 | 
curr.entries[labeling].prNonBlank += prNonBlank 127 | curr.entries[labeling].prBlank += prBlank 128 | curr.entries[labeling].prTotal += prBlank + prNonBlank 129 | curr.entries[labeling].prText = last.entries[labeling].prText # beam-labeling not changed, therefore also LM score unchanged from 130 | curr.entries[labeling].lmApplied = True # LM already applied at previous time-step for this 
beam-labeling 131 | 132 | # extend current beam-labeling 133 | for c in range(maxC - 1): 134 | # add new char to current beam-labeling 135 | newLabeling = labeling + (c,) 136 | 137 | # if new labeling contains duplicate char at the end, only consider paths ending with a blank 138 | if labeling and labeling[-1] == c: 139 | prNonBlank = mat[t, c] * last.entries[labeling].prBlank 140 | else: 141 | prNonBlank = mat[t, c] * last.entries[labeling].prTotal 142 | 143 | # add beam at current time-step if needed 144 | addBeam(curr, newLabeling) 145 | 146 | # fill in data 147 | curr.entries[newLabeling].labeling = newLabeling 148 | curr.entries[newLabeling].prNonBlank += prNonBlank 149 | curr.entries[newLabeling].prTotal += prNonBlank 150 | 151 | # apply LM 152 | applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm) 153 | 154 | # set new beam state 155 | last = curr 156 | 157 | # normalise LM scores according to beam-labeling-length 158 | last.norm() 159 | 160 | # sort by probability 161 | bestLabelings = last.sort()[:k] # keep the k most probable labelings 162 | 163 | output = [] 164 | for bestLabeling in bestLabelings: 165 | # map labels to chars 166 | res = '' 167 | for l in bestLabeling: 168 | res += classes[l] 169 | output.append(res) 170 | return output -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License.
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | """ 19 | Description : Training module for LipNet 20 | """ 21 | 22 | 23 | import sys 24 | import mxnet as mx 25 | from mxnet import gluon, autograd, nd 26 | from mxnet.gluon.data.vision import transforms 27 | from tqdm import tqdm, trange 28 | from data_loader import LipsDataset 29 | from models.network import LipNet 30 | from BeamSearch import ctcBeamSearch 31 | from utils.common import char_conv, int2char 32 | # set gpu count 33 | 34 | 35 | def setting_ctx(num_gpus): 36 | """ 37 | Description : set gpu module 38 | """ 39 | if num_gpus > 0: 40 | ctx = [mx.gpu(i) for i in range(num_gpus)] 41 | else: 42 | ctx = [mx.cpu()] 43 | return ctx 44 | 45 | 46 | ALPHABET = '' 47 | for i in range(27): 48 | ALPHABET += int2char(i) 49 | 50 | def char_beam_search(out): 51 | """ 52 | Description : apply beam search for prediction result 53 | """ 54 | out_conv = list() 55 | for idx in range(out.shape[0]): 56 | probs = out[idx] 57 | prob = probs.softmax().asnumpy() 58 | line_string_proposals = ctcBeamSearch(prob, ALPHABET, None, k=4, beamWidth=25) 59 | out_conv.append(line_string_proposals[0]) 60 | return out_conv 61 | 62 | # pylint: disable=too-many-instance-attributes, too-many-locals 63 | class Train: 64 | """ 65 | Description : Train class for training network 66 | """ 67 | def __init__(self, config): 68 | ##setting hyper-parameters 69 | self.batch_size = config.batch_size 70 | self.image_path = config.image_path 71 | self.align_path = config.align_path 72 | self.num_gpus = config.num_gpus 73 | self.ctx = 
setting_ctx(self.num_gpus) 74 | self.num_workers = config.num_workers 75 | self.seq_len = 75 76 | 77 | def build_model(self, dr_rate=0, path=None): 78 | """ 79 | Description : build network 80 | """ 81 | #set network 82 | self.net = LipNet(dr_rate) 83 | self.net.hybridize() 84 | self.net.initialize(ctx=self.ctx) 85 | 86 | if path is not None: 87 | self.load_model(path) 88 | 89 | #set optimizer 90 | self.loss_fn = gluon.loss.CTCLoss() 91 | self.trainer = gluon.Trainer(self.net.collect_params(), \ 92 | optimizer='SGD') 93 | 94 | def save_model(self, epoch, loss): 95 | """ 96 | Description : save parameter of network weight 97 | """ 98 | prefix = 'checkpoint/epoches' 99 | file_name = "{prefix}_{epoch}_loss_{l:.4f}".format(prefix=prefix, 100 | epoch=str(epoch), 101 | l=loss) 102 | self.net.save_parameters(file_name) 103 | 104 | def load_model(self, path=''): 105 | """ 106 | Description : load parameter of network weight 107 | """ 108 | self.net.load_parameters(path) 109 | 110 | def load_dataloader(self): 111 | """ 112 | Description : Setup the dataloader 113 | """ 114 | 115 | input_transform = transforms.Compose([transforms.ToTensor(), \ 116 | transforms.Normalize((0.7136, 0.4906, 0.3283), \ 117 | (0.1138, 0.1078, 0.0917))]) 118 | training_dataset = LipsDataset(self.image_path, 119 | self.align_path, 120 | mode='train', 121 | transform=input_transform, 122 | seq_len=self.seq_len) 123 | 124 | self.train_dataloader = mx.gluon.data.DataLoader(training_dataset, 125 | batch_size=self.batch_size, 126 | shuffle=True, 127 | num_workers=self.num_workers) 128 | 129 | valid_dataset = LipsDataset(self.image_path, 130 | self.align_path, 131 | mode='valid', 132 | transform=input_transform, 133 | seq_len=self.seq_len) 134 | 135 | self.valid_dataloader = mx.gluon.data.DataLoader(valid_dataset, 136 | batch_size=self.batch_size, 137 | shuffle=True, 138 | num_workers=self.num_workers) 139 | 140 | def train(self, data, label, batch_size): 141 | """ 142 | Description : training for LipNet 
143 | """ 144 | # pylint: disable=no-member 145 | sum_losses = 0 146 | len_losses = 0 147 | with autograd.record(): 148 | losses = [self.loss_fn(self.net(X), Y) for X, Y in zip(data, label)] 149 | for loss in losses: 150 | sum_losses += mx.nd.array(loss).sum().asscalar() 151 | len_losses += len(loss) 152 | loss.backward() 153 | self.trainer.step(batch_size) 154 | return sum_losses, len_losses 155 | 156 | def infer(self, input_data, input_label): 157 | """ 158 | Description : Print sentence for prediction result 159 | """ 160 | sum_losses = 0 161 | len_losses = 0 162 | for data, label in zip(input_data, input_label): 163 | pred = self.net(data) 164 | sum_losses += mx.nd.array(self.loss_fn(pred, label)).sum().asscalar() 165 | len_losses += len(data) 166 | pred_convert = char_beam_search(pred) 167 | label_convert = char_conv(label.asnumpy()) 168 | for target, pred in zip(label_convert, pred_convert): 169 | print("target:{t} pred:{p}".format(t=target, p=pred)) 170 | return sum_losses, len_losses 171 | 172 | def train_batch(self, dataloader): 173 | """ 174 | Description : training for LipNet 175 | """ 176 | sum_losses = 0 177 | len_losses = 0 178 | for input_data, input_label in tqdm(dataloader): 179 | data = gluon.utils.split_and_load(input_data, self.ctx, even_split=False) 180 | label = gluon.utils.split_and_load(input_label, self.ctx, even_split=False) 181 | batch_size = input_data.shape[0] 182 | sum_losses, len_losses = self.train(data, label, batch_size) 183 | sum_losses += sum_losses 184 | len_losses += len_losses 185 | 186 | return sum_losses, len_losses 187 | 188 | def infer_batch(self, dataloader): 189 | """ 190 | Description : inference for LipNet 191 | """ 192 | sum_losses = 0 193 | len_losses = 0 194 | for input_data, input_label in dataloader: 195 | data = gluon.utils.split_and_load(input_data, self.ctx, even_split=False) 196 | label = gluon.utils.split_and_load(input_label, self.ctx, even_split=False) 197 | sum_losses, len_losses = self.infer(data, label) 
198 | sum_losses += sum_losses 199 | len_losses += len_losses 200 | 201 | return sum_losses, len_losses 202 | 203 | def run(self, epochs): 204 | """ 205 | Description : Run training for LipNet 206 | """ 207 | best_loss = sys.maxsize 208 | for epoch in trange(epochs): 209 | iter_no = 0 210 | 211 | ## train 212 | sum_losses, len_losses = self.train_batch(self.train_dataloader) 213 | 214 | if iter_no % 20 == 0: 215 | current_loss = sum_losses / len_losses 216 | print("[Train] epoch:{e} iter:{i} loss:{l:.4f}".format(e=epoch, 217 | i=iter_no, 218 | l=current_loss)) 219 | 220 | ## validating 221 | sum_val_losses, len_val_losses = self.infer_batch(self.valid_dataloader) 222 | 223 | current_val_loss = sum_val_losses / len_val_losses 224 | print("[Vaild] epoch:{e} iter:{i} loss:{l:.4f}".format(e=epoch, 225 | i=iter_no, 226 | l=current_val_loss)) 227 | 228 | if best_loss > current_val_loss: 229 | self.save_model(epoch, current_val_loss) 230 | best_loss = current_val_loss 231 | 232 | iter_no += 1 233 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 19 | 20 | # LipNet: End-to-End Sentence-level Lipreading 21 | 22 | --- 23 | 24 | This is a Gluon implementation of [LipNet: End-to-End Sentence-level Lipreading](https://arxiv.org/abs/1611.01599) 25 | 26 | ![net_structure](asset/network_structure.png) 27 | 28 | ![sample output](https://user-images.githubusercontent.com/11376047/52533982-d7227680-2d7e-11e9-9f18-c15b952faf0e.png) 29 | 30 | ## Requirements 31 | - Python 3.6.4 32 | - MXNet 1.3.0 33 | - Required disk space: 35 GB 34 | ``` 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | --- 39 | 40 | ## The Data 41 | - The GRID audiovisual sentence corpus (http://spandh.dcs.shef.ac.uk/gridcorpus/) 42 | - GRID is a large multi-talker audiovisual sentence corpus to support joint computational-behavioral studies in speech perception. 
In brief, the corpus consists of high-quality audio and video (facial) recordings of 1000 sentences spoken by each of 34 talkers (18 male, 16 female). Sentences are of the form "put red at G9 now". The corpus, together with transcriptions, is freely available for research use. 43 | - Video: normal quality (480 MB each) 44 | - Each video contains one sentence consisting of 6 words. 45 | - Align: word alignments (190 KB each) 46 | - Each align file gives a start and an end time for every word of the sentence. This tutorial only needs the sentence itself, because it trains with CTC loss. 47 | 48 | --- 49 | 50 | ## Pretrained model 51 | You can train the model yourself as described in the following sections, test inference with a pretrained model, or resume training from the model checkpoint. To work with the provided pretrained model, first download it, then run one of the provided Python scripts for inference (infer.py) or training (main.py). 52 | 53 | * Download the [pretrained model](https://github.com/soeque1/temp_files/files/2848870/epoches_81_loss_15.7157.zip) 54 | * Try inference with the following: 55 | 56 | ``` 57 | python infer.py --model_path='checkpoint/epoches_81_loss_15.7157' 58 | ``` 59 | 60 | * Resume training with the following: 61 | 62 | ``` 63 | python main.py --model_path='checkpoint/epoches_81_loss_15.7157' 64 | ``` 65 | 66 | ## Prepare the Data 67 | 68 | You can prepare the data yourself, or you can download preprocessed data. 69 | 70 | ### Option 1 - Download the preprocessed data 71 | 72 | There are two download routes provided for the preprocessed data. 73 | 74 | #### Download and untar the data 75 | To download the tar archives directly, download the following files and extract them into a folder called `data` in the root of this example folder.
You should have the following structure: 76 | ``` 77 | /lipnet/data/align 78 | /lipnet/data/datasets 79 | ``` 80 | 81 | * [align files](https://mxnet-public.s3.amazonaws.com/lipnet/data-archives/align.tgz) 82 | * [datasets files](https://mxnet-public.s3.amazonaws.com/lipnet/data-archives/datasets.tgz) 83 | 84 | #### Use AWS CLI to sync the data 85 | To get the folders and files already unpacked, you can use the following AWS CLI command. This will provide the folder structure for you. Run this command from `/lipnet/`: 86 | 87 | ``` 88 | aws s3 sync s3://mxnet-public/lipnet/data . 89 | ``` 90 | 91 | ### Option 2 (part 1) - Download the raw dataset 92 | - Outputs 93 | - Total videos (mpg): 16 GB 94 | - Total aligns (text): 134 MB 95 | - Arguments 96 | - src_path : Path for videos (default='./data/mp4s/') 97 | - align_path : Path for aligns (default='./data/') 98 | - n_process : number of processes (default=1) 99 | 100 | ``` 101 | cd ./utils && python download_data.py --n_process=$(nproc) 102 | ``` 103 | 104 | ### Option 2 (part 2) - Preprocess the raw dataset: extract the mouth images from each video and save them 105 | 106 | * Uses face landmark detection (http://dlib.net/) 107 | 108 | #### Preprocess (preprocess_data.py) 109 | * If the face landmark model is not present, it is downloaded automatically. 110 | * Using face landmark detection, it extracts the mouth region from each video frame. 111 | 112 | - example: 113 | - video: ./data/mp4s/s2/bbbf7p.mpg 114 | - align(target): ./data/align/s2/bbbf7p.align 115 | : 'sil bin blue by f seven please sil' 116 | 117 | 118 | - Video to images (75 frames) 119 | 120 | Frame 0 | Frame 1 | ... | Frame 74 | 121 | :-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: 122 | ![](asset/s2_bbbf7p_000.png) | ![](asset/s2_bbbf7p_001.png) | ... | ![](asset/s2_bbbf7p_074.png) 123 | 124 | - Extract the mouth from the images 125 | 126 | Frame 0 | Frame 1 | ...
| Frame 74 | 127 | :-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: 128 | ![](asset/mouth_000.png) | ![](asset/mouth_001.png) | ... | ![](asset/mouth_074.png) 129 | 130 | * Save the resulting images into tgt_path. 131 | 132 | ---- 133 | 134 | #### How to run the preprocess script 135 | 136 | - Arguments 137 | - src_path : Path for videos (default='./data/mp4s/') 138 | - tgt_path : Path for preprocessed images (default='./data/datasets/') 139 | - n_process : number of processes (default=1) 140 | 141 | - Outputs 142 | - Total images (png): 19 GB 143 | - Elapsed time 144 | - About 54 hours using 1 process 145 | - Using several processes shortens this roughly in proportion to the number of processes. 146 | - e.g. 9 hours using 6 processes 147 | 148 | You can run the preprocessing with just one processor, but this will take a long time (>48 hours). To use all of the available processors, use the following command: 149 | 150 | ``` 151 | cd ./utils && python preprocess_data.py --n_process=$(nproc) 152 | ``` 153 | 154 | #### Output: Data structure of the preprocessed data 155 | 156 | ``` 157 | The training data folder should look like : 158 | 159 | |--datasets 160 | |--s1 161 | |--bbir7s 162 | |--mouth_000.png 163 | |--mouth_001.png 164 | ... 165 | |--bgaa8p 166 | |--mouth_000.png 167 | |--mouth_001.png 168 | ... 169 | |--s2 170 | ... 171 | |--align 172 | |--bw1d8a.align 173 | |--bggzzs.align 174 | ... 175 | 176 | ``` 177 | 178 | --- 179 | 180 | ## Training 181 | After you have acquired the preprocessed data you are ready to train the LipNet model. 182 | 183 | - According to [LipNet: End-to-End Sentence-level Lipreading](https://arxiv.org/abs/1611.01599), four (S1, S2, S20, S22) of the 34 subjects are used for evaluation. 184 | The other subjects are used for training. 185 | 186 | - To train on multiple GPUs, it is recommended to make the batch size num_gpus times larger.
187 | 188 | - e.g. batch_size 128 on 1 GPU becomes batch_size 256 on 2 GPUs 189 | 190 | 191 | - arguments 192 | - batch_size : Define batch size (default=64) 193 | - epochs : Define total epochs (default=100) 194 | - image_path : Path for lip image files (default='./data/datasets/') 195 | - align_path : Path for align files (default='./data/align/') 196 | - dr_rate : Dropout rate (default=0.5) 197 | - num_gpus : Number of GPUs (if num_gpus is 0, the CPU is used) (default=1) 198 | - num_workers : Number of workers used to load data (default=0) 199 | - model_path : Path of pretrained model (default=None) 200 | 201 | ``` 202 | python main.py 203 | ``` 204 | 205 | --- 206 | 207 | ## Test Environment 208 | - 72 CPU cores 209 | - 1 GPU (NVIDIA Tesla V100 SXM2 32 GB) 210 | - Batch size 128 211 | 212 | - It takes over 24 hours (60 epochs) to get good results. 213 | 214 | --- 215 | 216 | ## Inference 217 | 218 | - arguments 219 | - batch_size : Define batch size (default=64) 220 | - image_path : Path for lip image files (default='./data/datasets/') 221 | - align_path : Path for align files (default='./data/align/') 222 | - num_gpus : Number of GPUs (if num_gpus is 0, the CPU is used) (default=1) 223 | - num_workers : Number of workers used to load data (default=0) 224 | - data_type : 'train' or 'valid' (default='valid') 225 | - model_path : Path of pretrained model (default=None) 226 | 227 | ``` 228 | python infer.py --model_path=<model_path> 229 | ``` 230 | 231 | 232 | ``` 233 | [Target] 234 | ['lay green with a zero again', 235 | 'bin blue with r nine please', 236 | 'set blue with e five again', 237 | 'bin green by t seven soon', 238 | 'lay red at d five now', 239 | 'bin green
in x eight now', 252 | 'bin blue with m one now', 253 | 'lay red at j nine now'] 254 | ``` 255 | -------------------------------------------------------------------------------- /utils/preprocess_data.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
17 | 18 | """ 19 | Module: preprocess_data 20 | Reference: https://github.com/rizkiarm/LipNet 21 | """ 22 | 23 | # pylint: disable=too-many-locals, no-self-use, c-extension-no-member 24 | 25 | import os 26 | import fnmatch 27 | import errno 28 | import numpy as np 29 | from scipy import ndimage 30 | from scipy.misc import imresize 31 | from skimage import io 32 | import skvideo.io 33 | import dlib 34 | 35 | def mkdir_p(path): 36 | """ 37 | Make a directory 38 | """ 39 | try: 40 | os.makedirs(path) 41 | except OSError as exc: # Python >2.5 42 | if exc.errno == errno.EEXIST and os.path.isdir(path): 43 | pass 44 | else: 45 | raise 46 | 47 | def find_files(directory, pattern): 48 | """ 49 | Find files 50 | """ 51 | for root, _, files in os.walk(directory): 52 | for basename in files: 53 | if fnmatch.fnmatch(basename, pattern): 54 | filename = os.path.join(root, basename) 55 | yield filename 56 | 57 | class Video(object): 58 | """ 59 | Preprocess for Video 60 | """ 61 | def __init__(self, vtype='mouth', face_predictor_path=None): 62 | if vtype == 'face' and face_predictor_path is None: 63 | raise AttributeError('Face video need to be accompanied with face predictor') 64 | self.face_predictor_path = face_predictor_path 65 | self.vtype = vtype 66 | self.face = None 67 | self.mouth = None 68 | self.data = None 69 | self.length = None 70 | 71 | def from_frames(self, path): 72 | """ 73 | Read from frames 74 | """ 75 | frames_path = sorted([os.path.join(path, x) for x in os.listdir(path)]) 76 | frames = [ndimage.imread(frame_path) for frame_path in frames_path] 77 | self.handle_type(frames) 78 | return self 79 | 80 | def from_video(self, path): 81 | """ 82 | Read from videos 83 | """ 84 | frames = self.get_video_frames(path) 85 | self.handle_type(frames) 86 | return self 87 | 88 | def from_array(self, frames): 89 | """ 90 | Read from array 91 | """ 92 | self.handle_type(frames) 93 | return self 94 | 95 | def handle_type(self, frames): 96 | """ 97 | Config video types 98 | """ 
99 | if self.vtype == 'mouth': 100 | self.process_frames_mouth(frames) 101 | elif self.vtype == 'face': 102 | self.process_frames_face(frames) 103 | else: 104 | raise Exception('Video type not found') 105 | 106 | def process_frames_face(self, frames): 107 | """ 108 | Preprocess from frames using face detector 109 | """ 110 | detector = dlib.get_frontal_face_detector() 111 | predictor = dlib.shape_predictor(self.face_predictor_path) 112 | mouth_frames = self.get_frames_mouth(detector, predictor, frames) 113 | self.face = np.array(frames) 114 | self.mouth = np.array(mouth_frames) 115 | if mouth_frames[0] is not None: 116 | self.set_data(mouth_frames) 117 | 118 | def process_frames_mouth(self, frames): 119 | """ 120 | Preprocess from frames using mouth detector 121 | """ 122 | self.face = np.array(frames) 123 | self.mouth = np.array(frames) 124 | self.set_data(frames) 125 | 126 | def get_frames_mouth(self, detector, predictor, frames): 127 | """ 128 | Get frames using mouth crop 129 | """ 130 | mouth_width = 100 131 | mouth_height = 50 132 | horizontal_pad = 0.19 133 | normalize_ratio = None 134 | mouth_frames = [] 135 | for frame in frames: 136 | dets = detector(frame, 1) 137 | shape = None 138 | for det in dets: 139 | shape = predictor(frame, det) 140 | i = -1 141 | if shape is None: # Detector doesn't detect face, just return None 142 | return [None] 143 | mouth_points = [] 144 | for part in shape.parts(): 145 | i += 1 146 | if i < 48: # Only take mouth region 147 | continue 148 | mouth_points.append((part.x, part.y)) 149 | np_mouth_points = np.array(mouth_points) 150 | 151 | mouth_centroid = np.mean(np_mouth_points[:, -2:], axis=0) 152 | 153 | if normalize_ratio is None: 154 | mouth_left = np.min(np_mouth_points[:, :-1]) * (1.0 - horizontal_pad) 155 | mouth_right = np.max(np_mouth_points[:, :-1]) * (1.0 + horizontal_pad) 156 | 157 | normalize_ratio = mouth_width / float(mouth_right - mouth_left) 158 | 159 | new_img_shape = (int(frame.shape[0] * normalize_ratio), 
160 | int(frame.shape[1] * normalize_ratio)) 161 | resized_img = imresize(frame, new_img_shape) 162 | 163 | mouth_centroid_norm = mouth_centroid * normalize_ratio 164 | 165 | mouth_l = int(mouth_centroid_norm[0] - mouth_width / 2) 166 | mouth_r = int(mouth_centroid_norm[0] + mouth_width / 2) 167 | mouth_t = int(mouth_centroid_norm[1] - mouth_height / 2) 168 | mouth_b = int(mouth_centroid_norm[1] + mouth_height / 2) 169 | 170 | mouth_crop_image = resized_img[mouth_t:mouth_b, mouth_l:mouth_r] 171 | 172 | mouth_frames.append(mouth_crop_image) 173 | return mouth_frames 174 | 175 | def get_video_frames(self, path): 176 | """ 177 | Get video frames 178 | """ 179 | videogen = skvideo.io.vreader(path) 180 | frames = np.array([frame for frame in videogen]) 181 | return frames 182 | 183 | def set_data(self, frames): 184 | """ 185 | Prepare the input of model 186 | """ 187 | data_frames = [] 188 | for frame in frames: 189 | #frame H x W x C 190 | frame = frame.swapaxes(0, 1) # swap width and height to form format W x H x C 191 | if len(frame.shape) < 3: 192 | frame = np.array([frame]).swapaxes(0, 2).swapaxes(0, 1) # Add grayscale channel 193 | data_frames.append(frame) 194 | frames_n = len(data_frames) 195 | data_frames = np.array(data_frames) # T x W x H x C 196 | data_frames = np.rollaxis(data_frames, 3) # C x T x W x H 197 | data_frames = data_frames.swapaxes(2, 3) # C x T x H x W = NCDHW 198 | 199 | self.data = data_frames 200 | self.length = frames_n 201 | 202 | def preprocess(from_idx, to_idx, _params): 203 | """ 204 | Preprocess: Convert a video into the mouth images 205 | """ 206 | source_exts = '*.mpg' 207 | src_path = _params['src_path'] 208 | tgt_path = _params['tgt_path'] 209 | face_predictor_path = './shape_predictor_68_face_landmarks.dat' 210 | 211 | succ = set() 212 | fail = set() 213 | for idx in range(from_idx, to_idx): 214 | s_id = 's' + str(idx) + '/' 215 | source_path = src_path + '/' + s_id 216 | target_path = tgt_path + '/' + s_id 217 | fail_cnt = 0 218 
| for filepath in find_files(source_path, source_exts): 219 | print("Processing: {}".format(filepath)) 220 | filepath_wo_ext = os.path.splitext(filepath)[0].split('/')[-2:] 221 | target_dir = os.path.join(tgt_path, '/'.join(filepath_wo_ext)) 222 | 223 | if os.path.exists(target_dir): 224 | continue 225 | 226 | try: 227 | video = Video(vtype='face', \ 228 | face_predictor_path=face_predictor_path).from_video(filepath) 229 | mkdir_p(target_dir) 230 | i = 0 231 | if video.mouth[0] is None: 232 | continue 233 | for frame in video.mouth: 234 | io.imsave(os.path.join(target_dir, "mouth_{0:03d}.png".format(i)), frame) 235 | i += 1 236 | except ValueError as error: 237 | print(error) 238 | fail_cnt += 1 239 | if fail_cnt == 0: 240 | succ.add(idx) 241 | else: 242 | fail.add(idx) 243 | return (succ, fail) 244 | 245 | if __name__ == '__main__': 246 | import argparse 247 | from multi import multi_p_run, put_worker 248 | PARSER = argparse.ArgumentParser() 249 | PARSER.add_argument('--src_path', type=str, default='../data/mp4s') 250 | PARSER.add_argument('--tgt_path', type=str, default='../data/datasets') 251 | PARSER.add_argument('--n_process', type=int, default=1) 252 | CONFIG = PARSER.parse_args() 253 | N_PROCESS = CONFIG.n_process 254 | PARAMS = {'src_path':CONFIG.src_path, 255 | 'tgt_path':CONFIG.tgt_path} 256 | 257 | os.makedirs('{tgt_path}'.format(tgt_path=PARAMS['tgt_path']), exist_ok=True) 258 | 259 | if N_PROCESS == 1: 260 | RES = preprocess(0, 35, PARAMS) 261 | else: 262 | RES = multi_p_run(35, put_worker, preprocess, PARAMS, N_PROCESS) 263 | -------------------------------------------------------------------------------- /utils/run_preprocess_single_process.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from download_data import multi_p_run, put_worker, 
test_worker, download_mp4, download_align" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "tot_movies=35" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## TEST" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]\n", 56 | "5\n", 57 | "35 >> [[0, 7], [7, 14], [14, 21], [21, 28], [28, 35]]\n", 58 | "[{'succ': {0, 1, 2, 3, 4, 5, 6}, 'fail': set()}, {'succ': {7, 8, 9, 10, 11, 12, 13}, 'fail': set()}, {'succ': {14, 15, 16, 17, 18, 19, 20}, 'fail': set()}, {'succ': {21, 22, 23, 24, 25, 26, 27}, 'fail': set()}, {'succ': {32, 33, 34, 28, 29, 30, 31}, 'fail': set()}]\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "res = multi_p_run(tot_movies, put_worker, test_worker, params={}, n_process=5)\n", 64 | "print (res)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "## Download Data" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "### Aligns" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s0/align/s0.tar && tar -xvf s0.tar\n", 91 | "cd ../data/align && wget 
http://spandh.dcs.shef.ac.uk/gridcorpus/s1/align/s1.tar && tar -xvf s1.tar\n", 92 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s2/align/s2.tar && tar -xvf s2.tar\n", 93 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s3/align/s3.tar && tar -xvf s3.tar\n", 94 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s4/align/s4.tar && tar -xvf s4.tar\n", 95 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s5/align/s5.tar && tar -xvf s5.tar\n", 96 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s6/align/s6.tar && tar -xvf s6.tar\n", 97 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s7/align/s7.tar && tar -xvf s7.tar\n", 98 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s8/align/s8.tar && tar -xvf s8.tar\n", 99 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s9/align/s9.tar && tar -xvf s9.tar\n", 100 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s10/align/s10.tar && tar -xvf s10.tar\n", 101 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s11/align/s11.tar && tar -xvf s11.tar\n", 102 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s12/align/s12.tar && tar -xvf s12.tar\n", 103 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s13/align/s13.tar && tar -xvf s13.tar\n", 104 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s14/align/s14.tar && tar -xvf s14.tar\n", 105 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s15/align/s15.tar && tar -xvf s15.tar\n", 106 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s16/align/s16.tar && tar -xvf s16.tar\n", 107 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s17/align/s17.tar && tar -xvf s17.tar\n", 108 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s18/align/s18.tar && tar -xvf s18.tar\n", 109 | "cd 
../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s19/align/s19.tar && tar -xvf s19.tar\n", 110 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s20/align/s20.tar && tar -xvf s20.tar\n", 111 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s21/align/s21.tar && tar -xvf s21.tar\n", 112 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s22/align/s22.tar && tar -xvf s22.tar\n", 113 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s23/align/s23.tar && tar -xvf s23.tar\n", 114 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s24/align/s24.tar && tar -xvf s24.tar\n", 115 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s25/align/s25.tar && tar -xvf s25.tar\n", 116 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s26/align/s26.tar && tar -xvf s26.tar\n", 117 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s27/align/s27.tar && tar -xvf s27.tar\n", 118 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s28/align/s28.tar && tar -xvf s28.tar\n", 119 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s29/align/s29.tar && tar -xvf s29.tar\n", 120 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s30/align/s30.tar && tar -xvf s30.tar\n", 121 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s31/align/s31.tar && tar -xvf s31.tar\n", 122 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s32/align/s32.tar && tar -xvf s32.tar\n", 123 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s33/align/s33.tar && tar -xvf s33.tar\n", 124 | "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s34/align/s34.tar && tar -xvf s34.tar\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "align_path = '../data/align'\n", 130 | "os.makedirs(align_path, exist_ok=True)\n", 131 | "\n", 132 | "res = download_align(0, 
tot_movies, {'align_path':align_path})" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 6, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}, set())\n" 145 | ] 146 | }, 147 | { 148 | "data": { 149 | "text/plain": [ 150 | "0" 151 | ] 152 | }, 153 | "execution_count": 6, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | } 157 | ], 158 | "source": [ 159 | "print (res)\n", 160 | "os.system('rm -f {align_path}/*.tar && rm -f {align_path}/Thumbs.db'.format(align_path=align_path))" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 7, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "### Movies(MP4s)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 8, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s0/video/s0.mpg_vcd.zip --output s0.mpg_vcd.zip && unzip s0.mpg_vcd.zip\n", 182 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s1/video/s1.mpg_vcd.zip --output s1.mpg_vcd.zip && unzip s1.mpg_vcd.zip\n", 183 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s2/video/s2.mpg_vcd.zip --output s2.mpg_vcd.zip && unzip s2.mpg_vcd.zip\n", 184 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s3/video/s3.mpg_vcd.zip --output s3.mpg_vcd.zip && unzip s3.mpg_vcd.zip\n", 185 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s4/video/s4.mpg_vcd.zip --output s4.mpg_vcd.zip && unzip s4.mpg_vcd.zip\n", 186 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s5/video/s5.mpg_vcd.zip --output s5.mpg_vcd.zip && unzip s5.mpg_vcd.zip\n", 187 | "cd 
../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s6/video/s6.mpg_vcd.zip --output s6.mpg_vcd.zip && unzip s6.mpg_vcd.zip\n", 188 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s7/video/s7.mpg_vcd.zip --output s7.mpg_vcd.zip && unzip s7.mpg_vcd.zip\n", 189 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s8/video/s8.mpg_vcd.zip --output s8.mpg_vcd.zip && unzip s8.mpg_vcd.zip\n", 190 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s9/video/s9.mpg_vcd.zip --output s9.mpg_vcd.zip && unzip s9.mpg_vcd.zip\n", 191 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s10/video/s10.mpg_vcd.zip --output s10.mpg_vcd.zip && unzip s10.mpg_vcd.zip\n", 192 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s11/video/s11.mpg_vcd.zip --output s11.mpg_vcd.zip && unzip s11.mpg_vcd.zip\n", 193 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s12/video/s12.mpg_vcd.zip --output s12.mpg_vcd.zip && unzip s12.mpg_vcd.zip\n", 194 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s13/video/s13.mpg_vcd.zip --output s13.mpg_vcd.zip && unzip s13.mpg_vcd.zip\n", 195 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s14/video/s14.mpg_vcd.zip --output s14.mpg_vcd.zip && unzip s14.mpg_vcd.zip\n", 196 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s15/video/s15.mpg_vcd.zip --output s15.mpg_vcd.zip && unzip s15.mpg_vcd.zip\n", 197 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s16/video/s16.mpg_vcd.zip --output s16.mpg_vcd.zip && unzip s16.mpg_vcd.zip\n", 198 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s17/video/s17.mpg_vcd.zip --output s17.mpg_vcd.zip && unzip s17.mpg_vcd.zip\n", 199 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s18/video/s18.mpg_vcd.zip --output s18.mpg_vcd.zip && unzip s18.mpg_vcd.zip\n", 200 | "cd ../data/mp4s && curl 
http://spandh.dcs.shef.ac.uk/gridcorpus/s19/video/s19.mpg_vcd.zip --output s19.mpg_vcd.zip && unzip s19.mpg_vcd.zip\n", 201 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s20/video/s20.mpg_vcd.zip --output s20.mpg_vcd.zip && unzip s20.mpg_vcd.zip\n", 202 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s21/video/s21.mpg_vcd.zip --output s21.mpg_vcd.zip && unzip s21.mpg_vcd.zip\n", 203 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s22/video/s22.mpg_vcd.zip --output s22.mpg_vcd.zip && unzip s22.mpg_vcd.zip\n", 204 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s23/video/s23.mpg_vcd.zip --output s23.mpg_vcd.zip && unzip s23.mpg_vcd.zip\n", 205 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s24/video/s24.mpg_vcd.zip --output s24.mpg_vcd.zip && unzip s24.mpg_vcd.zip\n", 206 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s25/video/s25.mpg_vcd.zip --output s25.mpg_vcd.zip && unzip s25.mpg_vcd.zip\n", 207 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s26/video/s26.mpg_vcd.zip --output s26.mpg_vcd.zip && unzip s26.mpg_vcd.zip\n", 208 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s27/video/s27.mpg_vcd.zip --output s27.mpg_vcd.zip && unzip s27.mpg_vcd.zip\n", 209 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s28/video/s28.mpg_vcd.zip --output s28.mpg_vcd.zip && unzip s28.mpg_vcd.zip\n", 210 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s29/video/s29.mpg_vcd.zip --output s29.mpg_vcd.zip && unzip s29.mpg_vcd.zip\n", 211 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s30/video/s30.mpg_vcd.zip --output s30.mpg_vcd.zip && unzip s30.mpg_vcd.zip\n", 212 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s31/video/s31.mpg_vcd.zip --output s31.mpg_vcd.zip && unzip s31.mpg_vcd.zip\n", 213 | "cd ../data/mp4s && curl 
http://spandh.dcs.shef.ac.uk/gridcorpus/s32/video/s32.mpg_vcd.zip --output s32.mpg_vcd.zip && unzip s32.mpg_vcd.zip\n", 214 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s33/video/s33.mpg_vcd.zip --output s33.mpg_vcd.zip && unzip s33.mpg_vcd.zip\n", 215 | "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s34/video/s34.mpg_vcd.zip --output s34.mpg_vcd.zip && unzip s34.mpg_vcd.zip\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "src_path = '../data/mp4s'\n", 221 | "res = download_mp4(0, tot_movies, {'src_path':src_path})" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 9, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "name": "stdout", 231 | "output_type": "stream", 232 | "text": [ 233 | "({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}, set())\n" 234 | ] 235 | }, 236 | { 237 | "data": { 238 | "text/plain": [ 239 | "0" 240 | ] 241 | }, 242 | "execution_count": 9, 243 | "metadata": {}, 244 | "output_type": "execute_result" 245 | } 246 | ], 247 | "source": [ 248 | "print (res)\n", 249 | "os.system('rm -f {src_path}/*.zip && rm -f {src_path}/*/Thumbs.db'.format(src_path=src_path))" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "## Preprocess Data" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 10, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "from preprocess_data import preprocess, find_files, Video" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 11, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "tgt_path = '../data/datasets'" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 12, 287 | 
"metadata": {}, 288 | "outputs": [ 289 | { 290 | "data": { 291 | "text/plain": [ 292 | "0" 293 | ] 294 | }, 295 | "execution_count": 12, 296 | "metadata": {}, 297 | "output_type": "execute_result" 298 | } 299 | ], 300 | "source": [ 301 | "os.makedirs('{tgt_path}'.format(tgt_path=tgt_path), exist_ok=True)\n", 302 | "os.system('rm -rf {tgt_path}'.format(tgt_path=tgt_path))" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 13, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "res = preprocess(0, tot_movies, {'src_path':src_path, 'tgt_path':tgt_path})" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 14, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}, set())\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "print (res)" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python [default]", 342 | "language": "python", 343 | "name": "python3" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.6.4" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 2 360 | } 361 | --------------------------------------------------------------------------------