├── .gitattributes
├── .gitignore
├── .travis.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── build-windows.bat
├── chapters.py
├── common.py
├── demux.py
├── keyframes.py
├── licenses
│   ├── OpenCV.txt
│   └── SciPy.txt
├── regression-tests.py
├── requirements-win.txt
├── requirements.txt
├── run-tests.py
├── setup.py
├── subs.py
├── sushi.py
├── tests.example.json
├── tests
│   ├── __init__.py
│   ├── demuxing.py
│   ├── main.py
│   ├── subtitles.py
│   └── timecodes.py
└── wav.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
4 | # Custom for Visual Studio
5 | *.cs diff=csharp
6 | *.sln merge=union
7 | *.csproj merge=union
8 | *.vbproj merge=union
9 | *.fsproj merge=union
10 | *.dbproj merge=union
11 |
12 | # Standard to msysgit
13 | *.doc diff=astextplain
14 | *.DOC diff=astextplain
15 | *.docx diff=astextplain
16 | *.DOCX diff=astextplain
17 | *.dot diff=astextplain
18 | *.DOT diff=astextplain
19 | *.pdf diff=astextplain
20 | *.PDF diff=astextplain
21 | *.rtf diff=astextplain
22 | *.RTF diff=astextplain
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | test
3 | *.pyc
4 | tests/media
5 | dist
6 | build
7 | tests.json
8 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | virtualenv:
4 | system_site_packages: true
5 | before_install:
6 | - sudo apt-get update
7 | - sudo apt-get install python-opencv
8 | - sudo dpkg -L python-opencv
9 | - sudo ln /dev/null /dev/raw1394
10 | install:
11 | - "pip install -r requirements.txt"
12 |
13 | python:
14 | - "2.7"
15 | script:
16 | - python run-tests.py
17 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Before submitting a bug, please read [how to do it](https://github.com/tp7/Sushi/wiki/Common-errors).
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014-2017 Victor Efimov
2 |
3 | Permission is hereby granted, free of charge, to any person
4 | obtaining a copy of this software and associated documentation
5 | files (the "Software"), to deal in the Software without
6 | restriction, including without limitation the rights to use,
7 | copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the
9 | Software is furnished to do so, subject to the following
10 | conditions:
11 |
12 | The above copyright notice and this permission notice shall be
13 | included in all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
17 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
19 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
20 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Sushi [](https://travis-ci.org/tp7/Sushi)
2 | Automatic shifter for SRT and ASS subtitles based on audio streams.
3 |
4 | ### Purpose
5 | Imagine you've got a subtitle file synced to one video file, but you want to use these subtitles with some other video you've obtained through totally legal means. Common examples include TV vs. BD releases, PAL vs. NTSC video, and releases in different countries. In many cases the subtitles won't match right away and you need to sync them.
6 |
7 | The purpose of this script is to avoid all the hassle of manual syncing. It attempts to synchronize subtitles by finding similarities in audio streams. The script is very fast and can be used right when you want to watch something.
8 |
9 | ### Downloads
10 | The latest Windows binary release can always be found in the [releases][1] section. You need the 7z archive in the top entry.
11 |
12 | ### How it works
13 | You need to provide two audio files and a subtitle file that matches one of them. For every line of the subtitles, the script extracts the corresponding audio from the source stream and tries to find the closest similar pattern in the destination stream, obtaining a shift value that is then applied to the subtitles.
14 |
15 | Detailed explanation of Sushi workflow and description of command-line arguments can be found in the [wiki][2].
16 |
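To make the matching step concrete, here is a minimal sketch of the principle using plain NumPy. It is only an illustration, not Sushi's actual implementation: cut a line's audio out of the source stream, scan a window around its original position in the destination stream, and take the offset where cross-correlation peaks.

```python
import numpy as np

def find_shift(src, dst, line_start, line_end, rate, window=10.0):
    # the audio of one subtitle line, taken from the source stream
    pattern = src[int(line_start * rate):int(line_end * rate)]
    # search only within +/- `window` seconds of the original position
    lo = max(0, int((line_start - window) * rate))
    hi = min(len(dst), int((line_end + window) * rate))
    search = dst[lo:hi]
    # the cross-correlation peak marks the most similar position
    scores = np.correlate(search - search.mean(), pattern - pattern.mean())
    best_sample = lo + int(np.argmax(scores))
    return best_sample / float(rate) - line_start
```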
17 | ### Usage
18 | The minimal command line looks like this:
19 | ```
20 | python sushi.py --src hdtv.wav --dst bluray.wav --script subs.ass
21 | ```
22 | Output file name is optional - `"{destination_path}.sushi.{subtitles_format}"` is used by default. See the [usage][3] page of the wiki for further examples.
23 |
24 | Do note that WAV is not the only format Sushi can work with. It can process audio/video files directly and decode various audio formats, provided that ffmpeg is available. For additional info refer to the [Demuxing][4] part of the wiki.
25 |
26 | ### Requirements
27 | Sushi should work on Windows, Linux and OS X. Please open an issue if it doesn't. To run it, you have to have the following installed:
28 |
29 | 1. [Python 2.7.x][5]
30 | 2. [NumPy][6] (1.8 or newer)
31 | 3. [OpenCV 2.4.x or newer][7] (on Windows putting [this file][8] in the same folder as Sushi should be enough, assuming you use x86 Python)
32 |
33 | Optionally, you might want:
34 |
35 | 1. [FFmpeg][9] for any kind of demuxing
36 | 2. [MkvExtract][10] for faster timecodes extraction when demuxing
37 | 3. [SCXvid-standalone][11] if you want Sushi to make keyframes
38 | 4. [Colorama](https://github.com/tartley/colorama) to add colors to console output on Windows
39 |
40 | The provided Windows binaries include all required components and Colorama, so you don't have to install them if you use the binary distribution. You still have to download the other applications yourself if you want to use Sushi's demuxing capabilities.
41 |
42 | #### Installation on Mac OS X
43 |
44 | No binary packages are provided for OS X right now, so you'll have to use the script form. Assuming you have Python 2, pip, and [homebrew](http://brew.sh/) installed, run the following:
45 | ```bash
46 | brew tap homebrew/science
47 | brew install git opencv
48 | pip install numpy
49 | git clone https://github.com/tp7/sushi
50 | # create a symlink if you want to run sushi globally
51 | ln -s `pwd`/sushi/sushi.py /usr/local/bin/sushi
52 | # install some optional dependencies
53 | brew install ffmpeg mkvtoolnix
54 | ```
55 | If you don't have pip, you can install numpy with homebrew, but that will probably add a few more dependencies.
56 | ```bash
57 | brew tap homebrew/python
58 | brew install numpy
59 | ```
60 |
61 | #### Installation on Linux
62 | If you have apt-get available, the installation process is trivial.
63 | ```bash
64 | sudo apt-get update
65 | sudo apt-get install git python python-numpy python-opencv
66 | git clone https://github.com/tp7/sushi
67 | ln -s `pwd`/sushi/sushi.py /usr/local/bin/sushi
68 | ```
69 |
70 | ### Limitations
71 | This script will never be able to properly handle frame-by-frame typesetting. If the underlying video stream changes (e.g. has a different telecine pattern), you might get incorrect output.
72 |
73 | This script cannot improve bad timing. If the original lines are mistimed, they will also be mistimed in the output file.
74 |
75 | In short, while this might be safe for immediate viewing, you probably shouldn't use it to blindly shift subtitles for permanent storage.
76 |
77 |
78 | [1]: https://github.com/tp7/Sushi/releases
79 | [2]: https://github.com/tp7/Sushi/wiki
80 | [3]: https://github.com/tp7/Sushi/wiki/Examples
81 | [4]: https://github.com/tp7/Sushi/wiki/Demuxing
82 | [5]: https://www.python.org/downloads/
83 | [6]: http://www.scipy.org/scipylib/download.html
84 | [7]: http://opencv.org/
85 | [8]: https://www.dropbox.com/s/nlylgdh4bgrjgxv/cv2.pyd?dl=0
86 | [9]: http://www.ffmpeg.org/download.html
87 | [10]: http://www.bunkus.org/videotools/mkvtoolnix/downloads.html
88 | [11]: https://github.com/soyokaze/SCXvid-standalone/releases
89 |
--------------------------------------------------------------------------------
/build-windows.bat:
--------------------------------------------------------------------------------
1 | rmdir /S /Q dist
2 |
3 | pyinstaller --noupx --onefile --noconfirm ^
4 | --exclude-module Tkconstants ^
5 | --exclude-module Tkinter ^
6 | --exclude-module matplotlib ^
7 | sushi.py
8 |
9 | mkdir dist\licenses
10 | copy /Y licenses\* dist\licenses\*
11 | copy LICENSE dist\licenses\Sushi.txt
12 | copy README.md dist\readme.md
--------------------------------------------------------------------------------
/chapters.py:
--------------------------------------------------------------------------------
1 | import re
2 | import common
3 |
4 |
5 | def parse_times(times):
6 | result = []
7 | for t in times:
8 | hours, minutes, seconds = map(float, t.split(':'))
9 | result.append(hours * 3600 + minutes * 60 + seconds)
10 |
11 | result.sort()
12 |     if not result or result[0] != 0:
13 | result.insert(0, 0)
14 | return result
15 |
16 |
17 | def parse_xml_start_times(text):
18 | times = re.findall(r'(\d+:\d+:\d+\.\d+)', text)
19 | return parse_times(times)
20 |
21 |
22 | def get_xml_start_times(path):
23 | return parse_xml_start_times(common.read_all_text(path))
24 |
25 |
26 | def parse_ogm_start_times(text):
27 | times = re.findall(r'CHAPTER\d+=(\d+:\d+:\d+\.\d+)', text, flags=re.IGNORECASE)
28 | return parse_times(times)
29 |
30 |
31 | def get_ogm_start_times(path):
32 | return parse_ogm_start_times(common.read_all_text(path))
33 |
34 |
35 | def format_ogm_chapters(start_times):
36 | return "\n".join("CHAPTER{0:02}={1}\nCHAPTER{0:02}NAME=".format(idx+1, common.format_srt_time(start).replace(',', '.'))
37 | for idx, start in enumerate(start_times)) + "\n"
38 |
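A quick illustration of the output format (assuming three chapter start points given in seconds); feeding the result back through `parse_ogm_start_times` recovers the same values:

```python
import chapters

print(chapters.format_ogm_chapters([0.0, 120.5, 300.0]))
# CHAPTER01=00:00:00.000
# CHAPTER01NAME=
# CHAPTER02=00:02:00.500
# CHAPTER02NAME=
# CHAPTER03=00:05:00.000
# CHAPTER03NAME=
```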
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class SushiError(Exception):
5 | pass
6 |
7 |
8 | def get_extension(path):
9 | return (os.path.splitext(path)[1]).lower()
10 |
11 |
12 | def read_all_text(path):
13 | with open(path) as file:
14 | return file.read()
15 |
16 |
17 | def ensure_static_collection(value):
18 | if isinstance(value, (set, list, tuple)):
19 | return value
20 | return list(value)
21 |
22 |
23 | def format_srt_time(seconds):
24 | cs = round(seconds * 1000)
25 | return u'{0:02d}:{1:02d}:{2:02d},{3:03d}'.format(
26 | int(cs // 3600000),
27 | int((cs // 60000) % 60),
28 | int((cs // 1000) % 60),
29 | int(cs % 1000))
30 |
31 |
32 | def format_time(seconds):
33 | cs = round(seconds * 100)
34 | return u'{0}:{1:02d}:{2:02d}.{3:02d}'.format(
35 | int(cs // 360000),
36 | int((cs // 6000) % 60),
37 | int((cs // 100) % 60),
38 | int(cs % 100))
39 |
40 |
41 | def clip(value, minimum, maximum):
42 | return max(min(value, maximum), minimum)
43 |
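For reference, sample outputs of the two time formatters (illustration only, not part of the module):

```python
from common import format_srt_time, format_time

format_srt_time(3661.5)  # -> u'01:01:01,500' (SRT style, millisecond precision)
format_time(3661.5)      # -> u'1:01:01.50'   (ASS style, centisecond precision)
```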
--------------------------------------------------------------------------------
/demux.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import subprocess
4 | from collections import namedtuple
5 | import logging
6 | import bisect
7 |
8 | from common import SushiError, get_extension
9 | import chapters
10 |
11 | MediaStreamInfo = namedtuple('MediaStreamInfo', ['id', 'info', 'default', 'title'])
12 | SubtitlesStreamInfo = namedtuple('SubtitlesStreamInfo', ['id', 'info', 'type', 'default', 'title'])
13 | MediaInfo = namedtuple('MediaInfo', ['video', 'audio', 'subtitles', 'chapters'])
14 |
15 |
16 | class FFmpeg(object):
17 | @staticmethod
18 | def get_info(path):
19 | try:
20 | process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE)
21 | out, err = process.communicate()
22 | process.wait()
23 | return err
24 | except OSError as e:
25 | if e.errno == 2:
26 | raise SushiError("Couldn't invoke ffmpeg, check that it's installed")
27 | raise
28 |
29 | @staticmethod
30 | def demux_file(input_path, **kwargs):
31 | args = ['ffmpeg', '-hide_banner', '-i', input_path, '-y']
32 |
33 | audio_stream = kwargs.get('audio_stream', None)
34 | audio_path = kwargs.get('audio_path', None)
35 | audio_rate = kwargs.get('audio_rate', None)
36 | if audio_stream is not None:
37 | args.extend(('-map', '0:{0}'.format(audio_stream)))
38 | if audio_rate:
39 | args.extend(('-ar', str(audio_rate)))
40 | args.extend(('-ac', '1', '-acodec', 'pcm_s16le', audio_path))
41 |
42 | script_stream = kwargs.get('script_stream', None)
43 | script_path = kwargs.get('script_path', None)
44 | if script_stream is not None:
45 | args.extend(('-map', '0:{0}'.format(script_stream), script_path))
46 |
47 | video_stream = kwargs.get('video_stream', None)
48 | timecodes_path = kwargs.get('timecodes_path', None)
49 | if timecodes_path is not None:
50 | args.extend(('-map', '0:{0}'.format(video_stream), '-f', 'mkvtimestamp_v2', timecodes_path))
51 |
52 | logging.info('ffmpeg args: {0}'.format(' '.join(('"{0}"' if ' ' in a else '{0}').format(a) for a in args)))
53 | try:
54 | subprocess.call(args)
55 | except OSError as e:
56 | if e.errno == 2:
57 | raise SushiError("Couldn't invoke ffmpeg, check that it's installed")
58 | raise
59 |
60 | @staticmethod
61 | def _get_audio_streams(info):
62 | streams = re.findall(r'Stream\s\#0:(\d+).*?Audio:\s*(.*?(?:\((default)\))?)\s*?(?:\(forced\))?\r?\n'
63 | r'(?:\s*Metadata:\s*\r?\n'
64 | r'\s*title\s*:\s*(.*?)\r?\n)?',
65 | info, flags=re.VERBOSE)
66 | return [MediaStreamInfo(int(x[0]), x[1], x[2] != '', x[3]) for x in streams]
67 |
68 | @staticmethod
69 | def _get_video_streams(info):
70 | streams = re.findall(r'Stream\s\#0:(\d+).*?Video:\s*(.*?(?:\((default)\))?)\s*?(?:\(forced\))?\r?\n'
71 | r'(?:\s*Metadata:\s*\r?\n'
72 | r'\s*title\s*:\s*(.*?)\r?\n)?',
73 | info, flags=re.VERBOSE)
74 | return [MediaStreamInfo(int(x[0]), x[1], x[2] != '', x[3]) for x in streams]
75 |
76 | @staticmethod
77 | def _get_chapters_times(info):
78 | return map(float, re.findall(r'Chapter #0.\d+: start (\d+\.\d+)', info))
79 |
80 | @staticmethod
81 | def _get_subtitles_streams(info):
82 | maps = {
83 | 'ssa': '.ass',
84 | 'ass': '.ass',
85 | 'subrip': '.srt'
86 | }
87 |
88 | streams = re.findall(r'Stream\s\#0:(\d+).*?Subtitle:\s*((\w*)\s*?(?:\((default)\))?\s*?(?:\(forced\))?)\r?\n'
89 | r'(?:\s*Metadata:\s*\r?\n'
90 | r'\s*title\s*:\s*(.*?)\r?\n)?',
91 | info, flags=re.VERBOSE)
92 | return [SubtitlesStreamInfo(int(x[0]), x[1], maps.get(x[2], x[2]), x[3] != '', x[4].strip()) for x in streams]
93 |
94 | @classmethod
95 | def get_media_info(cls, path):
96 | info = cls.get_info(path)
97 | video_streams = cls._get_video_streams(info)
98 | audio_streams = cls._get_audio_streams(info)
99 | subs_streams = cls._get_subtitles_streams(info)
100 | chapter_times = cls._get_chapters_times(info)
101 | return MediaInfo(video_streams, audio_streams, subs_streams, chapter_times)
102 |
103 |
104 | class MkvToolnix(object):
105 | @classmethod
106 | def extract_timecodes(cls, mkv_path, stream_idx, output_path):
107 | args = ['mkvextract', 'timecodes_v2', mkv_path, '{0}:{1}'.format(stream_idx, output_path)]
108 | subprocess.call(args)
109 |
110 |
111 | class SCXviD(object):
112 | @classmethod
113 | def make_keyframes(cls, video_path, log_path):
114 | try:
115 | ffmpeg_process = subprocess.Popen(['ffmpeg', '-i', video_path,
116 | '-f', 'yuv4mpegpipe',
117 | '-vf', 'scale=640:360',
118 | '-pix_fmt', 'yuv420p',
119 | '-vsync', 'drop', '-'], stdout=subprocess.PIPE)
120 | except OSError as e:
121 | if e.errno == 2:
122 | raise SushiError("Couldn't invoke ffmpeg, check that it's installed")
123 | raise
124 |
125 | try:
126 | scxvid_process = subprocess.Popen(['SCXvid', log_path], stdin=ffmpeg_process.stdout)
127 | except OSError as e:
128 | ffmpeg_process.kill()
129 | if e.errno == 2:
130 | raise SushiError("Couldn't invoke scxvid, check that it's installed")
131 | raise
132 | scxvid_process.wait()
133 |
134 |
135 | class Timecodes(object):
136 | def __init__(self, times, default_fps):
137 | super(Timecodes, self).__init__()
138 | self.times = times
139 | self.default_frame_duration = 1.0 / default_fps if default_fps else None
140 |
141 | def get_frame_time(self, number):
142 | try:
143 | return self.times[number]
144 | except IndexError:
145 | if not self.default_frame_duration:
146 | return self.get_frame_time(len(self.times)-1)
147 | if self.times:
148 | return self.times[-1] + (self.default_frame_duration) * (number - len(self.times) + 1)
149 | else:
150 | return number * self.default_frame_duration
151 |
152 | def get_frame_number(self, timestamp):
153 | if (not self.times or self.times[-1] < timestamp) and self.default_frame_duration:
154 |             return max(len(self.times) - 1, 0) + int((timestamp - (self.times[-1] if self.times else 0)) / self.default_frame_duration)
155 | return bisect.bisect_left(self.times, timestamp)
156 |
157 | def get_frame_size(self, timestamp):
158 | try:
159 | number = bisect.bisect_left(self.times, timestamp)
160 | except:
161 | return self.default_frame_duration
162 |
163 | c = self.get_frame_time(number)
164 |
165 | if number == len(self.times):
166 | p = self.get_frame_time(number - 1)
167 | return c - p
168 | else:
169 | n = self.get_frame_time(number + 1)
170 | return n - c
171 |
172 | @classmethod
173 | def _convert_v1_to_v2(cls, default_fps, overrides):
174 | # start, end, fps
175 | overrides = [(int(x[0]), int(x[1]), float(x[2])) for x in overrides]
176 | if not overrides:
177 | return []
178 |
179 | fps = [default_fps] * (overrides[-1][1] + 1)
180 | for o in overrides:
181 | fps[o[0]:o[1] + 1] = [o[2]] * (o[1] - o[0] + 1)
182 |
183 | v2 = [0]
184 | for d in (1.0 / f for f in fps):
185 | v2.append(v2[-1] + d)
186 | return v2
187 |
188 | @classmethod
189 | def parse(cls, text):
190 | lines = text.splitlines()
191 | if not lines:
192 | return []
193 | first = lines[0].lower().lstrip()
194 | if first.startswith('# timecode format v2') or first.startswith('# timestamp format v2'):
195 | tcs = [float(x) / 1000.0 for x in lines[1:]]
196 | return Timecodes(tcs, None)
197 | elif first.startswith('# timecode format v1'):
198 | default = float(lines[1].lower().replace('assume ', ""))
199 | overrides = (x.split(',') for x in lines[2:])
200 | return Timecodes(cls._convert_v1_to_v2(default, overrides), default)
201 | else:
202 | raise SushiError('This timecodes format is not supported')
203 |
204 | @classmethod
205 | def from_file(cls, path):
206 | with open(path) as file:
207 | return cls.parse(file.read())
208 |
209 | @classmethod
210 | def cfr(cls, fps):
211 | class CfrTimecodes(object):
212 | def __init__(self, fps):
213 | self.frame_duration = 1.0 / fps
214 |
215 | def get_frame_time(self, number):
216 | return number * self.frame_duration
217 |
218 | def get_frame_size(self, timestamp):
219 | return self.frame_duration
220 |
221 | def get_frame_number(self, timestamp):
222 | return int(timestamp / self.frame_duration)
223 |
224 | return CfrTimecodes(fps)
225 |
226 |
227 | class Demuxer(object):
228 | def __init__(self, path):
229 | super(Demuxer, self).__init__()
230 | self._path = path
231 | self._is_wav = get_extension(self._path) == '.wav'
232 | self._mi = None if self._is_wav else FFmpeg.get_media_info(self._path)
233 | self._demux_audio = self._demux_subs = self._make_timecodes = self._make_keyframes = self._write_chapters = False
234 |
235 | @property
236 | def is_wav(self):
237 | return self._is_wav
238 |
239 | @property
240 | def path(self):
241 | return self._path
242 |
243 | @property
244 | def chapters(self):
245 | if self.is_wav:
246 | return []
247 | return self._mi.chapters
248 |
249 | @property
250 | def has_video(self):
251 | return not self.is_wav and self._mi.video
252 |
253 | def set_audio(self, stream_idx, output_path, sample_rate):
254 | self._audio_stream = self._select_stream(self._mi.audio, stream_idx, 'audio')
255 | self._audio_output_path = output_path
256 | self._audio_sample_rate = sample_rate
257 | self._demux_audio = True
258 |
259 | def set_script(self, stream_idx, output_path):
260 | self._script_stream = self._select_stream(self._mi.subtitles, stream_idx, 'subtitles')
261 | self._script_output_path = output_path
262 | self._demux_subs = True
263 |
264 | def set_timecodes(self, output_path):
265 | self._timecodes_output_path = output_path
266 | self._make_timecodes = True
267 |
268 | def set_chapters(self, output_path):
269 | self._write_chapters = True
270 | self._chapters_output_path = output_path
271 |
272 | def set_keyframes(self, output_path):
273 | self._keyframes_output_path = output_path
274 | self._make_keyframes = True
275 |
276 | def get_subs_type(self, stream_idx):
277 | return self._select_stream(self._mi.subtitles, stream_idx, 'subtitles').type
278 |
279 | def demux(self):
280 | if self._write_chapters:
281 | with open(self._chapters_output_path, "w") as output_file:
282 | output_file.write(chapters.format_ogm_chapters(self.chapters))
283 |
284 | if self._make_keyframes:
285 | SCXviD.make_keyframes(self._path, self._keyframes_output_path)
286 |
287 | ffargs = {}
288 | if self._demux_audio:
289 | ffargs['audio_stream'] = self._audio_stream.id
290 | ffargs['audio_path'] = self._audio_output_path
291 | ffargs['audio_rate'] = self._audio_sample_rate
292 | if self._demux_subs:
293 | ffargs['script_stream'] = self._script_stream.id
294 | ffargs['script_path'] = self._script_output_path
295 |
296 | if self._make_timecodes:
297 | def set_ffmpeg_timecodes():
298 | ffargs['video_stream'] = self._mi.video[0].id
299 | ffargs['timecodes_path'] = self._timecodes_output_path
300 |
301 | if get_extension(self._path).lower() == '.mkv':
302 | try:
303 | MkvToolnix.extract_timecodes(self._path,
304 | stream_idx=self._mi.video[0].id,
305 | output_path=self._timecodes_output_path)
306 | except OSError as e:
307 | if e.errno == 2:
308 | set_ffmpeg_timecodes()
309 | else:
310 | raise
311 | else:
312 | set_ffmpeg_timecodes()
313 |
314 | if ffargs:
315 | FFmpeg.demux_file(self._path, **ffargs)
316 |
317 | def cleanup(self):
318 | if self._demux_audio:
319 | os.remove(self._audio_output_path)
320 | if self._demux_subs:
321 | os.remove(self._script_output_path)
322 | if self._make_timecodes:
323 | os.remove(self._timecodes_output_path)
324 | if self._write_chapters:
325 | os.remove(self._chapters_output_path)
326 |
327 | @classmethod
328 | def _format_stream(cls, stream):
329 | return '{0}{1}: {2}'.format(stream.id, ' (%s)' % stream.title if stream.title else '', stream.info)
330 |
331 | @classmethod
332 | def _format_streams_list(cls, streams):
333 | return '\n'.join(map(cls._format_stream, streams))
334 |
335 | def _select_stream(self, streams, chosen_idx, name):
336 | if not streams:
337 | raise SushiError('No {0} streams found in {1}'.format(name, self._path))
338 | if chosen_idx is None:
339 | if len(streams) > 1:
340 | default_track = next((s for s in streams if s.default), None)
341 | if default_track:
342 | logging.warning('Using default track {0} in {1} because there are multiple candidates'
343 | .format(self._format_stream(default_track), self._path))
344 | return default_track
345 |                 raise SushiError('More than one {0} stream found in {1}. '
346 | 'You need to specify the exact one to demux. Here are all candidates:\n'
347 | '{2}'.format(name, self._path, self._format_streams_list(streams)))
348 | return streams[0]
349 |
350 | try:
351 | return next(x for x in streams if x.id == chosen_idx)
352 | except StopIteration:
353 | raise SushiError("Stream with index {0} doesn't exist in {1}.\n"
354 | "Here are all that do:\n"
355 | "{2}".format(chosen_idx, self._path, self._format_streams_list(streams)))
356 |
357 |
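A small sketch of how `Timecodes.parse` handles the two formats above: v1 files declare a default fps plus `start,end,fps` override ranges, which `_convert_v1_to_v2` expands into the per-frame timestamps that v2 files store directly (illustration only; values are approximate):

```python
from demux import Timecodes

tc = Timecodes.parse("# timecode format v1\n"
                     "Assume 23.976\n"
                     "0,23,29.97")
tc.get_frame_time(1)   # ~0.0334: frames 0-23 run at 29.97 fps
tc.get_frame_time(24)  # ~0.8008: first frame past the override range
tc.get_frame_time(25)  # ~0.8425: extrapolated at the default 23.976 fps
```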
--------------------------------------------------------------------------------
/keyframes.py:
--------------------------------------------------------------------------------
1 | from common import SushiError, read_all_text
2 |
3 |
4 | def parse_scxvid_keyframes(text):
5 |     return [i - 3 for i, line in enumerate(text.splitlines()) if line and line[0] == 'i']  # 'i' marks a keyframe; the 3 log header lines offset the frame index
6 |
7 | def parse_keyframes(path):
8 | text = read_all_text(path)
9 | if '# XviD 2pass stat file' in text:
10 | frames = parse_scxvid_keyframes(text)
11 | else:
12 | raise SushiError('Unsupported keyframes type')
13 | if 0 not in frames:
14 | frames.insert(0, 0)
15 | return frames
16 |
--------------------------------------------------------------------------------
/licenses/OpenCV.txt:
--------------------------------------------------------------------------------
1 | By downloading, copying, installing or using the software you agree to this license.
2 | If you do not agree to this license, do not download, install,
3 | copy or use the software.
4 |
5 |
6 | License Agreement
7 | For Open Source Computer Vision Library
8 | (3-clause BSD License)
9 |
10 | Redistribution and use in source and binary forms, with or without modification,
11 | are permitted provided that the following conditions are met:
12 |
13 | * Redistributions of source code must retain the above copyright notice,
14 | this list of conditions and the following disclaimer.
15 |
16 | * Redistributions in binary form must reproduce the above copyright notice,
17 | this list of conditions and the following disclaimer in the documentation
18 | and/or other materials provided with the distribution.
19 |
20 | * Neither the names of the copyright holders nor the names of the contributors
21 | may be used to endorse or promote products derived from this software
22 | without specific prior written permission.
23 |
24 | This software is provided by the copyright holders and contributors "as is" and
25 | any express or implied warranties, including, but not limited to, the implied
26 | warranties of merchantability and fitness for a particular purpose are disclaimed.
27 | In no event shall copyright holders or contributors be liable for any direct,
28 | indirect, incidental, special, exemplary, or consequential damages
29 | (including, but not limited to, procurement of substitute goods or services;
30 | loss of use, data, or profits; or business interruption) however caused
31 | and on any theory of liability, whether in contract, strict liability,
32 | or tort (including negligence or otherwise) arising in any way out of
33 | the use of this software, even if advised of the possibility of such damage.
34 |
--------------------------------------------------------------------------------
/licenses/SciPy.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2001, 2002 Enthought, Inc.
2 | All rights reserved.
3 |
4 | Copyright (c) 2003-2012 SciPy Developers.
5 | All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | a. Redistributions of source code must retain the above copyright notice,
11 | this list of conditions and the following disclaimer.
12 | b. Redistributions in binary form must reproduce the above copyright
13 | notice, this list of conditions and the following disclaimer in the
14 | documentation and/or other materials provided with the distribution.
15 | c. Neither the name of Enthought nor the names of the SciPy Developers
16 | may be used to endorse or promote products derived from this software
17 | without specific prior written permission.
18 |
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS
24 | BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
25 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
30 | THE POSSIBILITY OF SUCH DAMAGE.
31 |
32 |
--------------------------------------------------------------------------------
/regression-tests.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import json
3 | import logging
4 | import os
5 | import gc
6 | import sys
7 | import resource
8 | import re
9 | import subprocess
10 | import argparse
11 |
12 | from common import format_time
13 | from demux import Timecodes
14 | from subs import AssScript
15 | from wav import WavStream
16 |
17 |
18 | root_logger = logging.getLogger('')
19 |
20 |
21 | def strip_tags(text):
22 | return re.sub(r'{.*?}', " ", text)
23 |
24 |
25 | @contextmanager
26 | def set_file_logger(path):
27 | handler = logging.FileHandler(path, mode='a')
28 | handler.setLevel(logging.DEBUG)
29 | handler.setFormatter(logging.Formatter('%(message)s'))
30 | root_logger.addHandler(handler)
31 | try:
32 | yield
33 | finally:
34 | root_logger.removeHandler(handler)
35 |
36 |
37 | def compare_scripts(ideal_path, test_path, timecodes, test_name, expected_errors):
38 | ideal_script = AssScript.from_file(ideal_path)
39 | test_script = AssScript.from_file(test_path)
40 | if len(test_script.events) != len(ideal_script.events):
41 | logging.critical("Script length didn't match: {0} in ideal vs {1} in test. Test {2}".format(
42 | len(ideal_script.events), len(test_script.events), test_name)
43 | )
44 | return False
45 | ideal_script.sort_by_time()
46 | test_script.sort_by_time()
47 | failed = 0
48 | ft = format_time
49 | for idx, (ideal, test) in enumerate(zip(ideal_script.events, test_script.events)):
50 | ideal_start_frame = timecodes.get_frame_number(ideal.start)
51 | ideal_end_frame = timecodes.get_frame_number(ideal.end)
52 |
53 | test_start_frame = timecodes.get_frame_number(test.start)
54 | test_end_frame = timecodes.get_frame_number(test.end)
55 |
56 | if ideal_start_frame != test_start_frame and ideal_end_frame != test_end_frame:
57 | logging.debug(u'{0}: start and end time failed at "{1}". {2}-{3} vs {4}-{5}'.format(
58 | idx, strip_tags(ideal.text), ft(ideal.start), ft(ideal.end), ft(test.start), ft(test.end))
59 | )
60 | failed += 1
61 | elif ideal_end_frame != test_end_frame:
62 | logging.debug(
63 | u'{0}: end time failed at "{1}". {2} vs {3}'.format(
64 | idx, strip_tags(ideal.text), ft(ideal.end), ft(test.end))
65 | )
66 | failed += 1
67 | elif ideal_start_frame != test_start_frame:
68 | logging.debug(
69 | u'{0}: start time failed at "{1}". {2} vs {3}'.format(
70 | idx, strip_tags(ideal.text), ft(ideal.start), ft(test.start))
71 | )
72 | failed += 1
73 |
74 | logging.info('Total lines: {0}, good: {1}, failed: {2}'.format(len(ideal_script.events), len(ideal_script.events)-failed, failed))
75 |
76 | if failed > expected_errors:
77 | logging.critical('Got more failed lines than expected ({0} actual vs {1} expected)'.format(failed, expected_errors))
78 | return False
79 | elif failed < expected_errors:
80 |         logging.critical('Got fewer failed lines than expected ({0} actual vs {1} expected)'.format(failed, expected_errors))
81 | return False
82 | else:
83 | logging.critical('Met expectations')
84 | return True
85 |
86 |
87 | def run_test(base_path, plots_path, test_name, params):
88 | def safe_add_key(args, key, name):
89 | if name in params:
90 | args.extend((key, str(params[name])))
91 |
92 | def safe_add_path(args, folder, key, name):
93 | if name in params:
94 | args.extend((key, os.path.join(folder, params[name])))
95 |
96 | logging.info('Testing "{0}"'.format(test_name))
97 |
98 | folder = os.path.join(base_path, params['folder'])
99 |
100 | cmd = ["sushi"]
101 |
102 | safe_add_path(cmd, folder, '--src', 'src')
103 | safe_add_path(cmd, folder, '--dst', 'dst')
104 | safe_add_path(cmd, folder, '--src-keyframes', 'src-keyframes')
105 | safe_add_path(cmd, folder, '--dst-keyframes', 'dst-keyframes')
106 | safe_add_path(cmd, folder, '--src-timecodes', 'src-timecodes')
107 | safe_add_path(cmd, folder, '--dst-timecodes', 'dst-timecodes')
108 | safe_add_path(cmd, folder, '--script', 'script')
109 | safe_add_path(cmd, folder, '--chapters', 'chapters')
110 | safe_add_path(cmd, folder, '--src-script', 'src-script')
111 | safe_add_path(cmd, folder, '--dst-script', 'dst-script')
112 | safe_add_key(cmd, '--max-kf-distance', 'max-kf-distance')
113 | safe_add_key(cmd, '--max-ts-distance', 'max-ts-distance')
114 | safe_add_key(cmd, '--max-ts-duration', 'max-ts-duration')
115 |
116 | output_path = os.path.join(folder, params['dst']) + '.sushi.test.ass'
117 | cmd.extend(('-o', output_path))
118 | if plots_path:
119 | cmd.extend(('--test-shift-plot', os.path.join(plots_path, '{0}.png'.format(test_name))))
120 |
121 | log_path = os.path.join(folder, 'sushi_test.log')
122 |
123 | with open(log_path, "w") as log_file:
124 | try:
125 | subprocess.call(cmd, stderr=log_file, stdout=log_file)
126 | except Exception as e:
127 | logging.critical('Sushi failed on test "{0}": {1}'.format(test_name, e.message))
128 | return False
129 |
130 | with set_file_logger(log_path):
131 | ideal_path = os.path.join(folder, params['ideal'])
132 | try:
133 | timecodes = Timecodes.from_file(os.path.join(folder, params['dst-timecodes']))
134 | except KeyError:
135 | timecodes = Timecodes.cfr(params['fps'])
136 |
137 | return compare_scripts(ideal_path, output_path, timecodes, test_name, params['expected_errors'])
138 |
139 |
140 | def run_wav_test(test_name, file_path, params):
141 | gc.collect(2)
142 |
143 | before = resource.getrusage(resource.RUSAGE_SELF)
144 | loaded = WavStream(file_path, params.get('sample_rate', 12000), params.get('sample_type', 'uint8'))
145 | after = resource.getrusage(resource.RUSAGE_SELF)
146 |
147 | total_time = (after.ru_stime - before.ru_stime) + (after.ru_utime - before.ru_utime)
148 | ram_difference = (after.ru_maxrss - before.ru_maxrss) / 1024.0 / 1024.0
149 |
150 | if 'max_time' in params and total_time > params['max_time']:
151 | logging.critical('Loading "{0}" took too much time: {1} vs {2} seconds'
152 | .format(test_name, total_time, params['max_time']))
153 | return False
154 | if 'max_memory' in params and ram_difference > params['max_memory']:
155 | logging.critical('Loading "{0}" consumed too much RAM: {1} vs {2}'
156 | .format(test_name, ram_difference, params['max_memory']))
157 | return False
158 | return True
159 |
160 |
161 | def create_arg_parser():
162 | parser = argparse.ArgumentParser(description='Sushi regression testing util')
163 |
164 | parser.add_argument('--only', dest="run_only", nargs="*", metavar='',
165 | help='Test names to run')
166 | parser.add_argument('-c', '--conf', default="tests.json", dest='conf_path', metavar='',
167 | help='Config file path')
168 |
169 | return parser
170 |
171 |
172 | def run():
173 | root_logger.setLevel(logging.DEBUG)
174 | console_handler = logging.StreamHandler()
175 | console_handler.setLevel(logging.INFO)
176 | console_handler.setFormatter(logging.Formatter('%(message)s'))
177 | root_logger.addHandler(console_handler)
178 |
179 | args = create_arg_parser().parse_args()
180 |
181 | try:
182 | with open(args.conf_path) as file:
183 | config = json.load(file)
184 | except IOError as e:
185 | logging.critical(e)
186 | sys.exit(2)
187 |
188 | def should_run(name):
189 | return not args.run_only or name in args.run_only
190 |
191 | failed = ran = 0
192 | for test_name, params in config.get('tests', {}).iteritems():
193 | if not should_run(test_name):
194 | continue
195 | if not params.get('disabled', False):
196 | ran += 1
197 | if not run_test(config['basepath'], config['plots'], test_name, params):
198 | failed += 1
199 | logging.info('')
200 | else:
201 | logging.warn('Test "{0}" disabled'.format(test_name))
202 |
203 | if should_run("wavs"):
204 | for test_name, params in config.get('wavs', {}).iteritems():
205 | ran += 1
206 | if not run_wav_test(test_name, os.path.join(config['basepath'], params['file']), params):
207 | failed += 1
208 | logging.info('')
209 |
210 | logging.info('Ran {0} tests, {1} failed'.format(ran, failed))
211 |
212 |
213 | if __name__ == '__main__':
214 | run()
215 |
--------------------------------------------------------------------------------
/requirements-win.txt:
--------------------------------------------------------------------------------
1 | pyinstaller
2 | colorama
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | mock
3 |
--------------------------------------------------------------------------------
/run-tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from tests.timecodes import *
3 | from tests.main import *
4 | from tests.subtitles import *
5 | from tests.demuxing import *
6 |
7 | unittest.main(verbosity=0)
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | import sushi
3 |
4 | setup(
5 | name='Sushi',
6 | description='Automatic subtitle shifter based on audio',
7 | version=sushi.VERSION,
8 | url='https://github.com/tp7/Sushi',
9 | console=['sushi.py'],
10 | license='MIT'
11 | )
12 |
13 |
--------------------------------------------------------------------------------
/subs.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import os
3 | import re
4 | import collections
5 |
6 | from common import SushiError, format_time, format_srt_time
7 |
8 |
9 | def _parse_ass_time(string):
10 | hours, minutes, seconds = map(float, string.split(':'))
11 | return hours * 3600 + minutes * 60 + seconds
12 |
13 |
14 | class ScriptEventBase(object):
15 | def __init__(self, source_index, start, end, text):
16 | self.source_index = source_index
17 | self.start = start
18 | self.end = end
19 | self.text = text
20 |
21 | self._shift = 0
22 | self._diff = 1
23 | self._linked_event = None
24 | self._start_shift = 0
25 | self._end_shift = 0
26 |
27 | @property
28 | def shift(self):
29 | return self._linked_event.shift if self.linked else self._shift
30 |
31 | @property
32 | def diff(self):
33 | return self._linked_event.diff if self.linked else self._diff
34 |
35 | @property
36 | def duration(self):
37 | return self.end - self.start
38 |
39 | @property
40 | def shifted_end(self):
41 | return self.end + self.shift + self._end_shift
42 |
43 | @property
44 | def shifted_start(self):
45 | return self.start + self.shift + self._start_shift
46 |
47 | def apply_shift(self):
48 | self.start = self.shifted_start
49 | self.end = self.shifted_end
50 |
51 | def set_shift(self, shift, audio_diff):
52 | assert not self.linked, 'Cannot set shift of a linked event'
53 | self._shift = shift
54 | self._diff = audio_diff
55 |
56 | def adjust_additional_shifts(self, start_shift, end_shift):
57 | assert not self.linked, 'Cannot apply additional shifts to a linked event'
58 | self._start_shift += start_shift
59 | self._end_shift += end_shift
60 |
61 | def get_link_chain_end(self):
62 | return self._linked_event.get_link_chain_end() if self.linked else self
63 |
64 | def link_event(self, other):
65 | assert other.get_link_chain_end() is not self, 'Circular link detected'
66 | self._linked_event = other
67 |
68 | def resolve_link(self):
69 | assert self.linked, 'Cannot resolve unlinked events'
70 | self._shift = self._linked_event.shift
71 | self._diff = self._linked_event.diff
72 | self._linked_event = None
73 |
74 | @property
75 | def linked(self):
76 | return self._linked_event is not None
77 |
78 | def adjust_shift(self, value):
79 | assert not self.linked, 'Cannot adjust time of linked events'
80 | self._shift += value
81 |
82 | def __repr__(self):
83 | return unicode(self)
84 |
85 |
86 | class ScriptBase(object):
87 | def __init__(self, events):
88 | self.events = events
89 |
90 | def sort_by_time(self):
91 | self.events.sort(key=lambda x: x.start)
92 |
93 |
94 | class SrtEvent(ScriptEventBase):
95 | is_comment = False
96 | style = None
97 |
98 |     EVENT_REGEX = re.compile(r"""
99 | (\d+?)\s+? # line-number
100 | (\d{1,2}:\d{1,2}:\d{1,2},\d+)\s-->\s(\d{1,2}:\d{1,2}:\d{1,2},\d+). # timestamp
101 | (.+?) # actual text
102 | (?= # lookahead for the next line or end of the file
103 | (?:\d+?\s+? # line-number
104 | \d{1,2}:\d{1,2}:\d{1,2},\d+\s-->\s\d{1,2}:\d{1,2}:\d{1,2},\d+) # timestamp
105 | |$
106 | )""", flags=re.VERBOSE | re.DOTALL)
107 |
108 | @classmethod
109 | def from_string(cls, text):
110 | match = cls.EVENT_REGEX.match(text)
111 | start = cls.parse_time(match.group(2))
112 | end = cls.parse_time(match.group(3))
113 | return SrtEvent(int(match.group(1)), start, end, match.group(4).strip())
114 |
115 | def __unicode__(self):
116 | return u'{0}\n{1} --> {2}\n{3}'.format(self.source_index, self._format_time(self.start),
117 | self._format_time(self.end), self.text)
118 |
119 | @staticmethod
120 | def parse_time(time_string):
121 | return _parse_ass_time(time_string.replace(',', '.'))
122 |
123 | @staticmethod
124 | def _format_time(seconds):
125 | return format_srt_time(seconds)
126 |
127 |
128 | class SrtScript(ScriptBase):
129 | @classmethod
130 | def from_file(cls, path):
131 | try:
132 | with codecs.open(path, encoding='utf-8-sig') as script:
133 | text = script.read()
134 | events_list = [SrtEvent(
135 | source_index=int(match.group(1)),
136 | start=SrtEvent.parse_time(match.group(2)),
137 | end=SrtEvent.parse_time(match.group(3)),
138 | text=match.group(4).strip()
139 | ) for match in SrtEvent.EVENT_REGEX.finditer(text)]
140 | return cls(events_list)
141 | except IOError:
142 | raise SushiError("Script {0} not found".format(path))
143 |
144 | def save_to_file(self, path):
145 | text = '\n\n'.join(map(unicode, self.events))
146 | with codecs.open(path, encoding='utf-8', mode='w') as script:
147 | script.write(text)
148 |
149 |
150 | class AssEvent(ScriptEventBase):
151 | def __init__(self, text, position=0):
152 | kind, _, rest = text.partition(':')
153 | split = [x.strip() for x in rest.split(',', 9)]
154 |
155 | super(AssEvent, self).__init__(
156 | source_index=position,
157 | start=_parse_ass_time(split[1]),
158 | end=_parse_ass_time(split[2]),
159 | text=split[9]
160 | )
161 | self.kind = kind
162 | self.is_comment = self.kind.lower() == 'comment'
163 | self.layer = split[0]
164 | self.style = split[3]
165 | self.name = split[4]
166 | self.margin_left = split[5]
167 | self.margin_right = split[6]
168 | self.margin_vertical = split[7]
169 | self.effect = split[8]
170 |
171 | def __unicode__(self):
172 | return u'{0}: {1},{2},{3},{4},{5},{6},{7},{8},{9},{10}'.format(self.kind, self.layer,
173 | self._format_time(self.start),
174 | self._format_time(self.end),
175 | self.style, self.name,
176 | self.margin_left, self.margin_right,
177 | self.margin_vertical, self.effect,
178 | self.text)
179 |
180 | @staticmethod
181 | def _format_time(seconds):
182 | return format_time(seconds)
183 |
184 |
185 | class AssScript(ScriptBase):
186 | def __init__(self, script_info, styles, events, other):
187 | super(AssScript, self).__init__(events)
188 | self.script_info = script_info
189 | self.styles = styles
190 | self.other = other
191 |
192 | @classmethod
193 | def from_file(cls, path):
194 | script_info, styles, events = [], [], []
195 | other_sections = collections.OrderedDict()
196 |
197 | def parse_script_info_line(line):
198 | if line.startswith(u'Format:'):
199 | return
200 | script_info.append(line)
201 |
202 | def parse_styles_line(line):
203 | if line.startswith(u'Format:'):
204 | return
205 | styles.append(line)
206 |
207 | def parse_event_line(line):
208 | if line.startswith(u'Format:'):
209 | return
210 | events.append(AssEvent(line, position=len(events)+1))
211 |
212 | def create_generic_parse(section_name):
213 | if section_name in other_sections:
214 | raise SushiError("Duplicate section detected, invalid script?")
215 | other_sections[section_name] = []
216 | return other_sections[section_name].append
217 |
218 | parse_function = None
219 |
220 | try:
221 | with codecs.open(path, encoding='utf-8-sig') as script:
222 | for line_idx, line in enumerate(script):
223 | line = line.strip()
224 | if not line:
225 | continue
226 | low = line.lower()
227 | if low == u'[script info]':
228 | parse_function = parse_script_info_line
229 | elif low == u'[v4+ styles]':
230 | parse_function = parse_styles_line
231 | elif low == u'[events]':
232 | parse_function = parse_event_line
233 | elif re.match(r'\[.+?\]', low):
234 | parse_function = create_generic_parse(line)
235 | elif not parse_function:
236 | raise SushiError("That's some invalid ASS script")
237 | else:
238 | try:
239 | parse_function(line)
240 | except Exception as e:
241 | raise SushiError("That's some invalid ASS script: {0} [line {1}]".format(e.message, line_idx))
242 | except IOError:
243 | raise SushiError("Script {0} not found".format(path))
244 | return cls(script_info, styles, events, other_sections)
245 |
246 | def save_to_file(self, path):
247 | # if os.path.exists(path):
248 | # raise RuntimeError('File %s already exists' % path)
249 | lines = []
250 | if self.script_info:
251 | lines.append(u'[Script Info]')
252 | lines.extend(self.script_info)
253 | lines.append('')
254 |
255 | if self.styles:
256 | lines.append(u'[V4+ Styles]')
257 | lines.append(u'Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding')
258 | lines.extend(self.styles)
259 | lines.append('')
260 |
261 | if self.events:
262 | events = sorted(self.events, key=lambda x: x.source_index)
263 | lines.append(u'[Events]')
264 | lines.append(u'Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text')
265 | lines.extend(map(unicode, events))
266 |
267 | if self.other:
268 | for section_name, section_lines in self.other.iteritems():
269 | lines.append('')
270 | lines.append(section_name)
271 | lines.extend(section_lines)
272 |
273 | with codecs.open(path, encoding='utf-8-sig', mode='w') as script:
274 | script.write(unicode(os.linesep).join(lines))
275 |
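As an illustration of the field handling above (not part of the module): `split(',', 9)` stops after nine commas, so commas inside the text field survive parsing.

```python
from subs import AssEvent

ev = AssEvent(u'Dialogue: 0,0:00:01.00,0:00:02.50,Default,,0,0,0,,Hello, world')
ev.start, ev.end  # (1.0, 2.5)
ev.text           # u'Hello, world' - the comma in the text is kept intact
```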
--------------------------------------------------------------------------------
/sushi.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | import logging
3 | import sys
4 | import operator
5 | import argparse
6 | import os
7 | import bisect
8 | import collections
9 | from itertools import takewhile, izip, chain
10 | import time
11 |
12 | import numpy as np
13 |
14 | import chapters
15 | from common import SushiError, get_extension, format_time, ensure_static_collection
16 | from demux import Timecodes, Demuxer
17 | import keyframes
18 | from subs import AssScript, SrtScript
19 | from wav import WavStream
20 |
21 |
22 | try:
23 | import matplotlib.pyplot as plt
24 | plot_enabled = True
25 | except ImportError:
26 | plot_enabled = False
27 |
28 | if sys.platform == 'win32':
29 | try:
30 | import colorama
31 | colorama.init()
32 | console_colors_supported = True
33 | except ImportError:
34 | console_colors_supported = False
35 | else:
36 | console_colors_supported = True
37 |
38 |
39 | ALLOWED_ERROR = 0.01
40 | MAX_GROUP_STD = 0.025
41 | VERSION = '0.5.1'
42 |
43 |
44 | class ColoredLogFormatter(logging.Formatter):
45 | bold_code = "\033[1m"
46 | reset_code = "\033[0m"
47 | grey_code = "\033[30m\033[1m"
48 |
49 | error_format = "{bold}ERROR: %(message)s{reset}".format(bold=bold_code, reset=reset_code)
50 | warn_format = "{bold}WARNING: %(message)s{reset}".format(bold=bold_code, reset=reset_code)
51 | debug_format = "{grey}%(message)s{reset}".format(grey=grey_code, reset=reset_code)
52 | default_format = "%(message)s"
53 |
54 | def format(self, record):
55 | if record.levelno == logging.DEBUG:
56 | self._fmt = self.debug_format
57 | elif record.levelno == logging.WARN:
58 | self._fmt = self.warn_format
59 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
60 | self._fmt = self.error_format
61 | else:
62 | self._fmt = self.default_format
63 |
64 | return super(ColoredLogFormatter, self).format(record)
65 |
66 |
67 | def abs_diff(a, b):
68 | return abs(a - b)
69 |
70 |
71 | def interpolate_nones(data, points):
72 | data = ensure_static_collection(data)
73 | values_lookup = {p: v for p, v in izip(points, data) if v is not None}
74 | if not values_lookup:
75 | return []
76 |
77 | zero_points = {p for p, v in izip(points, data) if v is None}
78 | if not zero_points:
79 | return data
80 |
81 | data_list = sorted(values_lookup.iteritems())
82 | zero_points = sorted(x for x in zero_points if x not in values_lookup)
83 |
84 | out = np.interp(x=zero_points,
85 | xp=map(operator.itemgetter(0), data_list),
86 | fp=map(operator.itemgetter(1), data_list))
87 |
88 | values_lookup.update(izip(zero_points, out))
89 |
90 | return [
91 | values_lookup[point] if value is None else value
92 | for point, value in izip(points, data)
93 | ]
94 |
95 |
96 | # todo: compute this incrementally instead of recomputing the median for every window
97 | def running_median(values, window_size):
98 | if window_size % 2 != 1:
99 | raise SushiError('Median window size should be odd')
100 | half_window = window_size // 2
101 | medians = []
102 | items_count = len(values)
103 | for idx in xrange(items_count):
104 | radius = min(half_window, idx, items_count-idx-1)
105 | med = np.median(values[idx-radius:idx+radius+1])
106 | medians.append(med)
107 | return medians
108 |
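For intuition, a window of three drops an isolated outlier shift while keeping consistent runs (illustration only, not part of the module):

```python
running_median([1.0, 1.0, 5.0, 1.0, 1.0], window_size=3)
# -> [1.0, 1.0, 1.0, 1.0, 1.0]; the lone bad match at index 2 is discarded
# instead of dragging its neighbours along
```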
109 |
110 | def smooth_events(events, radius):
111 | if not radius:
112 | return
113 | window_size = radius*2+1
114 | shifts = [e.shift for e in events]
115 | smoothed = running_median(shifts, window_size)
116 | for event, new_shift in izip(events, smoothed):
117 | event.set_shift(new_shift, event.diff)
118 |
119 |
120 | def detect_groups(events_iter):
121 | events_iter = iter(events_iter)
122 | groups_list = [[next(events_iter)]]
123 | for event in events_iter:
124 | if abs_diff(event.shift, groups_list[-1][-1].shift) > ALLOWED_ERROR:
125 | groups_list.append([])
126 | groups_list[-1].append(event)
127 | return groups_list
128 |
129 |
130 | def groups_from_chapters(events, times):
131 | logging.info(u'Chapter start points: {0}'.format([format_time(t) for t in times]))
132 | groups = [[]]
133 |     chapter_times = iter(times[1:] + [36000000000])  # sentinel timestamp far past the end
134 | current_chapter = next(chapter_times)
135 |
136 | for event in events:
137 | if event.end > current_chapter:
138 | groups.append([])
139 | while event.end > current_chapter:
140 | current_chapter = next(chapter_times)
141 |
142 | groups[-1].append(event)
143 |
144 | groups = filter(None, groups) # non-empty groups
145 | # check if we have any groups where every event is linked
146 | # for example a chapter with only comments inside
147 | broken_groups = [group for group in groups if not any(e for e in group if not e.linked)]
148 | if broken_groups:
149 | for group in broken_groups:
150 | for event in group:
151 | parent = event.get_link_chain_end()
152 | parent_group = next(group for group in groups if parent in group)
153 | parent_group.append(event)
154 | del group[:]
155 | groups = filter(None, groups)
156 | # re-sort the groups again since we might break the order when inserting linked events
157 | # sorting everything again is far from optimal but python sorting is very fast for sorted arrays anyway
158 | for group in groups:
159 | group.sort(key=lambda event: event.start)
160 |
161 | return groups
162 |
163 |
164 | def split_broken_groups(groups):
165 | correct_groups = []
166 | broken_found = False
167 | for g in groups:
168 | std = np.std([e.shift for e in g])
169 | if std > MAX_GROUP_STD:
170 | logging.warn(u'Shift is not consistent between {0} and {1}, most likely chapters are wrong (std: {2}). '
171 | u'Switching to automatic grouping.'.format(format_time(g[0].start), format_time(g[-1].end),
172 | std))
173 | correct_groups.extend(detect_groups(g))
174 | broken_found = True
175 | else:
176 | correct_groups.append(g)
177 |
178 | if broken_found:
179 | groups_iter = iter(correct_groups)
180 | correct_groups = [list(next(groups_iter))]
181 | for group in groups_iter:
182 | if abs_diff(correct_groups[-1][-1].shift, group[0].shift) >= ALLOWED_ERROR \
183 | or np.std([e.shift for e in group + correct_groups[-1]]) >= MAX_GROUP_STD:
184 | correct_groups.append([])
185 |
186 | correct_groups[-1].extend(group)
187 | return correct_groups
188 |
189 |
190 | def fix_near_borders(events):
191 | """
192 |     We assume that lines whose diff is more than about five times off the median diff (across all events, or the ten events closest to the border) are broken
193 | """
194 | def fix_border(event_list, median_diff):
195 | last_ten_diff = np.median([x.diff for x in event_list[:10]], overwrite_input=True)
196 | diff_limit = min(last_ten_diff, median_diff)
197 | broken = []
198 | for event in event_list:
199 | if not 0.2 < (event.diff / diff_limit) < 5:
200 | broken.append(event)
201 | else:
202 | for x in broken:
203 | x.link_event(event)
204 | return len(broken)
205 | return 0
206 |
207 | median_diff = np.median([x.diff for x in events], overwrite_input=True)
208 |
209 | fixed_count = fix_border(events, median_diff)
210 | if fixed_count:
211 | logging.info('Fixing {0} border events right after {1}'.format(fixed_count, format_time(events[0].start)))
212 |
213 | fixed_count = fix_border(list(reversed(events)), median_diff)
214 | if fixed_count:
215 | logging.info('Fixing {0} border events right before {1}'.format(fixed_count, format_time(events[-1].end)))
216 |
217 |
218 | def get_distance_to_closest_kf(timestamp, keyframes):
219 | idx = bisect.bisect_left(keyframes, timestamp)
220 | if idx == 0:
221 | kf = keyframes[0]
222 | elif idx == len(keyframes):
223 | kf = keyframes[-1]
224 | else:
225 | before = keyframes[idx - 1]
226 | after = keyframes[idx]
227 | kf = after if after - timestamp < timestamp - before else before
228 | return kf - timestamp
229 |
230 |
231 | def find_keyframe_shift(group, src_keytimes, dst_keytimes, src_timecodes, dst_timecodes, max_kf_distance):
232 | def get_distance(src_distance, dst_distance, limit):
233 | if abs(dst_distance) > limit:
234 | return None
235 | shift = dst_distance - src_distance
236 | return shift if abs(shift) < limit else None
237 |
238 | src_start = get_distance_to_closest_kf(group[0].start, src_keytimes)
239 | src_end = get_distance_to_closest_kf(group[-1].end + src_timecodes.get_frame_size(group[-1].end), src_keytimes)
240 |
241 | dst_start = get_distance_to_closest_kf(group[0].shifted_start, dst_keytimes)
242 | dst_end = get_distance_to_closest_kf(group[-1].shifted_end + dst_timecodes.get_frame_size(group[-1].end), dst_keytimes)
243 |
244 | snapping_limit_start = src_timecodes.get_frame_size(group[0].start) * max_kf_distance
245 |     snapping_limit_end = src_timecodes.get_frame_size(group[-1].end) * max_kf_distance
246 |
247 | return (get_distance(src_start, dst_start, snapping_limit_start),
248 | get_distance(src_end, dst_end, snapping_limit_end))
249 |
250 |
251 | def find_keyframes_distances(event, src_keytimes, dst_keytimes, timecodes, max_kf_distance):
252 | def find_keyframe_distance(src_time, dst_time):
253 | src = get_distance_to_closest_kf(src_time, src_keytimes)
254 | dst = get_distance_to_closest_kf(dst_time, dst_keytimes)
255 | snapping_limit = timecodes.get_frame_size(src_time) * max_kf_distance
256 |
257 | if abs(src) < snapping_limit and abs(dst) < snapping_limit and abs(src-dst) < snapping_limit:
258 | return dst - src
259 | return 0
260 |
261 | ds = find_keyframe_distance(event.start, event.shifted_start)
262 | de = find_keyframe_distance(event.end, event.shifted_end)
263 | return ds, de
264 |
265 |
266 | def snap_groups_to_keyframes(events, chapter_times, max_ts_duration, max_ts_distance, src_keytimes, dst_keytimes,
267 | src_timecodes, dst_timecodes, max_kf_distance, kf_mode):
268 | if not max_kf_distance:
269 | return
270 |
271 | groups = merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance)
272 |
273 | if kf_mode == 'all' or kf_mode == 'shift':
274 | # step 1: snap events without changing their duration. Useful for some slight audio imprecision correction
275 | shifts = []
276 | times = []
277 | for group in groups:
278 | shifts.extend(find_keyframe_shift(group, src_keytimes, dst_keytimes, src_timecodes, dst_timecodes, max_kf_distance))
279 | times.extend((group[0].shifted_start, group[-1].shifted_end))
280 |
281 | shifts = interpolate_nones(shifts, times)
282 | if shifts:
283 | mean_shift = np.mean(shifts)
284 | shifts = zip(*(iter(shifts), ) * 2)
285 |
286 | logging.info('Group {0}-{1} corrected by {2}'.format(format_time(events[0].start), format_time(events[-1].end), mean_shift))
287 | for group, (start_shift, end_shift) in izip(groups, shifts):
288 | if abs(start_shift-end_shift) > 0.001 and len(group) > 1:
289 | actual_shift = min(start_shift, end_shift, key=lambda x: abs(x - mean_shift))
290 | logging.warning("Typesetting group at {0} had different shift at start/end points ({1} and {2}). Shifting by {3}."
291 | .format(format_time(group[0].start), start_shift, end_shift, actual_shift))
292 | for e in group:
293 | e.adjust_shift(actual_shift)
294 | else:
295 | for e in group:
296 | e.adjust_additional_shifts(start_shift, end_shift)
297 |
298 | if kf_mode == 'all' or kf_mode == 'snap':
299 | # step 2: snap start/end times separately
300 | for group in groups:
301 | if len(group) > 1:
302 |                 continue  # we don't snap typesetting
303 | start_shift, end_shift = find_keyframes_distances(group[0], src_keytimes, dst_keytimes, src_timecodes, max_kf_distance)
304 | if abs(start_shift) > 0.01 or abs(end_shift) > 0.01:
305 | logging.info('Snapping {0} to keyframes, start time by {1}, end: {2}'.format(format_time(group[0].start), start_shift, end_shift))
306 | group[0].adjust_additional_shifts(start_shift, end_shift)
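# kf_mode semantics, as implemented above: 'shift' only applies a common correction to
# whole groups, 'snap' only moves start/end times of single-line groups onto keyframes,
# and 'all' (the argparse default) does both.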
307 |
308 |
309 | def average_shifts(events):
310 | events = [e for e in events if not e.linked]
311 | shifts = [x.shift for x in events]
312 | weights = [1 - x.diff for x in events]
313 | avg = np.average(shifts, weights=weights)
314 | for e in events:
315 | e.set_shift(avg, e.diff)
316 | return avg
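# The weighted average favors events with a smaller diff, i.e. a better audio match.
# This assumes diff values stay below 1.0, which the normalized difference metric
# returned by the audio search is expected to satisfy.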
317 |
318 |
319 | def merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance):
320 | search_groups = []
321 | chapter_times = iter(chapter_times[1:] + [100000000])
322 | next_chapter = next(chapter_times)
323 | events = ensure_static_collection(events)
324 |
325 | processed = set()
326 | for idx, event in enumerate(events):
327 | if idx in processed:
328 | continue
329 |
330 | while event.end > next_chapter:
331 | next_chapter = next(chapter_times)
332 |
333 | if event.duration > max_ts_duration:
334 | search_groups.append([event])
335 | processed.add(idx)
336 | else:
337 | group = [event]
338 | group_end = event.end
339 | i = idx+1
340 | while i < len(events) and abs(group_end - events[i].start) < max_ts_distance:
341 | if events[i].end < next_chapter and events[i].duration <= max_ts_duration:
342 | processed.add(i)
343 | group.append(events[i])
344 | group_end = max(group_end, events[i].end)
345 | i += 1
346 |
347 | search_groups.append(group)
348 |
349 | return search_groups
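# In short: any line longer than max_ts_duration gets its own group, while runs of short
# (typesetting) lines starting within max_ts_distance of the group's current end are
# merged together, and a group is never extended past the next chapter point.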
350 |
351 |
352 | def prepare_search_groups(events, source_duration, chapter_times, max_ts_duration, max_ts_distance):
353 | last_unlinked = None
354 | for idx, event in enumerate(events):
355 | if event.is_comment:
356 | try:
357 | event.link_event(events[idx+1])
358 | except IndexError:
359 | event.link_event(last_unlinked)
360 | continue
361 | if (event.start + event.duration / 2.0) > source_duration:
362 | logging.info('Event time outside of audio range, ignoring: %s' % unicode(event))
363 | event.link_event(last_unlinked)
364 | continue
365 | elif event.end == event.start:
366 | logging.info('{0}: skipped because zero duration'.format(format_time(event.start)))
367 | try:
368 | event.link_event(events[idx + 1])
369 | except IndexError:
370 | event.link_event(last_unlinked)
371 | continue
372 |
373 | # link lines with start and end times identical to some other event
374 | # assuming scripts are sorted by start time so we don't search the entire collection
375 | same_start = lambda x: event.start == x.start
376 |         processed = next((x for x in takewhile(same_start, reversed(events[:idx])) if not x.linked and x.end == event.end), None)
377 | if processed:
378 | event.link_event(processed)
379 | else:
380 | last_unlinked = event
381 |
382 | events = (e for e in events if not e.linked)
383 |
384 | search_groups = merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance)
385 |
386 | # link groups contained inside other groups to the larger group
387 | passed_groups = []
388 | for idx, group in enumerate(search_groups):
389 | try:
390 | other = next(x for x in reversed(search_groups[:idx])
391 | if x[0].start <= group[0].start
392 | and x[-1].end >= group[-1].end)
393 | for event in group:
394 | event.link_event(other[0])
395 | except StopIteration:
396 | passed_groups.append(group)
397 | return passed_groups
398 |
399 |
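# Error recovery in calculate_shifts below: each group is first searched in a small
# 1.5-second window around the last committed shift; a group that can't be verified
# (the left and right halves of its audio disagree on the shift) is kept in
# uncommitted_states instead of being committed. Once rewind_thresh consecutive groups
# fail, the search rewinds to the first uncommitted group and retries with max_window.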
400 | def calculate_shifts(src_stream, dst_stream, groups_list, normal_window, max_window, rewind_thresh):
401 | def log_shift(state):
402 | logging.info('{0}-{1}: shift: {2:0.10f}, diff: {3:0.10f}'
403 | .format(format_time(state["start_time"]), format_time(state["end_time"]), state["shift"], state["diff"]))
404 |
405 | def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offset):
406 | logging.debug('{0}-{1}: shift: {2:0.5f} [{3:0.5f}, {4:0.5f}], search offset: {5:0.6f}'
407 | .format(format_time(state["start_time"]), format_time(state["end_time"]),
408 | shift, left_side_shift, right_side_shift, search_offset))
409 |
410 | small_window = 1.5
411 | idx = 0
412 | committed_states = []
413 | uncommitted_states = []
414 | window = normal_window
415 | while idx < len(groups_list):
416 | search_group = groups_list[idx]
417 | tv_audio = src_stream.get_substream(search_group[0].start, search_group[-1].end)
418 | original_time = search_group[0].start
419 | group_state = {"start_time": search_group[0].start, "end_time": search_group[-1].end, "shift": None, "diff": None}
420 | last_committed_shift = committed_states[-1]["shift"] if committed_states else 0
421 | diff = new_time = None
422 |
423 | if not uncommitted_states:
424 | if original_time + last_committed_shift > dst_stream.duration_seconds:
425 | # event outside of audio range, all events past it are also guaranteed to fail
426 | for g in groups_list[idx:]:
427 | committed_states.append({"start_time": g[0].start, "end_time": g[-1].end, "shift": None, "diff": None})
428 | logging.info("{0}-{1}: outside of audio range".format(format_time(g[0].start), format_time(g[-1].end)))
429 | break
430 |
431 | if small_window < window:
432 | diff, new_time = dst_stream.find_substream(tv_audio, original_time + last_committed_shift, small_window)
433 |
434 | if new_time is not None and abs_diff(new_time - original_time, last_committed_shift) <= ALLOWED_ERROR:
435 | # fastest case - small window worked, commit the group immediately
436 | group_state.update({"shift": new_time - original_time, "diff": diff})
437 | committed_states.append(group_state)
438 | log_shift(group_state)
439 | if window != normal_window:
440 | logging.info("Going back to window {0} from {1}".format(normal_window, window))
441 | window = normal_window
442 | idx += 1
443 | continue
444 |
445 | left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0])/2], axis=1)
446 | right_half_offset = len(left_audio_half[0]) / float(src_stream.sample_rate)
447 | terminate = False
448 | # searching from last committed shift
449 | if original_time + last_committed_shift < dst_stream.duration_seconds:
450 | diff, new_time = dst_stream.find_substream(tv_audio, original_time + last_committed_shift, window)
451 | left_side_time = dst_stream.find_substream(left_audio_half, original_time + last_committed_shift, window)[1]
452 | right_side_time = dst_stream.find_substream(right_audio_half, original_time + last_committed_shift + right_half_offset, window)[1] - right_half_offset
453 | terminate = abs_diff(left_side_time, right_side_time) <= ALLOWED_ERROR and abs_diff(new_time, left_side_time) <= ALLOWED_ERROR
454 | log_uncommitted(group_state, new_time - original_time, left_side_time - original_time,
455 | right_side_time - original_time, last_committed_shift)
456 |
457 | if not terminate and uncommitted_states and uncommitted_states[-1]["shift"] is not None \
458 | and original_time + uncommitted_states[-1]["shift"] < dst_stream.duration_seconds:
459 | start_offset = uncommitted_states[-1]["shift"]
460 | diff, new_time = dst_stream.find_substream(tv_audio, original_time + start_offset, window)
461 | left_side_time = dst_stream.find_substream(left_audio_half, original_time + start_offset, window)[1]
462 | right_side_time = dst_stream.find_substream(right_audio_half, original_time + start_offset + right_half_offset, window)[1] - right_half_offset
463 | terminate = abs_diff(left_side_time, right_side_time) <= ALLOWED_ERROR and abs_diff(new_time, left_side_time) <= ALLOWED_ERROR
464 | log_uncommitted(group_state, new_time - original_time, left_side_time - original_time,
465 | right_side_time - original_time, start_offset)
466 |
467 | shift = new_time - original_time
468 | if not terminate:
469 | # we aren't back on track yet - add this group to uncommitted
470 | group_state.update({"shift": shift, "diff": diff})
471 | uncommitted_states.append(group_state)
472 | idx += 1
473 | if rewind_thresh == len(uncommitted_states) and window < max_window:
474 |             logging.warning("Detected possibly broken segment starting at {0}, increasing the window from {1} to {2}"
475 |                             .format(format_time(uncommitted_states[0]["start_time"]), window, max_window))
476 | window = max_window
477 | idx = len(committed_states)
478 | del uncommitted_states[:]
479 | continue
480 |
481 | # we're back on track - apply current shift to all broken events
482 | if uncommitted_states:
483 | logging.warning("Events from {0} to {1} will most likely be broken!".format(
484 | format_time(uncommitted_states[0]["start_time"]),
485 | format_time(uncommitted_states[-1]["end_time"])))
486 |
487 | uncommitted_states.append(group_state)
488 | for state in uncommitted_states:
489 | state.update({"shift": shift, "diff": diff})
490 | log_shift(state)
491 | committed_states.extend(uncommitted_states)
492 | del uncommitted_states[:]
493 | idx += 1
494 |
495 | for state in uncommitted_states:
496 | log_shift(state)
497 |
498 | for idx, (search_group, group_state) in enumerate(izip(groups_list, chain(committed_states, uncommitted_states))):
499 | if group_state["shift"] is None:
500 | for group in reversed(groups_list[:idx]):
501 | link_to = next((x for x in reversed(group) if not x.linked), None)
502 | if link_to:
503 | for e in search_group:
504 | e.link_event(link_to)
505 | break
506 | else:
507 | for e in search_group:
508 | e.set_shift(group_state["shift"], group_state["diff"])
509 |
510 |
511 | def check_file_exists(path, file_title):
512 | if path and not os.path.exists(path):
513 | raise SushiError("{0} file doesn't exist".format(file_title))
514 |
515 |
516 | def format_full_path(temp_dir, base_path, postfix):
517 | if temp_dir:
518 | return os.path.join(temp_dir, os.path.basename(base_path) + postfix)
519 | else:
520 | return base_path + postfix
521 |
522 |
523 | def create_directory_if_not_exists(path):
524 | if path and not os.path.exists(path):
525 | os.makedirs(path)
526 |
527 |
528 | def run(args):
529 | ignore_chapters = args.chapters_file is not None and args.chapters_file.lower() == 'none'
530 | write_plot = plot_enabled and args.plot_path
531 | if write_plot:
532 | plt.clf()
533 | plt.ylabel('Shift, seconds')
534 | plt.xlabel('Event index')
535 |
536 | # first part should do all possible validation and should NOT take significant amount of time
537 | check_file_exists(args.source, 'Source')
538 | check_file_exists(args.destination, 'Destination')
539 | check_file_exists(args.src_timecodes, 'Source timecodes')
540 |     check_file_exists(args.dst_timecodes, 'Destination timecodes')
541 | check_file_exists(args.script_file, 'Script')
542 |
543 | if not ignore_chapters:
544 | check_file_exists(args.chapters_file, 'Chapters')
545 | if args.src_keyframes not in ('auto', 'make'):
546 | check_file_exists(args.src_keyframes, 'Source keyframes')
547 | if args.dst_keyframes not in ('auto', 'make'):
548 | check_file_exists(args.dst_keyframes, 'Destination keyframes')
549 |
550 | if (args.src_timecodes and args.src_fps) or (args.dst_timecodes and args.dst_fps):
551 | raise SushiError('Both fps and timecodes file cannot be specified at the same time')
552 |
553 | src_demuxer = Demuxer(args.source)
554 | dst_demuxer = Demuxer(args.destination)
555 |
556 | if src_demuxer.is_wav and not args.script_file:
557 | raise SushiError("Script file isn't specified")
558 |
559 | if (args.src_keyframes and not args.dst_keyframes) or (args.dst_keyframes and not args.src_keyframes):
560 | raise SushiError('Either none or both of src and dst keyframes should be provided')
561 |
562 | create_directory_if_not_exists(args.temp_dir)
563 |
564 | # selecting source audio
565 | if src_demuxer.is_wav:
566 | src_audio_path = args.source
567 | else:
568 | src_audio_path = format_full_path(args.temp_dir, args.source, '.sushi.wav')
569 | src_demuxer.set_audio(stream_idx=args.src_audio_idx, output_path=src_audio_path, sample_rate=args.sample_rate)
570 |
571 | # selecting destination audio
572 | if dst_demuxer.is_wav:
573 | dst_audio_path = args.destination
574 | else:
575 | dst_audio_path = format_full_path(args.temp_dir, args.destination, '.sushi.wav')
576 | dst_demuxer.set_audio(stream_idx=args.dst_audio_idx, output_path=dst_audio_path, sample_rate=args.sample_rate)
577 |
578 | # selecting source subtitles
579 | if args.script_file:
580 | src_script_path = args.script_file
581 | else:
582 | stype = src_demuxer.get_subs_type(args.src_script_idx)
583 |         src_script_path = format_full_path(args.temp_dir, args.source, '.sushi' + stype)
584 | src_demuxer.set_script(stream_idx=args.src_script_idx, output_path=src_script_path)
585 |
586 | script_extension = get_extension(src_script_path)
587 | if script_extension not in ('.ass', '.srt'):
588 | raise SushiError('Unknown script type')
589 |
590 |     # selecting destination subtitles
591 | if args.output_script:
592 | dst_script_path = args.output_script
593 | dst_script_extension = get_extension(args.output_script)
594 | if dst_script_extension != script_extension:
595 | raise SushiError("Source and destination script file types don't match ({0} vs {1})"
596 | .format(script_extension, dst_script_extension))
597 | else:
598 | dst_script_path = format_full_path(args.temp_dir, args.destination, '.sushi' + script_extension)
599 |
600 | # selecting chapters
601 | if args.grouping and not ignore_chapters:
602 | if args.chapters_file:
603 | if get_extension(args.chapters_file) == '.xml':
604 | chapter_times = chapters.get_xml_start_times(args.chapters_file)
605 | else:
606 | chapter_times = chapters.get_ogm_start_times(args.chapters_file)
607 | elif not src_demuxer.is_wav:
608 | chapter_times = src_demuxer.chapters
609 | output_path = format_full_path(args.temp_dir, src_demuxer.path, ".sushi.chapters.txt")
610 | src_demuxer.set_chapters(output_path)
611 | else:
612 | chapter_times = []
613 | else:
614 | chapter_times = []
615 |
616 | # selecting keyframes and timecodes
617 | if args.src_keyframes:
618 | def select_keyframes(file_arg, demuxer):
619 | auto_file = format_full_path(args.temp_dir, demuxer.path, '.sushi.keyframes.txt')
620 | if file_arg in ('auto', 'make'):
621 | if file_arg == 'make' or not os.path.exists(auto_file):
622 | if not demuxer.has_video:
623 | raise SushiError("Cannot make keyframes for {0} because it doesn't have any video!"
624 | .format(demuxer.path))
625 | demuxer.set_keyframes(output_path=auto_file)
626 | return auto_file
627 | else:
628 | return file_arg
629 |
630 | def select_timecodes(external_file, fps_arg, demuxer):
631 | if external_file:
632 | return external_file
633 | elif fps_arg:
634 | return None
635 | elif demuxer.has_video:
636 | path = format_full_path(args.temp_dir, demuxer.path, '.sushi.timecodes.txt')
637 | demuxer.set_timecodes(output_path=path)
638 | return path
639 | else:
640 | raise SushiError('Fps, timecodes or video files must be provided if keyframes are used')
641 |
642 | src_keyframes_file = select_keyframes(args.src_keyframes, src_demuxer)
643 | dst_keyframes_file = select_keyframes(args.dst_keyframes, dst_demuxer)
644 | src_timecodes_file = select_timecodes(args.src_timecodes, args.src_fps, src_demuxer)
645 | dst_timecodes_file = select_timecodes(args.dst_timecodes, args.dst_fps, dst_demuxer)
646 |
647 | # after this point nothing should fail so it's safe to start slow operations
648 | # like running the actual demuxing
649 | src_demuxer.demux()
650 | dst_demuxer.demux()
651 |
652 | try:
653 | if args.src_keyframes:
654 | src_timecodes = Timecodes.cfr(args.src_fps) if args.src_fps else Timecodes.from_file(src_timecodes_file)
655 | src_keytimes = [src_timecodes.get_frame_time(f) for f in keyframes.parse_keyframes(src_keyframes_file)]
656 |
657 | dst_timecodes = Timecodes.cfr(args.dst_fps) if args.dst_fps else Timecodes.from_file(dst_timecodes_file)
658 | dst_keytimes = [dst_timecodes.get_frame_time(f) for f in keyframes.parse_keyframes(dst_keyframes_file)]
659 |
660 | script = AssScript.from_file(src_script_path) if script_extension == '.ass' else SrtScript.from_file(src_script_path)
661 | script.sort_by_time()
662 |
663 | src_stream = WavStream(src_audio_path, sample_rate=args.sample_rate, sample_type=args.sample_type)
664 | dst_stream = WavStream(dst_audio_path, sample_rate=args.sample_rate, sample_type=args.sample_type)
665 |
666 | search_groups = prepare_search_groups(script.events,
667 | source_duration=src_stream.duration_seconds,
668 | chapter_times=chapter_times,
669 | max_ts_duration=args.max_ts_duration,
670 | max_ts_distance=args.max_ts_distance)
671 |
672 | calculate_shifts(src_stream, dst_stream, search_groups,
673 | normal_window=args.window,
674 | max_window=args.max_window,
675 | rewind_thresh=args.rewind_thresh if args.grouping else 0)
676 |
677 | events = script.events
678 |
679 | if write_plot:
680 | plt.plot([x.shift for x in events], label='From audio')
681 |
682 | if args.grouping:
683 | if not ignore_chapters and chapter_times:
684 | groups = groups_from_chapters(events, chapter_times)
685 | for g in groups:
686 | fix_near_borders(g)
687 | smooth_events([x for x in g if not x.linked], args.smooth_radius)
688 | groups = split_broken_groups(groups)
689 | else:
690 | fix_near_borders(events)
691 | smooth_events([x for x in events if not x.linked], args.smooth_radius)
692 | groups = detect_groups(events)
693 |
694 | if write_plot:
695 | plt.plot([x.shift for x in events], label='Borders fixed')
696 |
697 | for g in groups:
698 | start_shift = g[0].shift
699 | end_shift = g[-1].shift
700 | avg_shift = average_shifts(g)
701 | logging.info(u'Group (start: {0}, end: {1}, lines: {2}), '
702 | u'shifts (start: {3}, end: {4}, average: {5})'
703 | .format(format_time(g[0].start), format_time(g[-1].end), len(g), start_shift, end_shift,
704 | avg_shift))
705 |
706 | if args.src_keyframes:
707 | for e in (x for x in events if x.linked):
708 | e.resolve_link()
709 | for g in groups:
710 | snap_groups_to_keyframes(g, chapter_times, args.max_ts_duration, args.max_ts_distance, src_keytimes,
711 | dst_keytimes, src_timecodes, dst_timecodes, args.max_kf_distance, args.kf_mode)
712 | else:
713 | fix_near_borders(events)
714 | if write_plot:
715 | plt.plot([x.shift for x in events], label='Borders fixed')
716 |
717 | if args.src_keyframes:
718 | for e in (x for x in events if x.linked):
719 | e.resolve_link()
720 | snap_groups_to_keyframes(events, chapter_times, args.max_ts_duration, args.max_ts_distance, src_keytimes,
721 | dst_keytimes, src_timecodes, dst_timecodes, args.max_kf_distance, args.kf_mode)
722 |
723 | for event in events:
724 | event.apply_shift()
725 |
726 | script.save_to_file(dst_script_path)
727 |
728 | if write_plot:
729 | plt.plot([x.shift + (x._start_shift + x._end_shift)/2.0 for x in events], label='After correction')
730 | plt.legend(fontsize=5, frameon=False, fancybox=False)
731 | plt.savefig(args.plot_path, dpi=300)
732 |
733 | finally:
734 | if args.cleanup:
735 | src_demuxer.cleanup()
736 | dst_demuxer.cleanup()
737 |
738 |
739 | def create_arg_parser():
740 | parser = argparse.ArgumentParser(description='Sushi - Automatic Subtitle Shifter')
741 |
742 | parser.add_argument('--window', default=10, type=int, metavar='', dest='window',
743 | help='Search window size. [%(default)s]')
744 | parser.add_argument('--max-window', default=30, type=int, metavar='', dest='max_window',
745 |                         help="Maximum search window size Sushi is allowed to use when trying to recover from errors. [%(default)s]")
746 | parser.add_argument('--rewind-thresh', default=5, type=int, metavar='', dest='rewind_thresh',
747 | help="Number of consecutive errors Sushi has to encounter to consider results broken "
748 | "and retry with larger window. Set to 0 to disable. [%(default)s]")
749 | parser.add_argument('--no-grouping', action='store_false', dest='grouping',
750 |                         help="Don't group events before shifting. Also disables error recovery.")
751 | parser.add_argument('--max-kf-distance', default=2, type=float, metavar='', dest='max_kf_distance',
752 | help='Maximum keyframe snapping distance. [%(default)s]')
753 | parser.add_argument('--kf-mode', default='all', choices=['shift', 'snap', 'all'], dest='kf_mode',
754 | help='Keyframes-based shift correction/snapping mode. [%(default)s]')
755 | parser.add_argument('--smooth-radius', default=3, type=int, metavar='', dest='smooth_radius',
756 | help='Radius of smoothing median filter. [%(default)s]')
757 |
758 | # 10 frames at 23.976
759 | parser.add_argument('--max-ts-duration', default=1001.0 / 24000.0 * 10, type=float, metavar='',
760 | dest='max_ts_duration',
761 | help='Maximum duration of a line to be considered typesetting. [%(default).3f]')
762 | # 10 frames at 23.976
763 | parser.add_argument('--max-ts-distance', default=1001.0 / 24000.0 * 10, type=float, metavar='',
764 | dest='max_ts_distance',
765 | help='Maximum distance between two adjacent typesetting lines to be merged. [%(default).3f]')
766 |
767 | # deprecated/test options, do not use
768 | parser.add_argument('--test-shift-plot', default=None, dest='plot_path', help=argparse.SUPPRESS)
769 | parser.add_argument('--sample-type', default='uint8', choices=['float32', 'uint8'], dest='sample_type',
770 | help=argparse.SUPPRESS)
771 |
772 | parser.add_argument('--sample-rate', default=12000, type=int, metavar='', dest='sample_rate',
773 | help='Downsampled audio sample rate. [%(default)s]')
774 |
775 | parser.add_argument('--src-audio', default=None, type=int, metavar='', dest='src_audio_idx',
776 | help='Audio stream index of the source video')
777 | parser.add_argument('--src-script', default=None, type=int, metavar='', dest='src_script_idx',
778 | help='Script stream index of the source video')
779 | parser.add_argument('--dst-audio', default=None, type=int, metavar='', dest='dst_audio_idx',
780 | help='Audio stream index of the destination video')
781 | # files
782 | parser.add_argument('--no-cleanup', action='store_false', dest='cleanup',
783 | help="Don't delete demuxed streams")
784 | parser.add_argument('--temp-dir', default=None, dest='temp_dir', metavar='',
785 | help='Specify temporary folder to use when demuxing stream.')
786 | parser.add_argument('--chapters', default=None, dest='chapters_file', metavar='',
787 | help="XML or OGM chapters to use instead of any found in the source. 'none' to disable.")
788 | parser.add_argument('--script', default=None, dest='script_file', metavar='',
789 | help='Subtitle file path to use instead of any found in the source')
790 |
791 | parser.add_argument('--dst-keyframes', default=None, dest='dst_keyframes', metavar='',
792 | help='Destination keyframes file')
793 | parser.add_argument('--src-keyframes', default=None, dest='src_keyframes', metavar='',
794 | help='Source keyframes file')
795 | parser.add_argument('--dst-fps', default=None, type=float, dest='dst_fps', metavar='',
796 | help='Fps of the destination video. Must be provided if keyframes are used.')
797 | parser.add_argument('--src-fps', default=None, type=float, dest='src_fps', metavar='',
798 | help='Fps of the source video. Must be provided if keyframes are used.')
799 | parser.add_argument('--dst-timecodes', default=None, dest='dst_timecodes', metavar='',
800 | help='Timecodes file to use instead of making one from the destination (when possible)')
801 | parser.add_argument('--src-timecodes', default=None, dest='src_timecodes', metavar='',
802 | help='Timecodes file to use instead of making one from the source (when possible)')
803 |
804 | parser.add_argument('--src', required=True, dest="source", metavar='',
805 | help='Source audio/video')
806 | parser.add_argument('--dst', required=True, dest="destination", metavar='',
807 | help='Destination audio/video')
808 | parser.add_argument('-o', '--output', default=None, dest='output_script', metavar='',
809 | help='Output script')
810 |
811 | parser.add_argument('-v', '--verbose', default=False, dest='verbose', action='store_true',
812 | help='Enable verbose logging')
813 | parser.add_argument('--version', action='version', version=VERSION)
814 |
815 | return parser
816 |
817 |
818 | def parse_args_and_run(cmd_keys):
819 | def format_arg(arg):
820 | return arg if ' ' not in arg else '"{0}"'.format(arg)
821 |
822 | args = create_arg_parser().parse_args(cmd_keys)
823 | handler = logging.StreamHandler()
824 | if console_colors_supported and os.isatty(sys.stderr.fileno()):
825 | # enable colors
826 | handler.setFormatter(ColoredLogFormatter())
827 | else:
828 | handler.setFormatter(logging.Formatter(fmt=ColoredLogFormatter.default_format))
829 | logging.root.addHandler(handler)
830 | logging.root.setLevel(logging.DEBUG if args.verbose else logging.INFO)
831 |
832 | logging.info("Sushi's running with arguments: {0}".format(' '.join(map(format_arg, cmd_keys))))
833 | start_time = time.time()
834 | run(args)
835 | logging.info('Done in {0}s'.format(time.time() - start_time))
836 |
837 |
838 | if __name__ == '__main__':
839 | try:
840 | parse_args_and_run(sys.argv[1:])
841 | except SushiError as e:
842 | logging.critical(e.message)
843 | sys.exit(2)
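# A minimal invocation for reference (file names are placeholders; only --src and --dst
# are required, and --script is mandatory when the source is a plain WAV):
#   python sushi.py --src tv.mkv --dst bd.mkv --script subs.ass -o subs.shifted.ass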
844 |
--------------------------------------------------------------------------------
/tests.example.json:
--------------------------------------------------------------------------------
1 | {
2 | "basepath": "J:",
3 | "plots": "J:\\plots",
4 | "tests": {
5 | "first test":{
6 | "disabled": false,
7 | "folder": "test1",
8 | "src": "tv.wav",
9 | "dst": "bd.wav",
10 | "dst-keyframes": "bd.keyframes.txt",
11 | "src-keyframes": "tv.keyframes.txt",
12 | "dst-timecodes": "bd.timecodes.txt",
13 | "src-timecodes": "tv.timecodes.txt",
14 | "chapters": "tv.chapters.xml",
15 | "script": "tv.ass",
16 | "ideal": "ideal.ass",
17 | "fps": 23.976023976023978,
18 | "max-kf-distance": 3,
19 | "expected_errors": 84
20 | }
21 | },
22 | "wavs": {
23 | "first wav": {
24 | "max_time": 0.7,
25 | "max_memory": 120,
26 | "sample_type": "uint8",
27 | "sample_rate": 12000,
28 | "file": "wavs/file.wav"
29 | }
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tp7/Sushi/908c0ff228734059aebb914a8d10f8e4ce2e868c/tests/__init__.py
--------------------------------------------------------------------------------
/tests/demuxing.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import mock
3 |
4 | from demux import FFmpeg, MkvToolnix, SCXviD
5 | from common import SushiError
6 | import chapters
7 |
8 |
9 | def create_popen_mock():
10 | popen_mock = mock.Mock()
11 | process_mock = mock.Mock()
12 |     process_mock.configure_mock(**{'communicate.return_value': ('output', 'error')})
13 | popen_mock.return_value = process_mock
14 | return popen_mock
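# The mock above replaces subprocess.Popen entirely: any Popen(...) call yields a fake
# process whose communicate() returns the canned (stdout, stderr) pair, so these tests
# run without the external tools installed.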
15 |
16 |
17 | class FFmpegTestCase(unittest.TestCase):
18 | ffmpeg_output = '''Input #0, matroska,webm, from 'test.mkv':
19 | Stream #0:0(jpn): Video: h264 (High 10), yuv420p10le, 1280x720 [SAR 1:1 DAR 16:9], 23.98 fps, 23.98 tbr, 1k tbn, 47.95 tbc (default)
20 | Metadata:
21 | title : Video 10bit
22 | Stream #0:1(jpn): Audio: aac, 48000 Hz, stereo, fltp (default) (forced)
23 | Metadata:
24 | title : Audio AAC 2.0
25 | Stream #0:2(eng): Audio: aac, 48000 Hz, stereo, fltp
26 | Metadata:
27 | title : English Audio AAC 2.0
28 | Stream #0:3(eng): Subtitle: ssa (default) (forced)
29 | Metadata:
30 | title : English Subtitles
31 | Stream #0:4(enm): Subtitle: ass
32 | Metadata:
33 | title : English (JP honorifics)
34 | .................................'''
35 |
36 | def test_parses_audio_stream(self):
37 | audio = FFmpeg._get_audio_streams(self.ffmpeg_output)
38 | self.assertEqual(len(audio), 2)
39 | self.assertEqual(audio[0].id, 1)
40 | self.assertEqual(audio[0].title, 'Audio AAC 2.0')
41 | self.assertEqual(audio[1].id, 2)
42 | self.assertEqual(audio[1].title, 'English Audio AAC 2.0')
43 |
44 | def test_parses_video_stream(self):
45 | video = FFmpeg._get_video_streams(self.ffmpeg_output)
46 | self.assertEqual(len(video), 1)
47 | self.assertEqual(video[0].id, 0)
48 | self.assertEqual(video[0].title, 'Video 10bit')
49 |
50 | def test_parses_subtitles_stream(self):
51 | subs = FFmpeg._get_subtitles_streams(self.ffmpeg_output)
52 | self.assertEqual(len(subs), 2)
53 | self.assertEqual(subs[0].id, 3)
54 | self.assertTrue(subs[0].default)
55 | self.assertEqual(subs[0].title, 'English Subtitles')
56 | self.assertEqual(subs[1].id, 4)
57 | self.assertFalse(subs[1].default)
58 | self.assertEqual(subs[1].title, 'English (JP honorifics)')
59 |
60 | @mock.patch('subprocess.Popen', new_callable=create_popen_mock)
61 | def test_get_info_call_args(self, popen_mock):
62 | FFmpeg.get_info('random_file.mkv')
63 | self.assertEquals(popen_mock.call_args[0][0], ['ffmpeg', '-hide_banner', '-i', 'random_file.mkv'])
64 |
65 | @mock.patch('subprocess.Popen', new_callable=create_popen_mock)
66 |     def test_get_info_fail_when_no_ffmpeg(self, popen_mock):
67 | popen_mock.return_value.communicate.side_effect = OSError(2, "ignored")
68 | self.assertRaises(SushiError, lambda: FFmpeg.get_info('random.mkv'))
69 |
70 | @mock.patch('subprocess.call')
71 | def test_demux_file_call_args(self, call_mock):
72 | FFmpeg.demux_file('random.mkv', audio_stream=0, audio_path='audio1.wav')
73 | FFmpeg.demux_file('random.mkv', audio_stream=0, audio_path='audio2.wav', audio_rate=12000)
74 | FFmpeg.demux_file('random.mkv', script_stream=0, script_path='subs1.ass')
75 | FFmpeg.demux_file('random.mkv', video_stream=0, timecodes_path='tcs1.txt')
76 |
77 | FFmpeg.demux_file('random.mkv', audio_stream=1, audio_path='audio0.wav', audio_rate=12000,
78 | script_stream=2, script_path='out0.ass', video_stream=0, timecodes_path='tcs0.txt')
79 |
80 | call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y',
81 | '-map', '0:0', '-ac', '1', '-acodec', 'pcm_s16le', 'audio1.wav'])
82 | call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y',
83 | '-map', '0:0', '-ar', '12000', '-ac', '1', '-acodec', 'pcm_s16le', 'audio2.wav'])
84 | call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y',
85 | '-map', '0:0', 'subs1.ass'])
86 | call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y',
87 | '-map', '0:0', '-f', 'mkvtimestamp_v2', 'tcs1.txt'])
88 | call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y',
89 | '-map', '0:1', '-ar', '12000', '-ac', '1', '-acodec', 'pcm_s16le', 'audio0.wav',
90 | '-map', '0:2', 'out0.ass',
91 | '-map', '0:0', '-f', 'mkvtimestamp_v2', 'tcs0.txt'])
92 |
93 |
94 | class MkvExtractTestCase(unittest.TestCase):
95 | @mock.patch('subprocess.call')
96 | def test_extract_timecodes(self, call_mock):
97 | MkvToolnix.extract_timecodes('video.mkv', 1, 'timecodes.tsc')
98 | call_mock.assert_called_once_with(['mkvextract', 'timecodes_v2', 'video.mkv', '1:timecodes.tsc'])
99 |
100 |
101 | class SCXviDTestCase(unittest.TestCase):
102 | @mock.patch('subprocess.Popen')
103 | def test_make_keyframes(self, popen_mock):
104 | SCXviD.make_keyframes('video.mkv', 'keyframes.txt')
105 | self.assertTrue('ffmpeg' in (x.lower() for x in popen_mock.call_args_list[0][0][0]))
106 | self.assertTrue('scxvid' in (x.lower() for x in popen_mock.call_args_list[1][0][0]))
107 |
108 | @mock.patch('subprocess.Popen')
109 | def test_no_ffmpeg(self, popen_mock):
110 | def raise_no_app(cmd_args, **kwargs):
111 | if 'ffmpeg' in (x.lower() for x in cmd_args):
112 | raise OSError(2, 'ignored')
113 |
114 | popen_mock.side_effect = raise_no_app
115 | self.assertRaisesRegexp(SushiError, '[fF][fF][mM][pP][eE][gG]',
116 | lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt'))
117 |
118 | @mock.patch('subprocess.Popen')
119 | def test_no_scxvid(self, popen_mock):
120 | def raise_no_app(cmd_args, **kwargs):
121 | if 'scxvid' in (x.lower() for x in cmd_args):
122 | raise OSError(2, 'ignored')
123 | return mock.Mock()
124 |
125 | popen_mock.side_effect = raise_no_app
126 | self.assertRaisesRegexp(SushiError, '[sS][cC][xX][vV][iI][dD]',
127 | lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt'))
128 |
129 |
130 | class ExternalChaptersTestCase(unittest.TestCase):
131 | def test_parse_xml_start_times(self):
132 | file_text = """
133 |
134 | <Chapters>
135 |   <EditionEntry>
136 |     <EditionUID>2092209815</EditionUID>
137 |     <ChapterAtom>
138 |       <ChapterUID>3122448259</ChapterUID>
139 |       <ChapterTimeStart>00:00:00.000000000</ChapterTimeStart>
140 |       <ChapterDisplay>
141 |         <ChapterString>Prologue</ChapterString>
142 |       </ChapterDisplay>
143 |     </ChapterAtom>
144 |     <ChapterAtom>
145 |       <ChapterUID>998777246</ChapterUID>
146 |       <ChapterTimeStart>00:00:17.017000000</ChapterTimeStart>
147 |       <ChapterDisplay>
148 |         <ChapterString>Opening Song ("YES!")</ChapterString>
149 |       </ChapterDisplay>
150 |     </ChapterAtom>
151 |     <ChapterAtom>
152 |       <ChapterUID>55571857</ChapterUID>
153 |       <ChapterTimeStart>00:01:47.023000000</ChapterTimeStart>
154 |       <ChapterDisplay>
155 |         <ChapterString>Part A (Tale of the Doggypus)</ChapterString>
156 |       </ChapterDisplay>
157 |     </ChapterAtom>
158 |   </EditionEntry>
159 | </Chapters>
160 | """
161 | parsed_times = chapters.parse_xml_start_times(file_text)
162 | self.assertEqual(parsed_times, [0, 17.017, 107.023])
163 |
164 | def test_parse_ogm_start_times(self):
165 | file_text = """CHAPTER01=00:00:00.000
166 | CHAPTER01NAME=Prologue
167 | CHAPTER02=00:00:17.017
168 | CHAPTER02NAME=Opening Song ("YES!")
169 | CHAPTER03=00:01:47.023
170 | CHAPTER03NAME=Part A (Tale of the Doggypus)
171 | """
172 | parsed_times = chapters.parse_ogm_start_times(file_text)
173 | self.assertEqual(parsed_times, [0, 17.017, 107.023])
174 |
175 | def test_format_ogm_chapters(self):
176 | chapters_text = chapters.format_ogm_chapters(start_times=[0, 17.017, 107.023])
177 | self.assertEqual(chapters_text, """CHAPTER01=00:00:00.000
178 | CHAPTER01NAME=
179 | CHAPTER02=00:00:17.017
180 | CHAPTER02NAME=
181 | CHAPTER03=00:01:47.023
182 | CHAPTER03NAME=
183 | """)
184 |
--------------------------------------------------------------------------------
/tests/main.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import os
3 | import re
4 | import unittest
5 | from mock import patch, ANY
6 | from common import SushiError, format_time
7 | import sushi
8 |
9 | here = os.path.dirname(os.path.abspath(__file__))
10 |
11 |
12 | class FakeEvent(object):
13 | def __init__(self, shift=0.0, diff=0.0, end=0.0, start=0.0):
14 | self.shift = shift
15 | self.linked = None
16 | self.diff = diff
17 | self.start = start
18 | self.end = end
19 |
20 | def set_shift(self, shift, diff):
21 | self.shift = shift
22 | self.diff = diff
23 |
24 | def link_event(self, other):
25 | self.linked = other
26 |
27 | def __repr__(self):
28 | return repr(self.shift)
29 |
30 | def __eq__(self, other):
31 | return self.__dict__ == other.__dict__
32 |
33 |
34 | class InterpolateNonesTestCase(unittest.TestCase):
35 | def test_returns_empty_array_when_passed_empty_array(self):
36 | self.assertEquals(sushi.interpolate_nones([], []), [])
37 |
38 | def test_returns_false_when_no_valid_points(self):
39 | self.assertFalse(sushi.interpolate_nones([None, None, None], [1, 2, 3]))
40 |
41 | def test_returns_full_array_when_no_nones(self):
42 | self.assertEqual(sushi.interpolate_nones([1, 2, 3], [1, 2, 3]), [1, 2, 3])
43 |
44 | def test_interpolates_simple_nones(self):
45 | self.assertEqual(sushi.interpolate_nones([1, None, 3, None, 5], [1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])
46 |
47 | def test_interpolates_multiple_adjacent_nones(self):
48 | self.assertEqual(sushi.interpolate_nones([1, None, None, None, 5], [1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])
49 |
50 | def test_copies_values_to_borders(self):
51 | self.assertEqual(sushi.interpolate_nones([None, None, 2, None, None], [1, 2, 3, 4, 5]), [2, 2, 2, 2, 2])
52 |
53 | def test_copies_values_to_borders_when_everything_is_zero(self):
54 | self.assertEqual(sushi.interpolate_nones([None, 0, 0, 0, None], [1, 2, 3, 4, 5]), [0, 0, 0, 0, 0])
55 |
56 | def test_interpolates_based_on_passed_points(self):
57 | self.assertEqual(sushi.interpolate_nones([1, None, 10], [1, 2, 10]), [1, 2, 10])
58 |
59 |
60 | class RunningMedianTestCase(unittest.TestCase):
61 |     def test_does_not_touch_border_values(self):
62 | shifts = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
63 | smooth = sushi.running_median(shifts, 5)
64 | self.assertEqual(shifts, smooth)
65 |
66 | def test_removes_broken_values(self):
67 | shifts = [0.1, 0.1, 0.1, 9001, 0.1, 0.1, 0.1]
68 | smooth = sushi.running_median(shifts, 5)
69 | self.assertEqual(smooth, [0.1] * 7)
70 |
71 |
72 | class SmoothEventsTestCase(unittest.TestCase):
73 | def test_smooths_events_shifts(self):
74 | events = [FakeEvent(x) for x in (0.1, 0.1, 0.1, 9001, 7777, 0.1, 0.1, 0.1)]
75 | sushi.smooth_events(events, 7)
76 | self.assertEqual([x.shift for x in events], [0.1] * 8)
77 |
78 | def test_keeps_diff_values(self):
79 | events = [FakeEvent(x, diff=x) for x in (0.1, 0.1, 0.1, 9001, 7777, 0.1, 0.1, 0.1)]
80 | diffs = [x.diff for x in events]
81 | sushi.smooth_events(events, 7)
82 | self.assertEqual([x.diff for x in events], diffs)
83 |
84 |
85 | class DetectGroupsTestCase(unittest.TestCase):
86 | def test_splits_three_simple_groups(self):
87 | events = [FakeEvent(0.5)] * 3 + [FakeEvent(1.0)] * 10 + [FakeEvent(0.5)] * 5
88 | groups = sushi.detect_groups(events)
89 | self.assertEqual(3, len(groups[0]))
90 | self.assertEqual(10, len(groups[1]))
91 | self.assertEqual(5, len(groups[2]))
92 |
93 | def test_single_group_for_all_events(self):
94 | events = [FakeEvent(0.5)] * 10
95 | groups = sushi.detect_groups(events)
96 | self.assertEqual(10, len(groups[0]))
97 |
98 |
99 | class GroupsFromChaptersTestCase(unittest.TestCase):
100 | def test_all_events_in_one_group_when_no_chapters(self):
101 | events = [FakeEvent(end=1), FakeEvent(end=2), FakeEvent(end=3)]
102 | groups = sushi.groups_from_chapters(events, [])
103 | self.assertEqual(1, len(groups))
104 | self.assertEqual(events, groups[0])
105 |
106 | def test_events_in_two_groups_one_chapter(self):
107 | events = [FakeEvent(end=1), FakeEvent(end=2), FakeEvent(end=3)]
108 | groups = sushi.groups_from_chapters(events, [0.0, 1.5])
109 | self.assertEqual(2, len(groups))
110 | self.assertItemsEqual([events[0]], groups[0])
111 | self.assertItemsEqual([events[1], events[2]], groups[1])
112 |
113 | def test_multiple_groups_multiple_chapters(self):
114 | events = [FakeEvent(end=x) for x in xrange(1, 10)]
115 | groups = sushi.groups_from_chapters(events, [0.0, 3.2, 4.4, 7.7])
116 | self.assertEqual(4, len(groups))
117 | self.assertItemsEqual(events[0:3], groups[0])
118 | self.assertItemsEqual(events[3:4], groups[1])
119 | self.assertItemsEqual(events[4:7], groups[2])
120 | self.assertItemsEqual(events[7:9], groups[3])
121 |
122 |
123 | class SplitBrokenGroupsTestCase(unittest.TestCase):
124 | def test_doing_nothing_on_correct_groups(self):
125 | groups = [[FakeEvent(0.5), FakeEvent(0.5)], [FakeEvent(10.0)]]
126 | fixed = sushi.split_broken_groups(groups)
127 | self.assertItemsEqual(groups, fixed)
128 |
129 | def test_split_groups_without_merging(self):
130 | groups = [
131 | [FakeEvent(0.5)] * 10 + [FakeEvent(10.0)] * 5,
132 | [FakeEvent(0.5)] * 10,
133 | ]
134 | fixed = sushi.split_broken_groups(groups)
135 | self.assertItemsEqual([
136 | [FakeEvent(0.5)] * 10,
137 | [FakeEvent(10.0)] * 5,
138 | [FakeEvent(0.5)] * 10
139 | ], fixed)
140 |
141 | def test_split_groups_with_merging(self):
142 | groups = [
143 | [FakeEvent(0.5), FakeEvent(10.0)],
144 | [FakeEvent(10.0), FakeEvent(10.0), FakeEvent(15.0)]
145 | ]
146 | fixed = sushi.split_broken_groups(groups)
147 | self.assertItemsEqual([
148 | [FakeEvent(0.5)],
149 | [FakeEvent(10.0), FakeEvent(10.0), FakeEvent(10.0)],
150 | [FakeEvent(15.0)]
151 | ], fixed)
152 |
153 |
154 | class FixNearBordersTestCase(unittest.TestCase):
155 | def test_propagates_last_correct_shift_to_broken_events(self):
156 | events = [FakeEvent(diff=x) for x in (0.9, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.9)]
157 | sushi.fix_near_borders(events)
158 | sf = events[2]
159 | sl = events[-3]
160 | self.assertEqual([x.linked for x in events], [sf, sf, None, None, None, None, None, sl, sl])
161 |
162 | def test_returns_array_with_no_broken_events_unchanged(self):
163 | events = [FakeEvent(diff=x) for x in (0.9, 0.9, 0.9, 1.0, 0.9)]
164 | sushi.fix_near_borders(events)
165 | self.assertEqual([x.linked for x in events], [None, None, None, None, None])
166 |
167 |
168 | class GetDistanceToClosestKeyframeTestCase(unittest.TestCase):
169 | KEYTIMES = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
170 |
171 | def test_finds_correct_distance_to_first_keyframe(self):
172 | self.assertEqual(sushi.get_distance_to_closest_kf(0, self.KEYTIMES), 0)
173 |
174 | def test_finds_correct_distance_to_last_keyframe(self):
175 | self.assertEqual(sushi.get_distance_to_closest_kf(105, self.KEYTIMES), -5)
176 |
177 | def test_finds_correct_distance_to_keyframe_before(self):
178 | self.assertEqual(sushi.get_distance_to_closest_kf(63, self.KEYTIMES), -3)
179 |
180 | def test_finds_distance_to_keyframe_after(self):
181 | self.assertEqual(sushi.get_distance_to_closest_kf(36, self.KEYTIMES), 4)
182 |
183 |
184 | @patch('sushi.check_file_exists')
185 | class MainScriptTestCase(unittest.TestCase):
186 | @staticmethod
187 | def any_case_regex(text):
188 | return re.compile(text, flags=re.IGNORECASE)
189 |
190 | def test_checks_that_files_exist(self, mock_object):
191 | keys = ['--dst', 'dst', '--src', 'src', '--script', 'script', '--chapters', 'chapters',
192 | '--dst-keyframes', 'dst-keyframes', '--src-keyframes', 'src-keyframes',
193 | '--src-timecodes', 'src-tcs', '--dst-timecodes', 'dst-tcs']
194 | try:
195 | sushi.parse_args_and_run(keys)
196 | except SushiError:
197 | pass
198 | mock_object.assert_any_call('src', ANY)
199 | mock_object.assert_any_call('dst', ANY)
200 | mock_object.assert_any_call('script', ANY)
201 | mock_object.assert_any_call('chapters', ANY)
202 | mock_object.assert_any_call('dst-keyframes', ANY)
203 | mock_object.assert_any_call('src-keyframes', ANY)
204 | mock_object.assert_any_call('dst-tcs', ANY)
205 | mock_object.assert_any_call('src-tcs', ANY)
206 |
207 | def test_raises_on_unknown_script_type(self, ignore):
208 | keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.mp4']
209 | self.assertRaisesRegexp(SushiError, self.any_case_regex(r'script.*type'), lambda: sushi.parse_args_and_run(keys))
210 |
211 | def test_raises_on_script_type_not_matching(self, ignore):
212 | keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '-o', 'd.srt']
213 | self.assertRaisesRegexp(SushiError, self.any_case_regex(r'script.*type.*match'),
214 | lambda: sushi.parse_args_and_run(keys))
215 |
216 | def test_raises_on_timecodes_and_fps_being_defined_together(self, ignore):
217 | keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '--src-timecodes', 'tc.txt', '--src-fps', '25']
218 | self.assertRaisesRegexp(SushiError, self.any_case_regex(r'timecodes'), lambda: sushi.parse_args_and_run(keys))
219 |
220 |
221 | class FormatTimeTestCase(unittest.TestCase):
222 | def test_format_time_zero(self):
223 | self.assertEqual('0:00:00.00', format_time(0))
224 |
225 | def test_format_time_65_seconds(self):
226 | self.assertEqual('0:01:05.00', format_time(65))
227 |
228 | def test_format_time_float_seconds(self):
229 | self.assertEqual('0:00:05.56', format_time(5.559))
230 |
231 | def test_format_time_hours(self):
232 | self.assertEqual('1:15:35.15', format_time(3600 + 60 * 15 + 35.15))
233 |
234 | def test_format_100ms(self):
235 | self.assertEqual('0:09:05.00', format_time(544.997))
236 |
--------------------------------------------------------------------------------
/tests/subtitles.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import tempfile
3 | import os
4 | import codecs
5 | from subs import AssEvent, AssScript, SrtEvent, SrtScript
6 |
7 | SINGLE_LINE_SRT_EVENT = """1
8 | 00:14:21,960 --> 00:14:22,960
9 | HOW DID IT END UP LIKE THIS?"""
10 |
11 | MULTILINE_SRT_EVENT = """2
12 | 00:13:12,140 --> 00:13:14,100
13 | APPEARANCE!
14 | Appearrance (teisai)!
15 | No wait, you're the worst (saitei)!"""
16 |
17 | ASS_EVENT = r"Dialogue: 0,0:18:50.98,0:18:55.28,Default,,0,0,0,,Are you trying to (ouch) crush it (ouch)\N like a (ouch) vise (ouch, ouch)?"
18 |
19 | ASS_COMMENT = r"Comment: 0,0:18:09.15,0:18:10.36,Default,,0,0,0,,I DON'T GET IT TOO WELL."
20 |
21 |
22 | class SrtEventTestCase(unittest.TestCase):
23 | def test_simple_parsing(self):
24 | event = SrtEvent.from_string(SINGLE_LINE_SRT_EVENT)
25 | self.assertEquals(14*60+21.960, event.start)
26 | self.assertEquals(14*60+22.960, event.end)
27 | self.assertEquals("HOW DID IT END UP LIKE THIS?", event.text)
28 |
29 | def test_multi_line_event_parsing(self):
30 | event = SrtEvent.from_string(MULTILINE_SRT_EVENT)
31 | self.assertEquals(13*60+12.140, event.start)
32 | self.assertEquals(13*60+14.100, event.end)
33 | self.assertEquals("APPEARANCE!\nAppearrance (teisai)!\nNo wait, you're the worst (saitei)!", event.text)
34 |
35 | def test_parsing_and_printing(self):
36 | self.assertEquals(SINGLE_LINE_SRT_EVENT, unicode(SrtEvent.from_string(SINGLE_LINE_SRT_EVENT)))
37 | self.assertEquals(MULTILINE_SRT_EVENT, unicode(SrtEvent.from_string(MULTILINE_SRT_EVENT)))
38 |
39 |
40 | class AssEventTestCase(unittest.TestCase):
41 | def test_simple_parsing(self):
42 | event = AssEvent(ASS_EVENT)
43 | self.assertFalse(event.is_comment)
44 | self.assertEquals("Dialogue", event.kind)
45 | self.assertEquals(18*60+50.98, event.start)
46 | self.assertEquals(18*60+55.28, event.end)
47 | self.assertEquals("0", event.layer)
48 | self.assertEquals("Default", event.style)
49 | self.assertEquals("", event.name)
50 | self.assertEquals("0", event.margin_left)
51 | self.assertEquals("0", event.margin_right)
52 | self.assertEquals("0", event.margin_vertical)
53 | self.assertEquals("", event.effect)
54 | self.assertEquals("Are you trying to (ouch) crush it (ouch)\N like a (ouch) vise (ouch, ouch)?", event.text)
55 |
56 | def test_comment_parsing(self):
57 | event = AssEvent(ASS_COMMENT)
58 | self.assertTrue(event.is_comment)
59 | self.assertEquals("Comment", event.kind)
60 |
61 | def test_parsing_and_printing(self):
62 | self.assertEquals(ASS_EVENT, unicode(AssEvent(ASS_EVENT)))
63 | self.assertEquals(ASS_COMMENT, unicode(AssEvent(ASS_COMMENT)))
64 |
65 |
66 | class ScriptTestBase(unittest.TestCase):
67 | def setUp(self):
68 | self.script_description, self.script_path = tempfile.mkstemp()
69 |
70 | def tearDown(self):
71 | os.remove(self.script_path)
72 |
73 |
74 | class SrtScriptTestCase(ScriptTestBase):
75 | def test_write_to_file(self):
76 | events = [SrtEvent.from_string(SINGLE_LINE_SRT_EVENT), SrtEvent.from_string(MULTILINE_SRT_EVENT)]
77 | SrtScript(events).save_to_file(self.script_path)
78 | with open(self.script_path) as script:
79 | text = script.read()
80 | self.assertEquals(SINGLE_LINE_SRT_EVENT + "\n\n" + MULTILINE_SRT_EVENT, text)
81 |
82 | def test_read_from_file(self):
83 | os.write(self.script_description, """1
84 | 00:00:17,500 --> 00:00:18,870
85 | Yeah, really!
86 |
87 | 2
88 | 00:00:17,500 --> 00:00:18,870
89 |
90 |
91 | 3
92 | 00:00:17,500 --> 00:00:18,870
93 | House number
94 | 35
95 |
96 | 4
97 | 00:00:21,250 --> 00:00:22,750
98 | Serves you right.""")
99 | parsed = SrtScript.from_file(self.script_path).events
100 | self.assertEquals(17.5, parsed[0].start)
101 | self.assertEquals(18.87, parsed[0].end)
102 | self.assertEquals("Yeah, really!", parsed[0].text)
103 | self.assertEquals(17.5, parsed[1].start)
104 | self.assertEquals(18.87, parsed[1].end)
105 | self.assertEquals("", parsed[1].text)
106 | self.assertEquals(17.5, parsed[2].start)
107 | self.assertEquals(18.87, parsed[2].end)
108 | self.assertEquals("House number\n35", parsed[2].text)
109 | self.assertEquals(21.25, parsed[3].start)
110 | self.assertEquals(22.75, parsed[3].end)
111 | self.assertEquals("Serves you right.", parsed[3].text)
112 |
113 |
114 | class AssScriptTestCase(ScriptTestBase):
115 | def test_write_to_file(self):
116 | styles = ["Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1"]
117 | events = [AssEvent(ASS_EVENT), AssEvent(ASS_EVENT)]
118 | AssScript([], styles, events, None).save_to_file(self.script_path)
119 |
120 | with codecs.open(self.script_path, encoding='utf-8-sig') as script:
121 | text = script.read()
122 |
123 | self.assertEquals("""[V4+ Styles]
124 | Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
125 | Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1
126 |
127 | [Events]
128 | Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
129 | {0}
130 | {0}""".format(ASS_EVENT), text)
131 |
132 | def test_read_from_file(self):
133 | text = """[Script Info]
134 | ; Script generated by Aegisub 3.1.1
135 | Title: script title
136 |
137 | [V4+ Styles]
138 | Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
139 | Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1
140 | Style: Signs,Gentium Basic,40,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,0,0,2,10,10,10,1
141 |
142 | [Events]
143 | Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
144 | Dialogue: 0,0:00:01.42,0:00:03.36,Default,,0000,0000,0000,,As you already know,
145 | Dialogue: 0,0:00:03.36,0:00:05.93,Default,,0000,0000,0000,,I'm concerned about the hair on my nipples."""
146 |
147 | os.write(self.script_description, text)
148 | script = AssScript.from_file(self.script_path)
149 | self.assertEquals(["; Script generated by Aegisub 3.1.1", "Title: script title"], script.script_info)
150 | self.assertEquals(["Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1",
151 | "Style: Signs,Gentium Basic,40,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,0,0,2,10,10,10,1"],
152 | script.styles)
153 | self.assertEquals([1, 2], [x.source_index for x in script.events])
154 | self.assertEquals(u"Dialogue: 0,0:00:01.42,0:00:03.36,Default,,0000,0000,0000,,As you already know,", unicode(script.events[0]))
155 | self.assertEquals(u"Dialogue: 0,0:00:03.36,0:00:05.93,Default,,0000,0000,0000,,I'm concerned about the hair on my nipples.", unicode(script.events[1]))
156 |
--------------------------------------------------------------------------------
/tests/timecodes.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from demux import Timecodes
3 |
4 |
5 | class CfrTimecodesTestCase(unittest.TestCase):
6 | def test_get_frame_time_zero(self):
7 | tcs = Timecodes.cfr(23.976)
8 | t = tcs.get_frame_time(0)
9 | self.assertEqual(t, 0)
10 |
11 | def test_get_frame_time_sane(self):
12 | tcs = Timecodes.cfr(23.976)
13 | t = tcs.get_frame_time(10)
14 | self.assertAlmostEqual(10.0/23.976, t)
15 |
16 | def test_get_frame_time_insane(self):
17 | tcs = Timecodes.cfr(23.976)
18 | t = tcs.get_frame_time(100000)
19 | self.assertAlmostEqual(100000.0/23.976, t)
20 |
21 | def test_get_frame_size(self):
22 | tcs = Timecodes.cfr(23.976)
23 | t1 = tcs.get_frame_size(0)
24 | t2 = tcs.get_frame_size(1000)
25 | self.assertAlmostEqual(1.0/23.976, t1)
26 | self.assertAlmostEqual(t1, t2)
27 |
28 | def test_get_frame_number(self):
29 | tcs = Timecodes.cfr(24000.0/1001.0)
30 | self.assertEqual(tcs.get_frame_number(0), 0)
31 | self.assertEqual(tcs.get_frame_number(1145.353), 27461)
32 | self.assertEqual(tcs.get_frame_number(1001.0/24000.0 * 1234567), 1234567)
33 |
34 |
35 | class TimecodesTestCase(unittest.TestCase):
36 | def test_cfr_timecodes_v2(self):
37 | text = '# timecode format v2\n' + '\n'.join(str(1000 * x / 23.976) for x in range(0, 30000))
38 | parsed = Timecodes.parse(text)
39 |
40 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(0))
41 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(25))
42 | self.assertAlmostEqual(1.0/23.976*100, parsed.get_frame_time(100))
43 | self.assertEqual(0, parsed.get_frame_time(0))
44 | self.assertEqual(0, parsed.get_frame_number(0))
45 | self.assertEqual(27461, parsed.get_frame_number(1145.353))
46 |
47 | def test_cfr_timecodes_v1(self):
48 | text = '# timecode format v1\nAssume 23.976024'
49 | parsed = Timecodes.parse(text)
50 | self.assertAlmostEqual(1.0/23.976024, parsed.get_frame_size(0))
51 | self.assertAlmostEqual(1.0/23.976024, parsed.get_frame_size(25))
52 | self.assertAlmostEqual(1.0/23.976024*100, parsed.get_frame_time(100))
53 | self.assertEqual(0, parsed.get_frame_time(0))
54 | self.assertEqual(0, parsed.get_frame_number(0))
55 | self.assertEqual(27461, parsed.get_frame_number(1145.353))
56 |
57 | def test_cfr_timecodes_v1_with_overrides(self):
58 | text = '# timecode format v1\nAssume 23.976000\n0,2000,23.976000\n3000,5000,23.976000'
59 | parsed = Timecodes.parse(text)
60 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(0))
61 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(25))
62 | self.assertAlmostEqual(1.0/23.976*100, parsed.get_frame_time(100))
63 | self.assertEqual(0, parsed.get_frame_time(0))
64 |
65 | def test_vfr_timecodes_v1_frame_size_at_first_frame(self):
66 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
67 | parsed = Timecodes.parse(text)
68 | self.assertAlmostEqual(1.0/29.97, parsed.get_frame_size(timestamp=0))
69 |
70 | def test_vfr_timecodes_v1_frame_size_outside_of_defined_range(self):
71 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
72 | parsed = Timecodes.parse(text)
73 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(timestamp=5000.0))
74 |
75 |     def test_vfr_timecodes_v1_frame_size_inside_override_block(self):
76 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
77 | parsed = Timecodes.parse(text)
78 | self.assertAlmostEqual(1.0/29.97, parsed.get_frame_size(timestamp=49.983))
79 |
80 |     def test_vfr_timecodes_v1_frame_size_between_override_blocks(self):
81 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
82 | parsed = Timecodes.parse(text)
83 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(timestamp=87.496))
84 |
85 | def test_vfr_timecodes_v1_frame_time_at_first_frame(self):
86 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
87 | parsed = Timecodes.parse(text)
88 | self.assertAlmostEqual(0, parsed.get_frame_time(number=0))
89 |
90 | def test_vfr_timecodes_v1_frame_time_outside_of_defined_range(self):
91 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
92 | parsed = Timecodes.parse(text)
93 | self.assertAlmostEqual(1000.968, parsed.get_frame_time(number=25000), places=3)
94 |
95 |     def test_vfr_timecodes_v1_frame_time_inside_override_block(self):
96 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
97 | parsed = Timecodes.parse(text)
98 | self.assertAlmostEqual(50.05, parsed.get_frame_time(number=1500), places=3)
99 |
100 |     def test_vfr_timecodes_v1_frame_time_between_override_blocks(self):
101 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
102 | parsed = Timecodes.parse(text)
103 | self.assertAlmostEqual(87.579, parsed.get_frame_time(number=2500), places=3)
104 |
105 |
106 |
--------------------------------------------------------------------------------
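The v1 tests above pin down the override arithmetic: each "start,end,fps" line plays frames start through end (inclusive) at the given rate, and frames not covered by any override run at the "Assume" rate. A minimal standalone sketch of that arithmetic (a hypothetical helper, not part of this repository) reproduces the expected values asserted in the tests:

    def v1_frame_time(number, default_fps, overrides):
        # overrides: list of (first_frame, last_frame, fps) tuples, both ends inclusive
        time, frame = 0.0, 0
        for first, last, fps in overrides:
            if number <= first:
                break
            time += (first - frame) / default_fps  # gap before this block runs at the default rate
            covered = min(number, last + 1) - first
            time += covered / fps
            frame = first + covered
        return time + (number - frame) / default_fps  # tail after the last block

    fps_overrides = [(0, 2000, 29.97), (3000, 4000, 59.94)]
    print(v1_frame_time(1500, 23.976, fps_overrides))  # inside the first block: 1500 / 29.97 ~= 50.05
    print(v1_frame_time(2500, 23.976, fps_overrides))  # between blocks: 2001 / 29.97 + 499 / 23.976 ~= 87.579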
/wav.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import cv2
3 | import numpy as np
4 | from chunk import Chunk
5 | import struct
6 | import math
7 | from time import time
8 | import os.path
9 | from common import SushiError, clip
10 |
11 | WAVE_FORMAT_PCM = 0x0001
12 | WAVE_FORMAT_EXTENSIBLE = 0xFFFE
13 |
14 |
15 | class DownmixedWavFile(object):
16 | _file = None
17 |
18 | def __init__(self, path):
19 | super(DownmixedWavFile, self).__init__()
20 | self._file = open(path, 'rb')
21 | try:
22 | riff = Chunk(self._file, bigendian=False)
23 | if riff.getname() != 'RIFF':
24 | raise SushiError('File does not start with RIFF id')
25 | if riff.read(4) != 'WAVE':
26 | raise SushiError('Not a WAVE file')
27 |
28 | fmt_chunk_read = False
29 |             data_chunk_read = False
30 | file_size = os.path.getsize(path)
31 |
32 | while True:
33 | try:
34 | chunk = Chunk(self._file, bigendian=False)
35 | except EOFError:
36 | break
37 |
38 | if chunk.getname() == 'fmt ':
39 | self._read_fmt_chunk(chunk)
40 | fmt_chunk_read = True
41 | elif chunk.getname() == 'data':
42 | if file_size > 0xFFFFFFFF:
43 |                         # the data chunk size field overflows for WAVs over 4 GB, so derive the frame count from the file size instead
44 | self.frames_count = (file_size - self._file.tell()) // self.frame_size
45 | else:
46 | self.frames_count = chunk.chunksize // self.frame_size
47 |                     data_chunk_read = True
48 | break
49 | chunk.skip()
50 |             if not fmt_chunk_read or not data_chunk_read:
51 | raise SushiError('Invalid WAV file')
52 |         except:  # close the file before re-raising any parsing error
53 | self.close()
54 | raise
55 |
56 | def __del__(self):
57 | self.close()
58 |
59 | def close(self):
60 | if self._file:
61 | self._file.close()
62 | self._file = None
63 |
64 | def readframes(self, count):
65 | if not count:
66 | return ''
67 | data = self._file.read(count * self.frame_size)
68 | if self.sample_width == 2:
69 | unpacked = np.fromstring(data, dtype=np.int16)
70 | elif self.sample_width == 3:
71 | raw_bytes = np.ndarray(len(data), 'int8', data)
72 |             unpacked = np.zeros(len(data) / 3, np.int16)  # one int16 per 24-bit sample
73 |             unpacked.view(dtype='int8')[0::2] = raw_bytes[1::3]  # middle byte of each sample -> low byte
74 |             unpacked.view(dtype='int8')[1::2] = raw_bytes[2::3]  # top byte -> high byte; the lowest byte is discarded
75 | else:
76 | raise SushiError('Unsupported sample width: {0}'.format(self.sample_width))
77 |
78 | unpacked = unpacked.astype('float32')
79 |
80 | if self.channels_count == 1:
81 | return unpacked
82 | else:
83 | min_length = len(unpacked) // self.channels_count
84 | real_length = len(unpacked) / float(self.channels_count)
85 | if min_length != real_length:
86 | logging.error("Length of audio channels didn't match. This might result in broken output")
87 |
88 | channels = (unpacked[i::self.channels_count] for i in xrange(self.channels_count))
89 |             data = reduce(lambda a, b: a[:min_length] + b[:min_length], channels)  # sum of all channels
90 |             data /= float(self.channels_count)  # divide by the channel count to get the mono average
91 | return data
92 |
93 | def _read_fmt_chunk(self, chunk):
94 |         wFormatTag, self.channels_count, self.framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack('<HHLLH', chunk.read(14))
95 |         if wFormatTag == WAVE_FORMAT_PCM or wFormatTag == WAVE_FORMAT_EXTENSIBLE:  # the remaining fmt fields are not needed
96 |             bits_per_sample = struct.unpack('<H', chunk.read(2))[0]
97 |             self.sample_width = (bits_per_sample + 7) // 8
98 |         else:
99 |             raise SushiError('Unknown or unsupported WAV format tag: {0}'.format(wFormatTag))
100 |         self.frame_size = self.channels_count * self.sample_width
101 |
102 |
103 | class WavStream(object):
104 |     READ_CHUNK_SIZE = 1  # read the file one second at a time
105 |
106 |     PADDING_SECONDS = 10
107 |
108 |     def __init__(self, path, sample_rate=12000, sample_type='uint8'):
109 |         if sample_type not in ('float32', 'uint8'):
110 |             raise SushiError('Unknown sample type of WAV stream, must be uint8 or float32')
111 |
112 |         stream = DownmixedWavFile(path)
113 |         total_seconds = stream.frames_count / float(stream.framerate)
114 |         downsample_rate = sample_rate / float(stream.framerate)
115 |
116 |         self.sample_count = math.ceil(total_seconds * sample_rate)
117 |         self.sample_rate = sample_rate
118 |         # pre-allocate the data array, with room for the padding on both sides
119 |         self.data = np.empty((1, int(self.PADDING_SECONDS * 2 * sample_rate + self.sample_count)), np.float32)
120 |         self.padding_size = self.PADDING_SECONDS * sample_rate
121 |         before_read = time()
122 |         try:
123 |             seconds_read = 0
124 |             samples_read = self.padding_size
125 |             while seconds_read < total_seconds:
126 |                 data = stream.readframes(int(self.READ_CHUNK_SIZE * stream.framerate))
127 |                 new_length = int(round(len(data) * downsample_rate))
128 |
129 |                 dst_view = self.data[0][samples_read:samples_read + new_length]
130 |
131 |                 if downsample_rate != 1:
132 |                     data = data.reshape((1, len(data)))
133 |                     data = cv2.resize(data, (new_length, 1), interpolation=cv2.INTER_NEAREST)[0]
134 |                 np.copyto(dst_view, data)
135 |
136 |                 samples_read += new_length
137 |                 seconds_read += self.READ_CHUNK_SIZE
138 |
139 |             # pad the audio on both sides with copies of its edge samples
140 |             self.data[0][0:self.padding_size].fill(self.data[0][self.padding_size])
141 |             self.data[0][-self.padding_size:].fill(self.data[0][-self.padding_size - 1])
142 |
143 |             # normalize the stream to [0, 1], clipping it
144 |             # at three median values on both sides of zero
145 |             max_value = np.median(self.data[self.data >= 0], overwrite_input=True) * 3
146 | min_value = np.median(self.data[self.data <= 0], overwrite_input=True) * 3
147 |
148 | np.clip(self.data, min_value, max_value, out=self.data)
149 |
150 | self.data -= min_value
151 | self.data /= (max_value - min_value)
152 |
153 | if sample_type == 'uint8':
154 | self.data *= 255.0
155 |                 self.data += 0.5  # so the uint8 cast below rounds to nearest instead of truncating
156 | self.data = self.data.astype('uint8')
157 |
158 | except Exception as e:
159 | raise SushiError('Error while loading {0}: {1}'.format(path, e))
160 | finally:
161 | stream.close()
162 | logging.info('Done reading WAV {0} in {1}s'.format(path, time() - before_read))
163 |
164 | @property
165 | def duration_seconds(self):
166 | return self.sample_count / self.sample_rate
167 |
168 | def get_substream(self, start, end):
169 | start_off = self._get_sample_for_time(start)
170 | end_off = self._get_sample_for_time(end)
171 | return self.data[:, start_off:end_off]
172 |
173 | def _get_sample_for_time(self, timestamp):
174 |         # map a timestamp to an absolute sample index in self.data, accounting for the left padding
175 | return int(self.sample_rate * timestamp) + self.padding_size
176 |
177 | def find_substream(self, pattern, window_center, window_size):
178 | start_time = clip(window_center - window_size, -self.PADDING_SECONDS, self.duration_seconds)
179 | end_time = clip(window_center + window_size, 0, self.duration_seconds + self.PADDING_SECONDS)
180 |
181 | start_sample = self._get_sample_for_time(start_time)
182 | end_sample = self._get_sample_for_time(end_time) + len(pattern[0])
183 |
184 | search_source = self.data[:, start_sample:end_sample]
185 |         result = cv2.matchTemplate(search_source, pattern, cv2.TM_SQDIFF_NORMED)  # lower score = better match
186 |         min_idx = result.argmin(axis=1)[0]  # offset of the best match within the search window
187 |
188 | return result[0][min_idx], start_time + (min_idx / float(self.sample_rate))
189 |
--------------------------------------------------------------------------------
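For orientation, a rough sketch of how a caller might wire these pieces together: decode two audio tracks into WavStream objects, cut a pattern out of one, then search for its best match inside a window of the other. The file names and window parameters below are placeholders, not taken from this repository:

    source = WavStream('source.wav', sample_rate=12000, sample_type='uint8')
    target = WavStream('target.wav', sample_rate=12000, sample_type='uint8')

    # one second of audio starting at 42.0s of the source...
    pattern = source.get_substream(42.0, 43.0)

    # ...matched against the target within +/-10s of the same spot; with
    # cv2.TM_SQDIFF_NORMED a lower score means a better match
    score, found_time = target.find_substream(pattern, window_center=42.0, window_size=10)
    print('shift: {0:.3f}s (score: {1:.5f})'.format(found_time - 42.0, score))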