├── .gitattributes
├── .gitignore
├── .travis.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── build-windows.bat
├── chapters.py
├── common.py
├── demux.py
├── keyframes.py
├── licenses
│   ├── OpenCV.txt
│   └── SciPy.txt
├── regression-tests.py
├── requirements-win.txt
├── requirements.txt
├── run-tests.py
├── setup.py
├── subs.py
├── sushi.py
├── tests.example.json
├── tests
│   ├── __init__.py
│   ├── demuxing.py
│   ├── main.py
│   ├── subtitles.py
│   └── timecodes.py
└── wav.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
4 | # Custom for Visual Studio
5 | *.cs diff=csharp
6 | *.sln merge=union
7 | *.csproj merge=union
8 | *.vbproj merge=union
9 | *.fsproj merge=union
10 | *.dbproj merge=union
11 | 
12 | # Standard to msysgit
13 | *.doc diff=astextplain
14 | *.DOC diff=astextplain
15 | *.docx diff=astextplain
16 | *.DOCX diff=astextplain
17 | *.dot diff=astextplain
18 | *.DOT diff=astextplain
19 | *.pdf diff=astextplain
20 | *.PDF diff=astextplain
21 | *.rtf diff=astextplain
22 | *.RTF diff=astextplain
23 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | test
3 | *.pyc
4 | tests/media
5 | dist
6 | build
7 | tests.json
8 | 
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | 
3 | virtualenv:
4 |   system_site_packages: true
5 | before_install:
6 |   - sudo apt-get update
7 |   - sudo apt-get install python-opencv
8 |   - sudo dpkg -L python-opencv
9 |   - sudo ln /dev/null /dev/raw1394
10 | install:
11 |   - "pip install -r requirements.txt"
12 | 
13 | python:
14 |   - "2.7"
15 | script:
16 |   - python run-tests.py
17 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Before submitting a bug, please read [how to do it](https://github.com/tp7/Sushi/wiki/Common-errors).
2 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014-2017 Victor Efimov
2 | 
3 | Permission is hereby granted, free of charge, to any person
4 | obtaining a copy of this software and associated documentation
5 | files (the "Software"), to deal in the Software without
6 | restriction, including without limitation the rights to use,
7 | copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the
9 | Software is furnished to do so, subject to the following
10 | conditions:
11 | 
12 | The above copyright notice and this permission notice shall be
13 | included in all copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
17 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
19 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
20 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Sushi [![Build Status](https://travis-ci.org/tp7/Sushi.svg?branch=master)](https://travis-ci.org/tp7/Sushi)
2 | Automatic shifter for SRT and ASS subtitles based on audio streams.
3 | 
4 | ### Purpose
5 | Imagine you've got a subtitle file synced to one video file, but you want to use these subtitles with some other video you've got via totally legal means. Common examples are TV vs. BD releases, PAL vs. NTSC video, and releases in different countries. In many cases the subtitles won't match right away and you need to sync them.
6 | 
7 | The purpose of this script is to avoid all the hassle of manual syncing. It attempts to synchronize subtitles by finding similarities in the audio streams. The script is very fast and can be used right when you want to watch something.
8 | 
9 | ### Downloads
10 | The latest Windows binary release can always be found in the [releases][1] section. You need the 7z archive in the top entry.
11 | 
12 | ### How it works
13 | You need to provide two audio files and a subtitle file that matches one of those files. For every line of the subtitles, the script extracts the corresponding audio from the source audio stream and tries to find the closest similar pattern in the destination audio stream, obtaining a shift value which is later applied to the subtitles.
14 | 
15 | Detailed explanation of the Sushi workflow and a description of the command-line arguments can be found in the [wiki][2].
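To make the matching step concrete, here is a minimal sketch of the idea. This is not Sushi's actual implementation (which, roughly speaking, uses OpenCV template matching on downsampled audio and restricts the search window for speed); the function name and the brute-force search are illustrative only:

```python
import numpy as np

def find_shift(src_chunk, dst_stream, sample_rate):
    """Slide src_chunk over dst_stream (both 1-D numpy arrays)
    and return the best-matching shift in seconds."""
    best_offset, best_error = 0, float('inf')
    # brute-force scan over every possible offset
    for offset in xrange(len(dst_stream) - len(src_chunk) + 1):
        window = dst_stream[offset:offset + len(src_chunk)]
        error = np.mean((window - src_chunk) ** 2)  # mean squared difference
        if error < best_error:
            best_error, best_offset = error, offset
    return best_offset / float(sample_rate)
```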
16 | 
17 | ### Usage
18 | The minimal command line looks like this:
19 | ```
20 | python sushi.py --src hdtv.wav --dst bluray.wav --script subs.ass
21 | ```
22 | The output file name is optional - `"{destination_path}.sushi.{subtitles_format}"` is used by default. See the [usage][3] page of the wiki for further examples.
23 | 
24 | Do note that WAV is not the only format Sushi can work with. It can process audio/video files directly and decode various audio formats, provided that ffmpeg is available. For additional info, refer to the [Demuxing][4] part of the wiki.
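For example, an invocation working directly on containers might look like this (hypothetical file names; stream selection flags omitted, see the wiki for the full argument list):
```
python sushi.py --src hdtv.mkv --dst bluray.mkv --script subs.ass
```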
25 | 
26 | ### Requirements
27 | Sushi should work on Windows, Linux and OS X. Please open an issue if it doesn't. To run it, you need the following installed:
28 | 
29 | 1. [Python 2.7.x][5]
30 | 2. [NumPy][6] (1.8 or newer)
31 | 3. [OpenCV 2.4.x or newer][7] (on Windows, putting [this file][8] in the same folder as Sushi should be enough, assuming you use x86 Python)
32 | 
33 | Optionally, you might want:
34 | 
35 | 1. [FFmpeg][9] for any kind of demuxing
36 | 2. [MkvExtract][10] for faster timecodes extraction when demuxing
37 | 3. [SCXvid-standalone][11] if you want Sushi to make keyframes
38 | 4. [Colorama](https://github.com/tartley/colorama) to add colors to console output on Windows
39 | 
40 | The provided Windows binaries include all required components and Colorama, so you don't have to install them if you use the binary distribution. You still have to download the other applications yourself if you want to use Sushi's demuxing capabilities.
41 | 
42 | #### Installation on Mac OS X
43 | 
44 | No binary packages are provided for OS X right now, so you'll have to use the script form. Assuming you have python 2, pip and [homebrew](http://brew.sh/) installed, run the following:
45 | ```bash
46 | brew tap homebrew/science
47 | brew install git opencv
48 | pip install numpy
49 | git clone https://github.com/tp7/sushi
50 | # create a symlink if you want to run sushi globally
51 | ln -s `pwd`/sushi/sushi.py /usr/local/bin/sushi
52 | # install some optional dependencies
53 | brew install ffmpeg mkvtoolnix
54 | ```
55 | If you don't have pip, you can install numpy with homebrew, but that will probably pull in a few more dependencies.
56 | ```bash
57 | brew tap homebrew/python
58 | brew install numpy
59 | ```
60 | 
61 | #### Installation on Linux
62 | If you have apt-get available, the installation process is trivial.
63 | ```bash
64 | sudo apt-get update
65 | sudo apt-get install git python python-numpy python-opencv
66 | git clone https://github.com/tp7/sushi
67 | ln -s `pwd`/sushi/sushi.py /usr/local/bin/sushi
68 | ```
69 | 
70 | ### Limitations
71 | This script will never be able to properly handle frame-by-frame typesetting. If the underlying video stream changes (e.g. has a different telecine pattern), you might get incorrect output.
72 | 
73 | This script cannot improve bad timing. If the original lines are mistimed, they will be mistimed in the output file too.
74 | 
75 | In short, while the result might be safe for immediate viewing, you probably shouldn't use it to blindly shift subtitles for permanent storage.
76 | 
77 | 
78 | [1]: https://github.com/tp7/Sushi/releases
79 | [2]: https://github.com/tp7/Sushi/wiki
80 | [3]: https://github.com/tp7/Sushi/wiki/Examples
81 | [4]: https://github.com/tp7/Sushi/wiki/Demuxing
82 | [5]: https://www.python.org/downloads/
83 | [6]: http://www.scipy.org/scipylib/download.html
84 | [7]: http://opencv.org/
85 | [8]: https://www.dropbox.com/s/nlylgdh4bgrjgxv/cv2.pyd?dl=0
86 | [9]: http://www.ffmpeg.org/download.html
87 | [10]: http://www.bunkus.org/videotools/mkvtoolnix/downloads.html
88 | [11]: https://github.com/soyokaze/SCXvid-standalone/releases
89 | 
--------------------------------------------------------------------------------
/build-windows.bat:
--------------------------------------------------------------------------------
1 | rmdir /S /Q dist
2 | 
3 | pyinstaller --noupx --onefile --noconfirm ^
4 | --exclude-module Tkconstants ^
5 | --exclude-module Tkinter ^
6 | --exclude-module matplotlib ^
7 | sushi.py
8 | 
9 | mkdir dist\licenses
10 | copy /Y licenses\* dist\licenses\*
11 | copy LICENSE dist\licenses\Sushi.txt
12 | copy README.md dist\readme.md
--------------------------------------------------------------------------------
/chapters.py:
--------------------------------------------------------------------------------
1 | import re
2 | import common
3 | 
4 | 
5 | def parse_times(times):
6 |     result = []
7 |     for t in times:
8 |         hours, minutes, seconds = map(float, t.split(':'))
9 |         result.append(hours * 3600 + minutes * 60 + seconds)
10 | 
11 |     result.sort()
12 |     if result[0] != 0:
13 |         result.insert(0, 0)
14 |     return result
15 | 
16 | 
17 | def parse_xml_start_times(text):
18 |     times = re.findall(r'(\d+:\d+:\d+\.\d+)', text)
19 |     return parse_times(times)
20 | 
21 | 
22 | def get_xml_start_times(path):
23 |     return parse_xml_start_times(common.read_all_text(path))
24 | 
25 | 
26 | def parse_ogm_start_times(text):
27 |     times = re.findall(r'CHAPTER\d+=(\d+:\d+:\d+\.\d+)', text, flags=re.IGNORECASE)
28 |     return parse_times(times)
29 | 
30 | 
31 | def get_ogm_start_times(path):
32 |     return parse_ogm_start_times(common.read_all_text(path))
33 | 
34 | 
35 | def format_ogm_chapters(start_times):
36 |     return "\n".join("CHAPTER{0:02}={1}\nCHAPTER{0:02}NAME=".format(idx+1, common.format_srt_time(start).replace(',', '.'))
37 |                      for idx, start in enumerate(start_times)) + "\n"
38 | 
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | class SushiError(Exception):
5 |     pass
6 | 
7 | 
8 | def get_extension(path):
9 |     return (os.path.splitext(path)[1]).lower()
10 | 
11 | 
12 | def read_all_text(path):
13 |     with open(path) as file:
14 |         return file.read()
15 | 
16 | 
17 | def ensure_static_collection(value):
18 |     if isinstance(value, (set, list, tuple)):
19 |         return value
20 |     return list(value)
21 | 
22 | 
23 | def format_srt_time(seconds):
24 |     cs = round(seconds * 1000)
25 |     return u'{0:02d}:{1:02d}:{2:02d},{3:03d}'.format(
26 |         int(cs // 3600000),
27 |         int((cs // 60000) % 60),
28 |         int((cs // 1000) % 60),
29 |         int(cs % 1000))
30 | 
31 | 
32 | def format_time(seconds):
33 |     cs = round(seconds * 100)
34 |     return u'{0}:{1:02d}:{2:02d}.{3:02d}'.format(
35 |         int(cs // 360000),
36 |         int((cs // 6000) % 60),
37 |         int((cs // 100) % 60),
38 |         int(cs % 100))
39 | 
40 | 
41 | def clip(value, minimum, maximum):
42 |     return max(min(value, maximum), minimum)
43 | 
--------------------------------------------------------------------------------
/demux.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import subprocess
4 | from collections import namedtuple
5 | import logging
6 | import bisect
7 | 
8 | from common import SushiError, get_extension
9 | import chapters
10 | 
11 | MediaStreamInfo = namedtuple('MediaStreamInfo', ['id', 'info', 'default', 'title'])
12 | SubtitlesStreamInfo = namedtuple('SubtitlesStreamInfo', ['id', 'info', 'type', 'default', 'title'])
13 | MediaInfo = namedtuple('MediaInfo', ['video', 'audio', 'subtitles', 'chapters'])
14 | 
15 | 
16 | class FFmpeg(object):
17 |     @staticmethod
18 |     def get_info(path):
19 |         try:
20 |             process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE)
21 |             out, err = process.communicate()
22 |             process.wait()
23 |             return err
24 |         except OSError as e:
25 |             if e.errno == 2:
26 |                 raise SushiError("Couldn't invoke ffmpeg, check that it's installed")
27 |             raise
28 | 
29 |     @staticmethod
30 |     def demux_file(input_path, **kwargs):
31 |         args = ['ffmpeg', '-hide_banner', '-i', input_path, '-y']
32 | 
33 |         audio_stream = kwargs.get('audio_stream', None)
34 |         audio_path = kwargs.get('audio_path', None)
35 |         audio_rate = kwargs.get('audio_rate', None)
36 |         if audio_stream is not None:
37 |             args.extend(('-map', '0:{0}'.format(audio_stream)))
38 |             if audio_rate:
39 |                 args.extend(('-ar', str(audio_rate)))
40 |             args.extend(('-ac', '1', '-acodec', 'pcm_s16le', audio_path))
41 | 
42 |         script_stream = kwargs.get('script_stream', None)
43 |         script_path = kwargs.get('script_path', None)
44 |         if script_stream is not None:
45 |             args.extend(('-map', '0:{0}'.format(script_stream), script_path))
46 | 
47 |         video_stream = kwargs.get('video_stream', None)
48 |         timecodes_path = kwargs.get('timecodes_path', None)
49 |         if timecodes_path is not None:
50 |             args.extend(('-map', '0:{0}'.format(video_stream), '-f', 'mkvtimestamp_v2', timecodes_path))
51 | 
52 |         logging.info('ffmpeg args: {0}'.format(' '.join(('"{0}"' if ' ' in a else '{0}').format(a) for a in args)))
53 |         try:
54 |             subprocess.call(args)
55 |         except OSError as e:
56 |             if 
e.errno == 2: 57 | raise SushiError("Couldn't invoke ffmpeg, check that it's installed") 58 | raise 59 | 60 | @staticmethod 61 | def _get_audio_streams(info): 62 | streams = re.findall(r'Stream\s\#0:(\d+).*?Audio:\s*(.*?(?:\((default)\))?)\s*?(?:\(forced\))?\r?\n' 63 | r'(?:\s*Metadata:\s*\r?\n' 64 | r'\s*title\s*:\s*(.*?)\r?\n)?', 65 | info, flags=re.VERBOSE) 66 | return [MediaStreamInfo(int(x[0]), x[1], x[2] != '', x[3]) for x in streams] 67 | 68 | @staticmethod 69 | def _get_video_streams(info): 70 | streams = re.findall(r'Stream\s\#0:(\d+).*?Video:\s*(.*?(?:\((default)\))?)\s*?(?:\(forced\))?\r?\n' 71 | r'(?:\s*Metadata:\s*\r?\n' 72 | r'\s*title\s*:\s*(.*?)\r?\n)?', 73 | info, flags=re.VERBOSE) 74 | return [MediaStreamInfo(int(x[0]), x[1], x[2] != '', x[3]) for x in streams] 75 | 76 | @staticmethod 77 | def _get_chapters_times(info): 78 | return map(float, re.findall(r'Chapter #0.\d+: start (\d+\.\d+)', info)) 79 | 80 | @staticmethod 81 | def _get_subtitles_streams(info): 82 | maps = { 83 | 'ssa': '.ass', 84 | 'ass': '.ass', 85 | 'subrip': '.srt' 86 | } 87 | 88 | streams = re.findall(r'Stream\s\#0:(\d+).*?Subtitle:\s*((\w*)\s*?(?:\((default)\))?\s*?(?:\(forced\))?)\r?\n' 89 | r'(?:\s*Metadata:\s*\r?\n' 90 | r'\s*title\s*:\s*(.*?)\r?\n)?', 91 | info, flags=re.VERBOSE) 92 | return [SubtitlesStreamInfo(int(x[0]), x[1], maps.get(x[2], x[2]), x[3] != '', x[4].strip()) for x in streams] 93 | 94 | @classmethod 95 | def get_media_info(cls, path): 96 | info = cls.get_info(path) 97 | video_streams = cls._get_video_streams(info) 98 | audio_streams = cls._get_audio_streams(info) 99 | subs_streams = cls._get_subtitles_streams(info) 100 | chapter_times = cls._get_chapters_times(info) 101 | return MediaInfo(video_streams, audio_streams, subs_streams, chapter_times) 102 | 103 | 104 | class MkvToolnix(object): 105 | @classmethod 106 | def extract_timecodes(cls, mkv_path, stream_idx, output_path): 107 | args = ['mkvextract', 'timecodes_v2', mkv_path, '{0}:{1}'.format(stream_idx, output_path)] 108 | subprocess.call(args) 109 | 110 | 111 | class SCXviD(object): 112 | @classmethod 113 | def make_keyframes(cls, video_path, log_path): 114 | try: 115 | ffmpeg_process = subprocess.Popen(['ffmpeg', '-i', video_path, 116 | '-f', 'yuv4mpegpipe', 117 | '-vf', 'scale=640:360', 118 | '-pix_fmt', 'yuv420p', 119 | '-vsync', 'drop', '-'], stdout=subprocess.PIPE) 120 | except OSError as e: 121 | if e.errno == 2: 122 | raise SushiError("Couldn't invoke ffmpeg, check that it's installed") 123 | raise 124 | 125 | try: 126 | scxvid_process = subprocess.Popen(['SCXvid', log_path], stdin=ffmpeg_process.stdout) 127 | except OSError as e: 128 | ffmpeg_process.kill() 129 | if e.errno == 2: 130 | raise SushiError("Couldn't invoke scxvid, check that it's installed") 131 | raise 132 | scxvid_process.wait() 133 | 134 | 135 | class Timecodes(object): 136 | def __init__(self, times, default_fps): 137 | super(Timecodes, self).__init__() 138 | self.times = times 139 | self.default_frame_duration = 1.0 / default_fps if default_fps else None 140 | 141 | def get_frame_time(self, number): 142 | try: 143 | return self.times[number] 144 | except IndexError: 145 | if not self.default_frame_duration: 146 | return self.get_frame_time(len(self.times)-1) 147 | if self.times: 148 | return self.times[-1] + (self.default_frame_duration) * (number - len(self.times) + 1) 149 | else: 150 | return number * self.default_frame_duration 151 | 152 | def get_frame_number(self, timestamp): 153 | if (not self.times or self.times[-1] < timestamp) and 
self.default_frame_duration:
154 |             return int((timestamp - self.times[-1]) / self.default_frame_duration) + len(self.times) - 1 if self.times else int(timestamp / self.default_frame_duration)  # extrapolate past the last known timecode
155 |         return bisect.bisect_left(self.times, timestamp)
156 | 
157 |     def get_frame_size(self, timestamp):
158 |         try:
159 |             number = bisect.bisect_left(self.times, timestamp)
160 |         except:
161 |             return self.default_frame_duration
162 | 
163 |         c = self.get_frame_time(number)
164 | 
165 |         if number == len(self.times):
166 |             p = self.get_frame_time(number - 1)
167 |             return c - p
168 |         else:
169 |             n = self.get_frame_time(number + 1)
170 |             return n - c
171 | 
172 |     @classmethod
173 |     def _convert_v1_to_v2(cls, default_fps, overrides):
174 |         # start, end, fps
175 |         overrides = [(int(x[0]), int(x[1]), float(x[2])) for x in overrides]
176 |         if not overrides:
177 |             return []
178 | 
179 |         fps = [default_fps] * (overrides[-1][1] + 1)
180 |         for o in overrides:
181 |             fps[o[0]:o[1] + 1] = [o[2]] * (o[1] - o[0] + 1)
182 | 
183 |         v2 = [0]
184 |         for d in (1.0 / f for f in fps):
185 |             v2.append(v2[-1] + d)
186 |         return v2
187 | 
188 |     @classmethod
189 |     def parse(cls, text):
190 |         lines = text.splitlines()
191 |         if not lines:
192 |             return []
193 |         first = lines[0].lower().lstrip()
194 |         if first.startswith('# timecode format v2') or first.startswith('# timestamp format v2'):
195 |             tcs = [float(x) / 1000.0 for x in lines[1:]]
196 |             return Timecodes(tcs, None)
197 |         elif first.startswith('# timecode format v1'):
198 |             default = float(lines[1].lower().replace('assume ', ""))
199 |             overrides = (x.split(',') for x in lines[2:])
200 |             return Timecodes(cls._convert_v1_to_v2(default, overrides), default)
201 |         else:
202 |             raise SushiError('This timecodes format is not supported')
203 | 
204 |     @classmethod
205 |     def from_file(cls, path):
206 |         with open(path) as file:
207 |             return cls.parse(file.read())
208 | 
209 |     @classmethod
210 |     def cfr(cls, fps):
211 |         class CfrTimecodes(object):
212 |             def __init__(self, fps):
213 |                 self.frame_duration = 1.0 / fps
214 | 
215 |             def get_frame_time(self, number):
216 |                 return number * self.frame_duration
217 | 
218 |             def get_frame_size(self, timestamp):
219 |                 return self.frame_duration
220 | 
221 |             def get_frame_number(self, timestamp):
222 |                 return int(timestamp / self.frame_duration)
223 | 
224 |         return CfrTimecodes(fps)
225 | 
226 | 
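# Usage sketch for the Timecodes class above (hypothetical values, written as
# a comment to keep the file structure intact): parsing a v2 timecodes file
# and mapping timestamps to frame numbers.
#
#     tcs = Timecodes.parse('# timecode format v2\n0\n41.7\n83.3')
#     tcs.get_frame_time(1)       # -> 0.0417 (input values are milliseconds)
#     tcs.get_frame_number(0.05)  # -> 2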
227 | class Demuxer(object):
228 |     def __init__(self, path):
229 |         super(Demuxer, self).__init__()
230 |         self._path = path
231 |         self._is_wav = get_extension(self._path) == '.wav'
232 |         self._mi = None if self._is_wav else FFmpeg.get_media_info(self._path)
233 |         self._demux_audio = self._demux_subs = self._make_timecodes = self._make_keyframes = self._write_chapters = False
234 | 
235 |     @property
236 |     def is_wav(self):
237 |         return self._is_wav
238 | 
239 |     @property
240 |     def path(self):
241 |         return self._path
242 | 
243 |     @property
244 |     def chapters(self):
245 |         if self.is_wav:
246 |             return []
247 |         return self._mi.chapters
248 | 
249 |     @property
250 |     def has_video(self):
251 |         return not self.is_wav and self._mi.video
252 | 
253 |     def set_audio(self, stream_idx, output_path, sample_rate):
254 |         self._audio_stream = self._select_stream(self._mi.audio, stream_idx, 'audio')
255 |         self._audio_output_path = output_path
256 |         self._audio_sample_rate = sample_rate
257 |         self._demux_audio = True
258 | 
259 |     def set_script(self, stream_idx, output_path):
260 |         self._script_stream = self._select_stream(self._mi.subtitles, stream_idx, 'subtitles')
261 |         self._script_output_path = output_path
262 |         self._demux_subs = True
263 | 
264 |     def set_timecodes(self, output_path):
265 |         self._timecodes_output_path = output_path
266 |         self._make_timecodes = True
267 | 
268 |     def set_chapters(self, output_path):
269 |         self._write_chapters = True
270 |         self._chapters_output_path = output_path
271 | 
272 |     def set_keyframes(self, output_path):
273 |         self._keyframes_output_path = output_path
274 |         self._make_keyframes = True
275 | 
276 |     def get_subs_type(self, stream_idx):
277 |         return self._select_stream(self._mi.subtitles, stream_idx, 'subtitles').type
278 | 
279 |     def demux(self):
280 |         if self._write_chapters:
281 |             with open(self._chapters_output_path, "w") as output_file:
282 |                 output_file.write(chapters.format_ogm_chapters(self.chapters))
283 | 
284 |         if self._make_keyframes:
285 |             SCXviD.make_keyframes(self._path, self._keyframes_output_path)
286 | 
287 |         ffargs = {}
288 |         if self._demux_audio:
289 |             ffargs['audio_stream'] = self._audio_stream.id
290 |             ffargs['audio_path'] = self._audio_output_path
291 |             ffargs['audio_rate'] = self._audio_sample_rate
292 |         if self._demux_subs:
293 |             ffargs['script_stream'] = self._script_stream.id
294 |             ffargs['script_path'] = self._script_output_path
295 | 
296 |         if self._make_timecodes:
297 |             def set_ffmpeg_timecodes():
298 |                 ffargs['video_stream'] = self._mi.video[0].id
299 |                 ffargs['timecodes_path'] = self._timecodes_output_path
300 | 
301 |             if get_extension(self._path).lower() == '.mkv':
302 |                 try:
303 |                     MkvToolnix.extract_timecodes(self._path,
304 |                                                  stream_idx=self._mi.video[0].id,
305 |                                                  output_path=self._timecodes_output_path)
306 |                 except OSError as e:
307 |                     if e.errno == 2:
308 |                         set_ffmpeg_timecodes()
309 |                     else:
310 |                         raise
311 |             else:
312 |                 set_ffmpeg_timecodes()
313 | 
314 |         if ffargs:
315 |             FFmpeg.demux_file(self._path, **ffargs)
316 | 
317 |     def cleanup(self):
318 |         if self._demux_audio:
319 |             os.remove(self._audio_output_path)
320 |         if self._demux_subs:
321 |             os.remove(self._script_output_path)
322 |         if self._make_timecodes:
323 |             os.remove(self._timecodes_output_path)
324 |         if self._write_chapters:
325 |             os.remove(self._chapters_output_path)
326 | 
327 |     @classmethod
328 |     def _format_stream(cls, stream):
329 |         return '{0}{1}: {2}'.format(stream.id, ' (%s)' % stream.title if stream.title else '', stream.info)
330 | 
331 |     @classmethod
332 |     def _format_streams_list(cls, streams):
333 |         return '\n'.join(map(cls._format_stream, streams))
334 | 
335 |     def _select_stream(self, streams, chosen_idx, name):
336 |         if not streams:
337 |             raise SushiError('No {0} streams found in {1}'.format(name, self._path))
338 |         if chosen_idx is None:
339 |             if len(streams) > 1:
340 |                 default_track = next((s for s in streams if s.default), None)
341 |                 if default_track:
342 |                     logging.warning('Using default track {0} in {1} because there are multiple candidates'
343 |                                     .format(self._format_stream(default_track), self._path))
344 |                     return default_track
345 |                 raise SushiError('More than one {0} stream found in {1}. '
346 |                                  'You need to specify the exact one to demux. 
Here are all candidates:\n' 347 | '{2}'.format(name, self._path, self._format_streams_list(streams))) 348 | return streams[0] 349 | 350 | try: 351 | return next(x for x in streams if x.id == chosen_idx) 352 | except StopIteration: 353 | raise SushiError("Stream with index {0} doesn't exist in {1}.\n" 354 | "Here are all that do:\n" 355 | "{2}".format(chosen_idx, self._path, self._format_streams_list(streams))) 356 | 357 | -------------------------------------------------------------------------------- /keyframes.py: -------------------------------------------------------------------------------- 1 | from common import SushiError, read_all_text 2 | 3 | 4 | def parse_scxvid_keyframes(text): 5 | return [i-3 for i,line in enumerate(text.splitlines()) if line and line[0] == 'i'] 6 | 7 | def parse_keyframes(path): 8 | text = read_all_text(path) 9 | if '# XviD 2pass stat file' in text: 10 | frames = parse_scxvid_keyframes(text) 11 | else: 12 | raise SushiError('Unsupported keyframes type') 13 | if 0 not in frames: 14 | frames.insert(0, 0) 15 | return frames 16 | -------------------------------------------------------------------------------- /licenses/OpenCV.txt: -------------------------------------------------------------------------------- 1 | By downloading, copying, installing or using the software you agree to this license. 2 | If you do not agree to this license, do not download, install, 3 | copy or use the software. 4 | 5 | 6 | License Agreement 7 | For Open Source Computer Vision Library 8 | (3-clause BSD License) 9 | 10 | Redistribution and use in source and binary forms, with or without modification, 11 | are permitted provided that the following conditions are met: 12 | 13 | * Redistributions of source code must retain the above copyright notice, 14 | this list of conditions and the following disclaimer. 15 | 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | * Neither the names of the copyright holders nor the names of the contributors 21 | may be used to endorse or promote products derived from this software 22 | without specific prior written permission. 23 | 24 | This software is provided by the copyright holders and contributors "as is" and 25 | any express or implied warranties, including, but not limited to, the implied 26 | warranties of merchantability and fitness for a particular purpose are disclaimed. 27 | In no event shall copyright holders or contributors be liable for any direct, 28 | indirect, incidental, special, exemplary, or consequential damages 29 | (including, but not limited to, procurement of substitute goods or services; 30 | loss of use, data, or profits; or business interruption) however caused 31 | and on any theory of liability, whether in contract, strict liability, 32 | or tort (including negligence or otherwise) arising in any way out of 33 | the use of this software, even if advised of the possibility of such damage. 34 | -------------------------------------------------------------------------------- /licenses/SciPy.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2001, 2002 Enthought, Inc. 2 | All rights reserved. 3 | 4 | Copyright (c) 2003-2012 SciPy Developers. 5 | All rights reserved. 
6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | a. Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | b. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | c. Neither the name of Enthought nor the names of the SciPy Developers 16 | may be used to endorse or promote products derived from this software 17 | without specific prior written permission. 18 | 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 23 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS 24 | BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 25 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 26 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 27 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 28 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 29 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 30 | THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | -------------------------------------------------------------------------------- /regression-tests.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import json 3 | import logging 4 | import os 5 | import gc 6 | import sys 7 | import resource 8 | import re 9 | import subprocess 10 | import argparse 11 | 12 | from common import format_time 13 | from demux import Timecodes 14 | from subs import AssScript 15 | from wav import WavStream 16 | 17 | 18 | root_logger = logging.getLogger('') 19 | 20 | 21 | def strip_tags(text): 22 | return re.sub(r'{.*?}', " ", text) 23 | 24 | 25 | @contextmanager 26 | def set_file_logger(path): 27 | handler = logging.FileHandler(path, mode='a') 28 | handler.setLevel(logging.DEBUG) 29 | handler.setFormatter(logging.Formatter('%(message)s')) 30 | root_logger.addHandler(handler) 31 | try: 32 | yield 33 | finally: 34 | root_logger.removeHandler(handler) 35 | 36 | 37 | def compare_scripts(ideal_path, test_path, timecodes, test_name, expected_errors): 38 | ideal_script = AssScript.from_file(ideal_path) 39 | test_script = AssScript.from_file(test_path) 40 | if len(test_script.events) != len(ideal_script.events): 41 | logging.critical("Script length didn't match: {0} in ideal vs {1} in test. 
Test {2}".format(
42 |             len(ideal_script.events), len(test_script.events), test_name)
43 |         )
44 |         return False
45 |     ideal_script.sort_by_time()
46 |     test_script.sort_by_time()
47 |     failed = 0
48 |     ft = format_time
49 |     for idx, (ideal, test) in enumerate(zip(ideal_script.events, test_script.events)):
50 |         ideal_start_frame = timecodes.get_frame_number(ideal.start)
51 |         ideal_end_frame = timecodes.get_frame_number(ideal.end)
52 | 
53 |         test_start_frame = timecodes.get_frame_number(test.start)
54 |         test_end_frame = timecodes.get_frame_number(test.end)
55 | 
56 |         if ideal_start_frame != test_start_frame and ideal_end_frame != test_end_frame:
57 |             logging.debug(u'{0}: start and end time failed at "{1}". {2}-{3} vs {4}-{5}'.format(
58 |                 idx, strip_tags(ideal.text), ft(ideal.start), ft(ideal.end), ft(test.start), ft(test.end))
59 |             )
60 |             failed += 1
61 |         elif ideal_end_frame != test_end_frame:
62 |             logging.debug(
63 |                 u'{0}: end time failed at "{1}". {2} vs {3}'.format(
64 |                     idx, strip_tags(ideal.text), ft(ideal.end), ft(test.end))
65 |             )
66 |             failed += 1
67 |         elif ideal_start_frame != test_start_frame:
68 |             logging.debug(
69 |                 u'{0}: start time failed at "{1}". {2} vs {3}'.format(
70 |                     idx, strip_tags(ideal.text), ft(ideal.start), ft(test.start))
71 |             )
72 |             failed += 1
73 | 
74 |     logging.info('Total lines: {0}, good: {1}, failed: {2}'.format(len(ideal_script.events), len(ideal_script.events)-failed, failed))
75 | 
76 |     if failed > expected_errors:
77 |         logging.critical('Got more failed lines than expected ({0} actual vs {1} expected)'.format(failed, expected_errors))
78 |         return False
79 |     elif failed < expected_errors:
80 |         logging.critical('Got fewer failed lines than expected ({0} actual vs {1} expected)'.format(failed, expected_errors))
81 |         return False
82 |     else:
83 |         logging.critical('Met expectations')
84 |         return True
85 | 
86 | 
87 | def run_test(base_path, plots_path, test_name, params):
88 |     def safe_add_key(args, key, name):
89 |         if name in params:
90 |             args.extend((key, str(params[name])))
91 | 
92 |     def safe_add_path(args, folder, key, name):
93 |         if name in params:
94 |             args.extend((key, os.path.join(folder, params[name])))
95 | 
96 |     logging.info('Testing "{0}"'.format(test_name))
97 | 
98 |     folder = os.path.join(base_path, params['folder'])
99 | 
100 |     cmd = ["sushi"]
101 | 
102 |     safe_add_path(cmd, folder, '--src', 'src')
103 |     safe_add_path(cmd, folder, '--dst', 'dst')
104 |     safe_add_path(cmd, folder, '--src-keyframes', 'src-keyframes')
105 |     safe_add_path(cmd, folder, '--dst-keyframes', 'dst-keyframes')
106 |     safe_add_path(cmd, folder, '--src-timecodes', 'src-timecodes')
107 |     safe_add_path(cmd, folder, '--dst-timecodes', 'dst-timecodes')
108 |     safe_add_path(cmd, folder, '--script', 'script')
109 |     safe_add_path(cmd, folder, '--chapters', 'chapters')
110 |     safe_add_path(cmd, folder, '--src-script', 'src-script')
111 |     safe_add_path(cmd, folder, '--dst-script', 'dst-script')
112 |     safe_add_key(cmd, '--max-kf-distance', 'max-kf-distance')
113 |     safe_add_key(cmd, '--max-ts-distance', 'max-ts-distance')
114 |     safe_add_key(cmd, '--max-ts-duration', 'max-ts-duration')
115 | 
116 |     output_path = os.path.join(folder, params['dst']) + '.sushi.test.ass'
117 |     cmd.extend(('-o', output_path))
118 |     if plots_path:
119 |         cmd.extend(('--test-shift-plot', os.path.join(plots_path, '{0}.png'.format(test_name))))
120 | 
121 |     log_path = os.path.join(folder, 'sushi_test.log')
122 | 
123 |     with open(log_path, "w") as log_file:
124 |         try:
125 |             subprocess.call(cmd, stderr=log_file, stdout=log_file)
126 |         
except Exception as e: 127 | logging.critical('Sushi failed on test "{0}": {1}'.format(test_name, e.message)) 128 | return False 129 | 130 | with set_file_logger(log_path): 131 | ideal_path = os.path.join(folder, params['ideal']) 132 | try: 133 | timecodes = Timecodes.from_file(os.path.join(folder, params['dst-timecodes'])) 134 | except KeyError: 135 | timecodes = Timecodes.cfr(params['fps']) 136 | 137 | return compare_scripts(ideal_path, output_path, timecodes, test_name, params['expected_errors']) 138 | 139 | 140 | def run_wav_test(test_name, file_path, params): 141 | gc.collect(2) 142 | 143 | before = resource.getrusage(resource.RUSAGE_SELF) 144 | loaded = WavStream(file_path, params.get('sample_rate', 12000), params.get('sample_type', 'uint8')) 145 | after = resource.getrusage(resource.RUSAGE_SELF) 146 | 147 | total_time = (after.ru_stime - before.ru_stime) + (after.ru_utime - before.ru_utime) 148 | ram_difference = (after.ru_maxrss - before.ru_maxrss) / 1024.0 / 1024.0 149 | 150 | if 'max_time' in params and total_time > params['max_time']: 151 | logging.critical('Loading "{0}" took too much time: {1} vs {2} seconds' 152 | .format(test_name, total_time, params['max_time'])) 153 | return False 154 | if 'max_memory' in params and ram_difference > params['max_memory']: 155 | logging.critical('Loading "{0}" consumed too much RAM: {1} vs {2}' 156 | .format(test_name, ram_difference, params['max_memory'])) 157 | return False 158 | return True 159 | 160 | 161 | def create_arg_parser(): 162 | parser = argparse.ArgumentParser(description='Sushi regression testing util') 163 | 164 | parser.add_argument('--only', dest="run_only", nargs="*", metavar='', 165 | help='Test names to run') 166 | parser.add_argument('-c', '--conf', default="tests.json", dest='conf_path', metavar='', 167 | help='Config file path') 168 | 169 | return parser 170 | 171 | 172 | def run(): 173 | root_logger.setLevel(logging.DEBUG) 174 | console_handler = logging.StreamHandler() 175 | console_handler.setLevel(logging.INFO) 176 | console_handler.setFormatter(logging.Formatter('%(message)s')) 177 | root_logger.addHandler(console_handler) 178 | 179 | args = create_arg_parser().parse_args() 180 | 181 | try: 182 | with open(args.conf_path) as file: 183 | config = json.load(file) 184 | except IOError as e: 185 | logging.critical(e) 186 | sys.exit(2) 187 | 188 | def should_run(name): 189 | return not args.run_only or name in args.run_only 190 | 191 | failed = ran = 0 192 | for test_name, params in config.get('tests', {}).iteritems(): 193 | if not should_run(test_name): 194 | continue 195 | if not params.get('disabled', False): 196 | ran += 1 197 | if not run_test(config['basepath'], config['plots'], test_name, params): 198 | failed += 1 199 | logging.info('') 200 | else: 201 | logging.warn('Test "{0}" disabled'.format(test_name)) 202 | 203 | if should_run("wavs"): 204 | for test_name, params in config.get('wavs', {}).iteritems(): 205 | ran += 1 206 | if not run_wav_test(test_name, os.path.join(config['basepath'], params['file']), params): 207 | failed += 1 208 | logging.info('') 209 | 210 | logging.info('Ran {0} tests, {1} failed'.format(ran, failed)) 211 | 212 | 213 | if __name__ == '__main__': 214 | run() 215 | -------------------------------------------------------------------------------- /requirements-win.txt: -------------------------------------------------------------------------------- 1 | pyinstaller 2 | colorama 3 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy 2 | mock 3 | -------------------------------------------------------------------------------- /run-tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from tests.timecodes import * 3 | from tests.main import * 4 | from tests.subtitles import * 5 | from tests.demuxing import * 6 | 7 | unittest.main(verbosity=0) 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | import sushi 3 | 4 | setup( 5 | name='Sushi', 6 | description='Automatic subtitle shifter based on audio', 7 | version=sushi.VERSION, 8 | url='https://github.com/tp7/Sushi', 9 | console=['sushi.py'], 10 | license='MIT' 11 | ) 12 | 13 | -------------------------------------------------------------------------------- /subs.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import re 4 | import collections 5 | 6 | from common import SushiError, format_time, format_srt_time 7 | 8 | 9 | def _parse_ass_time(string): 10 | hours, minutes, seconds = map(float, string.split(':')) 11 | return hours * 3600 + minutes * 60 + seconds 12 | 13 | 14 | class ScriptEventBase(object): 15 | def __init__(self, source_index, start, end, text): 16 | self.source_index = source_index 17 | self.start = start 18 | self.end = end 19 | self.text = text 20 | 21 | self._shift = 0 22 | self._diff = 1 23 | self._linked_event = None 24 | self._start_shift = 0 25 | self._end_shift = 0 26 | 27 | @property 28 | def shift(self): 29 | return self._linked_event.shift if self.linked else self._shift 30 | 31 | @property 32 | def diff(self): 33 | return self._linked_event.diff if self.linked else self._diff 34 | 35 | @property 36 | def duration(self): 37 | return self.end - self.start 38 | 39 | @property 40 | def shifted_end(self): 41 | return self.end + self.shift + self._end_shift 42 | 43 | @property 44 | def shifted_start(self): 45 | return self.start + self.shift + self._start_shift 46 | 47 | def apply_shift(self): 48 | self.start = self.shifted_start 49 | self.end = self.shifted_end 50 | 51 | def set_shift(self, shift, audio_diff): 52 | assert not self.linked, 'Cannot set shift of a linked event' 53 | self._shift = shift 54 | self._diff = audio_diff 55 | 56 | def adjust_additional_shifts(self, start_shift, end_shift): 57 | assert not self.linked, 'Cannot apply additional shifts to a linked event' 58 | self._start_shift += start_shift 59 | self._end_shift += end_shift 60 | 61 | def get_link_chain_end(self): 62 | return self._linked_event.get_link_chain_end() if self.linked else self 63 | 64 | def link_event(self, other): 65 | assert other.get_link_chain_end() is not self, 'Circular link detected' 66 | self._linked_event = other 67 | 68 | def resolve_link(self): 69 | assert self.linked, 'Cannot resolve unlinked events' 70 | self._shift = self._linked_event.shift 71 | self._diff = self._linked_event.diff 72 | self._linked_event = None 73 | 74 | @property 75 | def linked(self): 76 | return self._linked_event is not None 77 | 78 | def adjust_shift(self, value): 79 | assert not self.linked, 'Cannot adjust time of linked events' 80 | self._shift += value 81 | 82 | def __repr__(self): 83 | return unicode(self) 84 | 85 | 86 | class ScriptBase(object): 87 | def __init__(self, events): 88 | 
self.events = events 89 | 90 | def sort_by_time(self): 91 | self.events.sort(key=lambda x: x.start) 92 | 93 | 94 | class SrtEvent(ScriptEventBase): 95 | is_comment = False 96 | style = None 97 | 98 | EVENT_REGEX = re.compile(""" 99 | (\d+?)\s+? # line-number 100 | (\d{1,2}:\d{1,2}:\d{1,2},\d+)\s-->\s(\d{1,2}:\d{1,2}:\d{1,2},\d+). # timestamp 101 | (.+?) # actual text 102 | (?= # lookahead for the next line or end of the file 103 | (?:\d+?\s+? # line-number 104 | \d{1,2}:\d{1,2}:\d{1,2},\d+\s-->\s\d{1,2}:\d{1,2}:\d{1,2},\d+) # timestamp 105 | |$ 106 | )""", flags=re.VERBOSE | re.DOTALL) 107 | 108 | @classmethod 109 | def from_string(cls, text): 110 | match = cls.EVENT_REGEX.match(text) 111 | start = cls.parse_time(match.group(2)) 112 | end = cls.parse_time(match.group(3)) 113 | return SrtEvent(int(match.group(1)), start, end, match.group(4).strip()) 114 | 115 | def __unicode__(self): 116 | return u'{0}\n{1} --> {2}\n{3}'.format(self.source_index, self._format_time(self.start), 117 | self._format_time(self.end), self.text) 118 | 119 | @staticmethod 120 | def parse_time(time_string): 121 | return _parse_ass_time(time_string.replace(',', '.')) 122 | 123 | @staticmethod 124 | def _format_time(seconds): 125 | return format_srt_time(seconds) 126 | 127 | 128 | class SrtScript(ScriptBase): 129 | @classmethod 130 | def from_file(cls, path): 131 | try: 132 | with codecs.open(path, encoding='utf-8-sig') as script: 133 | text = script.read() 134 | events_list = [SrtEvent( 135 | source_index=int(match.group(1)), 136 | start=SrtEvent.parse_time(match.group(2)), 137 | end=SrtEvent.parse_time(match.group(3)), 138 | text=match.group(4).strip() 139 | ) for match in SrtEvent.EVENT_REGEX.finditer(text)] 140 | return cls(events_list) 141 | except IOError: 142 | raise SushiError("Script {0} not found".format(path)) 143 | 144 | def save_to_file(self, path): 145 | text = '\n\n'.join(map(unicode, self.events)) 146 | with codecs.open(path, encoding='utf-8', mode='w') as script: 147 | script.write(text) 148 | 149 | 150 | class AssEvent(ScriptEventBase): 151 | def __init__(self, text, position=0): 152 | kind, _, rest = text.partition(':') 153 | split = [x.strip() for x in rest.split(',', 9)] 154 | 155 | super(AssEvent, self).__init__( 156 | source_index=position, 157 | start=_parse_ass_time(split[1]), 158 | end=_parse_ass_time(split[2]), 159 | text=split[9] 160 | ) 161 | self.kind = kind 162 | self.is_comment = self.kind.lower() == 'comment' 163 | self.layer = split[0] 164 | self.style = split[3] 165 | self.name = split[4] 166 | self.margin_left = split[5] 167 | self.margin_right = split[6] 168 | self.margin_vertical = split[7] 169 | self.effect = split[8] 170 | 171 | def __unicode__(self): 172 | return u'{0}: {1},{2},{3},{4},{5},{6},{7},{8},{9},{10}'.format(self.kind, self.layer, 173 | self._format_time(self.start), 174 | self._format_time(self.end), 175 | self.style, self.name, 176 | self.margin_left, self.margin_right, 177 | self.margin_vertical, self.effect, 178 | self.text) 179 | 180 | @staticmethod 181 | def _format_time(seconds): 182 | return format_time(seconds) 183 | 184 | 185 | class AssScript(ScriptBase): 186 | def __init__(self, script_info, styles, events, other): 187 | super(AssScript, self).__init__(events) 188 | self.script_info = script_info 189 | self.styles = styles 190 | self.other = other 191 | 192 | @classmethod 193 | def from_file(cls, path): 194 | script_info, styles, events = [], [], [] 195 | other_sections = collections.OrderedDict() 196 | 197 | def parse_script_info_line(line): 198 | if 
line.startswith(u'Format:'): 199 | return 200 | script_info.append(line) 201 | 202 | def parse_styles_line(line): 203 | if line.startswith(u'Format:'): 204 | return 205 | styles.append(line) 206 | 207 | def parse_event_line(line): 208 | if line.startswith(u'Format:'): 209 | return 210 | events.append(AssEvent(line, position=len(events)+1)) 211 | 212 | def create_generic_parse(section_name): 213 | if section_name in other_sections: 214 | raise SushiError("Duplicate section detected, invalid script?") 215 | other_sections[section_name] = [] 216 | return other_sections[section_name].append 217 | 218 | parse_function = None 219 | 220 | try: 221 | with codecs.open(path, encoding='utf-8-sig') as script: 222 | for line_idx, line in enumerate(script): 223 | line = line.strip() 224 | if not line: 225 | continue 226 | low = line.lower() 227 | if low == u'[script info]': 228 | parse_function = parse_script_info_line 229 | elif low == u'[v4+ styles]': 230 | parse_function = parse_styles_line 231 | elif low == u'[events]': 232 | parse_function = parse_event_line 233 | elif re.match(r'\[.+?\]', low): 234 | parse_function = create_generic_parse(line) 235 | elif not parse_function: 236 | raise SushiError("That's some invalid ASS script") 237 | else: 238 | try: 239 | parse_function(line) 240 | except Exception as e: 241 | raise SushiError("That's some invalid ASS script: {0} [line {1}]".format(e.message, line_idx)) 242 | except IOError: 243 | raise SushiError("Script {0} not found".format(path)) 244 | return cls(script_info, styles, events, other_sections) 245 | 246 | def save_to_file(self, path): 247 | # if os.path.exists(path): 248 | # raise RuntimeError('File %s already exists' % path) 249 | lines = [] 250 | if self.script_info: 251 | lines.append(u'[Script Info]') 252 | lines.extend(self.script_info) 253 | lines.append('') 254 | 255 | if self.styles: 256 | lines.append(u'[V4+ Styles]') 257 | lines.append(u'Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding') 258 | lines.extend(self.styles) 259 | lines.append('') 260 | 261 | if self.events: 262 | events = sorted(self.events, key=lambda x: x.source_index) 263 | lines.append(u'[Events]') 264 | lines.append(u'Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text') 265 | lines.extend(map(unicode, events)) 266 | 267 | if self.other: 268 | for section_name, section_lines in self.other.iteritems(): 269 | lines.append('') 270 | lines.append(section_name) 271 | lines.extend(section_lines) 272 | 273 | with codecs.open(path, encoding='utf-8-sig', mode='w') as script: 274 | script.write(unicode(os.linesep).join(lines)) 275 | -------------------------------------------------------------------------------- /sushi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | import logging 3 | import sys 4 | import operator 5 | import argparse 6 | import os 7 | import bisect 8 | import collections 9 | from itertools import takewhile, izip, chain 10 | import time 11 | 12 | import numpy as np 13 | 14 | import chapters 15 | from common import SushiError, get_extension, format_time, ensure_static_collection 16 | from demux import Timecodes, Demuxer 17 | import keyframes 18 | from subs import AssScript, SrtScript 19 | from wav import WavStream 20 | 21 | 22 | try: 23 | import matplotlib.pyplot as plt 24 | 
plot_enabled = True 25 | except ImportError: 26 | plot_enabled = False 27 | 28 | if sys.platform == 'win32': 29 | try: 30 | import colorama 31 | colorama.init() 32 | console_colors_supported = True 33 | except ImportError: 34 | console_colors_supported = False 35 | else: 36 | console_colors_supported = True 37 | 38 | 39 | ALLOWED_ERROR = 0.01 40 | MAX_GROUP_STD = 0.025 41 | VERSION = '0.5.1' 42 | 43 | 44 | class ColoredLogFormatter(logging.Formatter): 45 | bold_code = "\033[1m" 46 | reset_code = "\033[0m" 47 | grey_code = "\033[30m\033[1m" 48 | 49 | error_format = "{bold}ERROR: %(message)s{reset}".format(bold=bold_code, reset=reset_code) 50 | warn_format = "{bold}WARNING: %(message)s{reset}".format(bold=bold_code, reset=reset_code) 51 | debug_format = "{grey}%(message)s{reset}".format(grey=grey_code, reset=reset_code) 52 | default_format = "%(message)s" 53 | 54 | def format(self, record): 55 | if record.levelno == logging.DEBUG: 56 | self._fmt = self.debug_format 57 | elif record.levelno == logging.WARN: 58 | self._fmt = self.warn_format 59 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 60 | self._fmt = self.error_format 61 | else: 62 | self._fmt = self.default_format 63 | 64 | return super(ColoredLogFormatter, self).format(record) 65 | 66 | 67 | def abs_diff(a, b): 68 | return abs(a - b) 69 | 70 | 71 | def interpolate_nones(data, points): 72 | data = ensure_static_collection(data) 73 | values_lookup = {p: v for p, v in izip(points, data) if v is not None} 74 | if not values_lookup: 75 | return [] 76 | 77 | zero_points = {p for p, v in izip(points, data) if v is None} 78 | if not zero_points: 79 | return data 80 | 81 | data_list = sorted(values_lookup.iteritems()) 82 | zero_points = sorted(x for x in zero_points if x not in values_lookup) 83 | 84 | out = np.interp(x=zero_points, 85 | xp=map(operator.itemgetter(0), data_list), 86 | fp=map(operator.itemgetter(1), data_list)) 87 | 88 | values_lookup.update(izip(zero_points, out)) 89 | 90 | return [ 91 | values_lookup[point] if value is None else value 92 | for point, value in izip(points, data) 93 | ] 94 | 95 | 96 | # todo: implement this as a running median 97 | def running_median(values, window_size): 98 | if window_size % 2 != 1: 99 | raise SushiError('Median window size should be odd') 100 | half_window = window_size // 2 101 | medians = [] 102 | items_count = len(values) 103 | for idx in xrange(items_count): 104 | radius = min(half_window, idx, items_count-idx-1) 105 | med = np.median(values[idx-radius:idx+radius+1]) 106 | medians.append(med) 107 | return medians 108 | 109 | 110 | def smooth_events(events, radius): 111 | if not radius: 112 | return 113 | window_size = radius*2+1 114 | shifts = [e.shift for e in events] 115 | smoothed = running_median(shifts, window_size) 116 | for event, new_shift in izip(events, smoothed): 117 | event.set_shift(new_shift, event.diff) 118 | 119 | 120 | def detect_groups(events_iter): 121 | events_iter = iter(events_iter) 122 | groups_list = [[next(events_iter)]] 123 | for event in events_iter: 124 | if abs_diff(event.shift, groups_list[-1][-1].shift) > ALLOWED_ERROR: 125 | groups_list.append([]) 126 | groups_list[-1].append(event) 127 | return groups_list 128 | 129 | 130 | def groups_from_chapters(events, times): 131 | logging.info(u'Chapter start points: {0}'.format([format_time(t) for t in times])) 132 | groups = [[]] 133 | chapter_times = iter(times[1:] + [36000000000]) # very large event at the end 134 | current_chapter = next(chapter_times) 135 | 136 | for event in 
events: 137 | if event.end > current_chapter: 138 | groups.append([]) 139 | while event.end > current_chapter: 140 | current_chapter = next(chapter_times) 141 | 142 | groups[-1].append(event) 143 | 144 | groups = filter(None, groups) # non-empty groups 145 | # check if we have any groups where every event is linked 146 | # for example a chapter with only comments inside 147 | broken_groups = [group for group in groups if not any(e for e in group if not e.linked)] 148 | if broken_groups: 149 | for group in broken_groups: 150 | for event in group: 151 | parent = event.get_link_chain_end() 152 | parent_group = next(group for group in groups if parent in group) 153 | parent_group.append(event) 154 | del group[:] 155 | groups = filter(None, groups) 156 | # re-sort the groups again since we might break the order when inserting linked events 157 | # sorting everything again is far from optimal but python sorting is very fast for sorted arrays anyway 158 | for group in groups: 159 | group.sort(key=lambda event: event.start) 160 | 161 | return groups 162 | 163 | 164 | def split_broken_groups(groups): 165 | correct_groups = [] 166 | broken_found = False 167 | for g in groups: 168 | std = np.std([e.shift for e in g]) 169 | if std > MAX_GROUP_STD: 170 | logging.warn(u'Shift is not consistent between {0} and {1}, most likely chapters are wrong (std: {2}). ' 171 | u'Switching to automatic grouping.'.format(format_time(g[0].start), format_time(g[-1].end), 172 | std)) 173 | correct_groups.extend(detect_groups(g)) 174 | broken_found = True 175 | else: 176 | correct_groups.append(g) 177 | 178 | if broken_found: 179 | groups_iter = iter(correct_groups) 180 | correct_groups = [list(next(groups_iter))] 181 | for group in groups_iter: 182 | if abs_diff(correct_groups[-1][-1].shift, group[0].shift) >= ALLOWED_ERROR \ 183 | or np.std([e.shift for e in group + correct_groups[-1]]) >= MAX_GROUP_STD: 184 | correct_groups.append([]) 185 | 186 | correct_groups[-1].extend(group) 187 | return correct_groups 188 | 189 | 190 | def fix_near_borders(events): 191 | """ 192 | We assume that all lines with diff greater than 5 * (median diff across all events) are broken 193 | """ 194 | def fix_border(event_list, median_diff): 195 | last_ten_diff = np.median([x.diff for x in event_list[:10]], overwrite_input=True) 196 | diff_limit = min(last_ten_diff, median_diff) 197 | broken = [] 198 | for event in event_list: 199 | if not 0.2 < (event.diff / diff_limit) < 5: 200 | broken.append(event) 201 | else: 202 | for x in broken: 203 | x.link_event(event) 204 | return len(broken) 205 | return 0 206 | 207 | median_diff = np.median([x.diff for x in events], overwrite_input=True) 208 | 209 | fixed_count = fix_border(events, median_diff) 210 | if fixed_count: 211 | logging.info('Fixing {0} border events right after {1}'.format(fixed_count, format_time(events[0].start))) 212 | 213 | fixed_count = fix_border(list(reversed(events)), median_diff) 214 | if fixed_count: 215 | logging.info('Fixing {0} border events right before {1}'.format(fixed_count, format_time(events[-1].end))) 216 | 217 | 218 | def get_distance_to_closest_kf(timestamp, keyframes): 219 | idx = bisect.bisect_left(keyframes, timestamp) 220 | if idx == 0: 221 | kf = keyframes[0] 222 | elif idx == len(keyframes): 223 | kf = keyframes[-1] 224 | else: 225 | before = keyframes[idx - 1] 226 | after = keyframes[idx] 227 | kf = after if after - timestamp < timestamp - before else before 228 | return kf - timestamp 229 | 230 | 231 | def find_keyframe_shift(group, src_keytimes, 
dst_keytimes, src_timecodes, dst_timecodes, max_kf_distance): 232 | def get_distance(src_distance, dst_distance, limit): 233 | if abs(dst_distance) > limit: 234 | return None 235 | shift = dst_distance - src_distance 236 | return shift if abs(shift) < limit else None 237 | 238 | src_start = get_distance_to_closest_kf(group[0].start, src_keytimes) 239 | src_end = get_distance_to_closest_kf(group[-1].end + src_timecodes.get_frame_size(group[-1].end), src_keytimes) 240 | 241 | dst_start = get_distance_to_closest_kf(group[0].shifted_start, dst_keytimes) 242 | dst_end = get_distance_to_closest_kf(group[-1].shifted_end + dst_timecodes.get_frame_size(group[-1].end), dst_keytimes) 243 | 244 | snapping_limit_start = src_timecodes.get_frame_size(group[0].start) * max_kf_distance 245 | snapping_limit_end = src_timecodes.get_frame_size(group[0].end) * max_kf_distance 246 | 247 | return (get_distance(src_start, dst_start, snapping_limit_start), 248 | get_distance(src_end, dst_end, snapping_limit_end)) 249 | 250 | 251 | def find_keyframes_distances(event, src_keytimes, dst_keytimes, timecodes, max_kf_distance): 252 | def find_keyframe_distance(src_time, dst_time): 253 | src = get_distance_to_closest_kf(src_time, src_keytimes) 254 | dst = get_distance_to_closest_kf(dst_time, dst_keytimes) 255 | snapping_limit = timecodes.get_frame_size(src_time) * max_kf_distance 256 | 257 | if abs(src) < snapping_limit and abs(dst) < snapping_limit and abs(src-dst) < snapping_limit: 258 | return dst - src 259 | return 0 260 | 261 | ds = find_keyframe_distance(event.start, event.shifted_start) 262 | de = find_keyframe_distance(event.end, event.shifted_end) 263 | return ds, de 264 | 265 | 266 | def snap_groups_to_keyframes(events, chapter_times, max_ts_duration, max_ts_distance, src_keytimes, dst_keytimes, 267 | src_timecodes, dst_timecodes, max_kf_distance, kf_mode): 268 | if not max_kf_distance: 269 | return 270 | 271 | groups = merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance) 272 | 273 | if kf_mode == 'all' or kf_mode == 'shift': 274 | # step 1: snap events without changing their duration. Useful for some slight audio imprecision correction 275 | shifts = [] 276 | times = [] 277 | for group in groups: 278 | shifts.extend(find_keyframe_shift(group, src_keytimes, dst_keytimes, src_timecodes, dst_timecodes, max_kf_distance)) 279 | times.extend((group[0].shifted_start, group[-1].shifted_end)) 280 | 281 | shifts = interpolate_nones(shifts, times) 282 | if shifts: 283 | mean_shift = np.mean(shifts) 284 | shifts = zip(*(iter(shifts), ) * 2) 285 | 286 | logging.info('Group {0}-{1} corrected by {2}'.format(format_time(events[0].start), format_time(events[-1].end), mean_shift)) 287 | for group, (start_shift, end_shift) in izip(groups, shifts): 288 | if abs(start_shift-end_shift) > 0.001 and len(group) > 1: 289 | actual_shift = min(start_shift, end_shift, key=lambda x: abs(x - mean_shift)) 290 | logging.warning("Typesetting group at {0} had different shift at start/end points ({1} and {2}). Shifting by {3}." 
291 |                                 .format(format_time(group[0].start), start_shift, end_shift, actual_shift))
292 |                 for e in group:
293 |                     e.adjust_shift(actual_shift)
294 |             else:
295 |                 for e in group:
296 |                     e.adjust_additional_shifts(start_shift, end_shift)
297 | 
298 |     if kf_mode == 'all' or kf_mode == 'snap':
299 |         # step 2: snap start/end times separately
300 |         for group in groups:
301 |             if len(group) > 1:
302 |                 continue  # we don't snap typesetting
303 |             start_shift, end_shift = find_keyframes_distances(group[0], src_keytimes, dst_keytimes, src_timecodes, max_kf_distance)
304 |             if abs(start_shift) > 0.01 or abs(end_shift) > 0.01:
305 |                 logging.info('Snapping {0} to keyframes, start time by {1}, end: {2}'.format(format_time(group[0].start), start_shift, end_shift))
306 |                 group[0].adjust_additional_shifts(start_shift, end_shift)
307 | 
308 | 
309 | def average_shifts(events):
310 |     events = [e for e in events if not e.linked]
311 |     shifts = [x.shift for x in events]
312 |     weights = [1 - x.diff for x in events]
313 |     avg = np.average(shifts, weights=weights)
314 |     for e in events:
315 |         e.set_shift(avg, e.diff)
316 |     return avg
317 | 
318 | 
319 | def merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance):
320 |     search_groups = []
321 |     chapter_times = iter(chapter_times[1:] + [100000000])
322 |     next_chapter = next(chapter_times)
323 |     events = ensure_static_collection(events)
324 | 
325 |     processed = set()
326 |     for idx, event in enumerate(events):
327 |         if idx in processed:
328 |             continue
329 | 
330 |         while event.end > next_chapter:
331 |             next_chapter = next(chapter_times)
332 | 
333 |         if event.duration > max_ts_duration:
334 |             search_groups.append([event])
335 |             processed.add(idx)
336 |         else:
337 |             group = [event]
338 |             group_end = event.end
339 |             i = idx + 1
340 |             while i < len(events) and abs(group_end - events[i].start) < max_ts_distance:
341 |                 if events[i].end < next_chapter and events[i].duration <= max_ts_duration:
342 |                     processed.add(i)
343 |                     group.append(events[i])
344 |                     group_end = max(group_end, events[i].end)
345 |                 i += 1
346 | 
347 |             search_groups.append(group)
348 | 
349 |     return search_groups
350 | 
351 | 
352 | def prepare_search_groups(events, source_duration, chapter_times, max_ts_duration, max_ts_distance):
353 |     last_unlinked = None
354 |     for idx, event in enumerate(events):
355 |         if event.is_comment:
356 |             try:
357 |                 event.link_event(events[idx + 1])
358 |             except IndexError:
359 |                 event.link_event(last_unlinked)
360 |             continue
361 |         if (event.start + event.duration / 2.0) > source_duration:
362 |             logging.info('Event time outside of audio range, ignoring: %s' % unicode(event))
363 |             event.link_event(last_unlinked)
364 |             continue
365 |         elif event.end == event.start:
366 |             logging.info('{0}: skipped because zero duration'.format(format_time(event.start)))
367 |             try:
368 |                 event.link_event(events[idx + 1])
369 |             except IndexError:
370 |                 event.link_event(last_unlinked)
371 |             continue
372 | 
373 |         # link lines with start and end times identical to some other event
374 |         # assuming scripts are sorted by start time so we don't search the entire collection
375 |         same_start = lambda x: event.start == x.start
376 |         processed = next((x for x in takewhile(same_start, reversed(events[:idx])) if not x.linked and x.end == event.end), None)
377 |         if processed:
378 |             event.link_event(processed)
379 |         else:
380 |             last_unlinked = event
381 | 
382 |     events = (e for e in events if not e.linked)
383 | 
384 |     search_groups = merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts_distance)
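To make the grouping heuristic above easier to follow, here is a rough self-contained sketch of the rule `merge_short_lines_into_groups` implements: lines no longer than `max_ts_duration` that start within `max_ts_distance` of the current group's end are merged into one search group, while longer (dialogue) lines always get single-event groups. Chapter boundaries, the `processed` bookkeeping, and bridging across interleaved long lines are omitted, and `FakeEvent` is only an illustration, not part of Sushi:

```python
from collections import namedtuple

# Stand-in for Sushi's subtitle event class; only the fields the sketch needs.
FakeEvent = namedtuple('FakeEvent', ['start', 'end', 'duration'])


def sketch_grouping(events, max_ts_duration, max_ts_distance):
    groups = []
    group_end = None
    for event in events:
        if event.duration > max_ts_duration:
            groups.append([event])  # dialogue-length line: searched on its own
            group_end = None
        elif group_end is not None and abs(group_end - event.start) < max_ts_distance:
            groups[-1].append(event)  # short typesetting line close to the group: merge
            group_end = max(group_end, event.end)
        else:
            groups.append([event])  # short line starting a new typesetting cluster
            group_end = event.end
    return groups


signs = [FakeEvent(10.0, 10.2, 0.2), FakeEvent(10.25, 10.4, 0.15)]
dialogue = [FakeEvent(12.0, 15.0, 3.0)]
print([len(g) for g in sketch_grouping(signs + dialogue, 0.417, 0.417)])  # [2, 1]
```

Merging short lines this way gives the audio search a pattern long enough to be distinctive, which is why sign-sized events are shifted together rather than one by one.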
385 | 386 | # link groups contained inside other groups to the larger group 387 | passed_groups = [] 388 | for idx, group in enumerate(search_groups): 389 | try: 390 | other = next(x for x in reversed(search_groups[:idx]) 391 | if x[0].start <= group[0].start 392 | and x[-1].end >= group[-1].end) 393 | for event in group: 394 | event.link_event(other[0]) 395 | except StopIteration: 396 | passed_groups.append(group) 397 | return passed_groups 398 | 399 | 400 | def calculate_shifts(src_stream, dst_stream, groups_list, normal_window, max_window, rewind_thresh): 401 | def log_shift(state): 402 | logging.info('{0}-{1}: shift: {2:0.10f}, diff: {3:0.10f}' 403 | .format(format_time(state["start_time"]), format_time(state["end_time"]), state["shift"], state["diff"])) 404 | 405 | def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offset): 406 | logging.debug('{0}-{1}: shift: {2:0.5f} [{3:0.5f}, {4:0.5f}], search offset: {5:0.6f}' 407 | .format(format_time(state["start_time"]), format_time(state["end_time"]), 408 | shift, left_side_shift, right_side_shift, search_offset)) 409 | 410 | small_window = 1.5 411 | idx = 0 412 | committed_states = [] 413 | uncommitted_states = [] 414 | window = normal_window 415 | while idx < len(groups_list): 416 | search_group = groups_list[idx] 417 | tv_audio = src_stream.get_substream(search_group[0].start, search_group[-1].end) 418 | original_time = search_group[0].start 419 | group_state = {"start_time": search_group[0].start, "end_time": search_group[-1].end, "shift": None, "diff": None} 420 | last_committed_shift = committed_states[-1]["shift"] if committed_states else 0 421 | diff = new_time = None 422 | 423 | if not uncommitted_states: 424 | if original_time + last_committed_shift > dst_stream.duration_seconds: 425 | # event outside of audio range, all events past it are also guaranteed to fail 426 | for g in groups_list[idx:]: 427 | committed_states.append({"start_time": g[0].start, "end_time": g[-1].end, "shift": None, "diff": None}) 428 | logging.info("{0}-{1}: outside of audio range".format(format_time(g[0].start), format_time(g[-1].end))) 429 | break 430 | 431 | if small_window < window: 432 | diff, new_time = dst_stream.find_substream(tv_audio, original_time + last_committed_shift, small_window) 433 | 434 | if new_time is not None and abs_diff(new_time - original_time, last_committed_shift) <= ALLOWED_ERROR: 435 | # fastest case - small window worked, commit the group immediately 436 | group_state.update({"shift": new_time - original_time, "diff": diff}) 437 | committed_states.append(group_state) 438 | log_shift(group_state) 439 | if window != normal_window: 440 | logging.info("Going back to window {0} from {1}".format(normal_window, window)) 441 | window = normal_window 442 | idx += 1 443 | continue 444 | 445 | left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0])/2], axis=1) 446 | right_half_offset = len(left_audio_half[0]) / float(src_stream.sample_rate) 447 | terminate = False 448 | # searching from last committed shift 449 | if original_time + last_committed_shift < dst_stream.duration_seconds: 450 | diff, new_time = dst_stream.find_substream(tv_audio, original_time + last_committed_shift, window) 451 | left_side_time = dst_stream.find_substream(left_audio_half, original_time + last_committed_shift, window)[1] 452 | right_side_time = dst_stream.find_substream(right_audio_half, original_time + last_committed_shift + right_half_offset, window)[1] - right_half_offset 453 | terminate = abs_diff(left_side_time, 
right_side_time) <= ALLOWED_ERROR and abs_diff(new_time, left_side_time) <= ALLOWED_ERROR 454 | log_uncommitted(group_state, new_time - original_time, left_side_time - original_time, 455 | right_side_time - original_time, last_committed_shift) 456 | 457 | if not terminate and uncommitted_states and uncommitted_states[-1]["shift"] is not None \ 458 | and original_time + uncommitted_states[-1]["shift"] < dst_stream.duration_seconds: 459 | start_offset = uncommitted_states[-1]["shift"] 460 | diff, new_time = dst_stream.find_substream(tv_audio, original_time + start_offset, window) 461 | left_side_time = dst_stream.find_substream(left_audio_half, original_time + start_offset, window)[1] 462 | right_side_time = dst_stream.find_substream(right_audio_half, original_time + start_offset + right_half_offset, window)[1] - right_half_offset 463 | terminate = abs_diff(left_side_time, right_side_time) <= ALLOWED_ERROR and abs_diff(new_time, left_side_time) <= ALLOWED_ERROR 464 | log_uncommitted(group_state, new_time - original_time, left_side_time - original_time, 465 | right_side_time - original_time, start_offset) 466 | 467 | shift = new_time - original_time 468 | if not terminate: 469 | # we aren't back on track yet - add this group to uncommitted 470 | group_state.update({"shift": shift, "diff": diff}) 471 | uncommitted_states.append(group_state) 472 | idx += 1 473 | if rewind_thresh == len(uncommitted_states) and window < max_window: 474 | logging.warn("Detected possibly broken segment starting at {0}, increasing the window from {1} to {2}" 475 | .format(format_time(uncommitted_states[0]["start_time"]), window, max_window)) 476 | window = max_window 477 | idx = len(committed_states) 478 | del uncommitted_states[:] 479 | continue 480 | 481 | # we're back on track - apply current shift to all broken events 482 | if uncommitted_states: 483 | logging.warning("Events from {0} to {1} will most likely be broken!".format( 484 | format_time(uncommitted_states[0]["start_time"]), 485 | format_time(uncommitted_states[-1]["end_time"]))) 486 | 487 | uncommitted_states.append(group_state) 488 | for state in uncommitted_states: 489 | state.update({"shift": shift, "diff": diff}) 490 | log_shift(state) 491 | committed_states.extend(uncommitted_states) 492 | del uncommitted_states[:] 493 | idx += 1 494 | 495 | for state in uncommitted_states: 496 | log_shift(state) 497 | 498 | for idx, (search_group, group_state) in enumerate(izip(groups_list, chain(committed_states, uncommitted_states))): 499 | if group_state["shift"] is None: 500 | for group in reversed(groups_list[:idx]): 501 | link_to = next((x for x in reversed(group) if not x.linked), None) 502 | if link_to: 503 | for e in search_group: 504 | e.link_event(link_to) 505 | break 506 | else: 507 | for e in search_group: 508 | e.set_shift(group_state["shift"], group_state["diff"]) 509 | 510 | 511 | def check_file_exists(path, file_title): 512 | if path and not os.path.exists(path): 513 | raise SushiError("{0} file doesn't exist".format(file_title)) 514 | 515 | 516 | def format_full_path(temp_dir, base_path, postfix): 517 | if temp_dir: 518 | return os.path.join(temp_dir, os.path.basename(base_path) + postfix) 519 | else: 520 | return base_path + postfix 521 | 522 | 523 | def create_directory_if_not_exists(path): 524 | if path and not os.path.exists(path): 525 | os.makedirs(path) 526 | 527 | 528 | def run(args): 529 | ignore_chapters = args.chapters_file is not None and args.chapters_file.lower() == 'none' 530 | write_plot = plot_enabled and args.plot_path 531 | if 
write_plot:
532 |         plt.clf()
533 |         plt.ylabel('Shift, seconds')
534 |         plt.xlabel('Event index')
535 | 
536 |     # first part should do all possible validation and should NOT take significant amount of time
537 |     check_file_exists(args.source, 'Source')
538 |     check_file_exists(args.destination, 'Destination')
539 |     check_file_exists(args.src_timecodes, 'Source timecodes')
540 |     check_file_exists(args.dst_timecodes, 'Destination timecodes')
541 |     check_file_exists(args.script_file, 'Script')
542 | 
543 |     if not ignore_chapters:
544 |         check_file_exists(args.chapters_file, 'Chapters')
545 |     if args.src_keyframes not in ('auto', 'make'):
546 |         check_file_exists(args.src_keyframes, 'Source keyframes')
547 |     if args.dst_keyframes not in ('auto', 'make'):
548 |         check_file_exists(args.dst_keyframes, 'Destination keyframes')
549 | 
550 |     if (args.src_timecodes and args.src_fps) or (args.dst_timecodes and args.dst_fps):
551 |         raise SushiError('Both fps and timecodes file cannot be specified at the same time')
552 | 
553 |     src_demuxer = Demuxer(args.source)
554 |     dst_demuxer = Demuxer(args.destination)
555 | 
556 |     if src_demuxer.is_wav and not args.script_file:
557 |         raise SushiError("Script file isn't specified")
558 | 
559 |     if (args.src_keyframes and not args.dst_keyframes) or (args.dst_keyframes and not args.src_keyframes):
560 |         raise SushiError('Either none or both of src and dst keyframes should be provided')
561 | 
562 |     create_directory_if_not_exists(args.temp_dir)
563 | 
564 |     # selecting source audio
565 |     if src_demuxer.is_wav:
566 |         src_audio_path = args.source
567 |     else:
568 |         src_audio_path = format_full_path(args.temp_dir, args.source, '.sushi.wav')
569 |         src_demuxer.set_audio(stream_idx=args.src_audio_idx, output_path=src_audio_path, sample_rate=args.sample_rate)
570 | 
571 |     # selecting destination audio
572 |     if dst_demuxer.is_wav:
573 |         dst_audio_path = args.destination
574 |     else:
575 |         dst_audio_path = format_full_path(args.temp_dir, args.destination, '.sushi.wav')
576 |         dst_demuxer.set_audio(stream_idx=args.dst_audio_idx, output_path=dst_audio_path, sample_rate=args.sample_rate)
577 | 
578 |     # selecting source subtitles
579 |     if args.script_file:
580 |         src_script_path = args.script_file
581 |     else:
582 |         stype = src_demuxer.get_subs_type(args.src_script_idx)
583 |         src_script_path = format_full_path(args.temp_dir, args.source, '.sushi' + stype)
584 |         src_demuxer.set_script(stream_idx=args.src_script_idx, output_path=src_script_path)
585 | 
586 |     script_extension = get_extension(src_script_path)
587 |     if script_extension not in ('.ass', '.srt'):
588 |         raise SushiError('Unknown script type')
589 | 
590 |     # selecting destination subtitles
591 |     if args.output_script:
592 |         dst_script_path = args.output_script
593 |         dst_script_extension = get_extension(args.output_script)
594 |         if dst_script_extension != script_extension:
595 |             raise SushiError("Source and destination script file types don't match ({0} vs {1})"
596 |                              .format(script_extension, dst_script_extension))
597 |     else:
598 |         dst_script_path = format_full_path(args.temp_dir, args.destination, '.sushi' + script_extension)
599 | 
600 |     # selecting chapters
601 |     if args.grouping and not ignore_chapters:
602 |         if args.chapters_file:
603 |             if get_extension(args.chapters_file) == '.xml':
604 |                 chapter_times = chapters.get_xml_start_times(args.chapters_file)
605 |             else:
606 |                 chapter_times = chapters.get_ogm_start_times(args.chapters_file)
607 |         elif not src_demuxer.is_wav:
608 |             chapter_times = src_demuxer.chapters
609 |             output_path =
format_full_path(args.temp_dir, src_demuxer.path, ".sushi.chapters.txt") 610 | src_demuxer.set_chapters(output_path) 611 | else: 612 | chapter_times = [] 613 | else: 614 | chapter_times = [] 615 | 616 | # selecting keyframes and timecodes 617 | if args.src_keyframes: 618 | def select_keyframes(file_arg, demuxer): 619 | auto_file = format_full_path(args.temp_dir, demuxer.path, '.sushi.keyframes.txt') 620 | if file_arg in ('auto', 'make'): 621 | if file_arg == 'make' or not os.path.exists(auto_file): 622 | if not demuxer.has_video: 623 | raise SushiError("Cannot make keyframes for {0} because it doesn't have any video!" 624 | .format(demuxer.path)) 625 | demuxer.set_keyframes(output_path=auto_file) 626 | return auto_file 627 | else: 628 | return file_arg 629 | 630 | def select_timecodes(external_file, fps_arg, demuxer): 631 | if external_file: 632 | return external_file 633 | elif fps_arg: 634 | return None 635 | elif demuxer.has_video: 636 | path = format_full_path(args.temp_dir, demuxer.path, '.sushi.timecodes.txt') 637 | demuxer.set_timecodes(output_path=path) 638 | return path 639 | else: 640 | raise SushiError('Fps, timecodes or video files must be provided if keyframes are used') 641 | 642 | src_keyframes_file = select_keyframes(args.src_keyframes, src_demuxer) 643 | dst_keyframes_file = select_keyframes(args.dst_keyframes, dst_demuxer) 644 | src_timecodes_file = select_timecodes(args.src_timecodes, args.src_fps, src_demuxer) 645 | dst_timecodes_file = select_timecodes(args.dst_timecodes, args.dst_fps, dst_demuxer) 646 | 647 | # after this point nothing should fail so it's safe to start slow operations 648 | # like running the actual demuxing 649 | src_demuxer.demux() 650 | dst_demuxer.demux() 651 | 652 | try: 653 | if args.src_keyframes: 654 | src_timecodes = Timecodes.cfr(args.src_fps) if args.src_fps else Timecodes.from_file(src_timecodes_file) 655 | src_keytimes = [src_timecodes.get_frame_time(f) for f in keyframes.parse_keyframes(src_keyframes_file)] 656 | 657 | dst_timecodes = Timecodes.cfr(args.dst_fps) if args.dst_fps else Timecodes.from_file(dst_timecodes_file) 658 | dst_keytimes = [dst_timecodes.get_frame_time(f) for f in keyframes.parse_keyframes(dst_keyframes_file)] 659 | 660 | script = AssScript.from_file(src_script_path) if script_extension == '.ass' else SrtScript.from_file(src_script_path) 661 | script.sort_by_time() 662 | 663 | src_stream = WavStream(src_audio_path, sample_rate=args.sample_rate, sample_type=args.sample_type) 664 | dst_stream = WavStream(dst_audio_path, sample_rate=args.sample_rate, sample_type=args.sample_type) 665 | 666 | search_groups = prepare_search_groups(script.events, 667 | source_duration=src_stream.duration_seconds, 668 | chapter_times=chapter_times, 669 | max_ts_duration=args.max_ts_duration, 670 | max_ts_distance=args.max_ts_distance) 671 | 672 | calculate_shifts(src_stream, dst_stream, search_groups, 673 | normal_window=args.window, 674 | max_window=args.max_window, 675 | rewind_thresh=args.rewind_thresh if args.grouping else 0) 676 | 677 | events = script.events 678 | 679 | if write_plot: 680 | plt.plot([x.shift for x in events], label='From audio') 681 | 682 | if args.grouping: 683 | if not ignore_chapters and chapter_times: 684 | groups = groups_from_chapters(events, chapter_times) 685 | for g in groups: 686 | fix_near_borders(g) 687 | smooth_events([x for x in g if not x.linked], args.smooth_radius) 688 | groups = split_broken_groups(groups) 689 | else: 690 | fix_near_borders(events) 691 | smooth_events([x for x in events if not 
x.linked], args.smooth_radius)
692 |                 groups = detect_groups(events)
693 | 
694 |             if write_plot:
695 |                 plt.plot([x.shift for x in events], label='Borders fixed')
696 | 
697 |             for g in groups:
698 |                 start_shift = g[0].shift
699 |                 end_shift = g[-1].shift
700 |                 avg_shift = average_shifts(g)
701 |                 logging.info(u'Group (start: {0}, end: {1}, lines: {2}), '
702 |                              u'shifts (start: {3}, end: {4}, average: {5})'
703 |                              .format(format_time(g[0].start), format_time(g[-1].end), len(g), start_shift, end_shift,
704 |                                      avg_shift))
705 | 
706 |             if args.src_keyframes:
707 |                 for e in (x for x in events if x.linked):
708 |                     e.resolve_link()
709 |                 for g in groups:
710 |                     snap_groups_to_keyframes(g, chapter_times, args.max_ts_duration, args.max_ts_distance, src_keytimes,
711 |                                              dst_keytimes, src_timecodes, dst_timecodes, args.max_kf_distance, args.kf_mode)
712 |         else:
713 |             fix_near_borders(events)
714 |             if write_plot:
715 |                 plt.plot([x.shift for x in events], label='Borders fixed')
716 | 
717 |             if args.src_keyframes:
718 |                 for e in (x for x in events if x.linked):
719 |                     e.resolve_link()
720 |                 snap_groups_to_keyframes(events, chapter_times, args.max_ts_duration, args.max_ts_distance, src_keytimes,
721 |                                          dst_keytimes, src_timecodes, dst_timecodes, args.max_kf_distance, args.kf_mode)
722 | 
723 |         for event in events:
724 |             event.apply_shift()
725 | 
726 |         script.save_to_file(dst_script_path)
727 | 
728 |         if write_plot:
729 |             plt.plot([x.shift + (x._start_shift + x._end_shift)/2.0 for x in events], label='After correction')
730 |             plt.legend(fontsize=5, frameon=False, fancybox=False)
731 |             plt.savefig(args.plot_path, dpi=300)
732 | 
733 |     finally:
734 |         if args.cleanup:
735 |             src_demuxer.cleanup()
736 |             dst_demuxer.cleanup()
737 | 
738 | 
739 | def create_arg_parser():
740 |     parser = argparse.ArgumentParser(description='Sushi - Automatic Subtitle Shifter')
741 | 
742 |     parser.add_argument('--window', default=10, type=int, metavar='', dest='window',
743 |                         help='Search window size. [%(default)s]')
744 |     parser.add_argument('--max-window', default=30, type=int, metavar='', dest='max_window',
745 |                         help="Maximum search window size Sushi is allowed to use when trying to recover from errors. [%(default)s]")
746 |     parser.add_argument('--rewind-thresh', default=5, type=int, metavar='', dest='rewind_thresh',
747 |                         help="Number of consecutive errors Sushi has to encounter to consider results broken "
748 |                              "and retry with larger window. Set to 0 to disable. [%(default)s]")
749 |     parser.add_argument('--no-grouping', action='store_false', dest='grouping',
750 |                         help="Don't split events into groups before shifting. Also disables error recovery.")
751 |     parser.add_argument('--max-kf-distance', default=2, type=float, metavar='', dest='max_kf_distance',
752 |                         help='Maximum keyframe snapping distance. [%(default)s]')
753 |     parser.add_argument('--kf-mode', default='all', choices=['shift', 'snap', 'all'], dest='kf_mode',
754 |                         help='Keyframes-based shift correction/snapping mode. [%(default)s]')
755 |     parser.add_argument('--smooth-radius', default=3, type=int, metavar='', dest='smooth_radius',
756 |                         help='Radius of smoothing median filter. [%(default)s]')
757 | 
758 |     # 10 frames at 23.976
759 |     parser.add_argument('--max-ts-duration', default=1001.0 / 24000.0 * 10, type=float, metavar='',
760 |                         dest='max_ts_duration',
761 |                         help='Maximum duration of a line to be considered typesetting.
[%(default).3f]') 762 | # 10 frames at 23.976 763 | parser.add_argument('--max-ts-distance', default=1001.0 / 24000.0 * 10, type=float, metavar='', 764 | dest='max_ts_distance', 765 | help='Maximum distance between two adjacent typesetting lines to be merged. [%(default).3f]') 766 | 767 | # deprecated/test options, do not use 768 | parser.add_argument('--test-shift-plot', default=None, dest='plot_path', help=argparse.SUPPRESS) 769 | parser.add_argument('--sample-type', default='uint8', choices=['float32', 'uint8'], dest='sample_type', 770 | help=argparse.SUPPRESS) 771 | 772 | parser.add_argument('--sample-rate', default=12000, type=int, metavar='', dest='sample_rate', 773 | help='Downsampled audio sample rate. [%(default)s]') 774 | 775 | parser.add_argument('--src-audio', default=None, type=int, metavar='', dest='src_audio_idx', 776 | help='Audio stream index of the source video') 777 | parser.add_argument('--src-script', default=None, type=int, metavar='', dest='src_script_idx', 778 | help='Script stream index of the source video') 779 | parser.add_argument('--dst-audio', default=None, type=int, metavar='', dest='dst_audio_idx', 780 | help='Audio stream index of the destination video') 781 | # files 782 | parser.add_argument('--no-cleanup', action='store_false', dest='cleanup', 783 | help="Don't delete demuxed streams") 784 | parser.add_argument('--temp-dir', default=None, dest='temp_dir', metavar='', 785 | help='Specify temporary folder to use when demuxing stream.') 786 | parser.add_argument('--chapters', default=None, dest='chapters_file', metavar='', 787 | help="XML or OGM chapters to use instead of any found in the source. 'none' to disable.") 788 | parser.add_argument('--script', default=None, dest='script_file', metavar='', 789 | help='Subtitle file path to use instead of any found in the source') 790 | 791 | parser.add_argument('--dst-keyframes', default=None, dest='dst_keyframes', metavar='', 792 | help='Destination keyframes file') 793 | parser.add_argument('--src-keyframes', default=None, dest='src_keyframes', metavar='', 794 | help='Source keyframes file') 795 | parser.add_argument('--dst-fps', default=None, type=float, dest='dst_fps', metavar='', 796 | help='Fps of the destination video. Must be provided if keyframes are used.') 797 | parser.add_argument('--src-fps', default=None, type=float, dest='src_fps', metavar='', 798 | help='Fps of the source video. 
Must be provided if keyframes are used.') 799 | parser.add_argument('--dst-timecodes', default=None, dest='dst_timecodes', metavar='', 800 | help='Timecodes file to use instead of making one from the destination (when possible)') 801 | parser.add_argument('--src-timecodes', default=None, dest='src_timecodes', metavar='', 802 | help='Timecodes file to use instead of making one from the source (when possible)') 803 | 804 | parser.add_argument('--src', required=True, dest="source", metavar='', 805 | help='Source audio/video') 806 | parser.add_argument('--dst', required=True, dest="destination", metavar='', 807 | help='Destination audio/video') 808 | parser.add_argument('-o', '--output', default=None, dest='output_script', metavar='', 809 | help='Output script') 810 | 811 | parser.add_argument('-v', '--verbose', default=False, dest='verbose', action='store_true', 812 | help='Enable verbose logging') 813 | parser.add_argument('--version', action='version', version=VERSION) 814 | 815 | return parser 816 | 817 | 818 | def parse_args_and_run(cmd_keys): 819 | def format_arg(arg): 820 | return arg if ' ' not in arg else '"{0}"'.format(arg) 821 | 822 | args = create_arg_parser().parse_args(cmd_keys) 823 | handler = logging.StreamHandler() 824 | if console_colors_supported and os.isatty(sys.stderr.fileno()): 825 | # enable colors 826 | handler.setFormatter(ColoredLogFormatter()) 827 | else: 828 | handler.setFormatter(logging.Formatter(fmt=ColoredLogFormatter.default_format)) 829 | logging.root.addHandler(handler) 830 | logging.root.setLevel(logging.DEBUG if args.verbose else logging.INFO) 831 | 832 | logging.info("Sushi's running with arguments: {0}".format(' '.join(map(format_arg, cmd_keys)))) 833 | start_time = time.time() 834 | run(args) 835 | logging.info('Done in {0}s'.format(time.time() - start_time)) 836 | 837 | 838 | if __name__ == '__main__': 839 | try: 840 | parse_args_and_run(sys.argv[1:]) 841 | except SushiError as e: 842 | logging.critical(e.message) 843 | sys.exit(2) 844 | -------------------------------------------------------------------------------- /tests.example.json: -------------------------------------------------------------------------------- 1 | { 2 | "basepath": "J:", 3 | "plots": "J:\\plots", 4 | "tests": { 5 | "first test":{ 6 | "disabled": false, 7 | "folder": "test1", 8 | "src": "tv.wav", 9 | "dst": "bd.wav", 10 | "dst-keyframes": "bd.keyframes.txt", 11 | "src-keyframes": "tv.keyframes.txt", 12 | "dst-timecodes": "bd.timecodes.txt", 13 | "src-timecodes": "tv.timecodes.txt", 14 | "chapters": "tv.chapters.xml", 15 | "script": "tv.ass", 16 | "ideal": "ideal.ass", 17 | "fps": 23.976023976023978, 18 | "max-kf-distance": 3, 19 | "expected_errors": 84 20 | } 21 | }, 22 | "wavs": { 23 | "first wav": { 24 | "max_time": 0.7, 25 | "max_memory": 120, 26 | "sample_type": "uint8", 27 | "sample_rate": 12000, 28 | "file": "wavs/file.wav" 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tp7/Sushi/908c0ff228734059aebb914a8d10f8e4ce2e868c/tests/__init__.py -------------------------------------------------------------------------------- /tests/demuxing.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import mock 3 | 4 | from demux import FFmpeg, MkvToolnix, SCXviD 5 | from common import SushiError 6 | import chapters 7 | 8 | 9 | def 
create_popen_mock(): 10 | popen_mock = mock.Mock() 11 | process_mock = mock.Mock() 12 | process_mock.configure_mock(**{'communicate.return_value': ('ouput', 'error')}) 13 | popen_mock.return_value = process_mock 14 | return popen_mock 15 | 16 | 17 | class FFmpegTestCase(unittest.TestCase): 18 | ffmpeg_output = '''Input #0, matroska,webm, from 'test.mkv': 19 | Stream #0:0(jpn): Video: h264 (High 10), yuv420p10le, 1280x720 [SAR 1:1 DAR 16:9], 23.98 fps, 23.98 tbr, 1k tbn, 47.95 tbc (default) 20 | Metadata: 21 | title : Video 10bit 22 | Stream #0:1(jpn): Audio: aac, 48000 Hz, stereo, fltp (default) (forced) 23 | Metadata: 24 | title : Audio AAC 2.0 25 | Stream #0:2(eng): Audio: aac, 48000 Hz, stereo, fltp 26 | Metadata: 27 | title : English Audio AAC 2.0 28 | Stream #0:3(eng): Subtitle: ssa (default) (forced) 29 | Metadata: 30 | title : English Subtitles 31 | Stream #0:4(enm): Subtitle: ass 32 | Metadata: 33 | title : English (JP honorifics) 34 | .................................''' 35 | 36 | def test_parses_audio_stream(self): 37 | audio = FFmpeg._get_audio_streams(self.ffmpeg_output) 38 | self.assertEqual(len(audio), 2) 39 | self.assertEqual(audio[0].id, 1) 40 | self.assertEqual(audio[0].title, 'Audio AAC 2.0') 41 | self.assertEqual(audio[1].id, 2) 42 | self.assertEqual(audio[1].title, 'English Audio AAC 2.0') 43 | 44 | def test_parses_video_stream(self): 45 | video = FFmpeg._get_video_streams(self.ffmpeg_output) 46 | self.assertEqual(len(video), 1) 47 | self.assertEqual(video[0].id, 0) 48 | self.assertEqual(video[0].title, 'Video 10bit') 49 | 50 | def test_parses_subtitles_stream(self): 51 | subs = FFmpeg._get_subtitles_streams(self.ffmpeg_output) 52 | self.assertEqual(len(subs), 2) 53 | self.assertEqual(subs[0].id, 3) 54 | self.assertTrue(subs[0].default) 55 | self.assertEqual(subs[0].title, 'English Subtitles') 56 | self.assertEqual(subs[1].id, 4) 57 | self.assertFalse(subs[1].default) 58 | self.assertEqual(subs[1].title, 'English (JP honorifics)') 59 | 60 | @mock.patch('subprocess.Popen', new_callable=create_popen_mock) 61 | def test_get_info_call_args(self, popen_mock): 62 | FFmpeg.get_info('random_file.mkv') 63 | self.assertEquals(popen_mock.call_args[0][0], ['ffmpeg', '-hide_banner', '-i', 'random_file.mkv']) 64 | 65 | @mock.patch('subprocess.Popen', new_callable=create_popen_mock) 66 | def test_get_info_fail_when_no_mmpeg(self, popen_mock): 67 | popen_mock.return_value.communicate.side_effect = OSError(2, "ignored") 68 | self.assertRaises(SushiError, lambda: FFmpeg.get_info('random.mkv')) 69 | 70 | @mock.patch('subprocess.call') 71 | def test_demux_file_call_args(self, call_mock): 72 | FFmpeg.demux_file('random.mkv', audio_stream=0, audio_path='audio1.wav') 73 | FFmpeg.demux_file('random.mkv', audio_stream=0, audio_path='audio2.wav', audio_rate=12000) 74 | FFmpeg.demux_file('random.mkv', script_stream=0, script_path='subs1.ass') 75 | FFmpeg.demux_file('random.mkv', video_stream=0, timecodes_path='tcs1.txt') 76 | 77 | FFmpeg.demux_file('random.mkv', audio_stream=1, audio_path='audio0.wav', audio_rate=12000, 78 | script_stream=2, script_path='out0.ass', video_stream=0, timecodes_path='tcs0.txt') 79 | 80 | call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y', 81 | '-map', '0:0', '-ac', '1', '-acodec', 'pcm_s16le', 'audio1.wav']) 82 | call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y', 83 | '-map', '0:0', '-ar', '12000', '-ac', '1', '-acodec', 'pcm_s16le', 'audio2.wav']) 84 | call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 
'random.mkv', '-y',
85 |                                    '-map', '0:0', 'subs1.ass'])
86 |         call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y',
87 |                                    '-map', '0:0', '-f', 'mkvtimestamp_v2', 'tcs1.txt'])
88 |         call_mock.assert_any_call(['ffmpeg', '-hide_banner', '-i', 'random.mkv', '-y',
89 |                                    '-map', '0:1', '-ar', '12000', '-ac', '1', '-acodec', 'pcm_s16le', 'audio0.wav',
90 |                                    '-map', '0:2', 'out0.ass',
91 |                                    '-map', '0:0', '-f', 'mkvtimestamp_v2', 'tcs0.txt'])
92 | 
93 | 
94 | class MkvExtractTestCase(unittest.TestCase):
95 |     @mock.patch('subprocess.call')
96 |     def test_extract_timecodes(self, call_mock):
97 |         MkvToolnix.extract_timecodes('video.mkv', 1, 'timecodes.tsc')
98 |         call_mock.assert_called_once_with(['mkvextract', 'timecodes_v2', 'video.mkv', '1:timecodes.tsc'])
99 | 
100 | 
101 | class SCXviDTestCase(unittest.TestCase):
102 |     @mock.patch('subprocess.Popen')
103 |     def test_make_keyframes(self, popen_mock):
104 |         SCXviD.make_keyframes('video.mkv', 'keyframes.txt')
105 |         self.assertTrue('ffmpeg' in (x.lower() for x in popen_mock.call_args_list[0][0][0]))
106 |         self.assertTrue('scxvid' in (x.lower() for x in popen_mock.call_args_list[1][0][0]))
107 | 
108 |     @mock.patch('subprocess.Popen')
109 |     def test_no_ffmpeg(self, popen_mock):
110 |         def raise_no_app(cmd_args, **kwargs):
111 |             if 'ffmpeg' in (x.lower() for x in cmd_args):
112 |                 raise OSError(2, 'ignored')
113 | 
114 |         popen_mock.side_effect = raise_no_app
115 |         self.assertRaisesRegexp(SushiError, '[fF][fF][mM][pP][eE][gG]',
116 |                                 lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt'))
117 | 
118 |     @mock.patch('subprocess.Popen')
119 |     def test_no_scxvid(self, popen_mock):
120 |         def raise_no_app(cmd_args, **kwargs):
121 |             if 'scxvid' in (x.lower() for x in cmd_args):
122 |                 raise OSError(2, 'ignored')
123 |             return mock.Mock()
124 | 
125 |         popen_mock.side_effect = raise_no_app
126 |         self.assertRaisesRegexp(SushiError, '[sS][cC][xX][vV][iI][dD]',
127 |                                 lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt'))
128 | 
129 | 
130 | class ExternalChaptersTestCase(unittest.TestCase):
131 |     def test_parse_xml_start_times(self):
132 |         file_text = """
133 | <?xml version="1.0" encoding="UTF-8"?>
134 | <Chapters>
135 |   <EditionEntry>
136 |     <EditionUID>2092209815</EditionUID>
137 |     <ChapterAtom>
138 |       <ChapterUID>3122448259</ChapterUID>
139 |       <ChapterTimeStart>00:00:00.000000000</ChapterTimeStart>
140 |       <ChapterDisplay>
141 |         <ChapterString>Prologue</ChapterString>
142 |       </ChapterDisplay>
143 |     </ChapterAtom>
144 |     <ChapterAtom>
145 |       <ChapterUID>998777246</ChapterUID>
146 |       <ChapterTimeStart>00:00:17.017000000</ChapterTimeStart>
147 |       <ChapterDisplay>
148 |         <ChapterString>Opening Song ("YES!")</ChapterString>
149 |       </ChapterDisplay>
150 |     </ChapterAtom>
151 |     <ChapterAtom>
152 |       <ChapterUID>55571857</ChapterUID>
153 |       <ChapterTimeStart>00:01:47.023000000</ChapterTimeStart>
154 |       <ChapterDisplay>
155 |         <ChapterString>Part A (Tale of the Doggypus)</ChapterString>
156 |       </ChapterDisplay>
157 |     </ChapterAtom>
158 |   </EditionEntry>
159 | </Chapters>
160 | """
161 |         parsed_times = chapters.parse_xml_start_times(file_text)
162 |         self.assertEqual(parsed_times, [0, 17.017, 107.023])
163 | 
164 |     def test_parse_ogm_start_times(self):
165 |         file_text = """CHAPTER01=00:00:00.000
166 | CHAPTER01NAME=Prologue
167 | CHAPTER02=00:00:17.017
168 | CHAPTER02NAME=Opening Song ("YES!")
169 | CHAPTER03=00:01:47.023
170 | CHAPTER03NAME=Part A (Tale of the Doggypus)
171 | """
172 |         parsed_times = chapters.parse_ogm_start_times(file_text)
173 |         self.assertEqual(parsed_times, [0, 17.017, 107.023])
174 | 
175 |     def test_format_ogm_chapters(self):
176 |         chapters_text = chapters.format_ogm_chapters(start_times=[0, 17.017, 107.023])
177 |         self.assertEqual(chapters_text, """CHAPTER01=00:00:00.000
178 | CHAPTER01NAME=
179 | CHAPTER02=00:00:17.017
180 | CHAPTER02NAME=
181 | CHAPTER03=00:01:47.023
182 | CHAPTER03NAME=
183 | """)
184 | 
--------------------------------------------------------------------------------
/tests/main.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import os
3 | 
import re 4 | import unittest 5 | from mock import patch, ANY 6 | from common import SushiError, format_time 7 | import sushi 8 | 9 | here = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | 12 | class FakeEvent(object): 13 | def __init__(self, shift=0.0, diff=0.0, end=0.0, start=0.0): 14 | self.shift = shift 15 | self.linked = None 16 | self.diff = diff 17 | self.start = start 18 | self.end = end 19 | 20 | def set_shift(self, shift, diff): 21 | self.shift = shift 22 | self.diff = diff 23 | 24 | def link_event(self, other): 25 | self.linked = other 26 | 27 | def __repr__(self): 28 | return repr(self.shift) 29 | 30 | def __eq__(self, other): 31 | return self.__dict__ == other.__dict__ 32 | 33 | 34 | class InterpolateNonesTestCase(unittest.TestCase): 35 | def test_returns_empty_array_when_passed_empty_array(self): 36 | self.assertEquals(sushi.interpolate_nones([], []), []) 37 | 38 | def test_returns_false_when_no_valid_points(self): 39 | self.assertFalse(sushi.interpolate_nones([None, None, None], [1, 2, 3])) 40 | 41 | def test_returns_full_array_when_no_nones(self): 42 | self.assertEqual(sushi.interpolate_nones([1, 2, 3], [1, 2, 3]), [1, 2, 3]) 43 | 44 | def test_interpolates_simple_nones(self): 45 | self.assertEqual(sushi.interpolate_nones([1, None, 3, None, 5], [1, 2, 3, 4, 5]), [1, 2, 3, 4, 5]) 46 | 47 | def test_interpolates_multiple_adjacent_nones(self): 48 | self.assertEqual(sushi.interpolate_nones([1, None, None, None, 5], [1, 2, 3, 4, 5]), [1, 2, 3, 4, 5]) 49 | 50 | def test_copies_values_to_borders(self): 51 | self.assertEqual(sushi.interpolate_nones([None, None, 2, None, None], [1, 2, 3, 4, 5]), [2, 2, 2, 2, 2]) 52 | 53 | def test_copies_values_to_borders_when_everything_is_zero(self): 54 | self.assertEqual(sushi.interpolate_nones([None, 0, 0, 0, None], [1, 2, 3, 4, 5]), [0, 0, 0, 0, 0]) 55 | 56 | def test_interpolates_based_on_passed_points(self): 57 | self.assertEqual(sushi.interpolate_nones([1, None, 10], [1, 2, 10]), [1, 2, 10]) 58 | 59 | 60 | class RunningMedianTestCase(unittest.TestCase): 61 | def test_does_no_touch_border_values(self): 62 | shifts = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 63 | smooth = sushi.running_median(shifts, 5) 64 | self.assertEqual(shifts, smooth) 65 | 66 | def test_removes_broken_values(self): 67 | shifts = [0.1, 0.1, 0.1, 9001, 0.1, 0.1, 0.1] 68 | smooth = sushi.running_median(shifts, 5) 69 | self.assertEqual(smooth, [0.1] * 7) 70 | 71 | 72 | class SmoothEventsTestCase(unittest.TestCase): 73 | def test_smooths_events_shifts(self): 74 | events = [FakeEvent(x) for x in (0.1, 0.1, 0.1, 9001, 7777, 0.1, 0.1, 0.1)] 75 | sushi.smooth_events(events, 7) 76 | self.assertEqual([x.shift for x in events], [0.1] * 8) 77 | 78 | def test_keeps_diff_values(self): 79 | events = [FakeEvent(x, diff=x) for x in (0.1, 0.1, 0.1, 9001, 7777, 0.1, 0.1, 0.1)] 80 | diffs = [x.diff for x in events] 81 | sushi.smooth_events(events, 7) 82 | self.assertEqual([x.diff for x in events], diffs) 83 | 84 | 85 | class DetectGroupsTestCase(unittest.TestCase): 86 | def test_splits_three_simple_groups(self): 87 | events = [FakeEvent(0.5)] * 3 + [FakeEvent(1.0)] * 10 + [FakeEvent(0.5)] * 5 88 | groups = sushi.detect_groups(events) 89 | self.assertEqual(3, len(groups[0])) 90 | self.assertEqual(10, len(groups[1])) 91 | self.assertEqual(5, len(groups[2])) 92 | 93 | def test_single_group_for_all_events(self): 94 | events = [FakeEvent(0.5)] * 10 95 | groups = sushi.detect_groups(events) 96 | self.assertEqual(10, len(groups[0])) 97 | 98 | 99 | class 
GroupsFromChaptersTestCase(unittest.TestCase): 100 | def test_all_events_in_one_group_when_no_chapters(self): 101 | events = [FakeEvent(end=1), FakeEvent(end=2), FakeEvent(end=3)] 102 | groups = sushi.groups_from_chapters(events, []) 103 | self.assertEqual(1, len(groups)) 104 | self.assertEqual(events, groups[0]) 105 | 106 | def test_events_in_two_groups_one_chapter(self): 107 | events = [FakeEvent(end=1), FakeEvent(end=2), FakeEvent(end=3)] 108 | groups = sushi.groups_from_chapters(events, [0.0, 1.5]) 109 | self.assertEqual(2, len(groups)) 110 | self.assertItemsEqual([events[0]], groups[0]) 111 | self.assertItemsEqual([events[1], events[2]], groups[1]) 112 | 113 | def test_multiple_groups_multiple_chapters(self): 114 | events = [FakeEvent(end=x) for x in xrange(1, 10)] 115 | groups = sushi.groups_from_chapters(events, [0.0, 3.2, 4.4, 7.7]) 116 | self.assertEqual(4, len(groups)) 117 | self.assertItemsEqual(events[0:3], groups[0]) 118 | self.assertItemsEqual(events[3:4], groups[1]) 119 | self.assertItemsEqual(events[4:7], groups[2]) 120 | self.assertItemsEqual(events[7:9], groups[3]) 121 | 122 | 123 | class SplitBrokenGroupsTestCase(unittest.TestCase): 124 | def test_doing_nothing_on_correct_groups(self): 125 | groups = [[FakeEvent(0.5), FakeEvent(0.5)], [FakeEvent(10.0)]] 126 | fixed = sushi.split_broken_groups(groups) 127 | self.assertItemsEqual(groups, fixed) 128 | 129 | def test_split_groups_without_merging(self): 130 | groups = [ 131 | [FakeEvent(0.5)] * 10 + [FakeEvent(10.0)] * 5, 132 | [FakeEvent(0.5)] * 10, 133 | ] 134 | fixed = sushi.split_broken_groups(groups) 135 | self.assertItemsEqual([ 136 | [FakeEvent(0.5)] * 10, 137 | [FakeEvent(10.0)] * 5, 138 | [FakeEvent(0.5)] * 10 139 | ], fixed) 140 | 141 | def test_split_groups_with_merging(self): 142 | groups = [ 143 | [FakeEvent(0.5), FakeEvent(10.0)], 144 | [FakeEvent(10.0), FakeEvent(10.0), FakeEvent(15.0)] 145 | ] 146 | fixed = sushi.split_broken_groups(groups) 147 | self.assertItemsEqual([ 148 | [FakeEvent(0.5)], 149 | [FakeEvent(10.0), FakeEvent(10.0), FakeEvent(10.0)], 150 | [FakeEvent(15.0)] 151 | ], fixed) 152 | 153 | 154 | class FixNearBordersTestCase(unittest.TestCase): 155 | def test_propagates_last_correct_shift_to_broken_events(self): 156 | events = [FakeEvent(diff=x) for x in (0.9, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.9)] 157 | sushi.fix_near_borders(events) 158 | sf = events[2] 159 | sl = events[-3] 160 | self.assertEqual([x.linked for x in events], [sf, sf, None, None, None, None, None, sl, sl]) 161 | 162 | def test_returns_array_with_no_broken_events_unchanged(self): 163 | events = [FakeEvent(diff=x) for x in (0.9, 0.9, 0.9, 1.0, 0.9)] 164 | sushi.fix_near_borders(events) 165 | self.assertEqual([x.linked for x in events], [None, None, None, None, None]) 166 | 167 | 168 | class GetDistanceToClosestKeyframeTestCase(unittest.TestCase): 169 | KEYTIMES = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 170 | 171 | def test_finds_correct_distance_to_first_keyframe(self): 172 | self.assertEqual(sushi.get_distance_to_closest_kf(0, self.KEYTIMES), 0) 173 | 174 | def test_finds_correct_distance_to_last_keyframe(self): 175 | self.assertEqual(sushi.get_distance_to_closest_kf(105, self.KEYTIMES), -5) 176 | 177 | def test_finds_correct_distance_to_keyframe_before(self): 178 | self.assertEqual(sushi.get_distance_to_closest_kf(63, self.KEYTIMES), -3) 179 | 180 | def test_finds_distance_to_keyframe_after(self): 181 | self.assertEqual(sushi.get_distance_to_closest_kf(36, self.KEYTIMES), 4) 182 | 183 | 184 | 
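The sign convention pinned down by these four tests is what the snapping code in sushi.py relies on: `get_distance_to_closest_kf` returns `kf - timestamp`, so a positive value means the nearest keyframe lies after the timestamp and a negative one means it lies before. A standalone check, with the function body reproduced from sushi.py above:

```python
import bisect


def get_distance_to_closest_kf(timestamp, keyframes):
    # same logic as in sushi.py; keyframes must be sorted
    idx = bisect.bisect_left(keyframes, timestamp)
    if idx == 0:
        kf = keyframes[0]
    elif idx == len(keyframes):
        kf = keyframes[-1]
    else:
        before, after = keyframes[idx - 1], keyframes[idx]
        kf = after if after - timestamp < timestamp - before else before
    return kf - timestamp


keytimes = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
print(get_distance_to_closest_kf(36, keytimes))  # 4: closest keyframe is ahead
print(get_distance_to_closest_kf(63, keytimes))  # -3: closest keyframe is behind
```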
@patch('sushi.check_file_exists') 185 | class MainScriptTestCase(unittest.TestCase): 186 | @staticmethod 187 | def any_case_regex(text): 188 | return re.compile(text, flags=re.IGNORECASE) 189 | 190 | def test_checks_that_files_exist(self, mock_object): 191 | keys = ['--dst', 'dst', '--src', 'src', '--script', 'script', '--chapters', 'chapters', 192 | '--dst-keyframes', 'dst-keyframes', '--src-keyframes', 'src-keyframes', 193 | '--src-timecodes', 'src-tcs', '--dst-timecodes', 'dst-tcs'] 194 | try: 195 | sushi.parse_args_and_run(keys) 196 | except SushiError: 197 | pass 198 | mock_object.assert_any_call('src', ANY) 199 | mock_object.assert_any_call('dst', ANY) 200 | mock_object.assert_any_call('script', ANY) 201 | mock_object.assert_any_call('chapters', ANY) 202 | mock_object.assert_any_call('dst-keyframes', ANY) 203 | mock_object.assert_any_call('src-keyframes', ANY) 204 | mock_object.assert_any_call('dst-tcs', ANY) 205 | mock_object.assert_any_call('src-tcs', ANY) 206 | 207 | def test_raises_on_unknown_script_type(self, ignore): 208 | keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.mp4'] 209 | self.assertRaisesRegexp(SushiError, self.any_case_regex(r'script.*type'), lambda: sushi.parse_args_and_run(keys)) 210 | 211 | def test_raises_on_script_type_not_matching(self, ignore): 212 | keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '-o', 'd.srt'] 213 | self.assertRaisesRegexp(SushiError, self.any_case_regex(r'script.*type.*match'), 214 | lambda: sushi.parse_args_and_run(keys)) 215 | 216 | def test_raises_on_timecodes_and_fps_being_defined_together(self, ignore): 217 | keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '--src-timecodes', 'tc.txt', '--src-fps', '25'] 218 | self.assertRaisesRegexp(SushiError, self.any_case_regex(r'timecodes'), lambda: sushi.parse_args_and_run(keys)) 219 | 220 | 221 | class FormatTimeTestCase(unittest.TestCase): 222 | def test_format_time_zero(self): 223 | self.assertEqual('0:00:00.00', format_time(0)) 224 | 225 | def test_format_time_65_seconds(self): 226 | self.assertEqual('0:01:05.00', format_time(65)) 227 | 228 | def test_format_time_float_seconds(self): 229 | self.assertEqual('0:00:05.56', format_time(5.559)) 230 | 231 | def test_format_time_hours(self): 232 | self.assertEqual('1:15:35.15', format_time(3600 + 60 * 15 + 35.15)) 233 | 234 | def test_format_100ms(self): 235 | self.assertEqual('0:09:05.00', format_time(544.997)) 236 | -------------------------------------------------------------------------------- /tests/subtitles.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tempfile 3 | import os 4 | import codecs 5 | from subs import AssEvent, AssScript, SrtEvent, SrtScript 6 | 7 | SINGLE_LINE_SRT_EVENT = """1 8 | 00:14:21,960 --> 00:14:22,960 9 | HOW DID IT END UP LIKE THIS?""" 10 | 11 | MULTILINE_SRT_EVENT = """2 12 | 00:13:12,140 --> 00:13:14,100 13 | APPEARANCE! 14 | Appearrance (teisai)! 15 | No wait, you're the worst (saitei)!""" 16 | 17 | ASS_EVENT = r"Dialogue: 0,0:18:50.98,0:18:55.28,Default,,0,0,0,,Are you trying to (ouch) crush it (ouch)\N like a (ouch) vise (ouch, ouch)?" 18 | 19 | ASS_COMMENT = r"Comment: 0,0:18:09.15,0:18:10.36,Default,,0,0,0,,I DON'T GET IT TOO WELL." 
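Before the parsing tests below, it helps to spell out the time math they assert: SRT timestamps store milliseconds after a comma, while ASS timestamps store centiseconds after a dot. These two helpers are hypothetical (subs.py does the equivalent arithmetic inside its own parsers):

```python
def srt_timestamp_to_seconds(ts):
    # '00:14:21,960' -> 861.96
    hours, minutes, rest = ts.split(':')
    seconds, millis = rest.split(',')
    return int(hours) * 3600 + int(minutes) * 60 + int(seconds) + int(millis) / 1000.0


def ass_timestamp_to_seconds(ts):
    # '0:18:50.98' -> 1130.98 (fractional part is centiseconds)
    hours, minutes, seconds = ts.split(':')
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


print(srt_timestamp_to_seconds('00:14:21,960'))  # 861.96, i.e. 14*60 + 21.960
print(ass_timestamp_to_seconds('0:18:50.98'))    # 1130.98, i.e. 18*60 + 50.98
```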
20 | 21 | 22 | class SrtEventTestCase(unittest.TestCase): 23 | def test_simple_parsing(self): 24 | event = SrtEvent.from_string(SINGLE_LINE_SRT_EVENT) 25 | self.assertEquals(14*60+21.960, event.start) 26 | self.assertEquals(14*60+22.960, event.end) 27 | self.assertEquals("HOW DID IT END UP LIKE THIS?", event.text) 28 | 29 | def test_multi_line_event_parsing(self): 30 | event = SrtEvent.from_string(MULTILINE_SRT_EVENT) 31 | self.assertEquals(13*60+12.140, event.start) 32 | self.assertEquals(13*60+14.100, event.end) 33 | self.assertEquals("APPEARANCE!\nAppearrance (teisai)!\nNo wait, you're the worst (saitei)!", event.text) 34 | 35 | def test_parsing_and_printing(self): 36 | self.assertEquals(SINGLE_LINE_SRT_EVENT, unicode(SrtEvent.from_string(SINGLE_LINE_SRT_EVENT))) 37 | self.assertEquals(MULTILINE_SRT_EVENT, unicode(SrtEvent.from_string(MULTILINE_SRT_EVENT))) 38 | 39 | 40 | class AssEventTestCase(unittest.TestCase): 41 | def test_simple_parsing(self): 42 | event = AssEvent(ASS_EVENT) 43 | self.assertFalse(event.is_comment) 44 | self.assertEquals("Dialogue", event.kind) 45 | self.assertEquals(18*60+50.98, event.start) 46 | self.assertEquals(18*60+55.28, event.end) 47 | self.assertEquals("0", event.layer) 48 | self.assertEquals("Default", event.style) 49 | self.assertEquals("", event.name) 50 | self.assertEquals("0", event.margin_left) 51 | self.assertEquals("0", event.margin_right) 52 | self.assertEquals("0", event.margin_vertical) 53 | self.assertEquals("", event.effect) 54 | self.assertEquals("Are you trying to (ouch) crush it (ouch)\N like a (ouch) vise (ouch, ouch)?", event.text) 55 | 56 | def test_comment_parsing(self): 57 | event = AssEvent(ASS_COMMENT) 58 | self.assertTrue(event.is_comment) 59 | self.assertEquals("Comment", event.kind) 60 | 61 | def test_parsing_and_printing(self): 62 | self.assertEquals(ASS_EVENT, unicode(AssEvent(ASS_EVENT))) 63 | self.assertEquals(ASS_COMMENT, unicode(AssEvent(ASS_COMMENT))) 64 | 65 | 66 | class ScriptTestBase(unittest.TestCase): 67 | def setUp(self): 68 | self.script_description, self.script_path = tempfile.mkstemp() 69 | 70 | def tearDown(self): 71 | os.remove(self.script_path) 72 | 73 | 74 | class SrtScriptTestCase(ScriptTestBase): 75 | def test_write_to_file(self): 76 | events = [SrtEvent.from_string(SINGLE_LINE_SRT_EVENT), SrtEvent.from_string(MULTILINE_SRT_EVENT)] 77 | SrtScript(events).save_to_file(self.script_path) 78 | with open(self.script_path) as script: 79 | text = script.read() 80 | self.assertEquals(SINGLE_LINE_SRT_EVENT + "\n\n" + MULTILINE_SRT_EVENT, text) 81 | 82 | def test_read_from_file(self): 83 | os.write(self.script_description, """1 84 | 00:00:17,500 --> 00:00:18,870 85 | Yeah, really! 
86 | 87 | 2 88 | 00:00:17,500 --> 00:00:18,870 89 | 90 | 91 | 3 92 | 00:00:17,500 --> 00:00:18,870 93 | House number 94 | 35 95 | 96 | 4 97 | 00:00:21,250 --> 00:00:22,750 98 | Serves you right.""") 99 | parsed = SrtScript.from_file(self.script_path).events 100 | self.assertEquals(17.5, parsed[0].start) 101 | self.assertEquals(18.87, parsed[0].end) 102 | self.assertEquals("Yeah, really!", parsed[0].text) 103 | self.assertEquals(17.5, parsed[1].start) 104 | self.assertEquals(18.87, parsed[1].end) 105 | self.assertEquals("", parsed[1].text) 106 | self.assertEquals(17.5, parsed[2].start) 107 | self.assertEquals(18.87, parsed[2].end) 108 | self.assertEquals("House number\n35", parsed[2].text) 109 | self.assertEquals(21.25, parsed[3].start) 110 | self.assertEquals(22.75, parsed[3].end) 111 | self.assertEquals("Serves you right.", parsed[3].text) 112 | 113 | 114 | class AssScriptTestCase(ScriptTestBase): 115 | def test_write_to_file(self): 116 | styles = ["Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1"] 117 | events = [AssEvent(ASS_EVENT), AssEvent(ASS_EVENT)] 118 | AssScript([], styles, events, None).save_to_file(self.script_path) 119 | 120 | with codecs.open(self.script_path, encoding='utf-8-sig') as script: 121 | text = script.read() 122 | 123 | self.assertEquals("""[V4+ Styles] 124 | Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding 125 | Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1 126 | 127 | [Events] 128 | Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text 129 | {0} 130 | {0}""".format(ASS_EVENT), text) 131 | 132 | def test_read_from_file(self): 133 | text = """[Script Info] 134 | ; Script generated by Aegisub 3.1.1 135 | Title: script title 136 | 137 | [V4+ Styles] 138 | Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding 139 | Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1 140 | Style: Signs,Gentium Basic,40,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,0,0,2,10,10,10,1 141 | 142 | [Events] 143 | Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text 144 | Dialogue: 0,0:00:01.42,0:00:03.36,Default,,0000,0000,0000,,As you already know, 145 | Dialogue: 0,0:00:03.36,0:00:05.93,Default,,0000,0000,0000,,I'm concerned about the hair on my nipples.""" 146 | 147 | os.write(self.script_description, text) 148 | script = AssScript.from_file(self.script_path) 149 | self.assertEquals(["; Script generated by Aegisub 3.1.1", "Title: script title"], script.script_info) 150 | self.assertEquals(["Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1", 151 | "Style: Signs,Gentium Basic,40,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,0,0,2,10,10,10,1"], 152 | script.styles) 153 | self.assertEquals([1, 2], [x.source_index for x in script.events]) 154 | self.assertEquals(u"Dialogue: 0,0:00:01.42,0:00:03.36,Default,,0000,0000,0000,,As you already know,", 
unicode(script.events[0])) 155 | self.assertEquals(u"Dialogue: 0,0:00:03.36,0:00:05.93,Default,,0000,0000,0000,,I'm concerned about the hair on my nipples.", unicode(script.events[1])) 156 | -------------------------------------------------------------------------------- /tests/timecodes.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from demux import Timecodes 3 | 4 | 5 | class CfrTimecodesTestCase(unittest.TestCase): 6 | def test_get_frame_time_zero(self): 7 | tcs = Timecodes.cfr(23.976) 8 | t = tcs.get_frame_time(0) 9 | self.assertEqual(t, 0) 10 | 11 | def test_get_frame_time_sane(self): 12 | tcs = Timecodes.cfr(23.976) 13 | t = tcs.get_frame_time(10) 14 | self.assertAlmostEqual(10.0/23.976, t) 15 | 16 | def test_get_frame_time_insane(self): 17 | tcs = Timecodes.cfr(23.976) 18 | t = tcs.get_frame_time(100000) 19 | self.assertAlmostEqual(100000.0/23.976, t) 20 | 21 | def test_get_frame_size(self): 22 | tcs = Timecodes.cfr(23.976) 23 | t1 = tcs.get_frame_size(0) 24 | t2 = tcs.get_frame_size(1000) 25 | self.assertAlmostEqual(1.0/23.976, t1) 26 | self.assertAlmostEqual(t1, t2) 27 | 28 | def test_get_frame_number(self): 29 | tcs = Timecodes.cfr(24000.0/1001.0) 30 | self.assertEqual(tcs.get_frame_number(0), 0) 31 | self.assertEqual(tcs.get_frame_number(1145.353), 27461) 32 | self.assertEqual(tcs.get_frame_number(1001.0/24000.0 * 1234567), 1234567) 33 | 34 | 35 | class TimecodesTestCase(unittest.TestCase): 36 | def test_cfr_timecodes_v2(self): 37 | text = '# timecode format v2\n' + '\n'.join(str(1000 * x / 23.976) for x in range(0, 30000)) 38 | parsed = Timecodes.parse(text) 39 | 40 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(0)) 41 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(25)) 42 | self.assertAlmostEqual(1.0/23.976*100, parsed.get_frame_time(100)) 43 | self.assertEqual(0, parsed.get_frame_time(0)) 44 | self.assertEqual(0, parsed.get_frame_number(0)) 45 | self.assertEqual(27461, parsed.get_frame_number(1145.353)) 46 | 47 | def test_cfr_timecodes_v1(self): 48 | text = '# timecode format v1\nAssume 23.976024' 49 | parsed = Timecodes.parse(text) 50 | self.assertAlmostEqual(1.0/23.976024, parsed.get_frame_size(0)) 51 | self.assertAlmostEqual(1.0/23.976024, parsed.get_frame_size(25)) 52 | self.assertAlmostEqual(1.0/23.976024*100, parsed.get_frame_time(100)) 53 | self.assertEqual(0, parsed.get_frame_time(0)) 54 | self.assertEqual(0, parsed.get_frame_number(0)) 55 | self.assertEqual(27461, parsed.get_frame_number(1145.353)) 56 | 57 | def test_cfr_timecodes_v1_with_overrides(self): 58 | text = '# timecode format v1\nAssume 23.976000\n0,2000,23.976000\n3000,5000,23.976000' 59 | parsed = Timecodes.parse(text) 60 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(0)) 61 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(25)) 62 | self.assertAlmostEqual(1.0/23.976*100, parsed.get_frame_time(100)) 63 | self.assertEqual(0, parsed.get_frame_time(0)) 64 | 65 | def test_vfr_timecodes_v1_frame_size_at_first_frame(self): 66 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' 67 | parsed = Timecodes.parse(text) 68 | self.assertAlmostEqual(1.0/29.97, parsed.get_frame_size(timestamp=0)) 69 | 70 | def test_vfr_timecodes_v1_frame_size_outside_of_defined_range(self): 71 | text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' 72 | parsed = Timecodes.parse(text) 73 | self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(timestamp=5000.0)) 
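All the VFR v1 expectations in these tests follow from one rule: an override block `start,end,fps` covers frames start through end inclusive, and frames outside every block run at the `Assume` rate. Frame 2500, for instance, sits between the two override blocks, so its timestamp is 2001 frames at 29.97 fps plus 499 frames at the assumed 23.976 fps. A quick check of the values asserted in the surrounding tests:

```python
# Frames 0..2000 run at 29.97 fps, frames 2001..2499 at the assumed 23.976 fps.
t_1500 = 1500 / 29.97                 # inside the first override block
t_2500 = 2001 / 29.97 + 499 / 23.976  # between the two blocks
print(round(t_1500, 3))  # 50.05
print(round(t_2500, 3))  # 87.579
```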
74 | 
75 |     def test_vfr_timecodes_v1_frame_size_inside_override_block(self):
76 |         text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
77 |         parsed = Timecodes.parse(text)
78 |         self.assertAlmostEqual(1.0/29.97, parsed.get_frame_size(timestamp=49.983))
79 | 
80 |     def test_vfr_timecodes_v1_frame_size_between_override_blocks(self):
81 |         text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
82 |         parsed = Timecodes.parse(text)
83 |         self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(timestamp=87.496))
84 | 
85 |     def test_vfr_timecodes_v1_frame_time_at_first_frame(self):
86 |         text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
87 |         parsed = Timecodes.parse(text)
88 |         self.assertAlmostEqual(0, parsed.get_frame_time(number=0))
89 | 
90 |     def test_vfr_timecodes_v1_frame_time_outside_of_defined_range(self):
91 |         text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
92 |         parsed = Timecodes.parse(text)
93 |         self.assertAlmostEqual(1000.968, parsed.get_frame_time(number=25000), places=3)
94 | 
95 |     def test_vfr_timecodes_v1_frame_time_inside_override_block(self):
96 |         text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
97 |         parsed = Timecodes.parse(text)
98 |         self.assertAlmostEqual(50.05, parsed.get_frame_time(number=1500), places=3)
99 | 
100 |     def test_vfr_timecodes_v1_frame_time_between_override_blocks(self):
101 |         text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000'
102 |         parsed = Timecodes.parse(text)
103 |         self.assertAlmostEqual(87.579, parsed.get_frame_time(number=2500), places=3)
104 | 
105 | 
106 | 
--------------------------------------------------------------------------------
/wav.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import cv2
3 | import numpy as np
4 | from chunk import Chunk
5 | import struct
6 | import math
7 | from time import time
8 | import os.path
9 | from common import SushiError, clip
10 | 
11 | WAVE_FORMAT_PCM = 0x0001
12 | WAVE_FORMAT_EXTENSIBLE = 0xFFFE
13 | 
14 | 
15 | class DownmixedWavFile(object):
16 |     _file = None
17 | 
18 |     def __init__(self, path):
19 |         super(DownmixedWavFile, self).__init__()
20 |         self._file = open(path, 'rb')
21 |         try:
22 |             riff = Chunk(self._file, bigendian=False)
23 |             if riff.getname() != 'RIFF':
24 |                 raise SushiError('File does not start with RIFF id')
25 |             if riff.read(4) != 'WAVE':
26 |                 raise SushiError('Not a WAVE file')
27 | 
28 |             fmt_chunk_read = False
29 |             data_chunk_read = False
30 |             file_size = os.path.getsize(path)
31 | 
32 |             while True:
33 |                 try:
34 |                     chunk = Chunk(self._file, bigendian=False)
35 |                 except EOFError:
36 |                     break
37 | 
38 |                 if chunk.getname() == 'fmt ':
39 |                     self._read_fmt_chunk(chunk)
40 |                     fmt_chunk_read = True
41 |                 elif chunk.getname() == 'data':
42 |                     if file_size > 0xFFFFFFFF:
43 |                         # large broken wav
44 |                         self.frames_count = (file_size - self._file.tell()) // self.frame_size
45 |                     else:
46 |                         self.frames_count = chunk.chunksize // self.frame_size
47 |                     data_chunk_read = True
48 |                     break
49 |                 chunk.skip()
50 |             if not fmt_chunk_read or not data_chunk_read:
51 |                 raise SushiError('Invalid WAV file')
52 |         except:
53 |             self.close()
54 |             raise
55 | 
56 |     def __del__(self):
57 |         self.close()
58 | 
59 |     def close(self):
60 |         if self._file:
61 |             self._file.close()
62 |             self._file = None
63 | 
64 |     def readframes(self, count):
65 |         if not count:
66 |             return ''
67 |         data = 
self._file.read(count * self.frame_size) 68 | if self.sample_width == 2: 69 | unpacked = np.fromstring(data, dtype=np.int16) 70 | elif self.sample_width == 3: 71 | raw_bytes = np.ndarray(len(data), 'int8', data) 72 | unpacked = np.zeros(len(data) / 3, np.int16) 73 | unpacked.view(dtype='int8')[0::2] = raw_bytes[1::3] 74 | unpacked.view(dtype='int8')[1::2] = raw_bytes[2::3] 75 | else: 76 | raise SushiError('Unsupported sample width: {0}'.format(self.sample_width)) 77 | 78 | unpacked = unpacked.astype('float32') 79 | 80 | if self.channels_count == 1: 81 | return unpacked 82 | else: 83 | min_length = len(unpacked) // self.channels_count 84 | real_length = len(unpacked) / float(self.channels_count) 85 | if min_length != real_length: 86 | logging.error("Length of audio channels didn't match. This might result in broken output") 87 | 88 | channels = (unpacked[i::self.channels_count] for i in xrange(self.channels_count)) 89 | data = reduce(lambda a, b: a[:min_length]+b[:min_length], channels) 90 | data /= float(self.channels_count) 91 | return data 92 | 93 | def _read_fmt_chunk(self, chunk): 94 | wFormatTag, self.channels_count, self.framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack('= 0], overwrite_input=True) * 3 146 | min_value = np.median(self.data[self.data <= 0], overwrite_input=True) * 3 147 | 148 | np.clip(self.data, min_value, max_value, out=self.data) 149 | 150 | self.data -= min_value 151 | self.data /= (max_value - min_value) 152 | 153 | if sample_type == 'uint8': 154 | self.data *= 255.0 155 | self.data += 0.5 156 | self.data = self.data.astype('uint8') 157 | 158 | except Exception as e: 159 | raise SushiError('Error while loading {0}: {1}'.format(path, e)) 160 | finally: 161 | stream.close() 162 | logging.info('Done reading WAV {0} in {1}s'.format(path, time() - before_read)) 163 | 164 | @property 165 | def duration_seconds(self): 166 | return self.sample_count / self.sample_rate 167 | 168 | def get_substream(self, start, end): 169 | start_off = self._get_sample_for_time(start) 170 | end_off = self._get_sample_for_time(end) 171 | return self.data[:, start_off:end_off] 172 | 173 | def _get_sample_for_time(self, timestamp): 174 | # this function gets REAL sample for time, taking padding into account 175 | return int(self.sample_rate * timestamp) + self.padding_size 176 | 177 | def find_substream(self, pattern, window_center, window_size): 178 | start_time = clip(window_center - window_size, -self.PADDING_SECONDS, self.duration_seconds) 179 | end_time = clip(window_center + window_size, 0, self.duration_seconds + self.PADDING_SECONDS) 180 | 181 | start_sample = self._get_sample_for_time(start_time) 182 | end_sample = self._get_sample_for_time(end_time) + len(pattern[0]) 183 | 184 | search_source = self.data[:, start_sample:end_sample] 185 | result = cv2.matchTemplate(search_source, pattern, cv2.TM_SQDIFF_NORMED) 186 | min_idx = result.argmin(axis=1)[0] 187 | 188 | return result[0][min_idx], start_time + (min_idx / float(self.sample_rate)) 189 | --------------------------------------------------------------------------------
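A closing note on `find_substream` above: Sushi treats downsampled audio as a 1×N float32 matrix and leaves the actual pattern search to OpenCV template matching, where `TM_SQDIFF_NORMED` is smallest at the best alignment. A self-contained sketch of that trick on synthetic data (not Sushi's API):

```python
import cv2
import numpy as np

rate = 12000                                               # samples per second
haystack = np.random.rand(1, rate * 5).astype(np.float32)  # five "seconds" of audio
needle = haystack[:, rate * 2:rate * 3].copy()             # one second hidden at 2.0s

result = cv2.matchTemplate(haystack, needle, cv2.TM_SQDIFF_NORMED)
best_sample = result.argmin(axis=1)[0]
print(best_sample / float(rate))  # 2.0 -> the shift, just as find_substream computes it
```

Matching on heavily downsampled uint8 or float32 audio is what keeps this search fast enough to run over full episodes.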