├── inference ├── __init__.py ├── transnetv2-weights │ ├── saved_model.pb │ └── variables │ │ ├── variables.index │ │ └── variables.data-00000-of-00001 ├── Dockerfile ├── README.md └── transnetv2.py ├── .gitignore ├── .gitattributes ├── configs ├── transnetv2-realtrans.gin ├── transnetv1.gin └── transnetv2.gin ├── training ├── video_utils.py ├── Dockerfile ├── evaluate.py ├── visualization_utils.py ├── metrics_utils.py ├── consolidate_datasets.py ├── models.py ├── create_dataset.py ├── bi_tempered_loss.py ├── training.py ├── weight_decay_optimizers.py ├── transnet.py └── input_processing.py ├── setup.py ├── LICENSE ├── inference-pytorch ├── README.md ├── convert_weights.py └── transnetv2_pytorch.py └── README.md /inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .transnetv2 import TransNetV2 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.tfrecord 2 | *.npy 3 | *.mp4 4 | *.pickle 5 | .idea 6 | .ipynb_checkpoints 7 | -------------------------------------------------------------------------------- /inference/transnetv2-weights/saved_model.pb: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8ac2a52c5719690d512805b6eaf5ce12097c1d8860b3d9de245dcbbc3100f554 3 | size 5933260 4 | -------------------------------------------------------------------------------- /inference/transnetv2-weights/variables/variables.index: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8b99e28b4ad11372d9a1ad9703298c2e370df14859da4245fdbe818e92dd403f 3 | size 5526 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.h5 filter=lfs diff=lfs merge=lfs -text 2 | *.pb filter=lfs diff=lfs merge=lfs -text 3 | *.index filter=lfs diff=lfs merge=lfs -text 4 | *.data-* filter=lfs diff=lfs merge=lfs -text 5 | -------------------------------------------------------------------------------- /inference/transnetv2-weights/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b8c9dc3eb807583e6215cabee9ca61737b3eb1bceff68418b43bf71459669367 3 | size 30516656 4 | -------------------------------------------------------------------------------- /inference/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.1.1-gpu 2 | 3 | RUN pip3 --no-cache-dir install \ 4 | Pillow \ 5 | ffmpeg-python 6 | 7 | RUN apt-get update && apt-get install -y --no-install-recommends \ 8 | ffmpeg 9 | 10 | COPY setup.py /tmp 11 | COPY inference /tmp/inference 12 | 13 | RUN cd /tmp && python3 setup.py install && rm -r * 14 | -------------------------------------------------------------------------------- /configs/transnetv2-realtrans.gin: -------------------------------------------------------------------------------- 1 | include "./configs/transnetv2.gin" 2 | 3 | options.log_name = "transnetv2-realtrans" 4 | options.n_epochs = 50 5 | options.transition_only_data_fraction = 0.15 6 | options.transition_only_trn_files = [ 7 | 
"data/48x27/ClipShotsTrainTransitions/*.tfrecord", 8 | "data/48x27/ClipShotsGradual-transitions/*.tfrecord" 9 | ] 10 | 11 | concat_shots.hard_cut_prob = 0.412 12 | -------------------------------------------------------------------------------- /training/video_utils.py: -------------------------------------------------------------------------------- 1 | import ffmpeg 2 | import numpy as np 3 | 4 | 5 | def get_frames(fn, width=48, height=27): 6 | video_stream, err = ( 7 | ffmpeg 8 | .input(fn) 9 | .output('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height)) 10 | .run(capture_stdout=True, capture_stderr=True) 11 | ) 12 | video = np.frombuffer(video_stream, np.uint8).reshape([-1, height, width, 3]) 13 | return video 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="transnetv2", 5 | version="1.0.0", 6 | # let user install tensorflow, etc. manually 7 | # install_requires=[ 8 | # "tensorflow>=2.0", 9 | # "ffmpeg-python", 10 | # "pillow" 11 | # ], 12 | entry_points={ 13 | "console_scripts": [ 14 | "transnetv2_predict = transnetv2.transnetv2:main", 15 | ] 16 | }, 17 | packages=["transnetv2"], 18 | package_dir={"transnetv2": "./inference"}, 19 | package_data={"transnetv2": [ 20 | "transnetv2-weights/*", 21 | "transnetv2-weights/variables/*" 22 | ]}, 23 | zip_safe=False 24 | ) 25 | -------------------------------------------------------------------------------- /training/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:devel-gpu-py3 2 | 3 | RUN apt-get update && apt-get install -y --no-install-recommends \ 4 | build-essential \ 5 | pkg-config \ 6 | rsync \ 7 | unzip \ 8 | zip \ 9 | zlib1g-dev \ 10 | wget \ 11 | git \ 12 | ffmpeg \ 13 | libsm6 \ 14 | libxext6 \ 15 | libxrender1 \ 16 | libfontconfig1 \ 17 | vim 18 | 19 | RUN pip3 --no-cache-dir install \ 20 | Pillow \ 21 | h5py \ 22 | keras_applications \ 23 | keras_preprocessing \ 24 | matplotlib \ 25 | mock \ 26 | numpy \ 27 | scipy \ 28 | sklearn \ 29 | pandas \ 30 | tensorflow-gpu \ 31 | tqdm \ 32 | ffmpeg-python \ 33 | pyyaml \ 34 | opencv-python \ 35 | opencv-contrib-python \ 36 | shapely 37 | 38 | RUN git clone https://github.com/soCzech/gin-config && cd gin-config && python3 setup.py install 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Tomáš Souček 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /inference-pytorch/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of TransNet V2 2 | 3 | This is a PyTorch reimplementation of the TransNetV2 model. 4 | It should produce results identical to the TensorFlow version. 5 | The code is for inference only; there is no plan to release a PyTorch implementation of the training code. 6 | 7 | See the [tensorflow inference readme](https://github.com/soCzech/TransNetV2/tree/master/inference) 8 | for details and code showing how to correctly get predictions for a whole video file. 9 | 10 | ### INSTALL REQUIREMENTS 11 | ```bash 12 | pip install tensorflow==2.1 # needed for model weights conversion 13 | conda install pytorch=1.7.1 cudatoolkit=10.1 -c pytorch 14 | ``` 15 | 16 | ### CONVERT WEIGHTS 17 | First, the TensorFlow weights file needs to be converted into a PyTorch weights file. 18 | ```bash 19 | python convert_weights.py [--test] 20 | ``` 21 | The PyTorch weights are saved into the *transnetv2-pytorch-weights.pth* file. 22 | 23 | ### ADVANCED USAGE 24 | 25 | ```python 26 | import torch 27 | from transnetv2_pytorch import TransNetV2 28 | 29 | model = TransNetV2() 30 | state_dict = torch.load("transnetv2-pytorch-weights.pth") 31 | model.load_state_dict(state_dict) 32 | model.eval().cuda() 33 | 34 | with torch.no_grad(): 35 | # shape: batch dim x video frames x frame height x frame width x RGB (not BGR) channels 36 | input_video = torch.zeros(1, 100, 27, 48, 3, dtype=torch.uint8) 37 | single_frame_pred, all_frame_pred = model(input_video.cuda()) 38 | 39 | single_frame_pred = torch.sigmoid(single_frame_pred).cpu().numpy() 40 | all_frame_pred = torch.sigmoid(all_frame_pred["many_hot"]).cpu().numpy() 41 | ``` 42 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransNet V2: Shot Boundary Detection Neural Network 2 | 3 | This repository contains code for [TransNet V2: An effective deep network architecture for fast shot transition detection](https://arxiv.org/abs/2008.04838).
4 | 5 | Our reevaluation of other publicly available state-of-the-art shot boundary methods (F1 scores): 6 | 7 | Model | ClipShots | BBC Planet Earth | RAI 8 | --- | :---: | :---: | :---: 9 | TransNet V2 (this repo) | **77.9** | **96.2** | 93.9 10 | [TransNet](https://arxiv.org/abs/1906.03363) [(github)](https://github.com/soCzech/TransNet) | 73.5 | 92.9 | **94.3** 11 | [Hassanien et al.](https://arxiv.org/abs/1705.03281) [(github)](https://github.com/melgharib/DSBD) | 75.9 | 92.6 | 93.9 12 | [Tang et al., ResNet baseline](https://arxiv.org/abs/1808.04234) [(github)](https://github.com/Tangshitao/ClipShots_basline) | 76.1 | 89.3 | 92.8 13 | 14 | 15 | ### :movie_camera: USE IT ON YOUR VIDEO 16 | :arrow_right: **See [_inference_ folder](https://github.com/soCzech/TransNetV2/tree/master/inference) and its _README_ file.** :arrow_left: 17 | 18 | 19 | ### :rocket: PYTORCH VERSION for inference RELEASED 20 | **See [_inference-pytorch_ folder](https://github.com/soCzech/TransNetV2/tree/master/inference-pytorch) and its _README_ file.** 21 | 22 | 23 | ### REPLICATE THE WORK 24 | > Note the datasets for training are tens of gigabytes in size, hundreds of gigabytes when exported. 25 | > 26 | > **You do not need to train the network, use code and instructions in [_inference_ folder](https://github.com/soCzech/TransNetV2/tree/master/inference) to detect shots in your videos.** 27 | 28 | This repository contains all that is needed to run any experiment for TransNet V2 network including network training and dataset creation. 29 | All experiments should be runnable in [this NVIDIA DOCKER file](https://github.com/soCzech/TransNetV2/blob/master/training/Dockerfile). 30 | 31 | In general these steps need to be done in order to replicate our work (in [_training_ folder](https://github.com/soCzech/TransNetV2/tree/master/training)): 32 | 33 | 1. Download RAI and BBC Planet Earth test datasets [(link)](https://aimagelab.ing.unimore.it/imagelab/researchActivity.asp?idActivity=19). 34 | Download ClipShots train/test dataset [(link)](https://github.com/Tangshitao/ClipShots). 35 | Optionally get IACC.3 dataset. 36 | 2. Edit and run `consolidate_datasets.py` in order to transform ground truth from all the datasets into one common format. 37 | 3. Take some videos from ClipShotsTrain aside as a validation dataset. 38 | 4. Run `create_dataset.py` to create all train/validation/test datasets. 39 | 5. Run `training.py ../configs/transnetv2.gin` to train a model. 40 | 6. Run `evaluate.py /path/to/run_log_dir epoch_no /path/to/test_dataset` for proper evaluation. 
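For reference, the consolidated ground-truth format produced in step 2 (and later read by `evaluate.py`) is a plain-text file per video with one `start end` pair of frame indices per shot. A minimal sketch of reading such a file, with a purely hypothetical path:

```python
import numpy as np

# hypothetical path to one annotation file written by consolidate_datasets.py
scenes = np.loadtxt("consolidated/ClipShotsTrain/some_video.txt", dtype=np.int32, ndmin=2)

# scenes[i] = [start_frame, end_frame] of the i-th shot;
# with contiguous shots, a transition lies between every pair of consecutive shots
print("shots:", len(scenes), "transitions:", max(len(scenes) - 1, 0))
```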
41 | 42 | 43 | ### CREDITS 44 | If found useful, please cite us;) 45 | - This paper: [TransNet V2: An effective deep network architecture for fast shot transition detection](https://arxiv.org/abs/2008.04838) 46 | ``` 47 | @article{soucek2020transnetv2, 48 | title={TransNet V2: An effective deep network architecture for fast shot transition detection}, 49 | author={Sou{\v{c}}ek, Tom{\'a}{\v{s}} and Loko{\v{c}}, Jakub}, 50 | year={2020}, 51 | journal={arXiv preprint arXiv:2008.04838}, 52 | } 53 | ``` 54 | 55 | - ACM Multimedia paper of the older version: [A Framework for Effective Known-item Search in Video](https://dl.acm.org/doi/abs/10.1145/3343031.3351046) 56 | 57 | - The older version paper: [TransNet: A deep network for fast detection of common shot transitions](https://arxiv.org/abs/1906.03363) 58 |
-------------------------------------------------------------------------------- /inference/README.md: -------------------------------------------------------------------------------- 1 | # TransNet V2: Shot Boundary Detection Neural Network 2 | 3 | Inference code for [TransNet V2: An effective deep network architecture for fast shot transition detection](https://arxiv.org/abs/2008.04838). 4 | 5 | ### INSTALL REQUIREMENTS 6 | ```bash 7 | pip install tensorflow==2.1 8 | ``` 9 | 10 | If you want to predict directly on video files, install `ffmpeg`. 11 | If you want to visualize results, also install `pillow` (simple usage requires both). 12 | ```bash 13 | apt-get install ffmpeg 14 | pip install ffmpeg-python pillow 15 | ``` 16 | 17 | or **use NVIDIA DOCKER**! 18 | ``` 19 | # run from the root directory of the repository 20 | docker build -t transnet -f inference/Dockerfile . 21 | ``` 22 | Then simply use it in the following way: 23 | ``` 24 | docker run -it --rm --gpus 1 -v /path/to/video/dir:/tmp transnet transnetv2_predict /tmp/video.mp4 [--visualize] 25 | ``` 26 | 27 | > Note the `transnetv2-weights` directory contains files stored in git-lfs. 28 | > You may need to install git-lfs and run `git lfs pull` from the root directory of the repository 29 | > (or you can download the `transnetv2-weights` directory manually). 30 | 31 | ### INSTALL AS PYTHON PACKAGE (optional) 32 | Run `python setup.py install` from the root directory of the repository. 33 | 34 | 35 | ### SIMPLE USAGE 36 | 37 | ``` 38 | # run from this directory 39 | python transnetv2.py /path/to/video.mp4 [--visualize] 40 | # or if installed as python package, run from anywhere 41 | transnetv2_predict /path/to/video.mp4 [--visualize] 42 | ``` 43 | 44 | It creates: 45 | - `/path/to/video.mp4.scenes.txt` file containing a list of scenes - pairs of 46 | *start-frame-index*, *end-frame-index* (indexed from zero, both limits inclusive). 47 | - `/path/to/video.mp4.predictions.txt` file with each line containing raw predictions for the corresponding frame 48 | (the first number is from the 'single-frame-per-transition' head, the second from the 'all-frames-per-transition' head) 49 | - optionally, it creates a visualization in the file `/path/to/video.mp4.vis.png` 50 | 51 | 52 | ### ADVANCED USAGE 53 | - Get predictions: 54 | ```python 55 | from transnetv2 import TransNetV2 56 | 57 | # location of learned weights is automatically inferred 58 | # add argument model_dir="/path/to/transnetv2-weights/" to TransNetV2() if it fails 59 | model = TransNetV2() 60 | video_frames, single_frame_predictions, all_frame_predictions = \ 61 | model.predict_video("/path/to/video.mp4") 62 | 63 | # or 64 | video_frames = ...
# np.array, shape: [n_frames, 27, 48, 3], dtype: np.uint8, RGB (not BGR) 65 | single_frame_predictions, all_frame_predictions = \ 66 | model.predict_frames(video_frames) 67 | ``` 68 | 69 | - Get scenes from predictions: 70 | ```python 71 | model.predictions_to_scenes(single_frame_predictions) 72 | ``` 73 | 74 | - Visualize predictions: 75 | ```python 76 | model.visualize_predictions( 77 | video_frames, predictions=(single_frame_predictions, all_frame_predictions)) 78 | ``` 79 | 80 | ### NOTES 81 | > :exclamation: It may happen that you get **DecodeError**, **OSError**, **IOError** with text *'Error parsing message'*. It is caused by corrupted files in *transnetv2-weights* folder. To fix the error, re-download the files manually. SHA256 sums for the files can be found in [issue #1](https://github.com/soCzech/TransNetV2/issues/1#issuecomment-647357796). 82 | 83 | > Note that your results on test sets can slightly differ when using different extraction method or particular `ffmpeg` version. 84 | 85 | ### CREDITS 86 | If found useful, please cite us;) 87 | - This paper: [TransNet V2: An effective deep network architecture for fast shot transition detection](https://arxiv.org/abs/2008.04838) 88 | ``` 89 | @article{soucek2020transnetv2, 90 | title={TransNet V2: An effective deep network architecture for fast shot transition detection}, 91 | author={Sou{\v{c}}ek, Tom{\'a}{\v{s}} and Loko{\v{c}}, Jakub}, 92 | year={2020}, 93 | journal={arXiv preprint arXiv:2008.04838}, 94 | } 95 | ``` 96 | 97 | - ACM Multimedia paper of the older version: [A Framework for Effective Known-item Search in Video](https://dl.acm.org/doi/abs/10.1145/3343031.3351046) 98 | 99 | - The older version paper: [TransNet: A deep network for fast detection of common shot transitions](https://arxiv.org/abs/1906.03363) 100 | -------------------------------------------------------------------------------- /training/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gin 3 | import glob 4 | import tqdm 5 | import pickle 6 | import argparse 7 | import numpy as np 8 | import tensorflow as tf 9 | import gin.tf.external_configurables 10 | 11 | import models 12 | import transnet 13 | import training 14 | import metrics_utils 15 | import create_dataset 16 | import input_processing 17 | import visualization_utils 18 | 19 | import logging 20 | logger = tf.get_logger() 21 | logger.setLevel(logging.ERROR) 22 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 23 | 24 | 25 | def get_batches(frames): 26 | reminder = 50 - len(frames) % 50 27 | if reminder == 50: 28 | reminder = 0 29 | frames = np.concatenate([frames[:1]] * 25 + [frames] + [frames[-1:]] * (reminder + 25), 0) 30 | 31 | def func(): 32 | for i in range(0, len(frames) - 50, 50): 33 | yield frames[i:i+100] 34 | return func() 35 | 36 | 37 | if __name__ == "__main__": 38 | 39 | parser = argparse.ArgumentParser(description="Evaluate TransNet") 40 | parser.add_argument("log_dir", help="path to log dir") 41 | parser.add_argument("epoch", help="what weights to use", type=int) 42 | parser.add_argument("directory", help="path to the test dataset") 43 | parser.add_argument("--thr", default=0.5, type=float, help="threshold for transition") 44 | args = parser.parse_args() 45 | 46 | print(args) 47 | gin.parse_config_file(os.path.join(args.log_dir, "config.gin")) 48 | options = training.get_options_dict(create_dir_and_summaries=False) 49 | 50 | if options["original_transnet"]: 51 | net = models.OriginalTransNet() 52 | logit_fc = lambda x: 
tf.nn.softmax(x)[:, :, 1] 53 | 54 | else: 55 | net = transnet.TransNetV2() 56 | logit_fc = tf.sigmoid 57 | 58 | @tf.function(autograph=False) 59 | def predict(batch): 60 | one_hot = net(tf.cast(batch, tf.float32)[tf.newaxis]) 61 | if isinstance(one_hot, tuple): 62 | one_hot = one_hot[0] 63 | return logit_fc(one_hot)[0] 64 | 65 | net(tf.zeros([1] + options["input_shape"], tf.float32)) 66 | net.load_weights(os.path.join(args.log_dir, "weights-{:d}.h5".format(args.epoch))) 67 | files = glob.glob(os.path.join(args.directory, "*.npy")) 68 | 69 | results = [] 70 | total_stats = {"tp": 0, "fp": 0, "fn": 0} 71 | 72 | dataset_name = [i for i in args.directory.split("/") if i != ""][-1] 73 | img_dir = os.path.join(args.log_dir, "results", "{}-epoch{:d}".format(dataset_name, args.epoch)) 74 | os.makedirs(img_dir, exist_ok=True) 75 | 76 | for np_fn in tqdm.tqdm(files): 77 | predictions = [] 78 | frames = np.load(np_fn) 79 | 80 | for batch in get_batches(frames): 81 | one_hot = predict(batch) 82 | predictions.append(one_hot[25:75]) 83 | 84 | predictions = np.concatenate(predictions, 0)[:len(frames)] 85 | gt_scenes = np.loadtxt(np_fn[:-3] + "txt", dtype=np.int32, ndmin=2) 86 | 87 | _, _, _, (tp, fp, fn), fp_mistakes, fn_mistakes = metrics_utils.evaluate_scenes( 88 | gt_scenes, metrics_utils.predictions_to_scenes((predictions >= args.thr).astype(np.uint8)), 89 | return_mistakes=True) 90 | 91 | total_stats["tp"] += tp 92 | total_stats["fp"] += fp 93 | total_stats["fn"] += fn 94 | 95 | if len(fp_mistakes) > 0 or len(fn_mistakes) > 0: 96 | img = visualization_utils.visualize_errors( 97 | frames, predictions, 98 | create_dataset.scenes2zero_one_representation(gt_scenes, len(frames))[1], 99 | fp_mistakes, fn_mistakes) 100 | if img is not None: 101 | img.save(os.path.join(img_dir, os.path.basename(np_fn[:-3]) + "png")) 102 | 103 | results.append((np_fn, predictions, gt_scenes)) 104 | 105 | with open(os.path.join(args.log_dir, "results", "{}-epoch{:d}.pickle".format(dataset_name, args.epoch)), "wb") as f: 106 | pickle.dump(results, f) 107 | 108 | p = total_stats["tp"] / (total_stats["tp"] + total_stats["fp"]) 109 | r = total_stats["tp"] / (total_stats["tp"] + total_stats["fn"]) 110 | f1 = (p * r * 2) / (p + r) 111 | print(f""" 112 | Precision:{p*100:5.2f}% 113 | Recall: {r*100:5.2f}% 114 | F1 Score: {f1*100:5.2f}% 115 | """) 116 | -------------------------------------------------------------------------------- /inference-pytorch/convert_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | import transnetv2_pytorch 8 | 9 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 10 | 11 | 12 | def remap_name(x): 13 | x = x.replace("TransNet/", "") 14 | l = [] 15 | for a in x.split("/"): 16 | if a.startswith("SDDCNN") or a.startswith("DDCNN"): 17 | a = a.split("_") 18 | a = a[0] + "." 
+ str(int(a[1]) - 1) 19 | elif a == "conv_spatial": 20 | a = "layers.0" 21 | elif a == "conv_temporal": 22 | a = "layers.1" 23 | elif a == "kernel:0" or a == "gamma:0": 24 | a = "weight" 25 | elif a == "bias:0" or a == "beta:0": 26 | a = "bias" 27 | elif a == "dense": 28 | a = "fc1" 29 | elif a == "dense_1": 30 | a = "cls_layer1" 31 | elif a == "dense_2": 32 | a = "cls_layer2" 33 | elif a == "dense_3": 34 | a = "frame_sim_layer.projection" 35 | elif a == "dense_4": 36 | a = "frame_sim_layer.fc" 37 | elif a == "dense_5": 38 | a = "color_hist_layer.fc" 39 | elif a == "FrameSimilarity" or a == "ColorHistograms": 40 | a = "" 41 | elif a == "moving_mean:0": 42 | a = "running_mean" 43 | elif a == "moving_variance:0": 44 | a = "running_var" 45 | l.append(a) 46 | x = ".".join([a for a in l if a != ""]) 47 | return x 48 | 49 | 50 | def remap_tensor(x): 51 | x = x.numpy() 52 | if len(x.shape) == 5: 53 | x = np.transpose(x, [0, 1, 2, 4, 3]) 54 | x = np.transpose(x, [3, 4, 0, 1, 2]) 55 | elif len(x.shape) == 2: 56 | x = np.transpose(x) 57 | return torch.from_numpy(x).clone() 58 | 59 | 60 | def check_and_fix_dicts(tf_dict, torch_dict): 61 | error = False 62 | 63 | for k in torch_dict.keys(): 64 | if k not in tf_dict: 65 | if k.endswith("num_batches_tracked"): 66 | tf_dict[k] = torch.tensor(1., dtype=torch.float32) 67 | else: 68 | print("!", k, "missing in TF") 69 | error = True 70 | 71 | for k in tf_dict.keys(): 72 | if k not in torch_dict: 73 | print("!", k, "missing in TORCH") 74 | error = True 75 | if tuple(tf_dict[k].shape) != torch_dict[k]: 76 | print("!", k, f"has wrong shape (TF: {tuple(tf_dict[k].shape)}, TORCH: {torch_dict[k]})") 77 | error = True 78 | 79 | return not error 80 | 81 | 82 | def convert_weights(tf_weights_dir): 83 | tf_model = tf.saved_model.load(tf_weights_dir) 84 | tf_dict = {remap_name(v.name): remap_tensor(v) for v in tf_model.variables} 85 | 86 | torch_model = transnetv2_pytorch.TransNetV2() 87 | torch_dict = {k: tuple(v.shape) for k, v in list(torch_model.named_parameters()) + list(torch_model.named_buffers())} 88 | 89 | assert check_and_fix_dicts(tf_dict, torch_dict), "some errors occurred when converting" 90 | torch_model.load_state_dict(tf_dict) 91 | 92 | return torch_model, tf_model 93 | 94 | 95 | def test_models(torch_model, tf_model): 96 | input_tensors = [np.random.randint(0, 255, size=(2, 100, 27, 48, 3), dtype=np.uint8) for _ in range(10)] 97 | 98 | print("Tests: computing forward passes...") 99 | with torch.no_grad(): 100 | torch_outputs = [torch_model(torch.from_numpy(x)) for x in input_tensors] 101 | tf_outputs = [tf_model(tf.cast(x, tf.float32)) for x in input_tensors] 102 | 103 | for i, ((torch_single, torch_many), (tf_single, tf_many)) in enumerate(zip(torch_outputs, tf_outputs)): 104 | single = np.isclose(torch_single.numpy(), tf_single.numpy()).mean() 105 | many = np.isclose(torch_many["many_hot"].numpy(), tf_many["many_hot"].numpy()).mean() 106 | 107 | print(f"Test {i:2d}: " 108 | f"{single * 100:5.1f}% of 'single' predictions matching, " 109 | f"{many * 100:5.1f}% of 'many' predictions matching") 110 | 111 | 112 | def main(): 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument("--tf_weights", type=str, help="path to TransNet V2 weights", 115 | default="../inference/transnetv2-weights/") 116 | parser.add_argument('--test', action="store_true", help="run tests") 117 | args = parser.parse_args() 118 | 119 | torch_model, tf_model = convert_weights(args.tf_weights) 120 | 121 | print("Saving model to ./transnetv2-pytorch-weights.pth") 122 | 
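    # the saved state dict can later be loaded (as shown in inference-pytorch/README.md) via:
    #   model = transnetv2_pytorch.TransNetV2()
    #   model.load_state_dict(torch.load("transnetv2-pytorch-weights.pth"))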
torch.save(torch_model.state_dict(), "./transnetv2-pytorch-weights.pth") 123 | 124 | if args.test: 125 | test_models(torch_model, tf_model) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /training/visualization_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageDraw 3 | 4 | 5 | def visualize_scenes(frames: np.ndarray, scenes: np.ndarray): 6 | nf, ih, iw, ic = frames.shape 7 | width = 25 8 | if len(frames) % width != 0: 9 | pad_with = width - len(frames) % width 10 | frames = np.concatenate([frames, np.zeros([pad_with, ih, iw, ic], np.uint8)]) 11 | height = len(frames) // width 12 | 13 | scene = frames.reshape([height, width, ih, iw, ic]) 14 | scene = np.concatenate(np.split( 15 | np.concatenate(np.split(scene, height), axis=2)[0], width 16 | ), axis=2)[0] 17 | 18 | img = Image.fromarray(scene) 19 | draw = ImageDraw.Draw(img, "RGBA") 20 | 21 | def draw_start_frame(frame_no): 22 | w = frame_no % width 23 | h = frame_no // width 24 | draw.rectangle([(w * iw, h * ih), (w * iw + 2, h * ih + ih - 1)], fill=(255, 0, 0)) 25 | draw.polygon( 26 | [(w * iw + 7, h * ih + ih // 2 - 4), (w * iw + 12, h * ih + ih // 2), (w * iw + 7, h * ih + ih // 2 + 4)], 27 | fill=(255, 0, 0)) 28 | draw.rectangle([(w * iw, h * ih + ih // 2 - 1), (w * iw + 7, h * ih + ih // 2 + 1)], fill=(255, 0, 0)) 29 | 30 | def draw_end_frame(frame_no): 31 | w = frame_no % width 32 | h = frame_no // width 33 | draw.rectangle([(w * iw + iw - 1, h * ih), (w * iw + iw - 3, h * ih + ih - 1)], fill=(255, 0, 0)) 34 | draw.polygon([(w * iw + iw - 8, h * ih + ih // 2 - 4), (w * iw + iw - 13, h * ih + ih // 2), 35 | (w * iw + iw - 8, h * ih + ih // 2 + 4)], fill=(255, 0, 0)) 36 | draw.rectangle([(w * iw + iw - 1, h * ih + ih // 2 - 1), (w * iw + iw - 8, h * ih + ih // 2 + 1)], 37 | fill=(255, 0, 0)) 38 | 39 | def draw_transition_frame(frame_no): 40 | w = frame_no % width 41 | h = frame_no // width 42 | draw.rectangle([(w * iw, h * ih), (w * iw + iw - 1, h * ih + ih - 1)], fill=(128, 128, 128, 180)) 43 | 44 | curr_frm, curr_scn = 0, 0 45 | 46 | while curr_scn < len(scenes): 47 | start, end = scenes[curr_scn] 48 | # gray out frames that are not in any scene 49 | while curr_frm < start: 50 | draw_transition_frame(curr_frm) 51 | curr_frm += 1 52 | 53 | # draw start and end of a scene 54 | draw_start_frame(curr_frm) 55 | draw_end_frame(end) 56 | 57 | # go to the next scene 58 | curr_frm = end + 1 59 | curr_scn += 1 60 | 61 | # gray out the last frames that are not in any scene (if any) 62 | while curr_frm < nf: 63 | draw_transition_frame(curr_frm) 64 | curr_frm += 1 65 | 66 | return img 67 | 68 | 69 | def visualize_predictions(frame_sequence, one_hot_pred, one_hot_gt, many_hot_pred=None, many_hot_gt=None): 70 | batch_size = len(frame_sequence) 71 | 72 | images = [] 73 | for i in range(batch_size): 74 | scene = frame_sequence[i] 75 | scene_labels = one_hot_gt[i] 76 | scene_one_hot_pred = one_hot_pred[i] 77 | scene_many_hot_pred = many_hot_pred[i] if many_hot_pred is not None else None 78 | 79 | scene_len, ih, iw = scene.shape[:3] 80 | 81 | grid_width = max([i for i in range(int(scene_len ** .5), 0, -1) if scene_len % i == 0]) 82 | grid_height = scene_len // grid_width 83 | 84 | scene = scene.reshape([grid_height, grid_width] + list(scene.shape[1:])) 85 | scene = np.concatenate(np.split( 86 | np.concatenate(np.split(scene, grid_height), axis=2)[0], grid_width 
87 | ), axis=2)[0] 88 | 89 | img = Image.fromarray(scene.astype(np.uint8)) 90 | draw = ImageDraw.Draw(img) 91 | 92 | j = 0 93 | for h in range(grid_height): 94 | for w in range(grid_width): 95 | if scene_labels[j] == 1: 96 | draw.text((5 + w * iw, h * ih), "T", fill=(0, 255, 0)) 97 | 98 | draw.rectangle([(w * iw + iw - 1, h * ih), (w * iw + iw - 6, h * ih + ih - 1)], fill=(0, 0, 0)) 99 | draw.rectangle([(w * iw + iw - 4, h * ih), 100 | (w * iw + iw - 5, h * ih + (ih - 1) * scene_one_hot_pred[j])], fill=(0, 255, 0)) 101 | draw.rectangle([(w * iw + iw - 2, h * ih), 102 | (w * iw + iw - 3, h * ih + (ih - 1) * ( 103 | scene_many_hot_pred[j] if scene_many_hot_pred is not None else 0 104 | ))], fill=(255, 255, 0)) 105 | j += 1 106 | 107 | images.append(np.array(img)) 108 | 109 | images = np.stack(images, 0) 110 | return images 111 | 112 | 113 | def visualize_errors(frames, predictions, targets, fp_mistakes, fn_mistakes): 114 | scenes, scene_preds = [], [] 115 | _, ih, iw, _ = frames.shape 116 | 117 | for mistakes in [fp_mistakes, fn_mistakes]: 118 | for start, end in mistakes: 119 | idx = int(start + (end - start) // 2) 120 | scene = frames[max(0, idx - 25):][:50] 121 | scene_pred = predictions[max(0, idx - 25):][:50] 122 | scene_tar = targets[max(0, idx - 25):][:50] 123 | 124 | if len(scene) < 50: 125 | continue 126 | scenes.append(scene) 127 | scene_preds.append((scene_tar, scene_pred)) 128 | 129 | if len(scenes) == 0: 130 | return None 131 | scenes = np.concatenate([np.concatenate(list(scene), 1) for scene in scenes], 0) 132 | 133 | img = Image.fromarray(scenes) 134 | draw = ImageDraw.Draw(img) 135 | for h, preds in enumerate(scene_preds): 136 | for w, (tar, pred) in enumerate(zip(*preds)): 137 | if tar == 1: 138 | draw.text((w * iw + iw - 10, h * ih), "T", fill=(255, 0, 0)) 139 | 140 | draw.rectangle([(w * iw + iw - 1, h * ih), (w * iw + iw - 4, h * ih + ih - 1)], 141 | fill=(0, 0, 0)) 142 | draw.rectangle([(w * iw + iw - 2, h * ih), 143 | (w * iw + iw - 3, h * ih + (ih - 1) * pred)], fill=(0, 255, 0)) 144 | return img 145 | -------------------------------------------------------------------------------- /configs/transnetv1.gin: -------------------------------------------------------------------------------- 1 | # Global Config 2 | ################ 3 | 4 | shot_len = 100 5 | frame_width = 48 6 | frame_height = 27 7 | 8 | bi_tempered_loss = False 9 | bi_tempered_loss_temp1 = 1. 10 | bi_tempered_loss_temp2 = 1. 
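# note: `%name` refers to a macro defined above, `@name` to a gin-configurable object (gin-config syntax)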
11 | 12 | options.n_epochs = 10 13 | options.log_dir = "logs" 14 | options.log_name = "transnetv1" 15 | options.trn_files = ["data/48x27/IACC3Random3000/*.tfrecord"] 16 | # options.trn_files = [ 17 | # "data/48x27/IACC3Random3000/*.tfrecord", 18 | # "data/48x27/ClipShotsTrain-Train/*.tfrecord", 19 | # "data/48x27/ClipShotsGradual/*.tfrecord" 20 | # ] 21 | options.tst_files = { 22 | "clip_shots_val": ["data/48x27/ClipShotsTrain-Valid/*.tfrecord"], 23 | "dissolves_20": ["data/48x27/ClipShotsTrain-validation/dissolves_20.tfrecord"], 24 | "dissolves_60": ["data/48x27/ClipShotsTrain-validation/dissolves_60.tfrecord"], 25 | "hardcuts": ["data/48x27/ClipShotsTrain-validation/hardcuts.tfrecord"], 26 | "iacc100": ["data/48x27/IACC3Subset100/*.tfrecord"], 27 | } 28 | options.transition_only_trn_files = None 29 | options.transition_only_data_fraction = 0.3 # meaningful if transition_only_trn_files is not None 30 | options.input_shape = [%shot_len, %frame_height, %frame_width, 3] 31 | options.test_only = False 32 | options.restore = None 33 | options.restore_resnet_features = None 34 | options.original_transnet = False 35 | options.c3d_net = False 36 | options.bi_tempered_loss = %bi_tempered_loss 37 | options.bi_tempered_loss_temp2 = %bi_tempered_loss_temp2 38 | options.learning_rate_schedule = None 39 | options.learning_rate_decay = None 40 | 41 | training.log_freq = 200 42 | training.grad_clipping = 10. 43 | training.n_batches_per_epoch = 750 44 | training.evaluate_on_middle_frames_only = True 45 | training.optimizer = @tf.keras.optimizers.Adam 46 | tf.keras.optimizers.Adam.learning_rate = 0.001 47 | 48 | loss.transition_weight = 1. 49 | loss.many_hot_loss_weight = 0. 50 | loss.l2_loss_weight = 0. 51 | loss.dynamic_weight = None 52 | loss.bi_tempered_loss = %bi_tempered_loss 53 | loss.bi_tempered_loss_temp1 = %bi_tempered_loss_temp1 54 | loss.bi_tempered_loss_temp2 = %bi_tempered_loss_temp2 55 | 56 | 57 | # Dataset Config 58 | ################# 59 | 60 | train_pipeline.shuffle_buffer = 100 61 | train_pipeline.shot_len = %shot_len 62 | train_pipeline.frame_width = %frame_width 63 | train_pipeline.frame_height = %frame_height 64 | train_pipeline.batch_size = 16 65 | train_pipeline.repeat = True 66 | 67 | parse_train_sample.shot_len = %shot_len 68 | parse_train_sample.frame_width = %frame_width 69 | parse_train_sample.frame_height = %frame_height 70 | parse_train_sample.sudden_color_change_prob = 0. 71 | parse_train_sample.spacial_augmentation = False 72 | parse_train_sample.original_width = None 73 | parse_train_sample.original_height = None 74 | 75 | train_transition_pipeline.shuffle_buffer = 100 76 | train_transition_pipeline.batch_size = 16 77 | train_transition_pipeline.repeat = False 78 | 79 | parse_train_transition_sample.shot_len = %shot_len 80 | parse_train_transition_sample.frame_width = %frame_width 81 | parse_train_transition_sample.frame_height = %frame_height 82 | 83 | augment_shot.up_down_flip_prob = 0 84 | augment_shot.left_right_flip_prob = 0 85 | augment_shot.adjust_saturation = False 86 | augment_shot.adjust_contrast = False 87 | augment_shot.adjust_brightness = False 88 | augment_shot.adjust_hue = False 89 | augment_shot.equalize_prob = 0. 90 | augment_shot.posterize_prob = 0. 91 | augment_shot.posterize_min_bits = 2 92 | augment_shot.color_prob = 0. 
93 | augment_shot.color_min_val = 0.3 94 | augment_shot.color_max_val = 1.7 95 | 96 | augment_shot_spacial.random_shake_prob = 0.3 97 | augment_shot_spacial.random_shake_max_size = 15 98 | augment_shot_spacial.clip_left_right = 20 99 | augment_shot_spacial.clip_top_bottom = 10 100 | 101 | concat_shots.shot_len = %shot_len 102 | concat_shots.color_transfer_prob = 0. 103 | concat_shots.transition_min_len = 2 104 | concat_shots.transition_max_len = 30 105 | concat_shots.hard_cut_prob = 0.5 106 | concat_shots.cutout_prob = 0. 107 | concat_shots.advanced_shot_trans_prob = 0. 108 | 109 | # affected by: concat_shots.cutout_prob 110 | cutout.min_width_fraction = 0.3 111 | cutout.min_height_fraction = 0.3 112 | cutout.max_width_fraction = 0.6 113 | cutout.max_height_fraction = 0.6 114 | cutout.cutout_color = None # [0., 255.], None is random 115 | 116 | test_pipeline.shot_len = %shot_len 117 | test_pipeline.batch_size = 16 118 | 119 | parse_test_sample.frame_width = %frame_width 120 | parse_test_sample.frame_height = %frame_height 121 | 122 | 123 | # TransNet Config 124 | ################## 125 | 126 | TransNetV2.F = 16 127 | TransNetV2.L = 3 128 | TransNetV2.S = 2 129 | TransNetV2.D = 256 130 | TransNetV2.use_resnet_features = False 131 | TransNetV2.use_many_hot_targets = False 132 | TransNetV2.use_frame_similarity = False 133 | TransNetV2.use_mean_pooling = False 134 | TransNetV2.use_convex_comb_reg = False 135 | TransNetV2.dropout_rate = None 136 | TransNetV2.use_resnet_like_top = False 137 | TransNetV2.frame_similarity_on_last_layer = False 138 | TransNetV2.use_color_histograms = False 139 | 140 | StackedDDCNNV2.shortcut = False 141 | StackedDDCNNV2.use_octave_conv = False 142 | StackedDDCNNV2.pool_type = "max" 143 | StackedDDCNNV2.stochastic_depth_drop_prob = 0. # StackedDDCNNV2.shortcut must be True 144 | 145 | DilatedDCNNV2.batch_norm = False 146 | 147 | Conv3DConfigurable.separable = False 148 | Conv3DConfigurable.kernel_initializer = "glorot_uniform" 149 | 150 | # affected by: StackedDDCNNV2.use_octave_conv 151 | OctConv3D.alpha = 0.25 152 | 153 | # affected by: TransNetV2.use_resnet_features 154 | ResNetFeatures.trainable = False 155 | 156 | # affected by: TransNetV2.use_frame_similarity 157 | FrameSimilarity.similarity_dim = 128 158 | FrameSimilarity.lookup_window = 101 159 | FrameSimilarity.output_dim = 64 160 | FrameSimilarity.stop_gradient = False 161 | FrameSimilarity.use_bias = False 162 | 163 | # affected by: TransNetV2.use_convex_comb_reg 164 | ConvexCombinationRegularization.filters = 32 165 | ConvexCombinationRegularization.delta_scale = 10. 
166 | ConvexCombinationRegularization.loss_weight = 0.01 167 | 168 | # affected by: TransNetV2.use_color_histograms 169 | ColorHistograms.lookup_window = 101 170 | ColorHistograms.output_dim = None 171 | 172 | 173 | # C3D Config 174 | ############# 175 | C3DConvolutions.restore_from = None 176 | C3DNet.D = 256 177 | -------------------------------------------------------------------------------- /training/metrics_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import matplotlib.pyplot as plt 4 | plt.switch_backend("agg") 5 | 6 | 7 | def predictions_to_scenes(predictions): 8 | scenes = [] 9 | t, t_prev, start = -1, 0, 0 10 | for i, t in enumerate(predictions): 11 | if t_prev == 1 and t == 0: 12 | start = i 13 | if t_prev == 0 and t == 1 and i != 0: 14 | scenes.append([start, i]) 15 | t_prev = t 16 | if t == 0: 17 | scenes.append([start, i]) 18 | 19 | # just fix if all predictions are 1 20 | if len(scenes) == 0: 21 | return np.array([[0, len(predictions) - 1]], dtype=np.int32) 22 | 23 | return np.array(scenes, dtype=np.int32) 24 | 25 | 26 | def evaluate_scenes(gt_scenes, pred_scenes, return_mistakes=False, n_frames_miss_tolerance=2): 27 | """ 28 | Adapted from: https://github.com/gyglim/shot-detection-evaluation 29 | The original based on: http://imagelab.ing.unimore.it/imagelab/researchActivity.asp?idActivity=19 30 | 31 | n_frames_miss_tolerance: 32 | Number of frames it is possible to miss ground truth by, and still being counted as a correct detection. 33 | 34 | Examples of computation with different tolerance margin: 35 | n_frames_miss_tolerance = 0 36 | pred_scenes: [[0, 5], [6, 9]] -> pred_trans: [[5.5, 5.5]] 37 | gt_scenes: [[0, 5], [6, 9]] -> gt_trans: [[5.5, 5.5]] -> HIT 38 | gt_scenes: [[0, 4], [5, 9]] -> gt_trans: [[4.5, 4.5]] -> MISS 39 | n_frames_miss_tolerance = 1 40 | pred_scenes: [[0, 5], [6, 9]] -> pred_trans: [[5.0, 6.0]] 41 | gt_scenes: [[0, 5], [6, 9]] -> gt_trans: [[5.0, 6.0]] -> HIT 42 | gt_scenes: [[0, 4], [5, 9]] -> gt_trans: [[4.0, 5.0]] -> HIT 43 | gt_scenes: [[0, 3], [4, 9]] -> gt_trans: [[3.0, 4.0]] -> MISS 44 | n_frames_miss_tolerance = 2 45 | pred_scenes: [[0, 5], [6, 9]] -> pred_trans: [[4.5, 6.5]] 46 | gt_scenes: [[0, 5], [6, 9]] -> gt_trans: [[4.5, 6.5]] -> HIT 47 | gt_scenes: [[0, 4], [5, 9]] -> gt_trans: [[3.5, 5.5]] -> HIT 48 | gt_scenes: [[0, 3], [4, 9]] -> gt_trans: [[2.5, 4.5]] -> HIT 49 | gt_scenes: [[0, 2], [3, 9]] -> gt_trans: [[1.5, 3.5]] -> MISS 50 | """ 51 | 52 | shift = n_frames_miss_tolerance / 2 53 | gt_scenes = gt_scenes.astype(np.float32) + np.array([[-0.5 + shift, 0.5 - shift]]) 54 | pred_scenes = pred_scenes.astype(np.float32) + np.array([[-0.5 + shift, 0.5 - shift]]) 55 | 56 | gt_trans = np.stack([gt_scenes[:-1, 1], gt_scenes[1:, 0]], 1) 57 | pred_trans = np.stack([pred_scenes[:-1, 1], pred_scenes[1:, 0]], 1) 58 | 59 | i, j = 0, 0 60 | tp, fp, fn = 0, 0, 0 61 | fp_mistakes, fn_mistakes = [], [] 62 | 63 | while i < len(gt_trans) or j < len(pred_trans): 64 | if j == len(pred_trans): 65 | fn += 1 66 | fn_mistakes.append(gt_trans[i]) 67 | i += 1 68 | elif i == len(gt_trans): 69 | fp += 1 70 | fp_mistakes.append(pred_trans[j]) 71 | j += 1 72 | elif pred_trans[j, 1] < gt_trans[i, 0]: 73 | fp += 1 74 | fp_mistakes.append(pred_trans[j]) 75 | j += 1 76 | elif pred_trans[j, 0] > gt_trans[i, 1]: 77 | fn += 1 78 | fn_mistakes.append(gt_trans[i]) 79 | i += 1 80 | else: 81 | i += 1 82 | j += 1 83 | tp += 1 84 | 85 | if tp + fp != 0: 86 | p = tp / (tp + fp) 
87 | else: 88 | p = 0 89 | 90 | if tp + fn != 0: 91 | r = tp / (tp + fn) 92 | else: 93 | r = 0 94 | 95 | if p + r != 0: 96 | f1 = (p * r * 2) / (p + r) 97 | else: 98 | f1 = 0 99 | 100 | assert tp + fn == len(gt_trans) 101 | assert tp + fp == len(pred_trans) 102 | 103 | if return_mistakes: 104 | return p, r, f1, (tp, fp, fn), fp_mistakes, fn_mistakes 105 | return p, r, f1, (tp, fp, fn) 106 | 107 | 108 | def graph(data, labels=None, marker=""): 109 | fig = plt.figure(figsize=(6, 6)) 110 | 111 | plots = [] 112 | for x, y in data: 113 | p, = plt.plot(x, y, marker=marker) 114 | plots.append(p) 115 | 116 | # plt.legend(plots, legends) 117 | plt.grid(alpha=0.2) 118 | 119 | # remove figure border and ticks 120 | for spine in plt.gca().spines.values(): 121 | spine.set_visible(False) 122 | plt.tick_params(length=0) 123 | 124 | # bold 0 axis 125 | plt.axhline(0, color="k", linewidth=1) 126 | plt.axvline(0, color="k", linewidth=1) 127 | 128 | if labels is not None: 129 | plt.xlabel(labels[0]) 130 | plt.ylabel(labels[1]) 131 | 132 | fig.canvas.draw() 133 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 134 | 135 | width, height = fig.canvas.get_width_height() 136 | plt.close() 137 | return data.reshape([height, width, 3]) 138 | 139 | 140 | def create_scene_based_summaries(one_hot_pred, one_hot_gt, prefix="test", step=0): 141 | thresholds = np.array([ 142 | 0.02, 0.06, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 143 | ]) 144 | precision, recall, f1, tp, fp, fn = np.zeros_like(thresholds), np.zeros_like(thresholds),\ 145 | np.zeros_like(thresholds), np.zeros_like(thresholds),\ 146 | np.zeros_like(thresholds), np.zeros_like(thresholds) 147 | 148 | gt_scenes = predictions_to_scenes(one_hot_gt) 149 | for i in range(len(thresholds)): 150 | pred_scenes = predictions_to_scenes( 151 | (one_hot_pred > thresholds[i]).astype(np.uint8) 152 | ) 153 | precision[i], recall[i], f1[i], (tp[i], fp[i], fn[i]) = evaluate_scenes(gt_scenes, pred_scenes) 154 | 155 | best_idx = np.argmax(f1) 156 | tf.summary.scalar(prefix + "/scene/f1_score_0.1", f1[2], step=step) 157 | tf.summary.scalar(prefix + "/scene/f1_score_0.5", f1[7], step=step) 158 | tf.summary.scalar(prefix + "/scene/f1_max_score", f1[best_idx], step=step) 159 | tf.summary.scalar(prefix + "/scene/f1_max_score_thr", thresholds[best_idx], step=step) 160 | tf.summary.scalar(prefix + "/scene/tp", tp[best_idx], step=step) 161 | tf.summary.scalar(prefix + "/scene/fp", fp[best_idx], step=step) 162 | tf.summary.scalar(prefix + "/scene/fn", fn[best_idx], step=step) 163 | 164 | valid_idx = np.logical_and(recall != 0, precision != 0) 165 | 166 | tf.summary.image(prefix + "/precision_recall", 167 | graph(data=[(recall[valid_idx], precision[valid_idx])], 168 | labels=("Recall", "Precision"), 169 | marker=".")[np.newaxis], 170 | step=step) 171 | tf.summary.image(prefix + "/f1_score", 172 | graph(data=[(thresholds[valid_idx], f1[valid_idx])], 173 | labels=("Threshold", "F1 Score"), 174 | marker=".")[np.newaxis], 175 | step=step) 176 | return f1[best_idx] 177 | -------------------------------------------------------------------------------- /training/consolidate_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | from video_utils import get_frames 9 | from visualization_utils import visualize_scenes 10 | 11 | 12 | def get_scenes_from_transition_frames(transition_frames, video_len): 13 | 
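    # converts a sorted array of annotated transition-frame indices into [start, end] scene pairs:
    # each contiguous run of transition frames closes the current scene (at the run's first frame)
    # and the next scene starts right after the run's last frame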
prev_idx = curr_frame = -1 14 | scene_start = 0 15 | scenes = [] 16 | 17 | for curr_frame in transition_frames: 18 | if prev_idx + 1 != curr_frame: 19 | scenes.append((scene_start, curr_frame)) 20 | scene_start = curr_frame + 1 21 | prev_idx = curr_frame 22 | 23 | if curr_frame != video_len - 1: 24 | scenes.append((scene_start, video_len - 1)) 25 | 26 | return np.array(scenes) 27 | 28 | 29 | def save_csv(fn, csv_data): 30 | with open(fn + ".txt", "w") as f: 31 | for l in csv_data: 32 | f.write(l + "\n") 33 | 34 | 35 | BBC_mp4_files = "BBCDataset/*.mp4" 36 | BBC_txt_files = "BBCDataset/annotations/shots/" 37 | BBC_target_dir = "consolidated/BBCDataset" 38 | os.makedirs(BBC_target_dir, exist_ok=True) 39 | 40 | RAI_mp4_files = "RAIDataset/*.mp4" 41 | RAI_txt_files = "RAIDataset/labels/" 42 | RAI_target_dir = "consolidated/RAIDataset" 43 | os.makedirs(RAI_target_dir, exist_ok=True) 44 | 45 | CLIPSHOTS_TRN_mp4_files = "ClipShots/videos/train/*.mp4" 46 | CLIPSHOTS_TRN_txt_files = "ClipShots/annotations/train.json" 47 | CLIPSHOTS_TRN_target_dir = "consolidated/ClipShotsTrain" 48 | os.makedirs(CLIPSHOTS_TRN_target_dir, exist_ok=True) 49 | 50 | CLIPSHOTS_TST_mp4_files = "ClipShots/videos/test/*.mp4" 51 | CLIPSHOTS_TST_txt_files = "ClipShots/annotations/test.json" 52 | CLIPSHOTS_TST_target_dir = "consolidated/ClipShotsTest" 53 | os.makedirs(CLIPSHOTS_TST_target_dir, exist_ok=True) 54 | 55 | CLIPSHOTS_GRD_mp4_files = "ClipShots/videos/only_gradual/*.mp4" 56 | CLIPSHOTS_GRD_txt_files = "ClipShots/annotations/only_gradual.json" 57 | CLIPSHOTS_GRD_target_dir = "consolidated/ClipShotsGradual" 58 | os.makedirs(CLIPSHOTS_GRD_target_dir, exist_ok=True) 59 | 60 | IACC3_SUBSET100_mp4_files = "IACC3Subset100/*.mp4" 61 | IACC3_SUBSET100_txt_files = "IACC3Subset100/" 62 | IACC3_SUBSET100_target_dir = "consolidated/IACC3Subset100" 63 | os.makedirs(IACC3_SUBSET100_target_dir, exist_ok=True) 64 | 65 | IACC3_RANDOM3000_mp4_files = "/Datasets/IACC.3/random_3000/*.mp4" 66 | IACC3_RANDOM3000_txt_files = "/Datasets/IACC.3/msb" 67 | IACC3_RANDOM3000_map_file = "/Datasets/IACC.3/data/filenames.csv" 68 | IACC3_RANDOM3000_target_dir = "consolidated/IACC3Random3000" 69 | os.makedirs(IACC3_RANDOM3000_target_dir, exist_ok=True) 70 | 71 | 72 | # BBC Dataset 73 | print("Consolidating BBC Dataset...") 74 | csv_data = [] 75 | 76 | for fn in tqdm(glob.glob(BBC_mp4_files)): 77 | fn = os.path.abspath(fn) 78 | 79 | fn_idx = os.path.basename(fn).split(".")[0].split("_")[1] 80 | gt_fn = glob.glob(os.path.join(BBC_txt_files, fn_idx + "*"))[0] 81 | 82 | scenes = np.loadtxt(gt_fn, dtype=np.int32, ndmin=2) 83 | scenes = scenes + 1 84 | if scenes[0][0] == 1: 85 | scenes[0][0] = 0 86 | 87 | video = get_frames(fn) 88 | 89 | save_to = os.path.abspath(os.path.join(BBC_target_dir, fn_idx)) 90 | np.savetxt(save_to + ".txt", scenes, fmt="%d") 91 | 92 | visualize_scenes(video, scenes).save(save_to + ".png") 93 | csv_data.append("{},{}".format(fn, save_to + ".txt")) 94 | 95 | save_csv(BBC_target_dir, csv_data) 96 | 97 | 98 | # RAI Dataset 99 | print("Consolidating RAI Dataset...") 100 | csv_data = [] 101 | 102 | for fn in tqdm(glob.glob(RAI_mp4_files)): 103 | fn = os.path.abspath(fn) 104 | 105 | fn_idx = os.path.basename(fn).split(".")[0] 106 | gt_fn = os.path.join(RAI_txt_files, fn_idx + "_gt.txt") 107 | 108 | scenes = np.loadtxt(gt_fn, dtype=np.int32, ndmin=2) 109 | video = get_frames(fn) 110 | 111 | save_to = os.path.abspath(os.path.join(RAI_target_dir, fn_idx)) 112 | np.savetxt(save_to + ".txt", scenes, fmt="%d") 113 | 114 | 
visualize_scenes(video, scenes).save(save_to + ".png") 115 | csv_data.append("{},{}".format(fn, save_to + ".txt")) 116 | 117 | save_csv(RAI_target_dir, csv_data) 118 | 119 | 120 | # IACC3Subset100 Dataset 121 | print("Consolidating IACC3Subset100 Dataset...") 122 | csv_data = [] 123 | 124 | for fn in tqdm(glob.glob(IACC3_SUBSET100_mp4_files)): 125 | fn = os.path.abspath(fn) 126 | 127 | fn_idx = os.path.basename(fn).split(".")[0] 128 | gt_fn = os.path.join(IACC3_SUBSET100_txt_files, fn_idx + ".txt") 129 | 130 | video = get_frames(fn) 131 | transition_frames = np.loadtxt(gt_fn, dtype=np.int32, ndmin=1) if open(gt_fn).read() != "" else [] 132 | scenes = get_scenes_from_transition_frames(transition_frames, len(video)) 133 | 134 | save_to = os.path.abspath(os.path.join(IACC3_SUBSET100_target_dir, fn_idx)) 135 | np.savetxt(save_to + ".txt", scenes, fmt="%d") 136 | 137 | visualize_scenes(video, scenes).save(save_to + ".png") 138 | csv_data.append("{},{}".format(fn, save_to + ".txt")) 139 | 140 | save_csv(IACC3_SUBSET100_target_dir, csv_data) 141 | 142 | 143 | # IACC3Random3000 Dataset 144 | print("Consolidating IACC3Random3000 Dataset...") 145 | csv_data = [] 146 | id2filename = dict(pd.read_csv(IACC3_RANDOM3000_map_file, delimiter=";", header=None).values) 147 | 148 | for fn in tqdm(glob.glob(IACC3_RANDOM3000_mp4_files)): 149 | fn = os.path.abspath(fn) 150 | 151 | fn_idx = os.path.basename(fn).split(".")[0] 152 | gt_fn = os.path.join(IACC3_RANDOM3000_txt_files, id2filename[int(fn_idx)][:-4] + ".msb") 153 | 154 | scenes = np.loadtxt(gt_fn, dtype=np.int32, skiprows=2, ndmin=2) 155 | video = get_frames(fn) 156 | 157 | save_to = os.path.abspath(os.path.join(IACC3_RANDOM3000_target_dir, fn_idx)) 158 | np.savetxt(save_to + ".txt", scenes, fmt="%d") 159 | 160 | visualize_scenes(video, scenes).save(save_to + ".png") 161 | csv_data.append("{},{}".format(fn, save_to + ".txt")) 162 | 163 | save_csv(IACC3_RANDOM3000_target_dir, csv_data) 164 | 165 | 166 | # ClipShots Dataset 167 | def clipshots_dataset(txt_files, mp4_files, target_dir): 168 | csv_data = [] 169 | gt_data = json.load(open(txt_files)) 170 | 171 | for fn in tqdm(glob.glob(mp4_files)): 172 | fn = os.path.abspath(fn) 173 | k = os.path.basename(fn) 174 | 175 | # number of frames must be integer, check it is true 176 | assert int(gt_data[k]['frame_num']) == gt_data[k]['frame_num'] 177 | n_frames = int(gt_data[k]['frame_num']) 178 | 179 | video = get_frames(fn) 180 | if video is None: 181 | print("ERROR: Video file error", k) 182 | continue 183 | # gt data must match actual extracted data 184 | plus = 0 185 | if len(video) != n_frames: 186 | if len(video) != n_frames + 1: 187 | print("ERROR: {} video length {} vs length specified in gt {}, skipping it".format( 188 | k, len(video), n_frames)) 189 | continue 190 | print("WARN: {} video length {} vs length specified in gt {}, adjusting ground truth".format( 191 | k, len(video), n_frames)) 192 | plus = 1 193 | n_frames = len(video) 194 | 195 | translations = np.array(gt_data[k]["transitions"]) 196 | if len(translations) == 0: 197 | scenes = np.array([[0, n_frames - 1]]) 198 | else: 199 | scene_ends_zeroindexed = translations[:, 0] + plus 200 | scene_starts_zeroindexed = translations[:, 1] + plus 201 | scene_starts_zeroindexed = np.concatenate([[0], scene_starts_zeroindexed]) 202 | scene_ends_zeroindexed = np.concatenate([scene_ends_zeroindexed, [n_frames - 1]]) 203 | scenes = np.stack([scene_starts_zeroindexed, scene_ends_zeroindexed], 1) 204 | 205 | save_to = os.path.abspath(os.path.join(target_dir, 
k[:-4])) 206 | np.savetxt(save_to + ".txt", scenes, fmt="%d") 207 | 208 | visualize_scenes(video, scenes).save(save_to + ".png") 209 | csv_data.append("{},{}".format(fn, save_to + ".txt")) 210 | 211 | save_csv(target_dir, csv_data) 212 | 213 | 214 | print("Consolidating ClipShots Train Dataset...") 215 | clipshots_dataset(CLIPSHOTS_TRN_txt_files, CLIPSHOTS_TRN_mp4_files, CLIPSHOTS_TRN_target_dir) 216 | print("Consolidating ClipShots Test Dataset...") 217 | clipshots_dataset(CLIPSHOTS_TST_txt_files, CLIPSHOTS_TST_mp4_files, CLIPSHOTS_TST_target_dir) 218 | print("Consolidating ClipShots Gradual Dataset...") 219 | clipshots_dataset(CLIPSHOTS_GRD_txt_files, CLIPSHOTS_GRD_mp4_files, CLIPSHOTS_GRD_target_dir) 220 | -------------------------------------------------------------------------------- /inference/transnetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | 6 | class TransNetV2: 7 | 8 | def __init__(self, model_dir=None): 9 | if model_dir is None: 10 | model_dir = os.path.join(os.path.dirname(__file__), "transnetv2-weights/") 11 | if not os.path.isdir(model_dir): 12 | raise FileNotFoundError(f"[TransNetV2] ERROR: {model_dir} is not a directory.") 13 | else: 14 | print(f"[TransNetV2] Using weights from {model_dir}.") 15 | 16 | self._input_size = (27, 48, 3) 17 | try: 18 | self._model = tf.saved_model.load(model_dir) 19 | except OSError as exc: 20 | raise IOError(f"[TransNetV2] It seems that files in {model_dir} are corrupted or missing. " 21 | f"Re-download them manually and retry. For more info, see: " 22 | f"https://github.com/soCzech/TransNetV2/issues/1#issuecomment-647357796") from exc 23 | 24 | def predict_raw(self, frames: np.ndarray): 25 | assert len(frames.shape) == 5 and frames.shape[2:] == self._input_size, \ 26 | "[TransNetV2] Input shape must be [batch, frames, height, width, 3]." 27 | frames = tf.cast(frames, tf.float32) 28 | 29 | logits, dict_ = self._model(frames) 30 | single_frame_pred = tf.sigmoid(logits) 31 | all_frames_pred = tf.sigmoid(dict_["many_hot"]) 32 | 33 | return single_frame_pred, all_frames_pred 34 | 35 | def predict_frames(self, frames: np.ndarray): 36 | assert len(frames.shape) == 4 and frames.shape[1:] == self._input_size, \ 37 | "[TransNetV2] Input shape must be [frames, height, width, 3]." 
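        # frames are processed in overlapping windows of 100 frames moved forward by 50 frames;
        # only the middle 50 predictions of each window are kept, so every frame is predicted
        # with at least 25 frames of temporal context on each side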
38 | 39 | def input_iterator(): 40 | # return windows of size 100 where the first/last 25 frames are from the previous/next batch 41 | # the first and last window must be padded by copies of the first and last frame of the video 42 | no_padded_frames_start = 25 43 | no_padded_frames_end = 25 + 50 - (len(frames) % 50 if len(frames) % 50 != 0 else 50) # 25 - 74 44 | 45 | start_frame = np.expand_dims(frames[0], 0) 46 | end_frame = np.expand_dims(frames[-1], 0) 47 | padded_inputs = np.concatenate( 48 | [start_frame] * no_padded_frames_start + [frames] + [end_frame] * no_padded_frames_end, 0 49 | ) 50 | 51 | ptr = 0 52 | while ptr + 100 <= len(padded_inputs): 53 | out = padded_inputs[ptr:ptr + 100] 54 | ptr += 50 55 | yield out[np.newaxis] 56 | 57 | predictions = [] 58 | 59 | for inp in input_iterator(): 60 | single_frame_pred, all_frames_pred = self.predict_raw(inp) 61 | predictions.append((single_frame_pred.numpy()[0, 25:75, 0], 62 | all_frames_pred.numpy()[0, 25:75, 0])) 63 | 64 | print("\r[TransNetV2] Processing video frames {}/{}".format( 65 | min(len(predictions) * 50, len(frames)), len(frames) 66 | ), end="") 67 | print("") 68 | 69 | single_frame_pred = np.concatenate([single_ for single_, all_ in predictions]) 70 | all_frames_pred = np.concatenate([all_ for single_, all_ in predictions]) 71 | 72 | return single_frame_pred[:len(frames)], all_frames_pred[:len(frames)] # remove extra padded frames 73 | 74 | def predict_video(self, video_fn: str): 75 | try: 76 | import ffmpeg 77 | except ModuleNotFoundError: 78 | raise ModuleNotFoundError("For `predict_video` function `ffmpeg` needs to be installed in order to extract " 79 | "individual frames from video file. Install `ffmpeg` command line tool and then " 80 | "install python wrapper by `pip install ffmpeg-python`.") 81 | 82 | print("[TransNetV2] Extracting frames from {}".format(video_fn)) 83 | video_stream, err = ffmpeg.input(video_fn).output( 84 | "pipe:", format="rawvideo", pix_fmt="rgb24", s="48x27" 85 | ).run(capture_stdout=True, capture_stderr=True) 86 | 87 | video = np.frombuffer(video_stream, np.uint8).reshape([-1, 27, 48, 3]) 88 | return (video, *self.predict_frames(video)) 89 | 90 | @staticmethod 91 | def predictions_to_scenes(predictions: np.ndarray, threshold: float = 0.5): 92 | predictions = (predictions > threshold).astype(np.uint8) 93 | 94 | scenes = [] 95 | t, t_prev, start = -1, 0, 0 96 | for i, t in enumerate(predictions): 97 | if t_prev == 1 and t == 0: 98 | start = i 99 | if t_prev == 0 and t == 1 and i != 0: 100 | scenes.append([start, i]) 101 | t_prev = t 102 | if t == 0: 103 | scenes.append([start, i]) 104 | 105 | # just fix if all predictions are 1 106 | if len(scenes) == 0: 107 | return np.array([[0, len(predictions) - 1]], dtype=np.int32) 108 | 109 | return np.array(scenes, dtype=np.int32) 110 | 111 | @staticmethod 112 | def visualize_predictions(frames: np.ndarray, predictions): 113 | from PIL import Image, ImageDraw 114 | 115 | if isinstance(predictions, np.ndarray): 116 | predictions = [predictions] 117 | 118 | ih, iw, ic = frames.shape[1:] 119 | width = 25 120 | 121 | # pad frames so that length of the video is divisible by width 122 | # pad frames also by len(predictions) pixels in width in order to show predictions 123 | pad_with = width - len(frames) % width if len(frames) % width != 0 else 0 124 | frames = np.pad(frames, [(0, pad_with), (0, 1), (0, len(predictions)), (0, 0)]) 125 | 126 | predictions = [np.pad(x, (0, pad_with)) for x in predictions] 127 | height = len(frames) // width 128 | 129 | img = 
frames.reshape([height, width, ih + 1, iw + len(predictions), ic]) 130 | img = np.concatenate(np.split( 131 | np.concatenate(np.split(img, height), axis=2)[0], width 132 | ), axis=2)[0, :-1] 133 | 134 | img = Image.fromarray(img) 135 | draw = ImageDraw.Draw(img) 136 | 137 | # iterate over all frames 138 | for i, pred in enumerate(zip(*predictions)): 139 | x, y = i % width, i // width 140 | x, y = x * (iw + len(predictions)) + iw, y * (ih + 1) + ih - 1 141 | 142 | # we can visualize multiple predictions per single frame 143 | for j, p in enumerate(pred): 144 | color = [0, 0, 0] 145 | color[(j + 1) % 3] = 255 146 | 147 | value = round(p * (ih - 1)) 148 | if value != 0: 149 | draw.line((x + j, y, x + j, y - value), fill=tuple(color), width=1) 150 | return img 151 | 152 | 153 | def main(): 154 | import sys 155 | import argparse 156 | 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("files", type=str, nargs="+", help="path to video files to process") 159 | parser.add_argument("--weights", type=str, default=None, 160 | help="path to TransNet V2 weights, tries to infer the location if not specified") 161 | parser.add_argument('--visualize', action="store_true", 162 | help="save a png file with prediction visualization for each extracted video") 163 | args = parser.parse_args() 164 | 165 | model = TransNetV2(args.weights) 166 | for file in args.files: 167 | if os.path.exists(file + ".predictions.txt") or os.path.exists(file + ".scenes.txt"): 168 | print(f"[TransNetV2] {file}.predictions.txt or {file}.scenes.txt already exists. " 169 | f"Skipping video {file}.", file=sys.stderr) 170 | continue 171 | 172 | video_frames, single_frame_predictions, all_frame_predictions = \ 173 | model.predict_video(file) 174 | 175 | predictions = np.stack([single_frame_predictions, all_frame_predictions], 1) 176 | np.savetxt(file + ".predictions.txt", predictions, fmt="%.6f") 177 | 178 | scenes = model.predictions_to_scenes(single_frame_predictions) 179 | np.savetxt(file + ".scenes.txt", scenes, fmt="%d") 180 | 181 | if args.visualize: 182 | if os.path.exists(file + ".vis.png"): 183 | print(f"[TransNetV2] {file}.vis.png already exists. 
" 184 | f"Skipping visualization of video {file}.", file=sys.stderr) 185 | continue 186 | 187 | pil_image = model.visualize_predictions( 188 | video_frames, predictions=(single_frame_predictions, all_frame_predictions)) 189 | pil_image.save(file + ".vis.png") 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /configs/transnetv2.gin: -------------------------------------------------------------------------------- 1 | # Macros: 2 | # ============================================================================== 3 | bi_tempered_loss = False 4 | bi_tempered_loss_temp1 = 1.0 5 | bi_tempered_loss_temp2 = 1.0 6 | frame_height = 27 7 | frame_width = 48 8 | shot_len = 100 9 | 10 | # Parameters for Adam: 11 | # ============================================================================== 12 | Adam.learning_rate = 0.001 13 | 14 | # Parameters for augment_shot: 15 | # ============================================================================== 16 | augment_shot.adjust_brightness = True 17 | augment_shot.adjust_contrast = True 18 | augment_shot.adjust_hue = True 19 | augment_shot.adjust_saturation = True 20 | augment_shot.color_max_val = 1.2 21 | augment_shot.color_min_val = 0.5 22 | augment_shot.color_prob = 0.05 23 | augment_shot.equalize_prob = 0.05 24 | augment_shot.left_right_flip_prob = 0.5 25 | augment_shot.posterize_min_bits = 4 26 | augment_shot.posterize_prob = 0.05 27 | augment_shot.up_down_flip_prob = 0.1 28 | 29 | # Parameters for augment_shot_spacial: 30 | # ============================================================================== 31 | augment_shot_spacial.clip_left_right = 20 32 | augment_shot_spacial.clip_top_bottom = 10 33 | augment_shot_spacial.random_shake_max_size = 15 34 | augment_shot_spacial.random_shake_prob = 0.3 35 | 36 | # Parameters for C3DConvolutions: 37 | # ============================================================================== 38 | C3DConvolutions.restore_from = None 39 | 40 | # Parameters for C3DNet: 41 | # ============================================================================== 42 | C3DNet.D = 256 43 | 44 | # Parameters for ColorHistograms: 45 | # ============================================================================== 46 | ColorHistograms.lookup_window = 101 47 | ColorHistograms.output_dim = 128 48 | 49 | # Parameters for concat_shots: 50 | # ============================================================================== 51 | concat_shots.advanced_shot_trans_prob = 0.0 52 | concat_shots.color_transfer_prob = 0.1 53 | concat_shots.cutout_prob = 0.0 54 | concat_shots.hard_cut_prob = 0.5 55 | concat_shots.shot_len = %shot_len 56 | concat_shots.transition_max_len = 30 57 | concat_shots.transition_min_len = 2 58 | 59 | # Parameters for Conv3DConfigurable: 60 | # ============================================================================== 61 | Conv3DConfigurable.kernel_initializer = 'he_normal' 62 | Conv3DConfigurable.separable = True 63 | 64 | # Parameters for ConvexCombinationRegularization: 65 | # ============================================================================== 66 | ConvexCombinationRegularization.delta_scale = 10.0 67 | ConvexCombinationRegularization.filters = 32 68 | ConvexCombinationRegularization.loss_weight = 0.01 69 | 70 | # Parameters for cutout: 71 | # ============================================================================== 72 | cutout.cutout_color = None 73 | cutout.max_height_fraction = 0.6 74 | 
cutout.max_width_fraction = 0.6 75 | cutout.min_height_fraction = 0.3 76 | cutout.min_width_fraction = 0.3 77 | 78 | # Parameters for DilatedDCNNV2: 79 | # ============================================================================== 80 | DilatedDCNNV2.batch_norm = True 81 | 82 | # Parameters for FrameSimilarity: 83 | # ============================================================================== 84 | FrameSimilarity.lookup_window = 101 85 | FrameSimilarity.output_dim = 128 86 | FrameSimilarity.similarity_dim = 128 87 | FrameSimilarity.stop_gradient = False 88 | FrameSimilarity.use_bias = True 89 | 90 | # Parameters for loss: 91 | # ============================================================================== 92 | loss.bi_tempered_loss = %bi_tempered_loss 93 | loss.bi_tempered_loss_temp1 = %bi_tempered_loss_temp1 94 | loss.bi_tempered_loss_temp2 = %bi_tempered_loss_temp2 95 | loss.dynamic_weight = None 96 | loss.l2_loss_weight = 0.0001 97 | loss.many_hot_loss_weight = 0.1 98 | loss.transition_weight = 5.0 99 | 100 | # Parameters for OctConv3D: 101 | # ============================================================================== 102 | OctConv3D.alpha = 0.25 103 | 104 | # Parameters for options: 105 | # ============================================================================== 106 | options.bi_tempered_loss = %bi_tempered_loss 107 | options.bi_tempered_loss_temp2 = %bi_tempered_loss_temp2 108 | options.c3d_net = False 109 | options.input_shape = [%shot_len, %frame_height, %frame_width, 3] 110 | options.learning_rate_decay = None 111 | options.learning_rate_schedule = None 112 | options.log_dir = 'logs' 113 | options.log_name = 'transnetv2' 114 | options.n_epochs = 30 115 | options.original_transnet = False 116 | options.restore = None 117 | options.restore_resnet_features = None 118 | options.test_only = False 119 | options.transition_only_data_fraction = 0.3 120 | options.transition_only_trn_files = None 121 | options.trn_files = \ 122 | ['data/48x27/IACC3Random3000/*.tfrecord', 123 | 'data/48x27/ClipShotsTrain-Train/*.tfrecord', 124 | 'data/48x27/ClipShotsGradual/*.tfrecord'] 125 | options.tst_files = \ 126 | {'clip_shots_val': ['data/48x27/ClipShotsTrain-Valid/*.tfrecord'], 127 | 'dissolves_20': ['data/48x27/ClipShotsTrain-validation/dissolves_20.tfrecord'], 128 | 'dissolves_60': ['data/48x27/ClipShotsTrain-validation/dissolves_60.tfrecord'], 129 | 'hardcuts': ['data/48x27/ClipShotsTrain-validation/hardcuts.tfrecord'], 130 | 'iacc100': ['data/48x27/IACC3Subset100/*.tfrecord']} 131 | 132 | # Parameters for parse_test_sample: 133 | # ============================================================================== 134 | parse_test_sample.frame_height = %frame_height 135 | parse_test_sample.frame_width = %frame_width 136 | 137 | # Parameters for parse_train_sample: 138 | # ============================================================================== 139 | parse_train_sample.frame_height = %frame_height 140 | parse_train_sample.frame_width = %frame_width 141 | parse_train_sample.original_height = None 142 | parse_train_sample.original_width = None 143 | parse_train_sample.shot_len = %shot_len 144 | parse_train_sample.spacial_augmentation = False 145 | parse_train_sample.sudden_color_change_prob = 0.0 146 | 147 | # Parameters for parse_train_transition_sample: 148 | # ============================================================================== 149 | parse_train_transition_sample.frame_height = %frame_height 150 | parse_train_transition_sample.frame_width = %frame_width 151 | 
parse_train_transition_sample.shot_len = %shot_len 152 | 153 | # Parameters for ResNetFeatures: 154 | # ============================================================================== 155 | ResNetFeatures.trainable = False 156 | 157 | # Parameters for SGD: 158 | # ============================================================================== 159 | SGD.learning_rate = 0.01 160 | SGD.momentum = 0.9 161 | 162 | # Parameters for StackedDDCNNV2: 163 | # ============================================================================== 164 | StackedDDCNNV2.pool_type = 'avg' 165 | StackedDDCNNV2.shortcut = True 166 | StackedDDCNNV2.stochastic_depth_drop_prob = 0.0 167 | StackedDDCNNV2.use_octave_conv = False 168 | 169 | # Parameters for test_pipeline: 170 | # ============================================================================== 171 | test_pipeline.batch_size = 16 172 | test_pipeline.shot_len = %shot_len 173 | 174 | # Parameters for train_pipeline: 175 | # ============================================================================== 176 | train_pipeline.batch_size = 16 177 | train_pipeline.frame_height = %frame_height 178 | train_pipeline.frame_width = %frame_width 179 | train_pipeline.repeat = True 180 | train_pipeline.shot_len = %shot_len 181 | train_pipeline.shuffle_buffer = 100 182 | 183 | # Parameters for train_transition_pipeline: 184 | # ============================================================================== 185 | train_transition_pipeline.batch_size = 16 186 | train_transition_pipeline.repeat = False 187 | train_transition_pipeline.shuffle_buffer = 100 188 | 189 | # Parameters for training: 190 | # ============================================================================== 191 | training.evaluate_on_middle_frames_only = True 192 | training.grad_clipping = 10.0 193 | training.log_freq = 200 194 | training.n_batches_per_epoch = 750 195 | training.optimizer = @tf.keras.optimizers.SGD 196 | 197 | # Parameters for TransNetV2: 198 | # ============================================================================== 199 | TransNetV2.D = 1024 200 | TransNetV2.F = 16 201 | TransNetV2.L = 3 202 | TransNetV2.S = 2 203 | TransNetV2.dropout_rate = 0.5 204 | TransNetV2.frame_similarity_on_last_layer = False 205 | TransNetV2.use_color_histograms = True 206 | TransNetV2.use_convex_comb_reg = False 207 | TransNetV2.use_frame_similarity = True 208 | TransNetV2.use_many_hot_targets = True 209 | TransNetV2.use_mean_pooling = False 210 | TransNetV2.use_resnet_features = False 211 | TransNetV2.use_resnet_like_top = False 212 | -------------------------------------------------------------------------------- /training/models.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | 6 | class OriginalTransNet(tf.keras.Model): 7 | 8 | def __init__(self, F=16, L=3, S=2, D=256, name="TransNet"): 9 | super(OriginalTransNet, self).__init__(name=name) 10 | 11 | self.blocks = [StackedDDCNN(n_blocks=S, filters=F * 2 ** i, name="SDDCNN_{:d}".format(i + 1)) for i in range(L)] 12 | self.fc1 = tf.keras.layers.Dense(D, activation=tf.nn.relu) 13 | self.fc2 = tf.keras.layers.Dense(2, activation=None) 14 | 15 | def call(self, inputs, training=False): 16 | x = inputs / 255. 
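# Shape note (derived from the 27x48 frames used throughout this repo): with L=3 blocks, each
# SDDCNN's (1, 2, 2) max-pool halves the spatial dimensions, 27x48 -> 13x24 -> 6x12 -> 3x6,
# before the per-frame flatten below.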
17 | for block in self.blocks: 18 | x = block(x) 19 | 20 | shape = [tf.shape(x)[0], tf.shape(x)[1], np.prod(x.get_shape().as_list()[2:])] 21 | x = tf.reshape(x, shape=shape, name="flatten_3d") 22 | 23 | x = self.fc1(x) 24 | x = self.fc2(x) 25 | return x 26 | 27 | 28 | class StackedDDCNN(tf.keras.layers.Layer): 29 | 30 | def __init__(self, n_blocks, filters, name="StackedDDCNN"): 31 | super(StackedDDCNN, self).__init__(name=name) 32 | self.blocks = [DilatedDCNN(filters, name="DDCNN_{:d}".format(i)) for i in range(1, n_blocks + 1)] 33 | self.max_pool = tf.keras.layers.MaxPool3D(pool_size=(1, 2, 2)) 34 | 35 | def call(self, inputs): 36 | x = inputs 37 | for block in self.blocks: 38 | x = block(x) 39 | 40 | x = self.max_pool(x) 41 | return x 42 | 43 | 44 | class DilatedDCNN(tf.keras.layers.Layer): 45 | 46 | def __init__(self, filters, name="DilatedDCNN"): 47 | super(DilatedDCNN, self).__init__(name=name) 48 | 49 | self.conv1 = self._conv3d(filters, 1, name="Conv3D_1") 50 | self.conv2 = self._conv3d(filters, 2, name="Conv3D_2") 51 | self.conv3 = self._conv3d(filters, 4, name="Conv3D_4") 52 | self.conv4 = self._conv3d(filters, 8, name="Conv3D_8") 53 | 54 | @staticmethod 55 | def _conv3d(filters, dilation_rate, name="Conv3D"): 56 | return tf.keras.layers.Conv3D(filters, kernel_size=3, dilation_rate=(dilation_rate, 1, 1), 57 | padding="SAME", activation=tf.nn.relu, use_bias=True, name=name) 58 | 59 | def call(self, inputs): 60 | inputs = tf.identity(inputs) 61 | conv1 = self.conv1(inputs) 62 | conv2 = self.conv2(inputs) 63 | conv3 = self.conv3(inputs) 64 | conv4 = self.conv4(inputs) 65 | x = tf.concat([conv1, conv2, conv3, conv4], axis=4) 66 | return x 67 | 68 | 69 | class ResNet18(tf.keras.Model): 70 | 71 | MEAN = np.array([0.485, 0.456, 0.406], np.float32).reshape([1, 1, 1, 3]) * 255 72 | STD = np.array([0.229, 0.224, 0.225], np.float32).reshape([1, 1, 1, 3]) * 255 73 | 74 | def __init__(self, name="ResNet18"): 75 | super(ResNet18, self).__init__(name=name) 76 | 77 | self.conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), 78 | padding="SAME", use_bias=False, name="conv1") 79 | self.bn1 = tf.keras.layers.BatchNormalization(name="conv1/bn") 80 | self.max_pool = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="SAME") 81 | 82 | self.layer2a = ResNetBlock(64, name="Block2a") 83 | self.layer2b = ResNetBlock(64, name="Block2b") 84 | 85 | self.layer3a = ResNetBlock(128, strides=(2, 2), project=True, name="Block3a") 86 | self.layer3b = ResNetBlock(128, name="Block3b") 87 | 88 | self.layer4a = ResNetBlock(256, strides=(2, 2), project=True, name="Block4a") 89 | self.layer4b = ResNetBlock(256, name="Block4b") 90 | 91 | self.layer5a = ResNetBlock(512, strides=(2, 2), project=True, name="Block5a") 92 | self.layer5b = ResNetBlock(512, name="Block5b") 93 | 94 | self.avg_pool = tf.keras.layers.AveragePooling2D(pool_size=(7, 7), strides=(7, 7)) 95 | 96 | self.flatten = tf.keras.layers.Flatten() 97 | self.fc = tf.keras.layers.Dense(1000) 98 | 99 | def call(self, inputs, training=False): 100 | x = self.conv1(inputs) 101 | x = self.bn1(x, training=training) 102 | x = tf.nn.relu(x) 103 | x = self.max_pool(x) 104 | 105 | x = self.layer2a(x, training=training) 106 | x = self.layer2b(x, training=training) 107 | 108 | x = self.layer3a(x, training=training) 109 | x = self.layer3b(x, training=training) 110 | 111 | x = self.layer4a(x, training=training) 112 | x = self.layer4b(x, training=training) 113 | 114 | x = self.layer5a(x, training=training) 115 | x = self.layer5b(x, 
training=training) 116 | 117 | x = self.avg_pool(x) 118 | return self.fc(self.flatten(x)) 119 | 120 | @staticmethod 121 | def preprocess(inputs): 122 | assert inputs.dtype == np.uint8 or inputs.dtype == tf.uint8 123 | if len(inputs.shape) == 3: 124 | inputs = inputs[tf.newaxis] 125 | assert inputs.shape[1:] == (224, 224, 3) 126 | 127 | mean = tf.constant(ResNet18.MEAN) 128 | std = tf.constant(ResNet18.STD) 129 | 130 | x = tf.cast(inputs, tf.float32) 131 | return (x - mean) / std 132 | 133 | 134 | class ResNetBlock(tf.keras.layers.Layer): 135 | def __init__(self, filters, strides=(1, 1), project=False, name="ResNetBlock"): 136 | super(ResNetBlock, self).__init__(name=name) 137 | 138 | self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size=(3, 3), strides=strides, 139 | padding="SAME", use_bias=False, name="conv1") 140 | self.bn1 = tf.keras.layers.BatchNormalization(name="conv1/bn") 141 | 142 | self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="SAME", use_bias=False, name="conv2") 143 | self.bn2 = tf.keras.layers.BatchNormalization(gamma_initializer=tf.zeros_initializer(), name="conv2/bn") 144 | 145 | self.project = project 146 | if self.project: 147 | self.conv_shortcut = tf.keras.layers.Conv2D(filters, kernel_size=(1, 1), strides=strides, 148 | use_bias=False, name="conv_shortcut") 149 | self.bn_shortcut = tf.keras.layers.BatchNormalization(name="conv_shortcut/bn") 150 | 151 | def call(self, inputs, training=False): 152 | x = self.conv1(inputs) 153 | x = self.bn1(x, training=training) 154 | x = tf.nn.relu(x) 155 | 156 | x = self.conv2(x) 157 | x = self.bn2(x, training=training) 158 | 159 | shortcut = inputs 160 | if self.project: 161 | shortcut = self.conv_shortcut(shortcut) 162 | shortcut = self.bn_shortcut(shortcut, training=training) 163 | x += shortcut 164 | 165 | return tf.nn.relu(x) 166 | 167 | 168 | @gin.configurable(blacklist=["name"]) 169 | class C3DConvolutions(tf.keras.Model): 170 | # C3D model for UCF101 171 | # https://github.com/tqvinhcs/C3D-tensorflow/blob/master/m_c3d.py#L63 172 | 173 | def __init__(self, weights=None, restore_from=None, name="C3DConvolutions"): 174 | super(C3DConvolutions, self).__init__(name=name) 175 | if restore_from is not None: 176 | weights = self.get_weights(restore_from) 177 | elif weights is None: 178 | weights = [None] * 16 179 | 180 | def conv(filters, kernel_weights, bias_weights): 181 | return tf.keras.layers.Conv3D(filters, kernel_size=3, strides=1, padding="SAME", activation=tf.nn.relu, 182 | kernel_initializer=tf.constant_initializer(kernel_weights) \ 183 | if kernel_weights is not None else "glorot_uniform", 184 | bias_initializer=tf.constant_initializer(bias_weights) \ 185 | if bias_weights is not None else "zeros") 186 | 187 | self.conv_layers = [ 188 | conv(f, ker_init, bias_init) for f, ker_init, bias_init in [ 189 | (64, weights[0], weights[1]), 190 | (128, weights[2], weights[3]), 191 | (256, weights[4], weights[5]), 192 | (256, weights[6], weights[7]), 193 | (512, weights[8], weights[9]), 194 | (512, weights[10], weights[11]), 195 | (512, weights[12], weights[13]), 196 | (512, weights[14], weights[15]) 197 | ] 198 | ] 199 | self.max_pooling = tf.keras.layers.MaxPool3D(pool_size=(1, 2, 2), strides=(1, 2, 2), padding="SAME") 200 | 201 | def call(self, inputs, training=False): 202 | x = inputs - 96.6 203 | print(x.shape) 204 | x = self.conv_layers[0](x) 205 | x = self.max_pooling(x) 206 | print(x.shape) 207 | x = self.conv_layers[1](x) 208 | x = self.max_pooling(x) 209 | print(x.shape) 210 | x = 
self.conv_layers[2](x) 211 | x = self.conv_layers[3](x) 212 | x = self.max_pooling(x) 213 | print(x.shape) 214 | x = self.conv_layers[4](x) 215 | x = self.conv_layers[5](x) 216 | x = self.max_pooling(x) 217 | print(x.shape) 218 | x = self.conv_layers[6](x) 219 | x = self.conv_layers[7](x) 220 | x = self.max_pooling(x) 221 | print(x.shape) 222 | return x 223 | 224 | @staticmethod 225 | def get_weights(filename): 226 | import scipy.io as sio 227 | return sio.loadmat(filename, squeeze_me=True)['weights'] 228 | 229 | 230 | @gin.configurable(blacklist=["name"]) 231 | class C3DNet(tf.keras.Model): 232 | 233 | def __init__(self, D=256, name="C3DNet"): 234 | super(C3DNet, self).__init__(name=name) 235 | self.convs = C3DConvolutions() 236 | self.fc1 = tf.keras.layers.Dense(D, activation=tf.nn.relu) 237 | self.cls_layer1 = tf.keras.layers.Dense(1, activation=None) 238 | 239 | def call(self, inputs, training=False): 240 | x = self.convs(inputs, training=training) 241 | x = tf.math.reduce_mean(x, axis=[2, 3]) 242 | x = self.fc1(x) 243 | x = self.cls_layer1(x) 244 | 245 | return x 246 | -------------------------------------------------------------------------------- /training/create_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import random 4 | import shutil 5 | import argparse 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | import video_utils 10 | 11 | 12 | def _bytes_feature(value): 13 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 14 | 15 | 16 | def _int64_feature(value): 17 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 18 | 19 | 20 | def scenes2zero_one_representation(scenes, n_frames): 21 | prev_end = 0 22 | one_hot = np.zeros([n_frames], np.uint64) 23 | many_hot = np.zeros([n_frames], np.uint64) 24 | 25 | for start, end in scenes: 26 | # number of frames in transition: start - prev_end - 1 (hardcut has 0) 27 | 28 | # values of many_hot_index 29 | # frame with index (0..n-1) is from a scene, frame [x] is a transition frame 30 | # [0][1] -> 0 31 | # [0][x][2] -> 0, 1 32 | # [0][x][x][3] -> 0, 1, 2 33 | # [0][x][x][x][4] -> 0, 1, 2, 3 34 | # [0][x][x][x][x][5] -> 0, 1, 2, 3, 4 35 | for i in range(prev_end, start): 36 | many_hot[i] = 1 37 | 38 | # values of one_hot_index 39 | # frame with index (0..n-1) is from a scene, frame [x] is a transition frame 40 | # [0]|[1] -> 0 41 | # [0][x]|[2] -> 1 42 | # [0][x]|[x][3] -> 1 43 | # [0][x][x]|[x][4] -> 2 44 | # [0][x][x]|[x][x][5] -> 2 45 | # ... 
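# Worked example (hypothetical input): scenes = [[0, 9], [13, 20]] with n_frames = 21 yields
# many_hot[9:13] = 1 (the last frame of the first scene plus the three transition frames 10-12)
# and one_hot[11] = 1 (the single frame in the middle of that span); all other entries stay 0.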
46 | if not (prev_end == 0 and start == 0): 47 | one_hot_index = prev_end + (start - prev_end) // 2 48 | one_hot[one_hot_index] = 1 49 | 50 | prev_end = end 51 | 52 | # if scene ends with transition 53 | if prev_end + 1 != n_frames: 54 | for i in range(prev_end, n_frames): 55 | many_hot[i] = 1 56 | 57 | one_hot_index = prev_end + (n_frames - prev_end) // 2 58 | one_hot[one_hot_index] = 1 59 | 60 | return one_hot, many_hot 61 | 62 | 63 | def create_test_tfrecord(video_fn, scenes_fn, target_fn, width, height, six_channels=False): 64 | frames = video_utils.get_frames(video_fn, width, height) 65 | if six_channels: 66 | frame_centers = video_utils.get_frames(video_fn, width * 3, height * 3)[:, height:height * 2, width:width * 2] 67 | frames = np.concatenate([frames, frame_centers], -1) 68 | n_frames = len(frames) 69 | 70 | scenes = np.loadtxt(scenes_fn, dtype=np.int32, ndmin=2) 71 | one_hot, many_hot = scenes2zero_one_representation(scenes, n_frames) 72 | 73 | options = tf.io.TFRecordOptions(compression_type="GZIP") 74 | with tf.io.TFRecordWriter(target_fn, options) as writer: 75 | for frame_idx in range(n_frames): 76 | example = tf.train.Example(features=tf.train.Features(feature={ 77 | "frame": _bytes_feature(frames[frame_idx].tobytes("C")), 78 | "is_one_hot_transition": _int64_feature(one_hot[frame_idx]), 79 | "is_many_hot_transition": _int64_feature(many_hot[frame_idx]), 80 | "width": _int64_feature(width), 81 | "height": _int64_feature(height) 82 | })) 83 | writer.write(example.SerializeToString()) 84 | 85 | 86 | def create_test_dataset(target_dir, mapping_fn, width, height, six_channels=False): 87 | os.makedirs(target_dir, exist_ok=True) 88 | mapping = np.loadtxt(mapping_fn, dtype=np.str, delimiter=",") 89 | 90 | for video_fn, scenes_fn in tqdm.tqdm(mapping): 91 | target_fn = os.path.join(target_dir, os.path.splitext(os.path.basename(video_fn))[0] + ".tfrecord") 92 | create_test_tfrecord(video_fn, scenes_fn, target_fn, width, height, six_channels=six_channels) 93 | 94 | 95 | def create_test_npy_files(target_dir, mapping_fn, width, height): 96 | os.makedirs(target_dir, exist_ok=True) 97 | mapping = np.loadtxt(mapping_fn, dtype=np.str, delimiter=",") 98 | 99 | for video_fn, scenes_fn in tqdm.tqdm(mapping): 100 | fn = os.path.splitext(os.path.basename(video_fn))[0] 101 | target_fn = os.path.join(target_dir, fn + ".npy") 102 | frames = video_utils.get_frames(video_fn, width, height) 103 | 104 | shutil.copy2(scenes_fn, os.path.join(target_dir, fn + ".txt")) 105 | np.save(target_fn, frames) 106 | 107 | 108 | def get_scenes_from_video(video_fn, scenes_fn, width, height, min_scene_len=25, six_channels=False): 109 | frames = video_utils.get_frames(video_fn, width, height) 110 | if six_channels: 111 | frame_centers = video_utils.get_frames(video_fn, width * 3, height * 3)[:, height:height * 2, width:width * 2] 112 | frames = np.concatenate([frames, frame_centers], -1) 113 | scenes = np.loadtxt(scenes_fn, dtype=np.int32, ndmin=2) 114 | 115 | video_scenes = [frames[start:end + 1] for start, end in scenes if (end + 1) - start >= min_scene_len] 116 | 117 | selected_sequences = [] 118 | for scene in video_scenes: 119 | len_ = len(scene) 120 | if len_ < 300: 121 | selected_sequences.append(scene) 122 | elif len_ < 600: 123 | selected_sequences.append(scene[(len_ - 300) // 2:][:300]) 124 | else: 125 | selected_sequences.append(scene[:300]) 126 | if len_ >= 900: 127 | selected_sequences.append(scene[len_ // 2 - 150:][:300]) 128 | selected_sequences.append(scene[-300:]) 129 | 130 | return selected_sequences 
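# Illustrative behaviour of get_scenes_from_video (hypothetical scene lengths): a 250-frame
# scene is kept whole, a 400-frame scene contributes its centered 300 frames, a 700-frame scene
# contributes only its first 300 frames, and a 1000-frame scene contributes three 300-frame
# sequences taken from its beginning, middle, and end.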
131 | 132 | 133 | def create_train_dataset(target_dir, target_fn, mapping_fn, width, height, n_videos_in_tfrecord=20, six_channels=False): 134 | os.makedirs(target_dir, exist_ok=True) 135 | mapping = np.loadtxt(mapping_fn, dtype=np.str, delimiter=",").tolist() 136 | 137 | random.seed(42) 138 | random.shuffle(mapping) 139 | 140 | pbar = tqdm.tqdm(total=len(mapping)) 141 | 142 | for start_idx in range(0, len(mapping), n_videos_in_tfrecord): 143 | tfrecord_scenes = [] 144 | for video_fn, scenes_fn in mapping[start_idx:start_idx + n_videos_in_tfrecord]: 145 | tfrecord_scenes.extend( 146 | get_scenes_from_video(video_fn, scenes_fn, width, height, six_channels=six_channels) 147 | ) 148 | pbar.update() 149 | 150 | random.shuffle(tfrecord_scenes) 151 | 152 | options = tf.io.TFRecordOptions(compression_type="GZIP") 153 | with tf.io.TFRecordWriter( 154 | os.path.join(target_dir, "{}-{:04d}.tfrecord".format(target_fn, start_idx)), options) as writer: 155 | for scene in tfrecord_scenes: 156 | example = tf.train.Example(features=tf.train.Features(feature={ 157 | "scene": _bytes_feature(scene.tobytes()), 158 | "length": _int64_feature(len(scene)), 159 | "width": _int64_feature(width), 160 | "height": _int64_feature(height) 161 | })) 162 | writer.write(example.SerializeToString()) 163 | 164 | 165 | def get_transitions_from_video(video_fn, scenes_fn, width, height, window_size=160): 166 | frames = video_utils.get_frames(video_fn, width, height) 167 | n_frames = len(frames) 168 | 169 | scenes = np.loadtxt(scenes_fn, dtype=np.int32, ndmin=2) 170 | one_hot, many_hot = scenes2zero_one_representation(scenes, n_frames) 171 | 172 | transitions = [] 173 | for i, is_transition in enumerate(one_hot): 174 | if is_transition != 1: 175 | continue 176 | 177 | start = max(0, i - window_size // 2) 178 | scene = frames[start:][:window_size] 179 | if len(scene) != window_size: 180 | continue 181 | one = one_hot[start:][:window_size] 182 | many = many_hot[start:][:window_size] 183 | 184 | transitions.append((scene, one, many)) 185 | return transitions 186 | 187 | 188 | def create_train_transition_dataset(target_dir, target_fn, mapping_fn, width, height, n_videos_in_tfrecord=50): 189 | os.makedirs(target_dir, exist_ok=True) 190 | mapping = np.loadtxt(mapping_fn, dtype=np.str, delimiter=",").tolist() 191 | 192 | random.seed(42) 193 | random.shuffle(mapping) 194 | 195 | pbar = tqdm.tqdm(total=len(mapping)) 196 | n_transitions = 0 197 | 198 | for start_idx in range(0, len(mapping), n_videos_in_tfrecord): 199 | tfrecord_scenes = [] 200 | for video_fn, scenes_fn in mapping[start_idx:start_idx + n_videos_in_tfrecord]: 201 | tfrecord_scenes.extend( 202 | get_transitions_from_video(video_fn, scenes_fn, width, height) 203 | ) 204 | pbar.update() 205 | 206 | random.shuffle(tfrecord_scenes) 207 | 208 | options = tf.io.TFRecordOptions(compression_type="GZIP") 209 | with tf.io.TFRecordWriter( 210 | os.path.join(target_dir, "{}-{:04d}.tfrecord".format(target_fn, start_idx)), options) as writer: 211 | for scene, one_hot, many_hot in tfrecord_scenes: 212 | example = tf.train.Example(features=tf.train.Features(feature={ 213 | "scene": _bytes_feature(scene.tobytes()), 214 | "one_hot": _bytes_feature(one_hot.astype(np.uint8).tobytes()), 215 | "many_hot": _bytes_feature(many_hot.astype(np.uint8).tobytes()), 216 | "length": _int64_feature(len(scene)), 217 | "width": _int64_feature(width), 218 | "height": _int64_feature(height) 219 | })) 220 | writer.write(example.SerializeToString()) 221 | n_transitions += len(tfrecord_scenes) 222 | 223 | 
print("# Transitions: {:d}".format(n_transitions)) 224 | 225 | 226 | def create_test_tfrecord_from_dataset(dataset, target_fn): 227 | options = tf.io.TFRecordOptions(compression_type="GZIP") 228 | with tf.io.TFRecordWriter(target_fn, options) as writer: 229 | for scenes, one_hots, many_hots in dataset: 230 | scenes, one_hots, many_hots = scenes.numpy().astype(np.uint8), one_hots.numpy(), many_hots.numpy() 231 | for scene, one_hot, many_hot in zip(scenes, one_hots, many_hots): 232 | for frame_idx in range(len(scene)): 233 | example = tf.train.Example(features=tf.train.Features(feature={ 234 | "frame": _bytes_feature(scene[frame_idx].tobytes("C")), 235 | "is_one_hot_transition": _int64_feature(one_hot[frame_idx]), 236 | "is_many_hot_transition": _int64_feature(many_hot[frame_idx]), 237 | "width": _int64_feature(scene[frame_idx].shape[1]), 238 | "height": _int64_feature(scene[frame_idx].shape[0]) 239 | })) 240 | writer.write(example.SerializeToString()) 241 | 242 | 243 | if __name__ == "__main__": 244 | parser = argparse.ArgumentParser(description="Convert videos to tfrecords") 245 | parser.add_argument("type", type=str, choices=["train", "test", "train-transitions", "test-npy"], 246 | help="type of tfrecord to generate") 247 | parser.add_argument("--mapping_fn", type=str, help="path to mapping file containing lines in following format: " 248 | "/path/to/video.mp4,/path/to/scenes/gt", required=True) 249 | parser.add_argument("--target_dir", type=str, help="directory where to store the results", required=True) 250 | parser.add_argument("--target_fn", type=str, help="filename where to store the results (only if type=`train`)") 251 | parser.add_argument("--w", type=int, help="width of frames", default=48) 252 | parser.add_argument("--h", type=int, help="height of frames", default=27) 253 | parser.add_argument("--six_channels", action="store_true") 254 | 255 | args = parser.parse_args() 256 | 257 | if args.type == "train": 258 | assert args.target_fn is not None 259 | create_train_dataset(args.target_dir, args.target_fn, args.mapping_fn, args.w, args.h, 260 | six_channels=args.six_channels) 261 | elif args.type == "train-transitions": 262 | assert args.target_fn is not None 263 | assert not args.six_channels # not implemented 264 | create_train_transition_dataset(args.target_dir, args.target_fn, args.mapping_fn, args.w, args.h) 265 | elif args.type == "test": 266 | create_test_dataset(args.target_dir, args.mapping_fn, args.w, args.h, six_channels=args.six_channels) 267 | elif args.type == "test-npy": 268 | assert not args.six_channels # not implemented 269 | create_test_npy_files(args.target_dir, args.mapping_fn, args.w, args.h) 270 | -------------------------------------------------------------------------------- /inference-pytorch/transnetv2_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | import random 6 | 7 | 8 | class TransNetV2(nn.Module): 9 | 10 | def __init__(self, 11 | F=16, L=3, S=2, D=1024, 12 | use_many_hot_targets=True, 13 | use_frame_similarity=True, 14 | use_color_histograms=True, 15 | use_mean_pooling=False, 16 | dropout_rate=0.5, 17 | use_convex_comb_reg=False, # not supported 18 | use_resnet_features=False, # not supported 19 | use_resnet_like_top=False, # not supported 20 | frame_similarity_on_last_layer=False): # not supported 21 | super(TransNetV2, self).__init__() 22 | 23 | if use_resnet_features or use_resnet_like_top or 
use_convex_comb_reg or frame_similarity_on_last_layer: 24 | raise NotImplemented("Some options not implemented in Pytorch version of Transnet!") 25 | 26 | self.SDDCNN = nn.ModuleList( 27 | [StackedDDCNNV2(in_filters=3, n_blocks=S, filters=F, stochastic_depth_drop_prob=0.)] + 28 | [StackedDDCNNV2(in_filters=(F * 2 ** (i - 1)) * 4, n_blocks=S, filters=F * 2 ** i) for i in range(1, L)] 29 | ) 30 | 31 | self.frame_sim_layer = FrameSimilarity( 32 | sum([(F * 2 ** i) * 4 for i in range(L)]), lookup_window=101, output_dim=128, similarity_dim=128, use_bias=True 33 | ) if use_frame_similarity else None 34 | self.color_hist_layer = ColorHistograms( 35 | lookup_window=101, output_dim=128 36 | ) if use_color_histograms else None 37 | 38 | self.dropout = nn.Dropout(dropout_rate) if dropout_rate is not None else None 39 | 40 | output_dim = ((F * 2 ** (L - 1)) * 4) * 3 * 6 # 3x6 for spatial dimensions 41 | if use_frame_similarity: output_dim += 128 42 | if use_color_histograms: output_dim += 128 43 | 44 | self.fc1 = nn.Linear(output_dim, D) 45 | self.cls_layer1 = nn.Linear(D, 1) 46 | self.cls_layer2 = nn.Linear(D, 1) if use_many_hot_targets else None 47 | 48 | self.use_mean_pooling = use_mean_pooling 49 | self.eval() 50 | 51 | def forward(self, inputs): 52 | assert isinstance(inputs, torch.Tensor) and list(inputs.shape[2:]) == [27, 48, 3] and inputs.dtype == torch.uint8, \ 53 | "incorrect input type and/or shape" 54 | # uint8 of shape [B, T, H, W, 3] to float of shape [B, 3, T, H, W] 55 | x = inputs.permute([0, 4, 1, 2, 3]).float() 56 | x = x.div_(255.) 57 | 58 | block_features = [] 59 | for block in self.SDDCNN: 60 | x = block(x) 61 | block_features.append(x) 62 | 63 | if self.use_mean_pooling: 64 | x = torch.mean(x, dim=[3, 4]) 65 | x = x.permute(0, 2, 1) 66 | else: 67 | x = x.permute(0, 2, 3, 4, 1) 68 | x = x.reshape(x.shape[0], x.shape[1], -1) 69 | 70 | if self.frame_sim_layer is not None: 71 | x = torch.cat([self.frame_sim_layer(block_features), x], 2) 72 | 73 | if self.color_hist_layer is not None: 74 | x = torch.cat([self.color_hist_layer(inputs), x], 2) 75 | 76 | x = self.fc1(x) 77 | x = functional.relu(x) 78 | 79 | if self.dropout is not None: 80 | x = self.dropout(x) 81 | 82 | one_hot = self.cls_layer1(x) 83 | 84 | if self.cls_layer2 is not None: 85 | return one_hot, {"many_hot": self.cls_layer2(x)} 86 | 87 | return one_hot 88 | 89 | 90 | class StackedDDCNNV2(nn.Module): 91 | 92 | def __init__(self, 93 | in_filters, 94 | n_blocks, 95 | filters, 96 | shortcut=True, 97 | use_octave_conv=False, # not supported 98 | pool_type="avg", 99 | stochastic_depth_drop_prob=0.0): 100 | super(StackedDDCNNV2, self).__init__() 101 | 102 | if use_octave_conv: 103 | raise NotImplemented("Octave convolution not implemented in Pytorch version of Transnet!") 104 | 105 | assert pool_type == "max" or pool_type == "avg" 106 | if use_octave_conv and pool_type == "max": 107 | print("WARN: Octave convolution was designed with average pooling, not max pooling.") 108 | 109 | self.shortcut = shortcut 110 | self.DDCNN = nn.ModuleList([ 111 | DilatedDCNNV2(in_filters if i == 1 else filters * 4, filters, octave_conv=use_octave_conv, 112 | activation=functional.relu if i != n_blocks else None) for i in range(1, n_blocks + 1) 113 | ]) 114 | self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2)) if pool_type == "max" else nn.AvgPool3d(kernel_size=(1, 2, 2)) 115 | self.stochastic_depth_drop_prob = stochastic_depth_drop_prob 116 | 117 | def forward(self, inputs): 118 | x = inputs 119 | shortcut = None 120 | 121 | for block in self.DDCNN: 
122 | x = block(x) 123 | if shortcut is None: 124 | shortcut = x 125 | 126 | x = functional.relu(x) 127 | 128 | if self.shortcut is not None: 129 | if self.stochastic_depth_drop_prob != 0.: 130 | if self.training: 131 | if random.random() < self.stochastic_depth_drop_prob: 132 | x = shortcut 133 | else: 134 | x = x + shortcut 135 | else: 136 | x = (1 - self.stochastic_depth_drop_prob) * x + shortcut 137 | else: 138 | x += shortcut 139 | 140 | x = self.pool(x) 141 | return x 142 | 143 | 144 | class DilatedDCNNV2(nn.Module): 145 | 146 | def __init__(self, 147 | in_filters, 148 | filters, 149 | batch_norm=True, 150 | activation=None, 151 | octave_conv=False): # not supported 152 | super(DilatedDCNNV2, self).__init__() 153 | 154 | if octave_conv: 155 | raise NotImplemented("Octave convolution not implemented in Pytorch version of Transnet!") 156 | 157 | assert not (octave_conv and batch_norm) 158 | 159 | self.Conv3D_1 = Conv3DConfigurable(in_filters, filters, 1, use_bias=not batch_norm) 160 | self.Conv3D_2 = Conv3DConfigurable(in_filters, filters, 2, use_bias=not batch_norm) 161 | self.Conv3D_4 = Conv3DConfigurable(in_filters, filters, 4, use_bias=not batch_norm) 162 | self.Conv3D_8 = Conv3DConfigurable(in_filters, filters, 8, use_bias=not batch_norm) 163 | 164 | self.bn = nn.BatchNorm3d(filters * 4, eps=1e-3) if batch_norm else None 165 | self.activation = activation 166 | 167 | def forward(self, inputs): 168 | conv1 = self.Conv3D_1(inputs) 169 | conv2 = self.Conv3D_2(inputs) 170 | conv3 = self.Conv3D_4(inputs) 171 | conv4 = self.Conv3D_8(inputs) 172 | 173 | x = torch.cat([conv1, conv2, conv3, conv4], dim=1) 174 | 175 | if self.bn is not None: 176 | x = self.bn(x) 177 | 178 | if self.activation is not None: 179 | x = self.activation(x) 180 | 181 | return x 182 | 183 | 184 | class Conv3DConfigurable(nn.Module): 185 | 186 | def __init__(self, 187 | in_filters, 188 | filters, 189 | dilation_rate, 190 | separable=True, 191 | octave=False, # not supported 192 | use_bias=True, 193 | kernel_initializer=None): # not supported 194 | super(Conv3DConfigurable, self).__init__() 195 | 196 | if octave: 197 | raise NotImplemented("Octave convolution not implemented in Pytorch version of Transnet!") 198 | if kernel_initializer is not None: 199 | raise NotImplemented("Kernel initializers are not implemented in Pytorch version of Transnet!") 200 | 201 | assert not (separable and octave) 202 | 203 | if separable: 204 | # (2+1)D convolution https://arxiv.org/pdf/1711.11248.pdf 205 | conv1 = nn.Conv3d(in_filters, 2 * filters, kernel_size=(1, 3, 3), 206 | dilation=(1, 1, 1), padding=(0, 1, 1), bias=False) 207 | conv2 = nn.Conv3d(2 * filters, filters, kernel_size=(3, 1, 1), 208 | dilation=(dilation_rate, 1, 1), padding=(dilation_rate, 0, 0), bias=use_bias) 209 | self.layers = nn.ModuleList([conv1, conv2]) 210 | else: 211 | conv = nn.Conv3d(in_filters, filters, kernel_size=3, 212 | dilation=(dilation_rate, 1, 1), padding=(dilation_rate, 1, 1), bias=use_bias) 213 | self.layers = nn.ModuleList([conv]) 214 | 215 | def forward(self, inputs): 216 | x = inputs 217 | for layer in self.layers: 218 | x = layer(x) 219 | return x 220 | 221 | 222 | class FrameSimilarity(nn.Module): 223 | 224 | def __init__(self, 225 | in_filters, 226 | similarity_dim=128, 227 | lookup_window=101, 228 | output_dim=128, 229 | stop_gradient=False, # not supported 230 | use_bias=False): 231 | super(FrameSimilarity, self).__init__() 232 | 233 | if stop_gradient: 234 | raise NotImplemented("Stop gradient not implemented in Pytorch version of 
Transnet!") 235 | 236 | self.projection = nn.Linear(in_filters, similarity_dim, bias=use_bias) 237 | self.fc = nn.Linear(lookup_window, output_dim) 238 | 239 | self.lookup_window = lookup_window 240 | assert lookup_window % 2 == 1, "`lookup_window` must be odd integer" 241 | 242 | def forward(self, inputs): 243 | x = torch.cat([torch.mean(x, dim=[3, 4]) for x in inputs], dim=1) 244 | x = torch.transpose(x, 1, 2) 245 | 246 | x = self.projection(x) 247 | x = functional.normalize(x, p=2, dim=2) 248 | 249 | batch_size, time_window = x.shape[0], x.shape[1] 250 | similarities = torch.bmm(x, x.transpose(1, 2)) # [batch_size, time_window, time_window] 251 | similarities_padded = functional.pad(similarities, [(self.lookup_window - 1) // 2, (self.lookup_window - 1) // 2]) 252 | 253 | batch_indices = torch.arange(0, batch_size, device=x.device).view([batch_size, 1, 1]).repeat( 254 | [1, time_window, self.lookup_window]) 255 | time_indices = torch.arange(0, time_window, device=x.device).view([1, time_window, 1]).repeat( 256 | [batch_size, 1, self.lookup_window]) 257 | lookup_indices = torch.arange(0, self.lookup_window, device=x.device).view([1, 1, self.lookup_window]).repeat( 258 | [batch_size, time_window, 1]) + time_indices 259 | 260 | similarities = similarities_padded[batch_indices, time_indices, lookup_indices] 261 | return functional.relu(self.fc(similarities)) 262 | 263 | 264 | class ColorHistograms(nn.Module): 265 | 266 | def __init__(self, 267 | lookup_window=101, 268 | output_dim=None): 269 | super(ColorHistograms, self).__init__() 270 | 271 | self.fc = nn.Linear(lookup_window, output_dim) if output_dim is not None else None 272 | self.lookup_window = lookup_window 273 | assert lookup_window % 2 == 1, "`lookup_window` must be odd integer" 274 | 275 | @staticmethod 276 | def compute_color_histograms(frames): 277 | frames = frames.int() 278 | 279 | def get_bin(frames): 280 | # returns 0 .. 
511 281 | R, G, B = frames[:, :, 0], frames[:, :, 1], frames[:, :, 2] 282 | R, G, B = R >> 5, G >> 5, B >> 5 283 | return (R << 6) + (G << 3) + B 284 | 285 | batch_size, time_window, height, width, no_channels = frames.shape 286 | assert no_channels == 3 287 | frames_flatten = frames.view(batch_size * time_window, height * width, 3) 288 | 289 | binned_values = get_bin(frames_flatten) 290 | frame_bin_prefix = (torch.arange(0, batch_size * time_window, device=frames.device) << 9).view(-1, 1) 291 | binned_values = (binned_values + frame_bin_prefix).view(-1) 292 | 293 | histograms = torch.zeros(batch_size * time_window * 512, dtype=torch.int32, device=frames.device) 294 | histograms.scatter_add_(0, binned_values, torch.ones(len(binned_values), dtype=torch.int32, device=frames.device)) 295 | 296 | histograms = histograms.view(batch_size, time_window, 512).float() 297 | histograms_normalized = functional.normalize(histograms, p=2, dim=2) 298 | return histograms_normalized 299 | 300 | def forward(self, inputs): 301 | x = self.compute_color_histograms(inputs) 302 | 303 | batch_size, time_window = x.shape[0], x.shape[1] 304 | similarities = torch.bmm(x, x.transpose(1, 2)) # [batch_size, time_window, time_window] 305 | similarities_padded = functional.pad(similarities, [(self.lookup_window - 1) // 2, (self.lookup_window - 1) // 2]) 306 | 307 | batch_indices = torch.arange(0, batch_size, device=x.device).view([batch_size, 1, 1]).repeat( 308 | [1, time_window, self.lookup_window]) 309 | time_indices = torch.arange(0, time_window, device=x.device).view([1, time_window, 1]).repeat( 310 | [batch_size, 1, self.lookup_window]) 311 | lookup_indices = torch.arange(0, self.lookup_window, device=x.device).view([1, 1, self.lookup_window]).repeat( 312 | [batch_size, time_window, 1]) + time_indices 313 | 314 | similarities = similarities_padded[batch_indices, time_indices, lookup_indices] 315 | 316 | if self.fc is not None: 317 | return functional.relu(self.fc(similarities)) 318 | return similarities 319 | -------------------------------------------------------------------------------- /training/bi_tempered_loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Robust Bi-Tempered Logistic Loss Based on Bregman Divergences. 
17 | 18 | Source: https://arxiv.org/pdf/1906.03361.pdf 19 | """ 20 | 21 | import functools 22 | import tensorflow as tf 23 | 24 | 25 | def log_t(u, t): 26 | """Compute log_t for `u`.""" 27 | 28 | def _internal_log_t(u, t): 29 | return (u**(1.0 - t) - 1.0) / (1.0 - t) 30 | 31 | return tf.cond(tf.equal(t, 1.0), 32 | lambda: tf.math.log(u), 33 | functools.partial(_internal_log_t, u, t)) 34 | 35 | 36 | def exp_t(u, t): 37 | """Compute exp_t for `u`.""" 38 | 39 | def _internal_exp_t(u, t): 40 | return tf.nn.relu(1.0 + (1.0 - t) * u)**(1.0 / (1.0 - t)) 41 | 42 | return tf.cond(tf.equal(t, 1.0), 43 | lambda: tf.exp(u), 44 | functools.partial(_internal_exp_t, u, t)) 45 | 46 | 47 | def compute_normalization_fixed_point(activations, t, num_iters=5): 48 | """Returns the normalization value for each example (t > 1.0). 49 | 50 | Args: 51 | activations: A multi-dimensional tensor with last dimension `num_classes`. 52 | t: Temperature 2 (> 1.0 for tail heaviness). 53 | num_iters: Number of iterations to run the method. 54 | Return: A tensor of same rank as activation with the last dimension being 1. 55 | """ 56 | 57 | mu = tf.reduce_max(activations, -1, keepdims=True) 58 | normalized_activations_step_0 = activations - mu 59 | shape_normalized_activations = tf.shape(normalized_activations_step_0) 60 | 61 | def iter_condition(i, unused_normalized_activations): 62 | return i < num_iters 63 | 64 | def iter_body(i, normalized_activations): 65 | logt_partition = tf.reduce_sum(exp_t(normalized_activations, t), -1, keepdims=True) 66 | normalized_activations_t = tf.reshape( 67 | normalized_activations_step_0 * tf.pow(logt_partition, 1.0 - t), shape_normalized_activations) 68 | return [i + 1, normalized_activations_t] 69 | 70 | _, normalized_activations_t = tf.while_loop( 71 | iter_condition, 72 | iter_body, [0, normalized_activations_step_0], 73 | maximum_iterations=num_iters) 74 | logt_partition = tf.reduce_sum(exp_t(normalized_activations_t, t), -1, keepdims=True) 75 | return -log_t(1.0 / logt_partition, t) + mu 76 | 77 | 78 | def compute_normalization_binary_search(activations, t, num_iters=10): 79 | """Returns the normalization value for each example (t < 1.0). 80 | 81 | Args: 82 | activations: A multi-dimensional tensor with last dimension `num_classes`. 83 | t: Temperature 2 (< 1.0 for finite support). 84 | num_iters: Number of iterations to run the method. 85 | Return: A tensor of same rank as activation with the last dimension being 1. 
86 | """ 87 | mu = tf.reduce_max(activations, -1, keepdims=True) 88 | normalized_activations = activations - mu 89 | shape_activations = tf.shape(activations) 90 | effective_dim = tf.cast( 91 | tf.reduce_sum( 92 | tf.cast(tf.greater(normalized_activations, -1.0 / (1.0 - t)), tf.int32), 93 | -1, keepdims=True), tf.float32) 94 | shape_partition = tf.concat([shape_activations[:-1], [1]], 0) 95 | lower = tf.zeros(shape_partition) 96 | upper = -log_t(1.0 / effective_dim, t) * tf.ones(shape_partition) 97 | 98 | def iter_condition(i, unused_lower, unused_upper): 99 | return i < num_iters 100 | 101 | def iter_body(i, lower, upper): 102 | logt_partition = (upper + lower)/2.0 103 | sum_probs = tf.reduce_sum(exp_t(normalized_activations - logt_partition, t), -1, keepdims=True) 104 | update = tf.cast(tf.less(sum_probs, 1.0), tf.float32) 105 | lower = tf.reshape(lower * update + (1.0 - update) * logt_partition, shape_partition) 106 | upper = tf.reshape(upper * (1.0 - update) + update * logt_partition, shape_partition) 107 | return [i + 1, lower, upper] 108 | 109 | _, lower, upper = tf.while_loop( 110 | iter_condition, 111 | iter_body, [0, lower, upper], 112 | maximum_iterations=num_iters) 113 | logt_partition = (upper + lower)/2.0 114 | return logt_partition + mu 115 | 116 | 117 | def compute_normalization(activations, t, num_iters=5): 118 | """Returns the normalization value for each example. 119 | 120 | Args: 121 | activations: A multi-dimensional tensor with last dimension `num_classes`. 122 | t: Temperature 2 (< 1.0 for finite support, > 1.0 for tail heaviness). 123 | num_iters: Number of iterations to run the method. 124 | Return: A tensor of same rank as activation with the last dimension being 1. 125 | """ 126 | return tf.cond(tf.less(t, 1.0), 127 | functools.partial(compute_normalization_binary_search, activations, t, num_iters), 128 | functools.partial(compute_normalization_fixed_point, activations, t, num_iters)) 129 | 130 | 131 | def _internal_bi_tempered_logistic_loss(activations, labels, t1, t2): 132 | """Computes the Bi-Tempered logistic loss. 133 | 134 | Args: 135 | activations: A multi-dimensional tensor with last dimension `num_classes`. 136 | labels: batch_size 137 | t1: Temperature 1 (< 1.0 for boundedness). 138 | t2: Temperature 2 (> 1.0 for tail heaviness). 139 | 140 | Returns: 141 | A loss tensor for robust loss. 
142 | """ 143 | if t2 == 1.0: 144 | normalization_constants = tf.math.log(tf.reduce_sum(tf.exp(activations), -1, keepdims=True)) 145 | if t1 == 1.0: 146 | return normalization_constants + tf.reduce_sum( 147 | tf.multiply(labels, tf.math.log(labels + 1e-10) - activations), -1, 148 | keepdims=True) 149 | else: 150 | shifted_activations = tf.exp(activations - normalization_constants) 151 | one_minus_t2 = 1.0 152 | else: 153 | one_minus_t1 = (1.0 - t1) 154 | one_minus_t2 = (1.0 - t2) 155 | normalization_constants = compute_normalization(activations, t2, num_iters=5) 156 | shifted_activations = tf.nn.relu(1.0 + one_minus_t2 * (activations - normalization_constants)) 157 | 158 | if t1 == 1.0: 159 | return tf.reduce_sum( 160 | tf.multiply( 161 | tf.math.log(labels + 1e-10) - 162 | tf.math.log(tf.pow(shifted_activations, 1.0 / one_minus_t2)), labels 163 | ), -1, keepdims=True) 164 | else: 165 | beta = 1.0 + one_minus_t1 166 | logt_probs = (tf.pow(shifted_activations, one_minus_t1 / one_minus_t2) - 1.0) / one_minus_t1 167 | return tf.reduce_sum( 168 | tf.multiply(log_t(labels, t1) - logt_probs, labels) - 1.0 / beta * 169 | (tf.pow(labels, beta) - tf.pow(shifted_activations, beta / one_minus_t2)), -1) 170 | 171 | 172 | @tf.function(autograph=False) 173 | def tempered_sigmoid(activations, t, num_iters=5): 174 | """Tempered sigmoid function. 175 | 176 | Args: 177 | activations: Activations for the positive class for binary classification. 178 | t: Temperature tensor > 0.0. 179 | num_iters: Number of iterations to run the method. 180 | 181 | Returns: 182 | A probabilities tensor. 183 | """ 184 | t = tf.convert_to_tensor(t) 185 | input_shape = tf.shape(activations) 186 | activations_2d = tf.reshape(activations, [-1, 1]) 187 | internal_activations = tf.concat([tf.zeros_like(activations_2d), activations_2d], 1) 188 | normalization_constants = tf.cond( 189 | tf.equal(t, 1.0), 190 | lambda: tf.math.log(tf.reduce_sum(tf.exp(internal_activations), -1, keepdims=True)), 191 | functools.partial(compute_normalization, internal_activations, t, num_iters)) 192 | internal_probabilities = exp_t(internal_activations - normalization_constants, t) 193 | one_class_probabilities = tf.split(internal_probabilities, 2, axis=1)[1] 194 | return tf.reshape(one_class_probabilities, input_shape) 195 | 196 | 197 | @tf.function(autograph=False) 198 | def tempered_softmax(activations, t, num_iters=5): 199 | """Tempered softmax function. 200 | 201 | Args: 202 | activations: A multi-dimensional tensor with last dimension `num_classes`. 203 | t: Temperature tensor > 0.0. 204 | num_iters: Number of iterations to run the method. 205 | 206 | Returns: 207 | A probabilities tensor. 208 | """ 209 | t = tf.convert_to_tensor(t) 210 | normalization_constants = tf.cond( 211 | tf.equal(t, 1.0), 212 | lambda: tf.math.log(tf.reduce_sum(tf.exp(activations), -1, keepdims=True)), 213 | functools.partial(compute_normalization, activations, t, num_iters)) 214 | return exp_t(activations - normalization_constants, t) 215 | 216 | 217 | @tf.function(autograph=False) 218 | def bi_tempered_binary_logistic_loss(activations, 219 | labels, 220 | t1, 221 | t2, 222 | label_smoothing=0.0, 223 | num_iters=5): 224 | """Bi-Tempered binary logistic loss. 225 | 226 | Args: 227 | activations: A tensor containing activations for class 1. 228 | labels: A tensor with shape and dtype as activations. 229 | t1: Temperature 1 (< 1.0 for boundedness). 230 | t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 
231 | label_smoothing: Label smoothing 232 | num_iters: Number of iterations to run the method. 233 | 234 | Returns: 235 | A loss tensor. 236 | """ 237 | with tf.name_scope('binary_bitempered_logistic'): 238 | t1 = tf.convert_to_tensor(t1) 239 | t2 = tf.convert_to_tensor(t2) 240 | out_shape = tf.shape(labels) 241 | labels_2d = tf.reshape(labels, [-1, 1]) 242 | activations_2d = tf.reshape(activations, [-1, 1]) 243 | internal_labels = tf.concat([1.0 - labels_2d, labels_2d], 1) 244 | internal_logits = tf.concat([tf.zeros_like(activations_2d), activations_2d], 1) 245 | losses = bi_tempered_logistic_loss(internal_logits, internal_labels, t1, t2, label_smoothing, num_iters) 246 | return tf.reshape(losses, out_shape) 247 | 248 | 249 | @tf.function(autograph=False) 250 | def bi_tempered_logistic_loss(activations, 251 | labels, 252 | t1, 253 | t2, 254 | label_smoothing=0.0, 255 | num_iters=5): 256 | """Bi-Tempered Logistic Loss with custom gradient. 257 | 258 | Args: 259 | activations: A multi-dimensional tensor with last dimension `num_classes`. 260 | labels: A tensor with shape and dtype as activations. 261 | t1: Temperature 1 (< 1.0 for boundedness). 262 | t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 263 | label_smoothing: Label smoothing parameter between [0, 1). 264 | num_iters: Number of iterations to run the method. 265 | 266 | Returns: 267 | A loss tensor. 268 | """ 269 | with tf.name_scope('bitempered_logistic'): 270 | t1 = tf.convert_to_tensor(t1) 271 | t2 = tf.convert_to_tensor(t2) 272 | if label_smoothing > 0.0: 273 | num_classes = tf.cast(tf.shape(labels)[-1], tf.float32) 274 | labels = (1 - num_classes / (num_classes - 1) * label_smoothing) * labels + label_smoothing / (num_classes - 1) 275 | 276 | @tf.custom_gradient 277 | def _custom_gradient_bi_tempered_logistic_loss(activations): 278 | """Bi-Tempered Logistic Loss with custom gradient. 279 | 280 | Args: 281 | activations: A multi-dimensional tensor with last dim `num_classes`. 282 | 283 | Returns: 284 | A loss tensor, grad. 285 | """ 286 | with tf.name_scope('gradient_bitempered_logistic'): 287 | probabilities = tempered_softmax(activations, t2, num_iters) 288 | loss_values = tf.multiply( 289 | labels, log_t(labels + 1e-10, t1) - log_t(probabilities, t1)) - 1.0 / (2.0 - t1) * ( 290 | tf.pow(labels, 2.0 - t1) - tf.pow(probabilities, 2.0 - t1)) 291 | 292 | def grad(d_loss): 293 | """Explicit gradient calculation. 294 | 295 | Args: 296 | d_loss: Infinitesimal change in the loss value. 297 | Returns: Loss gradient. 298 | """ 299 | delta_probs = probabilities - labels 300 | forget_factor = tf.pow(probabilities, t2 - t1) 301 | delta_probs_times_forget_factor = tf.multiply(delta_probs, forget_factor) 302 | delta_forget_sum = tf.reduce_sum(delta_probs_times_forget_factor, -1, keepdims=True) 303 | escorts = tf.pow(probabilities, t2) 304 | escorts = escorts / tf.reduce_sum(escorts, -1, keepdims=True) 305 | derivative = delta_probs_times_forget_factor - tf.multiply(escorts, delta_forget_sum) 306 | return tf.multiply(d_loss, derivative) 307 | 308 | return loss_values, grad 309 | 310 | loss_values = _custom_gradient_bi_tempered_logistic_loss(activations) 311 | return tf.reduce_sum(loss_values, -1) 312 | 313 | 314 | def sparse_bi_tempered_logistic_loss(activations, labels, t1, t2, num_iters=5): 315 | """Sparse Bi-Tempered Logistic Loss with custom gradient. 316 | 317 | Args: 318 | activations: A multi-dimensional tensor with last dimension `num_classes`. 319 | labels: A tensor with dtype of int32. 
320 | t1: Temperature 1 (< 1.0 for boundedness). 321 | t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 322 | num_iters: Number of iterations to run the method. 323 | 324 | Returns: 325 | A loss tensor. 326 | """ 327 | with tf.name_scope('sparse_bitempered_logistic'): 328 | t1 = tf.convert_to_tensor(t1) 329 | t2 = tf.convert_to_tensor(t2) 330 | num_classes = tf.shape(activations)[-1] 331 | 332 | @tf.custom_gradient 333 | def _custom_gradient_sparse_bi_tempered_logistic_loss(activations): 334 | """Sparse Bi-Tempered Logistic Loss with custom gradient. 335 | 336 | Args: 337 | activations: A multi-dimensional tensor with last dim `num_classes`. 338 | 339 | Returns: 340 | A loss tensor, grad. 341 | """ 342 | with tf.name_scope('gradient_sparse_bitempered_logistic'): 343 | probabilities = tempered_softmax(activations, t2, num_iters) 344 | loss_values = -log_t( 345 | tf.reshape( 346 | tf.gather_nd(probabilities, 347 | tf.where(tf.one_hot(labels, num_classes))), 348 | tf.shape(activations)[:-1]), t1) - 1.0 / (2.0 - t1) * ( 349 | 1.0 - tf.reduce_sum(tf.pow(probabilities, 2.0 - t1), -1)) 350 | 351 | def grad(d_loss): 352 | """Explicit gradient calculation. 353 | 354 | Args: 355 | d_loss: Infinitesimal change in the loss value. 356 | Returns: Loss gradient. 357 | """ 358 | delta_probs = probabilities - tf.one_hot(labels, num_classes) 359 | forget_factor = tf.pow(probabilities, t2 - t1) 360 | delta_probs_times_forget_factor = tf.multiply(delta_probs, forget_factor) 361 | delta_forget_sum = tf.reduce_sum(delta_probs_times_forget_factor, -1, keepdims=True) 362 | escorts = tf.pow(probabilities, t2) 363 | escorts = escorts / tf.reduce_sum(escorts, -1, keepdims=True) 364 | derivative = delta_probs_times_forget_factor - tf.multiply(escorts, delta_forget_sum) 365 | return tf.multiply(d_loss, derivative) 366 | 367 | return loss_values, grad 368 | 369 | loss_values = _custom_gradient_sparse_bi_tempered_logistic_loss(activations) 370 | return loss_values 371 | -------------------------------------------------------------------------------- /training/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gin 3 | import glob 4 | import argparse 5 | import datetime 6 | from PIL import Image 7 | import tensorflow as tf 8 | import gin.tf.external_configurables 9 | 10 | import models 11 | import transnet 12 | import metrics_utils 13 | import input_processing 14 | import visualization_utils 15 | from bi_tempered_loss import bi_tempered_binary_logistic_loss, tempered_sigmoid 16 | import weight_decay_optimizers 17 | gin.config.external_configurable(weight_decay_optimizers.SGDW, 'weight_decay_optimizers.SGDW') 18 | 19 | 20 | @gin.configurable("options", blacklist=["create_dir_and_summaries"]) 21 | def get_options_dict(n_epochs=None, 22 | log_dir=gin.REQUIRED, 23 | log_name=gin.REQUIRED, 24 | trn_files=gin.REQUIRED, 25 | tst_files=gin.REQUIRED, 26 | input_shape=gin.REQUIRED, 27 | test_only=False, 28 | restore=None, 29 | restore_resnet_features=None, 30 | original_transnet=False, 31 | transition_only_trn_files=None, 32 | create_dir_and_summaries=True, 33 | transition_only_data_fraction=0.3, 34 | c3d_net=False, 35 | bi_tempered_loss=False, 36 | bi_tempered_loss_temp2=1., 37 | learning_rate_schedule=None, 38 | learning_rate_decay=None): 39 | trn_files_ = [] 40 | for fn in trn_files: 41 | trn_files_.extend(glob.glob(fn)) 42 | 43 | if transition_only_trn_files is not None: 44 | transition_trn_files_ = [] 45 | for fn in 
transition_only_trn_files: 46 | transition_trn_files_.extend(glob.glob(fn)) 47 | 48 | tst_files_ = {} 49 | for k, v in tst_files.items(): 50 | tst_files_[k] = [] 51 | for fn in v: 52 | tst_files_[k].extend(glob.glob(fn)) 53 | 54 | log_dir = os.path.join(log_dir, log_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")) 55 | summary_writer = tf.summary.create_file_writer(log_dir) if create_dir_and_summaries else None 56 | 57 | config_str = gin.config_str().replace("# ", "### ").split("\n") 58 | config_str = "\n\n".join([l for l in config_str if not l.startswith("### =====")]) 59 | 60 | if create_dir_and_summaries: 61 | with summary_writer.as_default(): 62 | tf.summary.text("config", config_str, step=0) 63 | with open(os.path.join(log_dir, "config.gin"), "w") as f: 64 | f.write(gin.config_str()) 65 | 66 | print("\n{}\n".format(log_name.upper())) 67 | 68 | return { 69 | "n_epochs": n_epochs, 70 | "log_dir": log_dir, 71 | "summary_writer": summary_writer, 72 | "trn_files": trn_files_, 73 | "tst_files": tst_files_, 74 | "input_shape": input_shape, 75 | "test_only": test_only, 76 | "restore": restore, 77 | "restore_resnet_features": restore_resnet_features, 78 | "original_transnet": original_transnet, 79 | "transition_only_trn_files": transition_trn_files_ if transition_only_trn_files is not None else None, 80 | "transition_only_data_fraction": transition_only_data_fraction, 81 | "c3d_net": c3d_net, 82 | "bi_tempered_loss": bi_tempered_loss, 83 | "bi_tempered_loss_temp2": bi_tempered_loss_temp2, 84 | "learning_rate_schedule": learning_rate_schedule, 85 | "learning_rate_decay": learning_rate_decay 86 | } 87 | 88 | 89 | @gin.configurable("training", blacklist=["net", "summary_writer"]) 90 | class Trainer: 91 | 92 | def __init__(self, net, summary_writer, 93 | optimizer=None, 94 | log_freq=None, 95 | grad_clipping=10., 96 | n_batches_per_epoch=None, 97 | evaluate_on_middle_frames_only=False): 98 | self.net = net 99 | self.summary_writer = summary_writer 100 | self.optimizer = optimizer() if optimizer is not None else None 101 | self.log_freq = log_freq 102 | self.grad_clipping = grad_clipping 103 | self.n_batches_per_epoch = n_batches_per_epoch 104 | self.mean_metrics = dict([(name, tf.keras.metrics.Mean(name=name, dtype=tf.float32)) for name in 105 | ["loss/total", "loss/one_hot_loss", "loss/many_hot_loss", "loss/l2_loss", 106 | "loss/comb_reg"]]) 107 | self.results = {} 108 | self.evaluate_on_middle_frames_only = evaluate_on_middle_frames_only 109 | 110 | @gin.configurable("loss", blacklist=["one_hot_pred", "one_hot_gt", "many_hot_pred", "many_hot_gt", "reg_losses"]) 111 | def compute_loss(self, one_hot_pred, one_hot_gt, many_hot_pred=None, many_hot_gt=None, 112 | transition_weight=1., 113 | many_hot_loss_weight=0., 114 | l2_loss_weight=0., 115 | dynamic_weight=None, 116 | reg_losses=None, 117 | bi_tempered_loss=False, 118 | bi_tempered_loss_temp1=1., 119 | bi_tempered_loss_temp2=1.): 120 | assert not (dynamic_weight and transition_weight != 1) 121 | 122 | one_hot_pred = one_hot_pred[:, :, 0] 123 | 124 | with tf.name_scope("losses"): 125 | if bi_tempered_loss: 126 | one_hot_loss = bi_tempered_binary_logistic_loss(activations=one_hot_pred, 127 | labels=tf.cast(one_hot_gt, tf.float32), 128 | t1=bi_tempered_loss_temp1, t2=bi_tempered_loss_temp2, 129 | label_smoothing=0.) 
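                # Note (comment added for clarity, not in the original source):
                # when `bi_tempered_loss` is enabled the raw logits are fed to
                # this bi-tempered loss, and at evaluation time the matching
                # tempered_sigmoid(x, t=bi_tempered_loss_temp2) is used instead
                # of tf.sigmoid to map logits to probabilities (see the
                # `logit_fc` selection in the __main__ block of this file).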
130 | else: 131 | one_hot_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=one_hot_pred, 132 | labels=tf.cast(one_hot_gt, tf.float32)) 133 | if transition_weight != 1: 134 | one_hot_loss *= 1 + tf.cast(one_hot_gt, tf.float32) * (transition_weight - 1) 135 | elif dynamic_weight is not None: 136 | pred_sigmoid = tf.nn.sigmoid(one_hot_pred) 137 | trans_weight = 4 * (dynamic_weight - 1) * (pred_sigmoid * pred_sigmoid - pred_sigmoid + 0.25) 138 | trans_weight = tf.where(pred_sigmoid < 0.5, trans_weight, 0) 139 | trans_weight = tf.stop_gradient(trans_weight) 140 | one_hot_loss *= 1 + tf.cast(one_hot_gt, tf.float32) * trans_weight 141 | 142 | one_hot_loss = tf.reduce_mean(one_hot_loss) 143 | 144 | many_hot_loss = 0. 145 | if many_hot_loss_weight != 0. and many_hot_pred is not None: 146 | many_hot_loss = many_hot_loss_weight * tf.reduce_mean( 147 | tf.nn.sigmoid_cross_entropy_with_logits(logits=many_hot_pred[:, :, 0], 148 | labels=tf.cast(many_hot_gt, tf.float32))) 149 | 150 | l2_loss = 0. 151 | if l2_loss_weight != 0.: 152 | l2_loss = l2_loss_weight * tf.add_n([tf.nn.l2_loss(v) for v in self.net.trainable_weights], 153 | name="l2_loss") 154 | 155 | total_loss = one_hot_loss + many_hot_loss + l2_loss 156 | losses = { 157 | "loss/one_hot_loss": one_hot_loss, 158 | "loss/many_hot_loss": many_hot_loss, 159 | "loss/l2_loss": l2_loss 160 | } 161 | 162 | if reg_losses is not None: 163 | for name, value in reg_losses.items(): 164 | if value is not None: 165 | total_loss += value 166 | losses["loss/" + name] = value 167 | losses["loss/total"] = total_loss 168 | 169 | return total_loss, losses 170 | 171 | @tf.function(autograph=False) 172 | def train_batch(self, frame_sequence, one_hot_gt, many_hot_gt, run_summaries=False): 173 | with tf.GradientTape() as tape: 174 | one_hot_pred = self.net(frame_sequence, training=True) 175 | 176 | dict_ = {} 177 | if isinstance(one_hot_pred, tuple): 178 | one_hot_pred, dict_ = one_hot_pred 179 | 180 | many_hot_pred = dict_.get("many_hot", None) 181 | alphas = dict_.get("alphas", None) 182 | comb_reg_loss = dict_.get("comb_reg_loss", None) 183 | 184 | total_loss, losses_dict = self.compute_loss(one_hot_pred, one_hot_gt, 185 | many_hot_pred, many_hot_gt, 186 | reg_losses={"comb_reg": comb_reg_loss}) 187 | 188 | grads = tape.gradient(total_loss, self.net.trainable_weights) 189 | with tf.name_scope("grad_check"): 190 | grads = [ 191 | tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) 192 | for g in grads] 193 | grads, grad_norm = tf.clip_by_global_norm(grads, self.grad_clipping) 194 | self.optimizer.apply_gradients(zip(grads, self.net.trainable_weights)) 195 | 196 | for loss_name, loss_value in losses_dict.items(): 197 | self.mean_metrics[loss_name].update_state(loss_value) 198 | 199 | with self.summary_writer.as_default(): 200 | tf.summary.scalar("grads/norm", grad_norm, step=self.optimizer.iterations) 201 | tf.summary.scalar("loss/immediate/total", total_loss, step=self.optimizer.iterations) 202 | 203 | if not run_summaries: 204 | return 205 | 206 | with self.summary_writer.as_default(): 207 | for grad, var in zip(grads, self.net.trainable_weights): 208 | tf.summary.histogram("grad/" + var.name, grad, step=self.optimizer.iterations) 209 | tf.summary.histogram("var/" + var.name, var.value(), step=self.optimizer.iterations) 210 | 211 | for loss_name, loss_value in losses_dict.items(): 212 | tf.summary.scalar(loss_name, self.mean_metrics[loss_name].result(), step=self.optimizer.iterations) 213 | self.mean_metrics[loss_name].reset_states() 214 | 
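                # Added comment: the per-variable histograms and mean-loss
                # scalars above (and the learning-rate scalar below) are only
                # written on every `log_freq`-th batch (run_summaries=True in
                # train_epoch); the gradient norm and the immediate total loss
                # are logged on every step.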
tf.summary.scalar("learning_rate", self.optimizer.learning_rate, step=self.optimizer.iterations) 215 | 216 | return one_hot_pred, alphas if alphas is not None else many_hot_pred, self.optimizer.iterations 217 | 218 | def train_epoch(self, dataset, logit_fc=tf.sigmoid): 219 | print("\nTraining") 220 | for metric in self.mean_metrics.values(): 221 | metric.reset_states() 222 | 223 | for i, (frame_sequence, one_hot_gt, many_hot_gt) in dataset.enumerate(): 224 | if i % self.log_freq == self.log_freq - 1: 225 | one_hot_pred, many_hot_pred, step = self.train_batch( 226 | frame_sequence, one_hot_gt, many_hot_gt, run_summaries=True) 227 | 228 | with self.summary_writer.as_default(): 229 | visualizations = visualization_utils.visualize_predictions( 230 | frame_sequence.numpy()[:, :, :, :, :3], logit_fc(one_hot_pred).numpy(), one_hot_gt.numpy(), 231 | logit_fc(many_hot_pred).numpy() if many_hot_pred is not None else None, many_hot_gt.numpy()) 232 | tf.summary.image("train/visualization", visualizations, step=step) 233 | 234 | for metric in self.mean_metrics.values(): 235 | metric.reset_states() 236 | else: 237 | self.train_batch(frame_sequence, one_hot_gt, many_hot_gt, run_summaries=False) 238 | print("\r", i.numpy(), end="") 239 | if self.n_batches_per_epoch is not None and self.n_batches_per_epoch == i: 240 | break 241 | 242 | @tf.function(autograph=False) 243 | def test_batch(self, frame_sequence, one_hot_gt, many_hot_gt): 244 | one_hot_pred = self.net(frame_sequence, training=False) 245 | 246 | dict_ = {} 247 | if isinstance(one_hot_pred, tuple): 248 | one_hot_pred, dict_ = one_hot_pred 249 | 250 | many_hot_pred = dict_.get("many_hot", None) 251 | alphas = dict_.get("alphas", None) 252 | comb_reg_loss = dict_.get("comb_reg_loss", None) 253 | 254 | total_loss, losses_dict = self.compute_loss(one_hot_pred, one_hot_gt, 255 | many_hot_pred, many_hot_gt, 256 | reg_losses={"comb_reg": comb_reg_loss}) 257 | 258 | for loss_name, loss_value in losses_dict.items(): 259 | self.mean_metrics[loss_name].update_state(loss_value) 260 | 261 | return one_hot_pred, alphas if alphas is not None else many_hot_pred 262 | 263 | def test_epoch(self, datasets, epoch_no, save_visualization_to=None, trace=False, logit_fc=tf.sigmoid): 264 | for metric in self.mean_metrics.values(): 265 | metric.reset_states() 266 | 267 | for ds_name, dataset in datasets: 268 | print("\nEvaluating", ds_name.upper()) 269 | one_hot_gt_list, one_hot_pred_list = [], [] 270 | 271 | for i, (frame_sequence, one_hot_gt, many_hot_gt) in dataset.enumerate(): 272 | if trace: 273 | tf.summary.trace_on(graph=True, profiler=False) 274 | one_hot_pred, many_hot_pred = self.test_batch(frame_sequence, one_hot_gt, many_hot_gt) 275 | with self.summary_writer.as_default(): 276 | if trace: 277 | tf.summary.trace_export(name="graph", step=0) 278 | trace = False 279 | 280 | if self.evaluate_on_middle_frames_only: 281 | x = int(one_hot_gt.shape[1] * 0.25) 282 | one_hot_gt_list.extend(one_hot_gt.numpy()[:, x:-x].flatten().tolist()) 283 | one_hot_pred_list.extend(logit_fc(one_hot_pred).numpy()[:, x:-x].flatten().tolist()) 284 | else: 285 | one_hot_gt_list.extend(one_hot_gt.numpy().flatten().tolist()) 286 | one_hot_pred_list.extend(logit_fc(one_hot_pred).numpy().flatten().tolist()) 287 | 288 | print("\r", i.numpy(), end="") 289 | if i != 0 or save_visualization_to is None: 290 | continue 291 | 292 | with self.summary_writer.as_default(): 293 | visualizations = visualization_utils.visualize_predictions( 294 | frame_sequence.numpy()[:, :, :, :, :3], 
logit_fc(one_hot_pred).numpy(), one_hot_gt.numpy(), 295 | logit_fc(many_hot_pred).numpy() if many_hot_pred is not None else None, many_hot_gt.numpy()) 296 | tf.summary.image("test/{}/visualization".format(ds_name), visualizations, step=epoch_no) 297 | 298 | for idx, img in enumerate(visualizations): 299 | Image.fromarray(img).save("{}_{}_{:02d}.png".format(save_visualization_to, ds_name, idx)) 300 | 301 | with self.summary_writer.as_default(): 302 | for loss_name, loss in self.mean_metrics.items(): 303 | tf.summary.scalar("test/{}/{}".format(ds_name, loss_name), loss.result(), step=epoch_no) 304 | 305 | f1 = metrics_utils.create_scene_based_summaries(one_hot_pred_list, one_hot_gt_list, 306 | prefix="test/" + ds_name, step=epoch_no) 307 | if self.results.get(ds_name, 0) < f1: 308 | self.results[ds_name] = f1 309 | 310 | def finish(self): 311 | with self.summary_writer.as_default(): 312 | for ds_name, f1 in self.results.items(): 313 | tf.summary.scalar("test/" + ds_name + "/scene/best_f1", f1, step=0) 314 | 315 | 316 | if __name__ == "__main__": 317 | 318 | parser = argparse.ArgumentParser(description="Train TransNet") 319 | parser.add_argument("config", help="path to config") 320 | args = parser.parse_args() 321 | 322 | gin.parse_config_file(args.config) 323 | options = get_options_dict() 324 | 325 | trn_ds = input_processing.train_pipeline(options["trn_files"]) if len(options["trn_files"]) > 0 else None 326 | if options["transition_only_trn_files"] is not None: 327 | trn_ds_ = input_processing.train_transition_pipeline(options["transition_only_trn_files"]) 328 | if trn_ds is not None: 329 | frac = options["transition_only_data_fraction"] 330 | trn_ds = tf.data.experimental.sample_from_datasets([trn_ds, trn_ds_], weights=[1 - frac, frac]) 331 | else: 332 | trn_ds = trn_ds_ 333 | 334 | tst_ds = [(name, input_processing.test_pipeline(files)) 335 | for name, files in options["tst_files"].items()] 336 | 337 | if options["original_transnet"]: 338 | net = models.OriginalTransNet() 339 | logit_fc = lambda x: tf.nn.softmax(x)[:, :, 1] 340 | elif options["c3d_net"]: 341 | net = models.C3DNet() 342 | logit_fc = tf.sigmoid 343 | else: 344 | net = transnet.TransNetV2() 345 | logit_fc = tf.sigmoid 346 | if options["bi_tempered_loss"]: 347 | logit_fc = lambda x: tempered_sigmoid(x, t=options["bi_tempered_loss_temp2"]) 348 | 349 | net(tf.zeros([1] + options["input_shape"], tf.float32)) 350 | trainer = Trainer(net, options["summary_writer"]) 351 | 352 | if options["restore_resnet_features"] is not None: 353 | net.resnet_layers.restore_me(options["restore_resnet_features"]) 354 | print("ResNet weights restored from", options["restore_resnet_features"]) 355 | 356 | if options["restore"] is not None: 357 | net.load_weights(options["restore"]) 358 | print("Weights restored from", options["restore"]) 359 | 360 | if options["test_only"]: 361 | trainer.test_epoch(tst_ds, 0, os.path.join(options["log_dir"], "visualization-00"), trace=True, 362 | logit_fc=logit_fc) 363 | exit() 364 | 365 | for epoch in range(1, options["n_epochs"] + 1): 366 | trainer.train_epoch(trn_ds, logit_fc=logit_fc) 367 | net.save_weights(os.path.join(options["log_dir"], "weights-{}.h5".format(epoch))) 368 | 369 | trainer.test_epoch(tst_ds, epoch, os.path.join(options["log_dir"], "visualization-{:02d}".format(epoch)), 370 | trace=epoch == 1, logit_fc=logit_fc) 371 | 372 | if options["learning_rate_schedule"] is not None: 373 | if epoch in options["learning_rate_schedule"]: 374 | trainer.optimizer.learning_rate = \ 375 | 
trainer.optimizer.learning_rate.numpy() * options["learning_rate_decay"] 376 | trainer.finish() 377 | -------------------------------------------------------------------------------- /training/weight_decay_optimizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Base class to make optimizers weight decay ready.""" 16 | 17 | import gin 18 | import numpy as np 19 | import tensorflow as tf 20 | from typing import Union, Callable, Type 21 | 22 | 23 | FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64] 24 | 25 | 26 | def _ref(var): 27 | return var.ref() if hasattr(var, "ref") else var.experimental_ref() 28 | 29 | 30 | class DecoupledWeightDecayExtension: 31 | """This class allows to extend optimizers with decoupled weight decay. 32 | 33 | It implements the decoupled weight decay described by Loshchilov & Hutter 34 | (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is 35 | decoupled from the optimization steps w.r.t. to the loss function. 36 | For SGD variants, this simplifies hyperparameter search since it decouples 37 | the settings of weight decay and learning rate. 38 | For adaptive gradient algorithms, it regularizes variables with large 39 | gradients more than L2 regularization would, which was shown to yield 40 | better training loss and generalization error in the paper above. 41 | 42 | This class alone is not an optimizer but rather extends existing 43 | optimizers with decoupled weight decay. We explicitly define the two 44 | examples used in the above paper (SGDW and AdamW), but in general this 45 | can extend any OptimizerX by using 46 | `extend_with_decoupled_weight_decay( 47 | OptimizerX, weight_decay=weight_decay)`. 48 | In order for it to work, it must be the first class the Optimizer with 49 | weight decay inherits from, e.g. 50 | 51 | ```python 52 | class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): 53 | def __init__(self, weight_decay, *args, **kwargs): 54 | super(AdamW, self).__init__(weight_decay, *args, **kwargs). 55 | ``` 56 | 57 | Note: this extension decays weights BEFORE applying the update based 58 | on the gradient, i.e. this extension only has the desired behaviour for 59 | optimizers which do not depend on the value of'var' in the update step! 60 | 61 | Note: when applying a decay to the learning rate, be sure to manually apply 62 | the decay to the `weight_decay` as well. For example: 63 | 64 | ```python 65 | step = tf.Variable(0, trainable=False) 66 | schedule = tf.optimizers.schedules.PiecewiseConstantDecay( 67 | [10000, 15000], [1e-0, 1e-1, 1e-2]) 68 | # lr and wd can be a function or a tensor 69 | lr = 1e-1 * schedule(step) 70 | wd = lambda: 1e-4 * schedule(step) 71 | 72 | # ... 
73 | 74 | optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) 75 | ``` 76 | """ 77 | 78 | def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs): 79 | """Extension class that adds weight decay to an optimizer. 80 | 81 | Args: 82 | weight_decay: A `Tensor` or a floating point value, the factor by 83 | which a variable is decayed in the update step. 84 | **kwargs: Optional list or tuple or set of `Variable` objects to 85 | decay. 86 | """ 87 | wd = kwargs.pop("weight_decay", weight_decay) 88 | super().__init__(**kwargs) 89 | self._decay_var_list = None # is set in minimize or apply_gradients 90 | self._set_hyper("weight_decay", wd) 91 | 92 | def get_config(self): 93 | config = super().get_config() 94 | config.update( 95 | {"weight_decay": self._serialize_hyperparameter("weight_decay"),} 96 | ) 97 | return config 98 | 99 | def minimize(self, loss, var_list, grad_loss=None, name=None, decay_var_list=None): 100 | """Minimize `loss` by updating `var_list`. 101 | 102 | This method simply computes gradient using `tf.GradientTape` and calls 103 | `apply_gradients()`. If you want to process the gradient before 104 | applying then call `tf.GradientTape` and `apply_gradients()` explicitly 105 | instead of using this function. 106 | 107 | Args: 108 | loss: A callable taking no arguments which returns the value to 109 | minimize. 110 | var_list: list or tuple of `Variable` objects to update to 111 | minimize `loss`, or a callable returning the list or tuple of 112 | `Variable` objects. Use callable when the variable list would 113 | otherwise be incomplete before `minimize` since the variables 114 | are created at the first time `loss` is called. 115 | grad_loss: Optional. A `Tensor` holding the gradient computed for 116 | `loss`. 117 | decay_var_list: Optional list of variables to be decayed. Defaults 118 | to all variables in var_list. 119 | name: Optional name for the returned operation. 120 | Returns: 121 | An Operation that updates the variables in `var_list`. 122 | Raises: 123 | ValueError: If some of the variables are not `Variable` objects. 124 | """ 125 | self._decay_var_list = ( 126 | set([_ref(v) for v in decay_var_list]) if decay_var_list else False 127 | ) 128 | return super().minimize(loss, var_list=var_list, grad_loss=grad_loss, name=name) 129 | 130 | def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwargs): 131 | """Apply gradients to variables. 132 | 133 | This is the second part of `minimize()`. It returns an `Operation` that 134 | applies gradients. 135 | 136 | Args: 137 | grads_and_vars: List of (gradient, variable) pairs. 138 | name: Optional name for the returned operation. Default to the 139 | name passed to the `Optimizer` constructor. 140 | decay_var_list: Optional list of variables to be decayed. Defaults 141 | to all variables in var_list. 142 | **kwargs: Additional arguments to pass to the base optimizer's 143 | apply_gradient method, e.g., TF2.2 added an argument 144 | `all_reduce_sum_gradients`. 145 | Returns: 146 | An `Operation` that applies the specified gradients. 147 | Raises: 148 | TypeError: If `grads_and_vars` is malformed. 149 | ValueError: If none of the variables have gradients. 
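        Example (an illustrative sketch added here, not upstream
        documentation; `tape`, `loss`, `var1` and `var2` are hypothetical):

        ```python
        opt = SGDW(weight_decay=1e-4, learning_rate=0.01, momentum=0.9)
        grads = tape.gradient(loss, [var1, var2])
        # update both variables, but apply weight decay only to var1
        opt.apply_gradients(zip(grads, [var1, var2]), decay_var_list=[var1])
        ```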
150 | """ 151 | self._decay_var_list = ( 152 | set([_ref(v) for v in decay_var_list]) if decay_var_list else False 153 | ) 154 | return super().apply_gradients(grads_and_vars, name=name, **kwargs) 155 | 156 | def _decay_weights_op(self, var): 157 | if not self._decay_var_list or _ref(var) in self._decay_var_list: 158 | return var.assign_sub( 159 | self._get_hyper("weight_decay", var.dtype) * var, self._use_locking 160 | ) 161 | return tf.no_op() 162 | 163 | def _decay_weights_sparse_op(self, var, indices): 164 | if not self._decay_var_list or _ref(var) in self._decay_var_list: 165 | update = -self._get_hyper("weight_decay", var.dtype) * tf.gather( 166 | var, indices 167 | ) 168 | return self._resource_scatter_add(var, indices, update) 169 | return tf.no_op() 170 | 171 | # Here, we overwrite the apply functions that the base optimizer calls. 172 | # super().apply_x resolves to the apply_x function of the BaseOptimizer. 173 | 174 | def _resource_apply_dense(self, grad, var): 175 | with tf.control_dependencies([self._decay_weights_op(var)]): 176 | return super()._resource_apply_dense(grad, var) 177 | 178 | def _resource_apply_sparse(self, grad, var, indices): 179 | decay_op = self._decay_weights_sparse_op(var, indices) 180 | with tf.control_dependencies([decay_op]): 181 | return super()._resource_apply_sparse(grad, var, indices) 182 | 183 | 184 | def extend_with_decoupled_weight_decay( 185 | base_optimizer: Type[tf.keras.optimizers.Optimizer], 186 | ) -> Type[tf.keras.optimizers.Optimizer]: 187 | """Factory function returning an optimizer class with decoupled weight 188 | decay. 189 | 190 | Returns an optimizer class. An instance of the returned class computes the 191 | update step of `base_optimizer` and additionally decays the weights. 192 | E.g., the class returned by 193 | `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is 194 | equivalent to `tfa.optimizers.AdamW`. 195 | 196 | The API of the new optimizer class slightly differs from the API of the 197 | base optimizer: 198 | - The first argument to the constructor is the weight decay rate. 199 | - `minimize` and `apply_gradients` accept the optional keyword argument 200 | `decay_var_list`, which specifies the variables that should be decayed. 201 | If `None`, all variables that are optimized are decayed. 202 | 203 | Usage example: 204 | ```python 205 | # MyAdamW is a new class 206 | MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam) 207 | # Create a MyAdamW object 208 | optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) 209 | # update var1, var2 but only decay var1 210 | optimizer.minimize(loss, var_list=[var1, var2], decay_variables=[var1]) 211 | 212 | Note: this extension decays weights BEFORE applying the update based 213 | on the gradient, i.e. this extension only has the desired behaviour for 214 | optimizers which do not depend on the value of 'var' in the update step! 215 | 216 | Note: when applying a decay to the learning rate, be sure to manually apply 217 | the decay to the `weight_decay` as well. For example: 218 | 219 | ```python 220 | step = tf.Variable(0, trainable=False) 221 | schedule = tf.optimizers.schedules.PiecewiseConstantDecay( 222 | [10000, 15000], [1e-0, 1e-1, 1e-2]) 223 | # lr and wd can be a function or a tensor 224 | lr = 1e-1 * schedule(step) 225 | wd = lambda: 1e-4 * schedule(step) 226 | 227 | # ... 
228 | 229 | optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) 230 | ``` 231 | 232 | Note: you might want to register your own custom optimizer using 233 | `tf.keras.utils.get_custom_objects()`. 234 | 235 | Args: 236 | base_optimizer: An optimizer class that inherits from 237 | tf.optimizers.Optimizer. 238 | 239 | Returns: 240 | A new optimizer class that inherits from DecoupledWeightDecayExtension 241 | and base_optimizer. 242 | """ 243 | 244 | class OptimizerWithDecoupledWeightDecay( 245 | DecoupledWeightDecayExtension, base_optimizer 246 | ): 247 | """Base_optimizer with decoupled weight decay. 248 | 249 | This class computes the update step of `base_optimizer` and 250 | additionally decays the variable with the weight decay being 251 | decoupled from the optimization steps w.r.t. to the loss 252 | function, as described by Loshchilov & Hutter 253 | (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this 254 | simplifies hyperparameter search since it decouples the settings 255 | of weight decay and learning rate. For adaptive gradient 256 | algorithms, it regularizes variables with large gradients more 257 | than L2 regularization would, which was shown to yield better 258 | training loss and generalization error in the paper above. 259 | """ 260 | 261 | def __init__( 262 | self, weight_decay: Union[FloatTensorLike, Callable], *args, **kwargs 263 | ): 264 | # super delegation is necessary here 265 | super().__init__(weight_decay, *args, **kwargs) 266 | 267 | return OptimizerWithDecoupledWeightDecay 268 | 269 | 270 | class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD): 271 | """Optimizer that implements the Momentum algorithm with weight_decay. 272 | 273 | This is an implementation of the SGDW optimizer described in "Decoupled 274 | Weight Decay Regularization" by Loshchilov & Hutter 275 | (https://arxiv.org/abs/1711.05101) 276 | ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). 277 | It computes the update step of `tf.keras.optimizers.SGD` and additionally 278 | decays the variable. Note that this is different from adding 279 | L2 regularization on the variables to the loss. Decoupling the weight decay 280 | from other hyperparameters (in particular the learning rate) simplifies 281 | hyperparameter search. 282 | 283 | For further information see the documentation of the SGD Optimizer. 284 | 285 | This optimizer can also be instantiated as 286 | ```python 287 | extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD, 288 | weight_decay=weight_decay) 289 | ``` 290 | 291 | Note: when applying a decay to the learning rate, be sure to manually apply 292 | the decay to the `weight_decay` as well. For example: 293 | 294 | ```python 295 | step = tf.Variable(0, trainable=False) 296 | schedule = tf.optimizers.schedules.PiecewiseConstantDecay( 297 | [10000, 15000], [1e-0, 1e-1, 1e-2]) 298 | # lr and wd can be a function or a tensor 299 | lr = 1e-1 * schedule(step) 300 | wd = lambda: 1e-4 * schedule(step) 301 | 302 | # ... 303 | 304 | optimizer = tfa.optimizers.SGDW( 305 | learning_rate=lr, weight_decay=wd, momentum=0.9) 306 | ``` 307 | """ 308 | 309 | def __init__( 310 | self, 311 | weight_decay: Union[FloatTensorLike, Callable], 312 | learning_rate: Union[FloatTensorLike, Callable] = 0.001, 313 | momentum: Union[FloatTensorLike, Callable] = 0.0, 314 | nesterov: bool = False, 315 | name: str = "SGDW", 316 | **kwargs 317 | ): 318 | """Construct a new SGDW optimizer. 319 | 320 | For further information see the documentation of the SGD Optimizer. 
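        In this repository `SGDW` is additionally registered with gin
        (`gin.config.external_configurable` in `training/training.py`), so a
        training config can select it as the `Trainer`'s `optimizer` and set
        its `weight_decay`, `learning_rate`, etc. through gin bindings (the
        exact binding names depend on the .gin file used).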
321 | 322 | Args: 323 | learning_rate: float hyperparameter >= 0. Learning rate. 324 | momentum: float hyperparameter >= 0 that accelerates SGD in the 325 | relevant direction and dampens oscillations. 326 | nesterov: boolean. Whether to apply Nesterov momentum. 327 | name: Optional name prefix for the operations created when applying 328 | gradients. Defaults to 'SGD'. 329 | **kwargs: keyword arguments. Allowed to be {`clipnorm`, 330 | `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by 331 | norm; `clipvalue` is clip gradients by value, `decay` is 332 | included for backward compatibility to allow time inverse decay 333 | of learning rate. `lr` is included for backward compatibility, 334 | recommended to use `learning_rate` instead. 335 | """ 336 | super().__init__( 337 | weight_decay, 338 | learning_rate=learning_rate, 339 | momentum=momentum, 340 | nesterov=nesterov, 341 | name=name, 342 | **kwargs, 343 | ) 344 | 345 | 346 | class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): 347 | """Optimizer that implements the Adam algorithm with weight decay. 348 | 349 | This is an implementation of the AdamW optimizer described in "Decoupled 350 | Weight Decay Regularization" by Loshch ilov & Hutter 351 | (https://arxiv.org/abs/1711.05101) 352 | ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). 353 | 354 | It computes the update step of `tf.keras.optimizers.Adam` and additionally 355 | decays the variable. Note that this is different from adding L2 356 | regularization on the variables to the loss: it regularizes variables with 357 | large gradients more than L2 regularization would, which was shown to yield 358 | better training loss and generalization error in the paper above. 359 | 360 | For further information see the documentation of the Adam Optimizer. 361 | 362 | This optimizer can also be instantiated as 363 | ```python 364 | extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam, 365 | weight_decay=weight_decay) 366 | ``` 367 | 368 | Note: when applying a decay to the learning rate, be sure to manually apply 369 | the decay to the `weight_decay` as well. For example: 370 | 371 | ```python 372 | step = tf.Variable(0, trainable=False) 373 | schedule = tf.optimizers.schedules.PiecewiseConstantDecay( 374 | [10000, 15000], [1e-0, 1e-1, 1e-2]) 375 | # lr and wd can be a function or a tensor 376 | lr = 1e-1 * schedule(step) 377 | wd = lambda: 1e-4 * schedule(step) 378 | 379 | # ... 380 | 381 | optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) 382 | ``` 383 | """ 384 | 385 | def __init__( 386 | self, 387 | weight_decay: Union[FloatTensorLike, Callable], 388 | learning_rate: Union[FloatTensorLike, Callable] = 0.001, 389 | beta_1: Union[FloatTensorLike, Callable] = 0.9, 390 | beta_2: Union[FloatTensorLike, Callable] = 0.999, 391 | epsilon: FloatTensorLike = 1e-07, 392 | amsgrad: bool = False, 393 | name: str = "AdamW", 394 | **kwargs 395 | ): 396 | """Construct a new AdamW optimizer. 397 | 398 | For further information see the documentation of the Adam Optimizer. 399 | 400 | Args: 401 | weight_decay: A Tensor or a floating point value. The weight decay. 402 | learning_rate: A Tensor or a floating point value. The learning 403 | rate. 404 | beta_1: A float value or a constant float tensor. The exponential 405 | decay rate for the 1st moment estimates. 406 | beta_2: A float value or a constant float tensor. The exponential 407 | decay rate for the 2nd moment estimates. 408 | epsilon: A small constant for numerical stability. 
This epsilon is 409 | "epsilon hat" in the Kingma and Ba paper (in the formula just 410 | before Section 2.1), not the epsilon in Algorithm 1 of the 411 | paper. 412 | amsgrad: boolean. Whether to apply AMSGrad variant of this 413 | algorithm from the paper "On the Convergence of Adam and 414 | beyond". 415 | name: Optional name for the operations created when applying 416 | gradients. Defaults to "AdamW". 417 | **kwargs: keyword arguments. Allowed to be {`clipnorm`, 418 | `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by 419 | norm; `clipvalue` is clip gradients by value, `decay` is 420 | included for backward compatibility to allow time inverse decay 421 | of learning rate. `lr` is included for backward compatibility, 422 | recommended to use `learning_rate` instead. 423 | """ 424 | super().__init__( 425 | weight_decay, 426 | learning_rate=learning_rate, 427 | beta_1=beta_1, 428 | beta_2=beta_2, 429 | epsilon=epsilon, 430 | amsgrad=amsgrad, 431 | name=name, 432 | **kwargs, 433 | ) 434 | -------------------------------------------------------------------------------- /training/transnet.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import h5py 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from models import ResNet18, ResNetBlock 7 | 8 | 9 | @gin.configurable(blacklist=["name"]) 10 | class TransNetV2(tf.keras.Model): 11 | 12 | def __init__(self, F=16, L=3, S=2, D=256, 13 | use_resnet_features=False, 14 | use_many_hot_targets=False, 15 | use_frame_similarity=False, 16 | use_mean_pooling=False, 17 | use_convex_comb_reg=False, 18 | dropout_rate=None, 19 | use_resnet_like_top=False, 20 | frame_similarity_on_last_layer=False, 21 | use_color_histograms=False, 22 | name="TransNet"): 23 | super(TransNetV2, self).__init__(name=name) 24 | 25 | self.resnet_layers = ResNetFeatures() if use_resnet_features else (lambda x, training=False: x / 255.) 
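        # Comment added for clarity: when `use_resnet_features` is False the
        # frame "backbone" is just the lambda above, rescaling raw uint8
        # frames to [0, 1] via x / 255.; otherwise ResNetFeatures (defined
        # later in this file) standardizes frames with ResNet18.MEAN/STD and
        # runs the initial conv + two residual blocks before the SDDCNN stack.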
26 | self.blocks = [StackedDDCNNV2(n_blocks=S, filters=F, stochastic_depth_drop_prob=0., name="SDDCNN_1")] 27 | self.blocks += [StackedDDCNNV2(n_blocks=S, filters=F * 2**i, name="SDDCNN_{:d}".format(i + 1)) for i in range(1, L)] 28 | self.fc1 = tf.keras.layers.Dense(D, activation=tf.nn.relu) 29 | self.cls_layer1 = tf.keras.layers.Dense(1, activation=None) 30 | self.cls_layer2 = tf.keras.layers.Dense(1, activation=None) if use_many_hot_targets else None 31 | self.frame_sim_layer = FrameSimilarity() if use_frame_similarity else None 32 | self.color_hist_layer = ColorHistograms() if use_color_histograms else None 33 | self.use_mean_pooling = use_mean_pooling 34 | self.convex_comb_reg = ConvexCombinationRegularization() if use_convex_comb_reg else None 35 | self.dropout = tf.keras.layers.Dropout(dropout_rate) if dropout_rate is not None else None 36 | 37 | self.frame_similarity_on_last_layer = frame_similarity_on_last_layer 38 | self.resnet_like_top = use_resnet_like_top 39 | if self.resnet_like_top: 40 | self.resnet_like_top_conv = tf.keras.layers.Conv3D(filters=32, kernel_size=(3, 7, 7), strides=(1, 2, 2), 41 | padding="SAME", use_bias=False, 42 | name="resnet_like_top/conv") 43 | self.resnet_like_top_bn = tf.keras.layers.BatchNormalization(name="resnet_like_top/bn") 44 | self.resnet_like_top_max_pool = tf.keras.layers.MaxPooling3D(pool_size=(1, 3, 3), strides=(1, 2, 2), 45 | padding="SAME") 46 | 47 | def call(self, inputs, training=False): 48 | out_dict = {} 49 | 50 | x = inputs 51 | x = self.resnet_layers(x, training=training) 52 | 53 | if self.resnet_like_top: 54 | x = self.resnet_like_top_conv(x) 55 | x = self.resnet_like_top_bn(x) 56 | x = self.resnet_like_top_max_pool(x) 57 | 58 | block_features = [] 59 | for block in self.blocks: 60 | x = block(x, training=training) 61 | block_features.append(x) 62 | 63 | if self.convex_comb_reg is not None: 64 | out_dict["alphas"], out_dict["comb_reg_loss"] = self.convex_comb_reg(inputs, x) 65 | 66 | if self.use_mean_pooling: 67 | x = tf.math.reduce_mean(x, axis=[2, 3]) 68 | else: 69 | shape = [tf.shape(x)[0], tf.shape(x)[1], np.prod(x.get_shape().as_list()[2:])] 70 | x = tf.reshape(x, shape=shape, name="flatten_3d") 71 | 72 | if self.frame_sim_layer is not None and not self.frame_similarity_on_last_layer: 73 | x = tf.concat([self.frame_sim_layer(block_features), x], 2) 74 | 75 | if self.color_hist_layer is not None: 76 | x = tf.concat([self.color_hist_layer(inputs), x], 2) 77 | 78 | x = self.fc1(x) 79 | if self.dropout is not None: 80 | x = self.dropout(x, training=training) 81 | 82 | if self.frame_sim_layer is not None and self.frame_similarity_on_last_layer: 83 | x = tf.concat([self.frame_sim_layer(block_features), x], 2) 84 | 85 | one_hot = self.cls_layer1(x) 86 | 87 | if self.cls_layer2 is not None: 88 | out_dict["many_hot"] = self.cls_layer2(x) 89 | 90 | if len(out_dict) > 0: 91 | return one_hot, out_dict 92 | return one_hot 93 | 94 | 95 | @gin.configurable(whitelist=["shortcut", "use_octave_conv", "pool_type", "stochastic_depth_drop_prob"]) 96 | class StackedDDCNNV2(tf.keras.layers.Layer): 97 | 98 | def __init__(self, n_blocks, filters, shortcut=False, use_octave_conv=False, pool_type="max", 99 | stochastic_depth_drop_prob=0., name="StackedDDCNN"): 100 | super(StackedDDCNNV2, self).__init__(name=name) 101 | assert pool_type == "max" or pool_type == "avg" 102 | if use_octave_conv and pool_type == "max": 103 | print("WARN: Octave convolution was designed with average pooling, not max pooling.") 104 | 105 | self.shortcut = shortcut 106 | # 
self.shortcut = None 107 | # if shortcut: 108 | # self.shortcut = tf.keras.layers.Conv3D(filters * 4, kernel_size=1, dilation_rate=1, padding="SAME", 109 | # activation=None, use_bias=True, name="shortcut") 110 | 111 | self.blocks = [DilatedDCNNV2(filters, octave_conv=use_octave_conv, 112 | activation=tf.nn.relu if i != n_blocks else None, 113 | name="DDCNN_{:d}".format(i)) for i in range(1, n_blocks + 1)] 114 | self.pool = tf.keras.layers.MaxPool3D(pool_size=(1, 2, 2)) if pool_type == "max" else \ 115 | tf.keras.layers.AveragePooling3D(pool_size=(1, 2, 2)) 116 | self.octave = use_octave_conv 117 | self.stochastic_depth_drop_prob = stochastic_depth_drop_prob 118 | 119 | def call(self, inputs, training=False): 120 | x = inputs 121 | shortcut = None 122 | 123 | if self.octave: 124 | x = [self.pool(x), x] 125 | for block in self.blocks: 126 | x = block(x, training=training) 127 | if shortcut is None: 128 | shortcut = x 129 | if self.octave: 130 | x = tf.concat([x[0], self.pool(x[1])], -1) 131 | 132 | x = tf.nn.relu(x) 133 | 134 | if self.shortcut is not None: 135 | # shortcut = self.shortcut(inputs) 136 | if self.stochastic_depth_drop_prob != 0.: 137 | if training: 138 | x = tf.cond(tf.random.uniform([]) < self.stochastic_depth_drop_prob, 139 | lambda: shortcut, lambda: x + shortcut) 140 | else: 141 | x = (1 - self.stochastic_depth_drop_prob) * x + shortcut 142 | else: 143 | x += shortcut 144 | 145 | if not self.octave: 146 | x = self.pool(x) 147 | return x 148 | 149 | 150 | @gin.configurable(whitelist=["batch_norm"]) 151 | class DilatedDCNNV2(tf.keras.layers.Layer): 152 | 153 | def __init__(self, filters, batch_norm=False, activation=None, octave_conv=False, name="DilatedDCNN"): 154 | super(DilatedDCNNV2, self).__init__(name=name) 155 | assert not (octave_conv and batch_norm) 156 | 157 | self.conv1 = Conv3DConfigurable(filters, 1, use_bias=not batch_norm, octave=octave_conv, name="Conv3D_1") 158 | self.conv2 = Conv3DConfigurable(filters, 2, use_bias=not batch_norm, octave=octave_conv, name="Conv3D_2") 159 | self.conv3 = Conv3DConfigurable(filters, 4, use_bias=not batch_norm, octave=octave_conv, name="Conv3D_4") 160 | self.conv4 = Conv3DConfigurable(filters, 8, use_bias=not batch_norm, octave=octave_conv, name="Conv3D_8") 161 | self.octave = octave_conv 162 | 163 | self.batch_norm = tf.keras.layers.BatchNormalization(name="bn") if batch_norm else None 164 | self.activation = activation 165 | 166 | def call(self, inputs, training=False): 167 | conv1 = self.conv1(inputs, training=training) 168 | conv2 = self.conv2(inputs, training=training) 169 | conv3 = self.conv3(inputs, training=training) 170 | conv4 = self.conv4(inputs, training=training) 171 | 172 | if self.octave: 173 | x = [tf.concat([conv1[0], conv2[0], conv3[0], conv4[0]], axis=4), 174 | tf.concat([conv1[1], conv2[1], conv3[1], conv4[1]], axis=4)] 175 | else: 176 | x = tf.concat([conv1, conv2, conv3, conv4], axis=4) 177 | 178 | if self.batch_norm is not None: 179 | x = self.batch_norm(x, training=training) 180 | 181 | if self.activation is not None: 182 | if self.octave: 183 | x = [self.activation(x[0]), self.activation(x[1])] 184 | else: 185 | x = self.activation(x) 186 | return x 187 | 188 | 189 | @gin.configurable(whitelist=["separable", "kernel_initializer"]) 190 | class Conv3DConfigurable(tf.keras.layers.Layer): 191 | 192 | def __init__(self, 193 | filters, 194 | dilation_rate, 195 | separable=False, 196 | octave=False, 197 | use_bias=True, 198 | kernel_initializer="glorot_uniform", 199 | name="Conv3D"): 200 | 
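        # Comment added for clarity: depending on the flags this layer becomes
        # one of three variants - a factorized "(2+1)D" pair (spatial 1x3x3
        # conv followed by a temporal 3x1x1 conv), a single OctConv3D, or a
        # plain 3x3x3 Conv3D; in all cases only the temporal dimension uses
        # the given `dilation_rate`.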
super(Conv3DConfigurable, self).__init__(name=name) 201 | assert not (separable and octave) 202 | 203 | if separable: 204 | # (2+1)D convolution https://arxiv.org/pdf/1711.11248.pdf 205 | conv1 = tf.keras.layers.Conv3D(2 * filters, kernel_size=(1, 3, 3), dilation_rate=(1, 1, 1), 206 | padding="SAME", activation=None, use_bias=False, 207 | name="conv_spatial", kernel_initializer=kernel_initializer) 208 | conv2 = tf.keras.layers.Conv3D(filters, kernel_size=(3, 1, 1), dilation_rate=(dilation_rate, 1, 1), 209 | padding="SAME", activation=None, use_bias=use_bias, name="conv_temporal", 210 | kernel_initializer=kernel_initializer) 211 | self.layers = [conv1, conv2] 212 | elif octave: 213 | conv = OctConv3D(filters, kernel_size=3, dilation_rate=(dilation_rate, 1, 1), use_bias=use_bias, 214 | kernel_initializer=kernel_initializer) 215 | self.layers = [conv] 216 | else: 217 | conv = tf.keras.layers.Conv3D(filters, kernel_size=3, dilation_rate=(dilation_rate, 1, 1), 218 | padding="SAME", activation=None, use_bias=use_bias, name="conv", 219 | kernel_initializer=kernel_initializer) 220 | self.layers = [conv] 221 | 222 | def call(self, inputs): 223 | x = inputs 224 | for layer in self.layers: 225 | x = layer(x) 226 | return x 227 | 228 | 229 | @gin.configurable(whitelist=["alpha"]) 230 | class OctConv3D(tf.keras.layers.Layer): 231 | 232 | def __init__(self, filters, kernel_size=3, dilation_rate=(1, 1, 1), alpha=0.25, 233 | use_bias=True, kernel_initializer="glorot_uniform", name="OctConv3D"): 234 | super(OctConv3D, self).__init__(name=name) 235 | 236 | self.low_channels = int(filters * alpha) 237 | self.high_channels = filters - self.low_channels 238 | 239 | self.high_to_high = tf.keras.layers.Conv3D(self.high_channels, kernel_size=kernel_size, activation=None, 240 | dilation_rate=dilation_rate, padding="SAME", 241 | use_bias=use_bias, kernel_initializer=kernel_initializer, 242 | name="high_to_high") 243 | self.high_to_low = tf.keras.layers.Conv3D(self.low_channels, kernel_size=kernel_size, activation=None, 244 | dilation_rate=dilation_rate, padding="SAME", 245 | use_bias=False, kernel_initializer=kernel_initializer, 246 | name="high_to_low") 247 | self.low_to_high = tf.keras.layers.Conv3D(self.high_channels, kernel_size=kernel_size, activation=None, 248 | dilation_rate=dilation_rate, padding="SAME", 249 | use_bias=False, kernel_initializer=kernel_initializer, 250 | name="low_to_high") 251 | self.low_to_low = tf.keras.layers.Conv3D(self.low_channels, kernel_size=kernel_size, activation=None, 252 | dilation_rate=dilation_rate, padding="SAME", 253 | use_bias=use_bias, kernel_initializer=kernel_initializer, 254 | name="low_to_low") 255 | self.upsampler = tf.keras.layers.UpSampling3D(size=(1, 2, 2)) 256 | self.downsampler = tf.keras.layers.AveragePooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2), padding="SAME") 257 | 258 | @staticmethod 259 | def pad_to(tensor, target_shape): 260 | shape = tf.shape(tensor) 261 | padding = [[0, tar - curr] for curr, tar in zip(shape, target_shape)] 262 | return tf.pad(tensor, padding, "CONSTANT") 263 | 264 | @staticmethod 265 | def crop_to(tensor, target_width, target_height): 266 | return tensor[:, :, :target_height, :target_width] 267 | 268 | def call(self, inputs): 269 | low_inputs, high_inputs = inputs 270 | 271 | high_to_high = self.high_to_high(high_inputs) 272 | high_to_low = self.high_to_low(self.downsampler(high_inputs)) 273 | 274 | low_to_high = self.upsampler(self.low_to_high(low_inputs)) 275 | low_to_low = self.low_to_low(low_inputs) 276 | 277 | high_output = 
high_to_high[:, :, :tf.shape(low_to_high)[2], :tf.shape(low_to_high)[3]] + low_to_high 278 | low_output = low_to_low + high_to_low[:, :, :tf.shape(low_to_low)[2], :tf.shape(low_to_low)[3]] 279 | 280 | # print("OctConv3D:", low_inputs.shape, "->", low_output.shape, "|", high_inputs.shape, "->", high_output.shape) 281 | return low_output, high_output 282 | 283 | 284 | @gin.configurable(whitelist=["trainable"]) 285 | class ResNetFeatures(tf.keras.layers.Layer): 286 | 287 | def __init__(self, trainable=False, name="ResNetFeatures"): 288 | super(ResNetFeatures, self).__init__(trainable=trainable, name=name) 289 | 290 | self.conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), 291 | padding="SAME", use_bias=False, name="conv1") 292 | self.bn1 = tf.keras.layers.BatchNormalization(name="conv1/bn") 293 | self.max_pool = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="SAME") 294 | 295 | self.layer2a = ResNetBlock(64, name="Block2a") 296 | self.layer2b = ResNetBlock(64, name="Block2b") 297 | 298 | self.mean = tf.constant(ResNet18.MEAN) 299 | self.std = tf.constant(ResNet18.STD) 300 | 301 | def call(self, inputs, training=False): 302 | training = training if self.trainable else False 303 | shape = tf.shape(inputs) 304 | 305 | x = tf.reshape(inputs, [shape[0] * shape[1], shape[2], shape[3], shape[4]]) 306 | x = (x - self.mean) / self.std 307 | 308 | x = self.conv1(x) 309 | x = self.bn1(x, training=training) 310 | x = tf.nn.relu(x) 311 | x = self.max_pool(x) 312 | 313 | x = self.layer2a(x, training=training) 314 | x = self.layer2b(x, training=training) 315 | 316 | new_shape = tf.shape(x) 317 | x = tf.reshape(x, [shape[0], shape[1], new_shape[1], new_shape[2], new_shape[3]]) 318 | return x 319 | 320 | def restore_me(self, checkpoint): 321 | with h5py.File(checkpoint, "r") as f: 322 | for v in self.variables: 323 | name = v.name.split("/")[2:] 324 | if name[0].startswith("Block"): 325 | name = name[:1] + name 326 | else: 327 | name = name[:len(name) - 1] + name 328 | name = "/".join(name) 329 | v.assign(f[name][:]) 330 | 331 | 332 | @gin.configurable(whitelist=["similarity_dim", "lookup_window", "output_dim", "stop_gradient", "use_bias"]) 333 | class FrameSimilarity(tf.keras.layers.Layer): 334 | 335 | def __init__(self, 336 | similarity_dim=128, 337 | lookup_window=101, 338 | output_dim=128, 339 | stop_gradient=False, 340 | use_bias=False, 341 | name="FrameSimilarity"): 342 | super(FrameSimilarity, self).__init__(name=name) 343 | 344 | self.projection = tf.keras.layers.Dense(similarity_dim, use_bias=use_bias, activation=None) 345 | self.fc = tf.keras.layers.Dense(output_dim, activation=tf.nn.relu) 346 | 347 | self.lookup_window = lookup_window 348 | self.stop_gradient = stop_gradient 349 | assert lookup_window % 2 == 1, "`lookup_window` must be odd integer" 350 | 351 | def call(self, inputs): 352 | x = tf.concat([ 353 | tf.math.reduce_mean(x, axis=[2, 3]) for x in inputs 354 | ], axis=2) 355 | 356 | if self.stop_gradient: 357 | x = tf.stop_gradient(x) 358 | 359 | x = self.projection(x) 360 | x = tf.nn.l2_normalize(x, axis=2) 361 | 362 | batch_size, time_window = tf.shape(x)[0], tf.shape(x)[1] 363 | similarities = tf.matmul(x, x, transpose_b=True) # [batch_size, time_window, time_window] 364 | similarities_padded = tf.pad(similarities, [[0, 0], [0, 0], [(self.lookup_window - 1) // 2] * 2]) 365 | 366 | batch_indices = tf.tile( 367 | tf.reshape(tf.range(batch_size), [batch_size, 1, 1]), [1, time_window, self.lookup_window] 368 | ) 369 | time_indices = 
tf.tile( 370 | tf.reshape(tf.range(time_window), [1, time_window, 1]), [batch_size, 1, self.lookup_window] 371 | ) 372 | lookup_indices = tf.tile( 373 | tf.reshape(tf.range(self.lookup_window), [1, 1, self.lookup_window]), [batch_size, time_window, 1] 374 | ) + time_indices 375 | 376 | indices = tf.stack([batch_indices, time_indices, lookup_indices], -1) 377 | 378 | similarities = tf.gather_nd(similarities_padded, indices) 379 | return self.fc(similarities) 380 | 381 | 382 | @gin.configurable(whitelist=["filters", "delta_scale", "loss_weight"]) 383 | class ConvexCombinationRegularization(tf.keras.layers.Layer): 384 | 385 | def __init__(self, filters=32, delta_scale=10., loss_weight=0.01, name="ConvexCombinationRegularization"): 386 | super(ConvexCombinationRegularization, self).__init__(name=name) 387 | 388 | self.projection = tf.keras.layers.Conv3D(filters, kernel_size=1, dilation_rate=1, padding="SAME", 389 | activation=tf.nn.relu, use_bias=True) 390 | self.features = tf.keras.layers.Conv3D(filters * 2, kernel_size=(3, 3, 3), dilation_rate=1, padding="SAME", 391 | activation=tf.nn.relu, use_bias=True) 392 | self.dense = tf.keras.layers.Dense(1, activation=None, use_bias=True) 393 | self.loss = tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.NONE) 394 | self.delta_scale = delta_scale 395 | self.loss_weight = loss_weight 396 | 397 | def call(self, image_inputs, feature_inputs): 398 | x = feature_inputs 399 | x = self.projection(x) 400 | 401 | batch_size = tf.shape(x)[0] 402 | window_size = tf.shape(x)[1] 403 | 404 | first_frame = tf.tile(x[:, :1], [1, window_size, 1, 1, 1]) 405 | last_frame = tf.tile(x[:, -1:], [1, window_size, 1, 1, 1]) 406 | 407 | x = tf.concat([x, first_frame, last_frame], -1) 408 | x = self.features(x) 409 | 410 | x = tf.math.reduce_mean(x, axis=[2, 3]) 411 | alpha = self.dense(x) 412 | 413 | first_img = tf.tile(image_inputs[:, :1], [1, window_size, 1, 1, 1]) 414 | last_img = tf.tile(image_inputs[:, -1:], [1, window_size, 1, 1, 1]) 415 | 416 | alpha_ = tf.nn.sigmoid(alpha) 417 | alpha_ = tf.reshape(alpha_, [batch_size, window_size, 1, 1, 1]) 418 | predictions_ = (alpha_ * first_img + (1 - alpha_) * last_img) 419 | 420 | loss_ = self.loss(y_true=image_inputs / self.delta_scale, y_pred=predictions_ / self.delta_scale) 421 | loss_ = self.loss_weight * tf.math.reduce_mean(loss_) 422 | return alpha, loss_ 423 | 424 | 425 | @gin.configurable(whitelist=["lookup_window", "output_dim"]) 426 | class ColorHistograms(tf.keras.layers.Layer): 427 | 428 | def __init__(self, lookup_window=101, output_dim=None, name="ColorHistograms"): 429 | super(ColorHistograms, self).__init__(name=name) 430 | 431 | self.fc = tf.keras.layers.Dense(output_dim, activation=tf.nn.relu) if output_dim is not None else None 432 | self.lookup_window = lookup_window 433 | assert lookup_window % 2 == 1, "`lookup_window` must be odd integer" 434 | 435 | @staticmethod 436 | def compute_color_histograms(frames): 437 | frames = tf.cast(frames, tf.int32) 438 | 439 | def get_bin(frames): 440 | # returns 0 .. 
511 441 | R, G, B = frames[:, :, 0], frames[:, :, 1], frames[:, :, 2] 442 | R, G, B = tf.bitwise.right_shift(R, 5), tf.bitwise.right_shift(G, 5), tf.bitwise.right_shift(B, 5) 443 | return tf.bitwise.left_shift(R, 6) + tf.bitwise.left_shift(G, 3) + B 444 | 445 | batch_size, time_window, height, width = tf.shape(frames)[0], tf.shape(frames)[1], tf.shape(frames)[2], \ 446 | tf.shape(frames)[3] 447 | no_channels = frames.shape[-1] 448 | assert no_channels == 3 or no_channels == 6 449 | if no_channels == 3: 450 | frames_flatten = tf.reshape(frames, [batch_size * time_window, height * width, 3]) 451 | else: 452 | frames_flatten = tf.reshape(frames, [batch_size * time_window, height * width * 2, 3]) 453 | 454 | binned_values = get_bin(frames_flatten) 455 | frame_bin_prefix = tf.bitwise.left_shift(tf.range(batch_size * time_window), 9)[:, tf.newaxis] 456 | binned_values = binned_values + frame_bin_prefix 457 | 458 | ones = tf.ones_like(binned_values, dtype=tf.int32) 459 | histograms = tf.math.unsorted_segment_sum(ones, binned_values, batch_size * time_window * 512) 460 | histograms = tf.reshape(histograms, [batch_size, time_window, 512]) 461 | 462 | histograms_normalized = tf.cast(histograms, tf.float32) 463 | histograms_normalized = histograms_normalized / tf.linalg.norm(histograms_normalized, axis=2, keepdims=True) 464 | return histograms_normalized 465 | 466 | def call(self, inputs): 467 | x = self.compute_color_histograms(inputs) 468 | 469 | batch_size, time_window = tf.shape(x)[0], tf.shape(x)[1] 470 | similarities = tf.matmul(x, x, transpose_b=True) # [batch_size, time_window, time_window] 471 | similarities_padded = tf.pad(similarities, [[0, 0], [0, 0], [(self.lookup_window - 1) // 2] * 2]) 472 | 473 | batch_indices = tf.tile( 474 | tf.reshape(tf.range(batch_size), [batch_size, 1, 1]), [1, time_window, self.lookup_window] 475 | ) 476 | time_indices = tf.tile( 477 | tf.reshape(tf.range(time_window), [1, time_window, 1]), [batch_size, 1, self.lookup_window] 478 | ) 479 | lookup_indices = tf.tile( 480 | tf.reshape(tf.range(self.lookup_window), [1, 1, self.lookup_window]), [batch_size, time_window, 1] 481 | ) + time_indices 482 | 483 | indices = tf.stack([batch_indices, time_indices, lookup_indices], -1) 484 | 485 | similarities = tf.gather_nd(similarities_padded, indices) 486 | 487 | if self.fc is not None: 488 | return self.fc(similarities) 489 | return similarities 490 | -------------------------------------------------------------------------------- /training/input_processing.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import tensorflow as tf 3 | 4 | 5 | @gin.configurable(blacklist=["filenames"]) 6 | def train_pipeline(filenames, 7 | shuffle_buffer=100, 8 | shot_len=100, 9 | frame_width=48, 10 | frame_height=27, 11 | batch_size=16, 12 | repeat=False, 13 | no_channels=3): 14 | ds = tf.data.Dataset.from_tensor_slices(filenames) 15 | ds = ds.shuffle(len(filenames)) 16 | ds = ds.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP").map(parse_train_sample, 17 | num_parallel_calls=1), 18 | cycle_length=8, 19 | block_length=16, 20 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 21 | ds = ds.shuffle(shuffle_buffer) 22 | ds = ds.padded_batch(2, ([shot_len, frame_height, frame_width, no_channels], []), drop_remainder=True) 23 | ds = ds.map(concat_shots, num_parallel_calls=tf.data.experimental.AUTOTUNE) 24 | 25 | def filter_(*args): 26 | return args[-1] 27 | 28 | def map_(*args): 29 | return args[:-1] 30 | 31 | ds = 
ds.filter(filter_).map(map_) 32 | ds = ds.batch(batch_size) 33 | if repeat: 34 | ds = ds.repeat() 35 | ds = ds.prefetch(2) 36 | return ds 37 | 38 | 39 | @gin.configurable(blacklist=["filenames"]) 40 | def train_transition_pipeline(filenames, 41 | shuffle_buffer=100, 42 | batch_size=16, 43 | repeat=False): 44 | ds = tf.data.Dataset.from_tensor_slices(filenames) 45 | ds = ds.shuffle(len(filenames)) 46 | ds = ds.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP"), 47 | cycle_length=8, 48 | block_length=16, 49 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 50 | ds = ds.map(parse_train_transition_sample, num_parallel_calls=tf.data.experimental.AUTOTUNE) 51 | ds = ds.shuffle(shuffle_buffer) 52 | ds = ds.batch(batch_size) 53 | if repeat: 54 | ds = ds.repeat() 55 | ds = ds.prefetch(2) 56 | return ds 57 | 58 | 59 | @tf.function 60 | @gin.configurable(blacklist=["sample"]) 61 | def parse_train_transition_sample(sample, 62 | shot_len=None, 63 | frame_width=48, 64 | frame_height=27): 65 | features = tf.io.parse_single_example(sample, features={ 66 | "scene": tf.io.FixedLenFeature([], tf.string), 67 | "one_hot": tf.io.FixedLenFeature([], tf.string), 68 | "many_hot": tf.io.FixedLenFeature([], tf.string), 69 | "length": tf.io.FixedLenFeature([], tf.int64) 70 | }) 71 | length = tf.cast(features["length"], tf.int32) 72 | 73 | scene = tf.io.decode_raw(features["scene"], tf.uint8) 74 | scene = tf.reshape(scene, [length, frame_height, frame_width, 3]) 75 | 76 | one_hot = tf.io.decode_raw(features["one_hot"], tf.uint8) 77 | many_hot = tf.io.decode_raw(features["many_hot"], tf.uint8) 78 | 79 | shot_start = tf.random.uniform([], minval=0, maxval=length - shot_len, dtype=tf.int32) 80 | shot_end = shot_start + shot_len 81 | 82 | scene = tf.reshape(scene[shot_start:shot_end], [shot_len, frame_height, frame_width, 3]) 83 | scene = tf.cast(scene, dtype=tf.float32) 84 | scene = augment_shot(scene) 85 | 86 | one_hot = tf.cast(tf.reshape(one_hot[shot_start:shot_end], [shot_len]), tf.int32) 87 | many_hot = tf.cast(tf.reshape(many_hot[shot_start:shot_end], [shot_len]), tf.int32) 88 | 89 | return scene, one_hot, many_hot 90 | 91 | 92 | @tf.function 93 | @gin.configurable(blacklist=["sample"]) 94 | def parse_train_sample(sample, 95 | shot_len=None, 96 | frame_width=48, 97 | frame_height=27, 98 | sudden_color_change_prob=0., 99 | spacial_augmentation=False, 100 | original_width=None, 101 | original_height=None, 102 | no_channels=3): 103 | assert no_channels == 3 or no_channels == 6 104 | 105 | features = tf.io.parse_single_example(sample, features={ 106 | "scene": tf.io.FixedLenFeature([], tf.string), 107 | "length": tf.io.FixedLenFeature([], tf.int64) 108 | }) 109 | length = tf.cast(features["length"], tf.int32) 110 | 111 | original_width = original_width if spacial_augmentation else frame_width 112 | original_height = original_height if spacial_augmentation else frame_height 113 | 114 | scene = tf.io.decode_raw(features["scene"], tf.uint8) 115 | scene = tf.reshape(scene, [length, original_height, original_width, no_channels]) 116 | 117 | shot_start = tf.random.uniform([], minval=0, maxval=tf.maximum(1, length - shot_len), dtype=tf.int32) 118 | shot_end = shot_start + shot_len 119 | scene = scene[shot_start:shot_end] 120 | 121 | scene = tf.cast(scene, dtype=tf.float32) 122 | 123 | if sudden_color_change_prob != 0.: 124 | assert no_channels == 3 # not implemented 125 | 126 | def color_change(shot_): 127 | bound = tf.random.uniform([], minval=1, maxval=tf.shape(shot_)[0], dtype=tf.int32) 128 | start, 
end = shot_[:bound], shot_[bound:] 129 | start = augment_shot(start, up_down_flip_prob=0., left_right_flip_prob=0.) 130 | return tf.concat([start, end], axis=0) 131 | 132 | scene = tf.cond(tf.random.uniform([]) < sudden_color_change_prob, 133 | lambda: color_change(scene), lambda: scene) 134 | 135 | if spacial_augmentation: 136 | assert no_channels == 3 # not implemented 137 | scene = augment_shot_spacial(scene, frame_width, frame_height) 138 | 139 | scene = augment_shot(scene, no_channels=no_channels) 140 | return scene, tf.shape(scene)[0] 250 | assert transition_min_len % 2 == 0, "`transition_min_len` must be even" 251 | assert transition_max_len % 2 == 0, "`transition_max_len` must be even" 252 | shot1 = shots[0][:lens[0]] 253 | shot2 = shots[1][:lens[1]] 254 | 255 | if color_transfer_prob > 0: 256 | assert no_channels == 3 # not implemented 257 | shot2 = tf.cond(tf.random.uniform([]) < color_transfer_prob, 258 | lambda: color_transfer(source=shot1, target=shot2), lambda: shot2) 259 | 260 | transition_boundary = tf.random.uniform([], maxval=shot_len, dtype=tf.int32) # {0, ..., shot_len - 1} 261 | # convert transition boundary to vector with 1 at the boundary and 0 otherwise 262 | one_hot_gt = tf.one_hot(transition_boundary, shot_len, dtype=tf.int32) # [SHOT_LENGTH] 263 | 264 | # hard_cut 265 | hard_cut = tf.cast(tf.range(shot_len) <= transition_boundary, dtype=tf.float32) 266 | 267 | # dissolve 268 | dis_len = tf.random.uniform([], minval=transition_min_len // 2, 269 | maxval=(transition_max_len // 2) + 1, dtype=tf.int32) 270 | dis_kernel = tf.linspace(1., 0., dis_len * 2 + 2)[1:-1] 271 | dis_left_win = tf.minimum(dis_len - 1, transition_boundary) 272 | dis_right_win = tf.minimum(dis_len, (shot_len - 1) - transition_boundary) 273 | dissolve = tf.concat([ 274 | tf.ones([transition_boundary - dis_left_win], dtype=tf.float32), 275 | dis_kernel[dis_len - dis_left_win - 1:dis_len + dis_right_win], 276 | tf.zeros([shot_len - (transition_boundary + dis_right_win + 1)], dtype=tf.float32) 277 | ], axis=0) 278 | dissolve_trans = tf.reshape(tf.cast( 279 | tf.logical_and(tf.not_equal(dissolve, 0.), tf.not_equal(dissolve, 1.)), tf.int32 280 | ), [shot_len]) 281 | 282 | # switch between hard cut and dissolve 283 | is_dissolve = tf.random.uniform([]) > hard_cut_prob 284 | transition, many_hot_gt = tf.cond(is_dissolve, 285 | lambda: (dissolve, dissolve_trans), 286 | lambda: (hard_cut, one_hot_gt)) 287 | 288 | # pad shots to full length if they are smaller 289 | many_hot_gt_indices = tf.cast(tf.where(many_hot_gt), tf.int32) 290 | shot1_min_len = tf.reduce_max(many_hot_gt_indices) 291 | shot2_min_len = shot_len - tf.reduce_min(many_hot_gt_indices) 292 | 293 | shot1_pad_start = tf.maximum(shot1_min_len - lens[0], 0) 294 | shot1_pad_end = tf.maximum(shot_len - (lens[0] + shot1_pad_start), 0) 295 | shot1 = tf.pad(shot1, [[shot1_pad_start, shot1_pad_end], [0, 0], [0, 0], [0, 0]]) 296 | 297 | shot2_pad_end = tf.maximum(shot2_min_len - lens[1], 0) 298 | shot2_pad_start = tf.maximum(shot_len - (lens[1] + shot2_pad_end), 0) 299 | shot2 = tf.pad(shot2, [[shot2_pad_start, shot2_pad_end], [0, 0], [0, 0], [0, 0]]) 300 | 301 | def basic_shot_transitions(shot1, shot2, trans_interpolation): 302 | # add together two shots 303 | trans_interpolation = tf.reshape(trans_interpolation, [tf.shape(shot1)[0], 1, 1, 1]) 304 | return shot1 * trans_interpolation + shot2 * (1 - trans_interpolation) 305 | 306 | # [SHOT_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3] 307 | shot = tf.cond(tf.logical_and(is_dissolve, tf.random.uniform([]) < advanced_shot_trans_prob),
lambda: advanced_shot_transitions(shot1, shot2, transition), 309 | lambda: basic_shot_transitions(shot1, shot2, transition)) 310 | 311 | if cutout_prob > 0.: 312 | assert no_channels == 3 # not implemented 313 | shot = tf.cond(tf.random.uniform([]) < cutout_prob, 314 | lambda: cutout(shot), lambda: shot) 315 | return shot, one_hot_gt, many_hot_gt, tf.maximum(shot1_pad_start, shot2_pad_end) == 0 316 | 317 | 318 | @tf.function 319 | def advanced_shot_transitions(shot1, shot2, trans_interpolation): 320 | # transition in horizontal or vertical direction 321 | flip_wh = tf.cast(tf.random.uniform([], maxval=2, dtype=tf.int32), tf.bool) 322 | shot1, shot2 = tf.cond(flip_wh, 323 | lambda: (tf.transpose(shot1, [0, 2, 1, 3]), tf.transpose(shot2, [0, 2, 1, 3])), 324 | lambda: (shot1, shot2)) 325 | 326 | # transition from top to bottom or from bottom to top 327 | flip_fromto = tf.cast(tf.random.uniform([], maxval=2, dtype=tf.int32), tf.bool) 328 | trans_interpolation = tf.cond(flip_fromto, lambda: trans_interpolation, lambda: 1 - trans_interpolation) 329 | 330 | shot_len, shot_height, shot_width = tf.shape(shot1)[0], tf.shape(shot1)[1], tf.shape(shot1)[2] 331 | # compute gather indices 332 | time_indices = tf.tile(tf.reshape(tf.range(shot_len), [-1, 1]), [1, shot_height]) 333 | initial_rows = tf.tile(tf.reshape(tf.range(shot_height), [1, -1]), [shot_len, 1]) 334 | row_additions = tf.cast(tf.reshape(trans_interpolation, [-1, 1]) * tf.cast(shot_height, tf.float32), tf.int32) 335 | indices = tf.stack([time_indices, initial_rows + row_additions], -1) 336 | 337 | # makes the shot move 338 | shot1_out = tf.gather_nd(tf.concat([shot1, tf.zeros_like(shot1)], 1), indices) 339 | shot2_out = tf.gather_nd(tf.concat([tf.zeros_like(shot2), shot2], 1), indices) 340 | # makes the shot stationary 341 | shot1_mask = tf.gather_nd(tf.concat([tf.ones_like(shot1), tf.zeros_like(shot1)], 1), indices) 342 | shot2_mask = tf.gather_nd(tf.concat([tf.zeros_like(shot2), tf.ones_like(shot2)], 1), indices) 343 | 344 | # select between moving or stationary variant for each shot 345 | shot1 = tf.cond(tf.cast(tf.random.uniform([], maxval=2, dtype=tf.int32), tf.bool), 346 | lambda: shot1_out, lambda: shot1 * shot1_mask) 347 | shot2 = tf.cond(tf.cast(tf.random.uniform([], maxval=2, dtype=tf.int32), tf.bool), 348 | lambda: shot2_out, lambda: shot2 * shot2_mask) 349 | 350 | result = shot1 + shot2 351 | 352 | # flip back if needed 353 | result = tf.cond(flip_wh, lambda: tf.transpose(result, [0, 2, 1, 3]), lambda: result) 354 | return result 355 | 356 | 357 | @tf.function 358 | @gin.configurable(blacklist=["shot"]) 359 | def cutout(shot, 360 | min_width_fraction=1/4, 361 | min_height_fraction=1/4, 362 | max_width_fraction=2/3, 363 | max_height_fraction=2/3, 364 | cutout_color=None): 365 | frame_height, frame_width = tf.shape(shot)[1], tf.shape(shot)[2] 366 | frame_height_float, frame_width_float = tf.cast(frame_height, tf.float32), tf.cast(frame_width, tf.float32) 367 | 368 | height = tf.random.uniform([], 369 | tf.cast(frame_height_float * min_height_fraction, tf.int32), 370 | tf.cast(frame_height_float * max_height_fraction, tf.int32), 371 | tf.int32) 372 | width = tf.random.uniform([], 373 | tf.cast(frame_width_float * min_width_fraction, tf.int32), 374 | tf.cast(frame_width_float * max_width_fraction, tf.int32), 375 | tf.int32) 376 | 377 | left = tf.random.uniform([], 0, frame_width - width, tf.int32) 378 | top = tf.random.uniform([], 0, frame_height - height, tf.int32) 379 | 380 | bottom = tf.minimum(top + height, frame_height) 381 | 
right = tf.minimum(left + width, frame_width) 382 | 383 | if cutout_color is not None: 384 | t = tf.fill([1, height, width, 3], tf.constant(cutout_color, dtype=tf.float32)) 385 | # t = tf.zeros([1, height, width, 3], dtype=tf.float32) + cutout_color 386 | else: 387 | t = tf.random.uniform([1, height, width, 3], 0, 255., dtype=tf.float32) 388 | 389 | random_patch = tf.pad(t, [[0, 0], [top, frame_height - bottom], [left, frame_width - right], [0, 0]]) 390 | mask = tf.pad(tf.zeros([1, height, width, 1]), 391 | [[0, 0], [top, frame_height - bottom], [left, frame_width - right], [0, 0]], constant_values=1.) 392 | return random_patch + shot * mask 393 | 394 | 395 | @tf.function 396 | def pil_equalize(shot): 397 | # Implements Equalize function from PIL using TF ops. 398 | # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 399 | def scale_channel(im, c): 400 | im = tf.cast(im[:, :, c], tf.int32) 401 | # Compute the histogram of the image channel. 402 | histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) 403 | 404 | # For the purposes of computing the step, filter out the nonzeros. 405 | nonzero = tf.where(tf.not_equal(histo, 0)) 406 | nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) 407 | step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 408 | 409 | def build_lut(histo, step): 410 | # Compute the cumulative sum, shifting by step // 2 and then normalization by step. 411 | lut = (tf.cumsum(histo) + (step // 2)) // step 412 | # Shift lut, prepending with 0. 413 | lut = tf.concat([[0], lut[:-1]], 0) 414 | # Clip the counts to be in range. This is done in the C code for image.point. 415 | return tf.clip_by_value(lut, 0, 255) 416 | 417 | # If step is zero, return the original image. 418 | # Otherwise, build lut from the full histogram and step and then index from it. 419 | result = tf.cond(tf.equal(step, 0), 420 | lambda: im, 421 | lambda: tf.gather(build_lut(histo, step), im)) 422 | 423 | return tf.cast(result, tf.uint8) 424 | 425 | # Assumes RGB for now. Scales each channel independently and then stacks the result. 426 | l, h, w, c = tf.shape(shot)[0], tf.shape(shot)[1], tf.shape(shot)[2], tf.shape(shot)[3] 427 | 428 | shot = tf.reshape(shot, [l * h, w, c]) 429 | s1 = scale_channel(shot, 0) 430 | s2 = scale_channel(shot, 1) 431 | s3 = scale_channel(shot, 2) 432 | shot = tf.stack([s1, s2, s3], 2) 433 | shot = tf.reshape(shot, [l, h, w, c]) 434 | return shot 435 | 436 | 437 | def pil_posterize(image, bits): 438 | # Equivalent of PIL Posterize. 439 | # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 440 | shift = 8 - bits 441 | return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) 442 | 443 | 444 | def pil_color(shot, factor): 445 | # Equivalent of PIL Color. 446 | # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 447 | degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(shot)) 448 | difference = shot - degenerate 449 | scaled = factor * difference 450 | return tf.clip_by_value(degenerate + scaled, 0., 255.) 
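# --- Editor's illustrative sketch (not part of the original repository file) --
# The three `pil_*` helpers above are TF ports of PIL's equalize, posterize and
# color-balance operations so they can run inside the tf.data pipeline. A
# minimal, hypothetical example of chaining them on a float32 RGB shot of shape
# [SHOT_LENGTH, HEIGHT, WIDTH, 3]; the probabilities and the `bits`/`factor`
# values below are made up for illustration and are not the repository's:
#
#   def illustrative_pil_augment(shot):
#       u8 = tf.cast(tf.clip_by_value(shot, 0., 255.), tf.uint8)
#       u8 = tf.cond(tf.random.uniform([]) < 0.1,
#                    lambda: pil_posterize(u8, bits=5), lambda: u8)
#       u8 = tf.cond(tf.random.uniform([]) < 0.1,
#                    lambda: pil_equalize(u8), lambda: u8)
#       return pil_color(tf.cast(u8, tf.float32),
#                        factor=tf.random.uniform([], 0.5, 1.5))
# ------------------------------------------------------------------------------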
451 | 452 | 453 | @tf.function 454 | def color_transfer(source, target, name="color_transfer"): 455 | # based on https://www.pyimagesearch.com/2014/06/30/super-fast-color-transfer-images/ 456 | # transfers based from https://github.com/xahidbuffon/rgb-lab-conv 457 | with tf.name_scope(name): 458 | source = rgb_to_lab(source) 459 | target = rgb_to_lab(target) 460 | 461 | src_mean, src_var = tf.nn.moments(source, axes=(0, 1, 2), keepdims=True) 462 | src_std = tf.sqrt(src_var) 463 | tar_mean, tar_var = tf.nn.moments(target, axes=(0, 1, 2), keepdims=True) 464 | tar_std = tf.sqrt(tar_var) 465 | 466 | # ensure reasonable scaling and prevent division by zero 467 | src_std = tf.maximum(src_std, 1) 468 | tar_std = tf.maximum(tar_std, 1) 469 | 470 | target_shifted = (target - tar_mean) * (src_std / tar_std) + src_mean 471 | 472 | lab_min = tf.constant([[[[0, -86.185, -107.863]]]]) 473 | lab_max = tf.constant([[[[100, 98.254, 94.482]]]]) 474 | target_clipped = tf.clip_by_value(target_shifted, lab_min, lab_max) 475 | 476 | return lab_to_rgb(target_clipped) 477 | 478 | 479 | def rgb_to_lab(rgb): 480 | with tf.name_scope("rgb_to_lab"): 481 | srgb_pixels = tf.reshape(rgb, [-1, 3]) / 255. 482 | with tf.name_scope("srgb_to_xyz"): 483 | linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32) 484 | exponential_mask = 1 - linear_mask 485 | rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask 486 | rgb_to_xyz = tf.constant([ 487 | # X Y Z 488 | [0.412453, 0.212671, 0.019334], # R 489 | [0.357580, 0.715160, 0.119193], # G 490 | [0.180423, 0.072169, 0.950227], # B 491 | ]) 492 | xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz) 493 | 494 | # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions 495 | with tf.name_scope("xyz_to_cielab"): 496 | # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn) 497 | 498 | # normalize for D65 white point 499 | xyz_normalized_pixels = tf.multiply(xyz_pixels, tf.constant([[1/0.950456, 1.0, 1/1.088754]], tf.float32)) 500 | # fix when values -0.0001 result in Nan if raised to 1/3 501 | xyz_normalized_pixels = tf.maximum(xyz_normalized_pixels, 0) 502 | 503 | epsilon = 6/29 504 | linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32) 505 | exponential_mask = 1 - linear_mask 506 | fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask 507 | 508 | # convert to lab 509 | fxfyfz_to_lab = tf.constant([ 510 | # l a b 511 | [ 0.0, 500.0, 0.0], # fx 512 | [116.0, -500.0, 200.0], # fy 513 | [ 0.0, 0.0, -200.0], # fz 514 | ]) 515 | lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0]) 516 | 517 | return tf.reshape(lab_pixels, tf.shape(rgb)) 518 | 519 | 520 | def lab_to_rgb(lab): 521 | with tf.name_scope("lab_to_rgb"): 522 | lab_pixels = tf.reshape(lab, [-1, 3]) 523 | # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions 524 | with tf.name_scope("cielab_to_xyz"): 525 | # convert to fxfyfz 526 | lab_to_fxfyfz = tf.constant([ 527 | # fx fy fz 528 | [1/116.0, 1/116.0, 1/116.0], # l 529 | [1/500.0, 0.0, 0.0], # a 530 | [ 0.0, 0.0, -1/200.0], # b 531 | ]) 532 | fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz) 533 | 534 | # convert to xyz 535 | epsilon = 6/29 536 | linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32) 537 | exponential_mask = 1 - linear_mask 538 | xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * 
linear_mask + (fxfyfz_pixels ** 3) * exponential_mask 539 | 540 | # denormalize for D65 white point 541 | xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754]) 542 | 543 | with tf.name_scope("xyz_to_srgb"): 544 | xyz_to_rgb = tf.constant([ 545 | # r g b 546 | [ 3.2404542, -0.9692660, 0.0556434], # x 547 | [-1.5371385, 1.8760108, -0.2040259], # y 548 | [-0.4985314, 0.0415560, 1.0572252], # z 549 | ]) 550 | rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb) 551 | # avoid a slightly negative number messing up the conversion 552 | rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0) 553 | linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32) 554 | exponential_mask = 1 - linear_mask 555 | rgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask 556 | 557 | return tf.reshape(rgb_pixels, tf.shape(lab)) * 255. 558 | 559 | 560 | @gin.configurable(blacklist=["filenames"]) 561 | def test_pipeline(filenames, 562 | shot_len=100, 563 | batch_size=16): 564 | ds = tf.data.Dataset.from_tensor_slices(filenames) 565 | ds = ds.interleave( 566 | lambda x: tf.data.TFRecordDataset( 567 | x, compression_type="GZIP").map(parse_test_sample, 568 | num_parallel_calls=1).batch(shot_len, drop_remainder=True), 569 | cycle_length=8, 570 | block_length=16, 571 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 572 | ds = ds.batch(batch_size) 573 | ds = ds.prefetch(2) 574 | return ds 575 | 576 | 577 | @tf.function 578 | @gin.configurable(blacklist=["sample"]) 579 | def parse_test_sample(sample, 580 | frame_width=48, 581 | frame_height=27, 582 | no_channels=3): 583 | features = tf.io.parse_single_example(sample, features={ 584 | "frame": tf.io.FixedLenFeature([], tf.string), 585 | "is_one_hot_transition": tf.io.FixedLenFeature([], tf.int64), 586 | "is_many_hot_transition": tf.io.FixedLenFeature([], tf.int64) 587 | }) 588 | 589 | frame = tf.io.decode_raw(features["frame"], tf.uint8) 590 | frame = tf.reshape(frame, [frame_height, frame_width, no_channels]) 591 | 592 | one_hot = tf.cast(features["is_one_hot_transition"], tf.int32) 593 | many_hot = tf.cast(features["is_many_hot_transition"], tf.int32) 594 | 595 | frame = tf.cast(frame, dtype=tf.float32) 596 | return frame, one_hot, many_hot 597 | --------------------------------------------------------------------------------
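Usage sketch (editor's addition, not a file from the repository): assuming TensorFlow 2.x and the gin fork installed by training/Dockerfile are available, and assuming the snippet is run from the training/ directory so that input_processing is importable, GZIP TFRecords carrying the fields expected by parse_test_sample ("frame", "is_one_hot_transition", "is_many_hot_transition") could be read roughly as follows; the glob path is a placeholder.

    import glob
    import tensorflow as tf
    import input_processing

    # placeholder path to 48x27 GZIP TFRecord test data
    files = glob.glob("data/48x27/SomeTestSet/*.tfrecord")
    ds = input_processing.test_pipeline(files, shot_len=100, batch_size=16)

    for frames, one_hot, many_hot in ds:
        # frames: [batch, 100, 27, 48, 3] float32 in 0..255
        # one_hot / many_hot: [batch, 100] int32 per-frame transition indicators
        print(frames.shape, one_hot.shape, many_hot.shape)
        break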