├── Audio Caption Datasets ├── Readme.md └── tts_ws_python3_demo.py ├── LICENSE ├── README.md ├── assets ├── intro.png └── main_structure.png ├── checkpoints └── best │ └── Readme.md ├── configs ├── activitynet.yaml ├── charades-c3d.yaml ├── charades-i3d.yaml ├── charades-vgg.yaml ├── charades.yaml └── tacos.yaml ├── dataset ├── ActivityNet │ ├── test.json │ ├── test_audio_new.json │ ├── train.json │ ├── train_audio_new.json │ ├── val.json │ └── val_audio_new.json ├── Charades-STA │ ├── charades_test.json │ ├── charades_train.json │ ├── test_audio_new.json │ └── train_audio_new.json └── TACoS │ ├── test.json │ ├── test_audio_new.json │ ├── train.json │ ├── train_audio_new.json │ ├── val.json │ └── val_audio_new.json ├── dtfnet ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc ├── config │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── defaults-checkpoint.py │ │ └── paths_catalog-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── defaults.cpython-36.pyc │ │ ├── defaults.cpython-38.pyc │ │ ├── defaults.cpython-39.pyc │ │ ├── paths_catalog.cpython-36.pyc │ │ ├── paths_catalog.cpython-38.pyc │ │ └── paths_catalog.cpython-39.pyc │ ├── defaults.py │ └── paths_catalog.py ├── data │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── collate_batch-checkpoint.py │ │ └── samplers-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── collate_batch.cpython-36.pyc │ │ ├── collate_batch.cpython-38.pyc │ │ ├── collate_batch.cpython-39.pyc │ │ ├── samplers.cpython-36.pyc │ │ ├── samplers.cpython-38.pyc │ │ └── samplers.cpython-39.pyc │ ├── collate_batch.py │ ├── datasets │ │ ├── .ipynb_checkpoints │ │ │ ├── __init__-checkpoint.py │ │ │ ├── activitynet-3w-checkpoint.py │ │ │ ├── activitynet-Copy1-checkpoint.py │ │ │ ├── activitynet-Copy2-checkpoint.py │ │ │ ├── activitynet-checkpoint.py │ │ │ ├── activitynet-single-checkpoint.py │ │ │ ├── charades-checkpoint.py │ │ │ ├── concat_dataset-checkpoint.py │ │ │ ├── evaluation-checkpoint.py │ │ │ ├── tacos-checkpoint.py │ │ │ └── utils-checkpoint.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── activitynet.cpython-36.pyc │ │ │ ├── activitynet.cpython-38.pyc │ │ │ ├── activitynet.cpython-39.pyc │ │ │ ├── charades.cpython-36.pyc │ │ │ ├── charades.cpython-38.pyc │ │ │ ├── charades.cpython-39.pyc │ │ │ ├── concat_dataset.cpython-36.pyc │ │ │ ├── concat_dataset.cpython-38.pyc │ │ │ ├── concat_dataset.cpython-39.pyc │ │ │ ├── evaluation.cpython-36.pyc │ │ │ ├── evaluation.cpython-38.pyc │ │ │ ├── evaluation.cpython-39.pyc │ │ │ ├── tacos.cpython-36.pyc │ │ │ ├── tacos.cpython-38.pyc │ │ │ ├── tacos.cpython-39.pyc │ │ │ ├── utils.cpython-36.pyc │ │ │ ├── utils.cpython-38.pyc │ │ │ └── utils.cpython-39.pyc │ │ ├── activitynet-3w.py │ │ ├── activitynet-Copy1.py │ │ ├── activitynet-Copy2.py │ │ ├── activitynet-Copy3.py │ │ ├── activitynet-single.py │ │ ├── activitynet.py │ │ ├── charades.py │ │ ├── concat_dataset.py │ │ ├── evaluation.py │ │ ├── tacos.py │ │ └── utils.py │ └── samplers.py ├── engine │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── inference-checkpoint.py │ │ └── trainer-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── 
__init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── inference.cpython-36.pyc │ │ ├── inference.cpython-38.pyc │ │ ├── inference.cpython-39.pyc │ │ ├── trainer.cpython-36.pyc │ │ ├── trainer.cpython-38.pyc │ │ └── trainer.cpython-39.pyc │ ├── inference.py │ └── trainer.py ├── modeling │ ├── .ipynb_checkpoints │ │ └── __init__-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ └── __init__.cpython-39.pyc │ └── dtf │ │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── dtf_model-checkpoint.py │ │ ├── dynamic_encode-checkpoint.py │ │ ├── feat2d-checkpoint.py │ │ ├── featpool-checkpoint.py │ │ ├── loss-checkpoint.py │ │ ├── position_encoding-checkpoint.py │ │ ├── proposal_conv-checkpoint.py │ │ ├── text_encoder-checkpoint.py │ │ ├── text_out-checkpoint.py │ │ └── utils-checkpoint.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── GCNNet.cpython-36.pyc │ │ ├── GCNNet.cpython-38.pyc │ │ ├── GCNNet.cpython-39.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── dtf_model.cpython-36.pyc │ │ ├── dtf_model.cpython-38.pyc │ │ ├── dtf_model.cpython-39.pyc │ │ ├── dynamic_encode.cpython-36.pyc │ │ ├── feat2d.cpython-36.pyc │ │ ├── feat2d.cpython-38.pyc │ │ ├── feat2d.cpython-39.pyc │ │ ├── featpool.cpython-36.pyc │ │ ├── featpool.cpython-38.pyc │ │ ├── featpool.cpython-39.pyc │ │ ├── loss.cpython-36.pyc │ │ ├── loss.cpython-38.pyc │ │ ├── loss.cpython-39.pyc │ │ ├── mmn.cpython-36.pyc │ │ ├── position_encoding.cpython-36.pyc │ │ ├── position_encoding.cpython-38.pyc │ │ ├── position_encoding.cpython-39.pyc │ │ ├── proposal_conv.cpython-36.pyc │ │ ├── proposal_conv.cpython-38.pyc │ │ ├── proposal_conv.cpython-39.pyc │ │ ├── text_encoder.cpython-36.pyc │ │ ├── text_encoder.cpython-38.pyc │ │ ├── text_encoder.cpython-39.pyc │ │ ├── text_out.cpython-36.pyc │ │ ├── text_out.cpython-38.pyc │ │ ├── text_out.cpython-39.pyc │ │ └── utils.cpython-36.pyc │ │ ├── dtf_model.py │ │ ├── dynamic_encode.py │ │ ├── feat2d.py │ │ ├── featpool.py │ │ ├── loss.py │ │ ├── position_encoding.py │ │ ├── proposal_conv.py │ │ ├── text_encoder.py │ │ ├── text_out.py │ │ └── utils.py ├── structures │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ └── tlg_batch-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── tlg_batch.cpython-36.pyc │ │ ├── tlg_batch.cpython-38.pyc │ │ └── tlg_batch.cpython-39.pyc │ └── tlg_batch.py └── utils │ ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── checkpoint-checkpoint.py │ ├── comm-checkpoint.py │ ├── imports-checkpoint.py │ ├── logger-checkpoint.py │ ├── metric_logger-checkpoint.py │ ├── miscellaneous-checkpoint.py │ ├── model_serialization-checkpoint.py │ ├── registry-checkpoint.py │ └── timer-checkpoint.py │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── checkpoint.cpython-36.pyc │ ├── checkpoint.cpython-38.pyc │ ├── checkpoint.cpython-39.pyc │ ├── comm.cpython-36.pyc │ ├── comm.cpython-38.pyc │ ├── comm.cpython-39.pyc │ ├── imports.cpython-36.pyc │ ├── imports.cpython-38.pyc │ ├── imports.cpython-39.pyc │ ├── logger.cpython-36.pyc │ ├── logger.cpython-38.pyc │ ├── logger.cpython-39.pyc │ ├── metric_logger.cpython-36.pyc │ ├── metric_logger.cpython-38.pyc │ ├── metric_logger.cpython-39.pyc │ ├── 
miscellaneous.cpython-36.pyc │ ├── miscellaneous.cpython-38.pyc │ ├── miscellaneous.cpython-39.pyc │ ├── model_serialization.cpython-36.pyc │ ├── model_serialization.cpython-38.pyc │ ├── model_serialization.cpython-39.pyc │ ├── timer.cpython-36.pyc │ ├── timer.cpython-38.pyc │ └── timer.cpython-39.pyc │ ├── checkpoint.py │ ├── comm.py │ ├── imports.py │ ├── logger.py │ ├── metric_logger.py │ ├── miscellaneous.py │ ├── model_serialization.py │ ├── registry.py │ └── timer.py ├── pre_process ├── audio_encode.py ├── feat_pca.py └── text_encode.py ├── test_net.py └── train_net.py
/Audio Caption Datasets/Readme.md:
--------------------------------------------------------------------------------
1 | # How do we make audio caption datasets?
2 | 
3 | In the [VGCL paper](https://arxiv.org/pdf/2209.00277.pdf), 58 volunteers were asked to read the captions aloud, fluently and in a quiet environment, to produce the audio dataset corresponding to the ActivityNet Caption dataset.
4 | 
5 | In our paper, however, we use machine-synthesized human speech to build the audio caption datasets corresponding to the Charades-STA Caption and TACoS Caption datasets.
6 | 
7 | There are several reasons for using machine synthesis:
8 | 
9 | * Thanks to the development of [TTS technology](https://huggingface.co/microsoft/speecht5_tts), synthesized speech now closely approximates the human voice, including complex characteristics such as speaking style.
10 | * It saves cost.
11 | * It offers diverse vocal styles: for example, the [Matthijs/cmu-arctic-xvectors](https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors) dataset contains around 8,000 speaker embeddings to choose from ([details](http://www.festvox.org/cmu_arctic/)).
12 | 
--------------------------------------------------------------------------------
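A minimal sketch of the synthesis recipe described above, boiled down to a single sentence. The model checkpoints and the x-vector dataset are the same ones used by the demo script below; the example sentence, output path, and speaker index are illustrative only:

from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import soundfile as sf
import torch

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# each entry of this dataset is one CMU ARCTIC speaker x-vector
xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker = torch.tensor(xvectors[7306]["xvector"]).unsqueeze(0)

inputs = processor(text="A person opens the fridge and takes out a bottle.", return_tensors="pt")
speech = model.generate_speech(inputs["input_ids"], speaker, vocoder=vocoder)
sf.write("example.wav", speech.numpy(), samplerate=16000)  # SpeechT5 generates 16 kHz audio

The full script below applies the same steps in batch, assigning one randomly chosen speaker per video and recording the mapping in a JSON file.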
/Audio Caption Datasets/tts_ws_python3_demo.py:
--------------------------------------------------------------------------------
1 | # The following pip packages need to be installed:
2 | # !pip install git+https://github.com/huggingface/transformers sentencepiece datasets
3 | 
4 | from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
5 | import torch
6 | import soundfile as sf
7 | from datasets import load_dataset
8 | import numpy as np
9 | import os
10 | from tqdm import tqdm
11 | import json
12 | 
13 | # set random seed for reproducibility
14 | np.random.seed(1)
15 | processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
16 | model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
17 | vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
18 | # load x-vectors containing the speakers' voice characteristics from a dataset
19 | embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split='validation')
20 | 
21 | 
22 | def text_microsoft_mav(root_path, text_path, audios_dir):
23 | 
24 |     if not os.path.exists(audios_dir):
25 |         os.mkdir(audios_dir)
26 | 
27 |     if 'train' in text_path:
28 |         mode = 'train'
29 |     elif 'val' in text_path:
30 |         mode = 'val'
31 |     elif 'test' in text_path:
32 |         mode = 'test'
33 |     else:
34 |         raise ValueError("text_path should contain train, val or test")
35 | 
36 |     audio_json = {}
37 |     audio_json_path = os.path.join(root_path, f"{mode}_audio.json")
38 | 
39 |     data = json.load(open(text_path, 'r', encoding='utf-8'))
40 | 
41 |     # generate a random speaker id for each video
42 |     random_idx = np.random.randint(0, len(embeddings_dataset), size=len(data))
43 | 
44 |     for i, (k, v) in tqdm(enumerate(data.items()), total=len(data)):
45 |         vid = k
46 |         duration = v['duration']
47 |         timestamps = v['timestamps']
48 |         sentences = v['sentences']
49 |         audios = []
50 |         speaker = embeddings_dataset[int(random_idx[i])]["filename"]
51 |         speaker_id = int(random_idx[i])
52 |         for j, s in enumerate(sentences):  # j: sentence index (i above is the video index)
53 |             audio_name = f"{vid}_{mode}_{j+1}.wav"
54 |             save_path = os.path.join(audios_dir, audio_name)
55 |             audios.append(audio_name)
56 |             if os.path.exists(save_path):  # skip audio files that already exist on disk
57 |                 continue
58 |             else:
59 |                 inputs = processor(text=s, return_tensors="pt")
60 |                 speaker_embeddings = torch.tensor(embeddings_dataset[speaker_id]["xvector"]).unsqueeze(0)
61 |                 speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
62 |                 sf.write(save_path, speech.numpy(), samplerate=16000)
63 | 
64 |         audio_json.update({
65 |             f"{vid}": {
66 |                 'duration': duration,
67 |                 'timestamps': timestamps,
68 |                 'sentences': sentences,
69 |                 'audios': audios,
70 |                 'speaker': speaker,
71 |                 'speaker_id': speaker_id,
72 |             }
73 |         })
74 | 
75 |     json.dump(audio_json, open(audio_json_path, 'w', encoding='utf-8'), indent=4, ensure_ascii=False)
76 | 
77 | 
78 | def TACoS_text_microsoft_mav(root_path, text_path, audios_dir):
79 | 
80 |     if not os.path.exists(audios_dir):
81 |         os.mkdir(audios_dir)
82 | 
83 |     if 'train' in text_path:
84 |         mode = 'train'
85 |     elif 'val' in text_path:
86 |         mode = 'val'
87 |     elif 'test' in text_path:
88 |         mode = 'test'
89 |     else:
90 |         raise ValueError("text_path should contain train, val or test")
91 | 
92 |     audio_json = {}
93 |     audio_json_path = os.path.join(root_path, f"{mode}_audio.json")
94 | 
95 |     data = json.load(open(text_path, 'r', encoding='utf-8'))
96 | 
97 |     # generate a random speaker id for each video
98 |     random_idx = np.random.randint(0, len(embeddings_dataset), size=len(data))
99 | 
100 |     for i, (k, v) in tqdm(enumerate(data.items()), total=len(data)):
101 |         vid = k
102 |         vid_name = vid.split(".")[0]
103 |         num_frames = v['num_frames']
104 |         fps = v['fps']
105 |         timestamps = v['timestamps']
106 |         sentences = v['sentences']
107 |         audios = []
108 |         speaker = embeddings_dataset[int(random_idx[i])]["filename"]
109 |         speaker_id = int(random_idx[i])
110 |         for j, s in enumerate(sentences):  # j: sentence index (i above is the video index)
111 |             audio_name = f"{vid_name}_{mode}_{j+1}.wav"
112 |             save_path = os.path.join(audios_dir, audio_name)
113 |             audios.append(audio_name)
114 |             if os.path.exists(save_path):  # skip audio files that already exist on disk
115 |                 continue
116 |             else:
117 |                 inputs = processor(text=s, return_tensors="pt")
118 |                 speaker_embeddings = torch.tensor(embeddings_dataset[speaker_id]["xvector"]).unsqueeze(0)
119 |                 speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
120 |                 sf.write(save_path, speech.numpy(), samplerate=16000)
121 | 
122 |         audio_json.update({
123 |             f"{vid}": {
124 |                 'num_frames': num_frames,
125 |                 'fps': fps,
126 |                 'timestamps': timestamps,
127 |                 'sentences': sentences,
128 |                 'audios': audios,
129 |                 'speaker': speaker,
130 |                 'speaker_id': speaker_id,
131 |             }
132 |         })
133 | 
134 |     json.dump(audio_json, open(audio_json_path, 'w', encoding='utf-8'), indent=4, ensure_ascii=False)
135 | 
136 | 
137 | if __name__ == "__main__":
138 |     # root_path = r".\Charades_STA"
139 |     root_path = r".\TACoS"
140 |     train_path = os.path.join(root_path, "train.json")
141 |     val_path = os.path.join(root_path, "val.json")
142 |     test_path = os.path.join(root_path, "test.json")
143 | 
144 |     audios_dir = r".\TACoS\audios"
145 | 
146 |     # Charades_STA: train:5336 test:1334
147 |     # TACoS: train:75 val:27 test:25
148 |     TACoS_text_microsoft_mav(root_path, test_path, audios_dir)
149 |     TACoS_text_microsoft_mav(root_path, val_path, audios_dir)
150 | TACoS_text_microsoft_mav(root_path, train_path, audios_dir) 151 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Compute-Chem 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/assets/intro.png -------------------------------------------------------------------------------- /assets/main_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/assets/main_structure.png -------------------------------------------------------------------------------- /checkpoints/best/Readme.md: -------------------------------------------------------------------------------- 1 | please download the best trained models and put them in this folder 2 | -------------------------------------------------------------------------------- /configs/activitynet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ARCHITECTURE: "DTF" 3 | DTF: 4 | VIDEO_MODE: 'c3d' 5 | NUM_CLIPS: 64 6 | JOINT_SPACE_SIZE: 256 7 | FEATPOOL: 8 | INPUT_SIZE: 500 9 | HIDDEN_SIZE: 512 10 | KERNEL_SIZE: 4 11 | FEAT2D: 12 | NAME: "pool_b" 13 | POOLING_COUNTS: [15,8,8] 14 | TEXT_ENCODER: 15 | NAME: "BERT" 16 | PREDICTOR: 17 | HIDDEN_SIZE: 512 18 | KERNEL_SIZE: 9 19 | NUM_STACK_LAYERS: 4 20 | LOSS: 21 | MIN_IOU: 0.5 22 | MAX_IOU: 1.0 23 | NUM_POSTIVE_VIDEO_PROPOSAL: 1 24 | NEGATIVE_VIDEO_IOU: 0.5 25 | SENT_REMOVAL_IOU: 0.5 26 | TAU_VIDEO: 0.1 27 | TAU_SENT: 0.1 28 | MARGIN: 0.3 29 | CONTRASTIVE_WEIGHT: 0.1 30 | OUTPUT_DIR: "./activity/dist/full_model" 31 | DATASETS: 32 | NAME: "activitynet" 33 | TRAIN: ("activitynet_train", ) 34 | TEST: ("activitynet_test",) 35 | INPUT: 36 | NUM_PRE_CLIPS: 256 37 | DATALOADER: 38 | NUM_WORKERS: 16 39 | SOLVER: 40 | LR: 0.0008 41 | BATCH_SIZE: 12 42 | MILESTONES: (7, 15) 43 | MAX_EPOCH: 15 44 | TEST_PERIOD: 1 45 | CHECKPOINT_PERIOD: 1 46 | RESUME: False 47 | RESUME_EPOCH: 16 48 | FREEZE_BERT: 4 49 | ONLY_IOU: 6 50 | SKIP_TEST: 1 51 | USE_STATIC: True 52 | USE_GNN: 
True 53 | GNN_SPARSE: True 54 | GNN_MODE: 'gauss' # 'gauss' or 'd' or 'gat' or 'mlp' 55 | GNN_LAYERS: 2 56 | GNN_U: 5.0 57 | GNN_STEP: 0.1 58 | POS_EMBED: 'sine' # choices=['trainable', 'sine', 'learned'] 59 | TEST: 60 | NMS_THRESH: 0.5 61 | BATCH_SIZE: 12 62 | CONTRASTIVE_SCORE_POW: 0.5 63 | -------------------------------------------------------------------------------- /configs/charades-c3d.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ARCHITECTURE: "DTF" 3 | DTF: 4 | VIDEO_MODE: 'c3d' 5 | NUM_CLIPS: 16 6 | JOINT_SPACE_SIZE: 512 7 | FEATPOOL: 8 | INPUT_SIZE: 4096 9 | HIDDEN_SIZE: 512 10 | KERNEL_SIZE: 2 11 | FEAT2D: 12 | NAME: "conv" 13 | POOLING_COUNTS: [15] 14 | TEXT_ENCODER: 15 | NAME: "BERT" 16 | PREDICTOR: 17 | HIDDEN_SIZE: 512 18 | KERNEL_SIZE: 5 19 | NUM_STACK_LAYERS: 3 20 | LOSS: 21 | MIN_IOU: 0.5 22 | MAX_IOU: 1.0 23 | NUM_POSTIVE_VIDEO_PROPOSAL: 1 24 | NEGATIVE_VIDEO_IOU: 0.5 25 | SENT_REMOVAL_IOU: 0.5 26 | TAU_VIDEO: 0.1 27 | TAU_SENT: 0.1 28 | MARGIN: 0.4 29 | CONTRASTIVE_WEIGHT: 0.05 30 | OUTPUT_DIR: './sta/text_c3d_5_3' 31 | DATASETS: 32 | NAME: "charades" 33 | TRAIN: ("charades_train", ) 34 | TEST: ("charades_test",) 35 | INPUT: 36 | NUM_PRE_CLIPS: 32 37 | DATALOADER: 38 | NUM_WORKERS: 16 39 | SOLVER: 40 | LR: 0.0001 41 | BATCH_SIZE: 48 42 | MILESTONES: (8, 13) 43 | MAX_EPOCH: 30 44 | TEST_PERIOD: 1 45 | CHECKPOINT_PERIOD: 1 46 | RESUME: True 47 | RESUME_EPOCH: 19 48 | FREEZE_BERT: 20 49 | ONLY_IOU: 7 50 | SKIP_TEST: 1 51 | USE_STATIC: True 52 | USE_GNN: True 53 | GNN_SPARSE: True 54 | GNN_MODE: 'gauss' # 'gauss' or 'd' or 'gat' 55 | GNN_LAYERS: 2 56 | GNN_U: 5.0 57 | GNN_STEP: 0.1 58 | POS_EMBED: 'sine' # choices=['trainable', 'sine', 'learned'] 59 | TEST: 60 | NMS_THRESH: 0.5 61 | BATCH_SIZE: 48 62 | CONTRASTIVE_SCORE_POW: 0.5 63 | -------------------------------------------------------------------------------- /configs/charades-i3d.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ARCHITECTURE: "DTF" 3 | DTF: 4 | VIDEO_MODE: 'i3d' 5 | NUM_CLIPS: 64 6 | JOINT_SPACE_SIZE: 256 7 | FEATPOOL: 8 | INPUT_SIZE: 1024 9 | HIDDEN_SIZE: 512 10 | KERNEL_SIZE: 2 11 | FEAT2D: 12 | NAME: "pool" 13 | POOLING_COUNTS: [15,8,8] 14 | TEXT_ENCODER: 15 | NAME: "BERT" 16 | PREDICTOR: 17 | HIDDEN_SIZE: 512 18 | KERNEL_SIZE: 17 19 | NUM_STACK_LAYERS: 2 20 | LOSS: 21 | MIN_IOU: 0.5 22 | MAX_IOU: 1.0 23 | NUM_POSTIVE_VIDEO_PROPOSAL: 1 24 | NEGATIVE_VIDEO_IOU: 0.5 25 | SENT_REMOVAL_IOU: 0.5 26 | TAU_VIDEO: 0.1 27 | TAU_SENT: 0.1 28 | MARGIN: 0.4 29 | CONTRASTIVE_WEIGHT: 0.05 30 | OUTPUT_DIR: './sta/text_i3d_17_2' 31 | DATASETS: 32 | NAME: "charades" 33 | TRAIN: ("charades_train", ) 34 | TEST: ("charades_test",) 35 | INPUT: 36 | NUM_PRE_CLIPS: 128 37 | DATALOADER: 38 | NUM_WORKERS: 16 39 | SOLVER: 40 | LR: 0.001 41 | BATCH_SIZE: 12 42 | MILESTONES: (8, 13) 43 | MAX_EPOCH: 30 44 | TEST_PERIOD: 1 45 | CHECKPOINT_PERIOD: 1 46 | RESUME: False 47 | RESUME_EPOCH: 19 48 | FREEZE_BERT: 60 49 | ONLY_IOU: 7 50 | SKIP_TEST: 1 51 | USE_STATIC: True 52 | USE_GNN: True 53 | GNN_SPARSE: True 54 | GNN_MODE: 'gauss' # 'gauss' or 'd' or 'gat' 55 | GNN_LAYERS: 2 56 | GNN_U: 5.0 57 | GNN_STEP: 0.1 58 | POS_EMBED: 'sine' # choices=['trainable', 'sine', 'learned'] 59 | TEST: 60 | NMS_THRESH: 0.5 61 | BATCH_SIZE: 12 62 | CONTRASTIVE_SCORE_POW: 0.5 63 | -------------------------------------------------------------------------------- /configs/charades-vgg.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | ARCHITECTURE: "DTF" 3 | DTF: 4 | VIDEO_MODE: 'vgg' 5 | NUM_CLIPS: 16 6 | JOINT_SPACE_SIZE: 512 7 | FEATPOOL: 8 | INPUT_SIZE: 4096 9 | HIDDEN_SIZE: 512 10 | KERNEL_SIZE: 2 11 | FEAT2D: 12 | NAME: "conv" 13 | POOLING_COUNTS: [15] 14 | TEXT_ENCODER: 15 | NAME: "BERT" 16 | PREDICTOR: 17 | HIDDEN_SIZE: 512 18 | KERNEL_SIZE: 5 19 | NUM_STACK_LAYERS: 3 20 | LOSS: 21 | MIN_IOU: 0.5 22 | MAX_IOU: 1.0 23 | NUM_POSTIVE_VIDEO_PROPOSAL: 1 24 | NEGATIVE_VIDEO_IOU: 0.5 25 | SENT_REMOVAL_IOU: 0.5 26 | TAU_VIDEO: 0.1 27 | TAU_SENT: 0.1 28 | MARGIN: 0.4 29 | CONTRASTIVE_WEIGHT: 0.05 30 | OUTPUT_DIR: './sta/text_vgg_5_3' 31 | DATASETS: 32 | NAME: "charades" 33 | TRAIN: ("charades_train", ) 34 | TEST: ("charades_test",) 35 | INPUT: 36 | NUM_PRE_CLIPS: 32 37 | DATALOADER: 38 | NUM_WORKERS: 16 39 | SOLVER: 40 | LR: 0.0001 41 | BATCH_SIZE: 48 42 | MILESTONES: (8, 13) 43 | MAX_EPOCH: 30 44 | TEST_PERIOD: 1 45 | CHECKPOINT_PERIOD: 1 46 | RESUME: False 47 | RESUME_EPOCH: 19 48 | FREEZE_BERT: 20 49 | ONLY_IOU: 7 50 | SKIP_TEST: 1 51 | USE_STATIC: True 52 | USE_GNN: True 53 | GNN_SPARSE: True 54 | GNN_MODE: 'gauss' # 'gauss' or 'd' or 'gat' 55 | GNN_LAYERS: 2 56 | GNN_U: 5.0 57 | GNN_STEP: 0.1 58 | POS_EMBED: 'sine' # choices=['trainable', 'sine', 'learned'] 59 | TEST: 60 | NMS_THRESH: 0.5 61 | BATCH_SIZE: 48 62 | CONTRASTIVE_SCORE_POW: 0.5 63 | -------------------------------------------------------------------------------- /configs/charades.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ARCHITECTURE: "DTF" 3 | DTF: 4 | NUM_CLIPS: 16 5 | JOINT_SPACE_SIZE: 512 6 | FEATPOOL: 7 | INPUT_SIZE: 4096 8 | HIDDEN_SIZE: 512 9 | KERNEL_SIZE: 2 10 | FEAT2D: 11 | NAME: "pool" 12 | POOLING_COUNTS: [15] 13 | TEXT_ENCODER: 14 | NAME: "BERT" 15 | PREDICTOR: 16 | HIDDEN_SIZE: 512 17 | KERNEL_SIZE: 5 18 | NUM_STACK_LAYERS: 8 19 | LOSS: 20 | MIN_IOU: 0.5 21 | MAX_IOU: 1.0 22 | NUM_POSTIVE_VIDEO_PROPOSAL: 1 23 | NEGATIVE_VIDEO_IOU: 0.5 24 | SENT_REMOVAL_IOU: 0.5 25 | TAU_VIDEO: 0.1 26 | TAU_SENT: 0.1 27 | MARGIN: 0.4 28 | CONTRASTIVE_WEIGHT: 0.05 29 | DATASETS: 30 | NAME: "charades" 31 | TRAIN: ("charades_train", ) 32 | TEST: ("charades_test",) 33 | INPUT: 34 | NUM_PRE_CLIPS: 32 35 | DATALOADER: 36 | NUM_WORKERS: 0 37 | SOLVER: 38 | LR: 0.0001 39 | BATCH_SIZE: 4 40 | MILESTONES: (8, 13) 41 | MAX_EPOCH: 18 42 | TEST_PERIOD: 1 43 | CHECKPOINT_PERIOD: 1 44 | RESUME: False 45 | RESUME_EPOCH: 7 46 | FREEZE_BERT: 4 47 | ONLY_IOU: 7 48 | SKIP_TEST: 1 49 | USE_STATIC: False 50 | USE_GNN: False 51 | GNN_SPARSE: True 52 | GNN_MODE: 'gauss' # 'gauss' or 'd' or 'gat' 53 | GNN_LAYERS: 4 54 | GNN_U: 5.0 55 | GNN_STEP: 0.1 56 | POS_EMBED: 'sine' # choices=['trainable', 'sine', 'learned'] 57 | TEST: 58 | NMS_THRESH: 0.5 59 | BATCH_SIZE: 4 60 | CONTRASTIVE_SCORE_POW: 0.5 61 | -------------------------------------------------------------------------------- /configs/tacos.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ARCHITECTURE: "DTF" 3 | DTF: 4 | NUM_CLIPS: 128 5 | JOINT_SPACE_SIZE: 256 6 | FEATPOOL: 7 | INPUT_SIZE: 500 8 | HIDDEN_SIZE: 512 9 | KERNEL_SIZE: 4 10 | FEAT2D: 11 | NAME: "pool" 12 | POOLING_COUNTS: [15,8,8,8] 13 | TEXT_ENCODER: 14 | NAME: 'BERT' 15 | PREDICTOR: 16 | HIDDEN_SIZE: 512 17 | KERNEL_SIZE: 5 18 | NUM_STACK_LAYERS: 3 19 | LOSS: 20 | MIN_IOU: 0.3 21 | MAX_IOU: 0.7 22 | NUM_POSTIVE_VIDEO_PROPOSAL: 3 23 | NEGATIVE_VIDEO_IOU: 0.5 
24 | SENT_REMOVAL_IOU: 0.5 25 | TAU_VIDEO: 0.1 26 | TAU_SENT: 0.1 27 | MARGIN: 0.1 28 | CONTRASTIVE_WEIGHT: 0.1 29 | OUTPUT_DIR: "./tacos/glove/full_model_new" 30 | DATASETS: 31 | NAME: "tacos" 32 | TRAIN: ("tacos_train",) 33 | TEST: ("tacos_test",) 34 | INPUT: 35 | NUM_PRE_CLIPS: 512 36 | DATALOADER: 37 | NUM_WORKERS: 0 38 | SOLVER: 39 | LR: 0.0015 40 | BATCH_SIZE: 2 41 | MILESTONES: (130, 190) 42 | MAX_EPOCH: 250 43 | TEST_PERIOD: 10 44 | CHECKPOINT_PERIOD: 10 45 | RESUME: True 46 | RESUME_EPOCH: 211 47 | FREEZE_BERT: 400 48 | ONLY_IOU: 100 49 | SKIP_TEST: 0 50 | USE_STATIC: True 51 | USE_GNN: True 52 | GNN_SPARSE: True 53 | GNN_MODE: 'gauss' # 'gauss' or 'd' or 'gat' 54 | GNN_LAYERS: 2 55 | GNN_U: 5.0 56 | GNN_STEP: 0.1 57 | POS_EMBED: 'sine' # choices=['trainable', 'sine', 'learned'] 58 | TEST: 59 | NMS_THRESH: 0.4 60 | BATCH_SIZE: 2 61 | CONTRASTIVE_SCORE_POW: 0.3 62 | 63 | -------------------------------------------------------------------------------- /dtfnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/__init__.py -------------------------------------------------------------------------------- /dtfnet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dtfnet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dtfnet/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /dtfnet/config/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | -------------------------------------------------------------------------------- /dtfnet/config/.ipynb_checkpoints/defaults-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from yacs.config import CfgNode as CN 4 | # ----------------------------------------------------------------------------- 5 | # Convention about Training / Test specific parameters 6 | # ----------------------------------------------------------------------------- 7 | # Config definition 8 | # ----------------------------------------------------------------------------- 9 | _C = CN() 10 | 11 | _C.MODEL = CN() 12 | _C.MODEL.DEVICE = "cuda" 13 | _C.MODEL.ARCHITECTURE = "DTF" 14 | 15 | # ----------------------------------------------------------------------------- 16 | # INPUT 17 | # ----------------------------------------------------------------------------- 18 | _C.INPUT = CN() 19 | _C.INPUT.NUM_PRE_CLIPS = 256 20 | 21 | # ----------------------------------------------------------------------------- 22 | # Dataset 23 | # 
----------------------------------------------------------------------------- 24 | _C.DATASETS = CN() 25 | # List of the dataset names for training, as present in paths_catalog.py 26 | _C.DATASETS.TRAIN = () 27 | # List of the dataset names for testing, as present in paths_catalog.py 28 | _C.DATASETS.TEST = () 29 | _C.DATASETS.NAME = "" 30 | 31 | # ----------------------------------------------------------------------------- 32 | # DataLoader 33 | # ----------------------------------------------------------------------------- 34 | _C.DATALOADER = CN() 35 | # Number of data loading threads 36 | _C.DATALOADER.NUM_WORKERS = 4 37 | 38 | # ----------------------------------------------------------------------------- 39 | # Models 40 | # ----------------------------------------------------------------------------- 41 | _C.MODEL.DTF = CN() 42 | _C.MODEL.DTF.VIDEO_MODE = 'c3d' 43 | _C.MODEL.DTF.NUM_CLIPS = 128 44 | _C.MODEL.DTF.JOINT_SPACE_SIZE = 256 45 | 46 | _C.MODEL.DTF.FEATPOOL = CN() 47 | _C.MODEL.DTF.FEATPOOL.INPUT_SIZE = 4096 48 | _C.MODEL.DTF.FEATPOOL.HIDDEN_SIZE = 512 49 | _C.MODEL.DTF.FEATPOOL.KERNEL_SIZE = 2 50 | 51 | _C.MODEL.DTF.FEAT2D = CN() 52 | _C.MODEL.DTF.FEAT2D.NAME = "pool" 53 | _C.MODEL.DTF.FEAT2D.POOLING_COUNTS = [15, 8, 8, 8] 54 | 55 | _C.MODEL.DTF.TEXT_ENCODER = CN() 56 | _C.MODEL.DTF.TEXT_ENCODER.NAME = 'BERT' 57 | 58 | _C.MODEL.DTF.PREDICTOR = CN() 59 | _C.MODEL.DTF.PREDICTOR.HIDDEN_SIZE = 512 60 | _C.MODEL.DTF.PREDICTOR.KERNEL_SIZE = 5 61 | _C.MODEL.DTF.PREDICTOR.NUM_STACK_LAYERS = 8 62 | 63 | 64 | _C.MODEL.DTF.LOSS = CN() 65 | _C.MODEL.DTF.LOSS.MIN_IOU = 0.3 66 | _C.MODEL.DTF.LOSS.MAX_IOU = 0.7 67 | _C.MODEL.DTF.LOSS.BCE_WEIGHT = 1 68 | _C.MODEL.DTF.LOSS.NUM_POSTIVE_VIDEO_PROPOSAL = 1 69 | _C.MODEL.DTF.LOSS.NEGATIVE_VIDEO_IOU = 0.5 70 | _C.MODEL.DTF.LOSS.SENT_REMOVAL_IOU = 0.5 71 | _C.MODEL.DTF.LOSS.PAIRWISE_SENT_WEIGHT = 0.0 72 | _C.MODEL.DTF.LOSS.CONTRASTIVE_WEIGHT = 0.05 73 | _C.MODEL.DTF.LOSS.TAU_VIDEO = 0.2 74 | _C.MODEL.DTF.LOSS.TAU_SENT = 0.2 75 | _C.MODEL.DTF.LOSS.MARGIN = 0.2 76 | 77 | # ---------------------------------------------------------------------------- # 78 | # Solver 79 | # ---------------------------------------------------------------------------- # 80 | _C.SOLVER = CN() 81 | _C.SOLVER.MAX_EPOCH = 12 82 | _C.SOLVER.LR = 0.01 83 | _C.SOLVER.CHECKPOINT_PERIOD = 1 84 | _C.SOLVER.TEST_PERIOD = 1 85 | _C.SOLVER.BATCH_SIZE = 32 86 | _C.SOLVER.MILESTONES = (8, 11) 87 | _C.SOLVER.RESUME = False 88 | _C.SOLVER.RESUME_EPOCH = 1 89 | _C.SOLVER.FREEZE_BERT = 4 90 | _C.SOLVER.ONLY_IOU = 7 91 | _C.SOLVER.SKIP_TEST = 0 92 | _C.SOLVER.USE_STATIC = True 93 | _C.SOLVER.USE_GNN = True 94 | _C.SOLVER.GNN_SPARSE = True 95 | _C.SOLVER.GNN_MODE = 'gauss' 96 | _C.SOLVER.POS_EMBED = 'sine' 97 | _C.SOLVER.GNN_LAYERS = 2 98 | _C.SOLVER.GNN_U = 5.0 99 | _C.SOLVER.GNN_STEP = 0.1 100 | # ---------------------------------------------------------------------------- # 101 | # Specific test options 102 | # ---------------------------------------------------------------------------- # 103 | _C.TEST = CN() 104 | _C.TEST.BATCH_SIZE = 64 105 | _C.TEST.NMS_THRESH = 0.5 106 | _C.TEST.CONTRASTIVE_SCORE_POW = 0.5 107 | 108 | # ---------------------------------------------------------------------------- # 109 | # Misc options 110 | # ---------------------------------------------------------------------------- # 111 | _C.OUTPUT_DIR = "./activity/textfull" 112 | _C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py") 113 | 
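A minimal sketch of how these yacs defaults are typically consumed: an experiment YAML from configs/ is merged over _C and the result is frozen. The merge_from_list overrides shown here are illustrative, not taken from this repo's training script:

from dtfnet.config import cfg  # _C exported as cfg in config/__init__.py

cfg.merge_from_file("configs/activitynet.yaml")  # overlay experiment values on the defaults
cfg.merge_from_list(["SOLVER.BATCH_SIZE", 12, "TEST.NMS_THRESH", 0.5])  # optional CLI-style overrides
cfg.freeze()  # prevent accidental mutation of hyperparameters later

assert cfg.MODEL.DTF.NUM_CLIPS == 64  # 128 by default, overridden by activitynet.yaml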
-------------------------------------------------------------------------------- /dtfnet/config/.ipynb_checkpoints/paths_catalog-checkpoint.py: -------------------------------------------------------------------------------- 1 | """Centralized catalog of paths.""" 2 | import os 3 | from dtfnet.config import cfg 4 | 5 | class DatasetCatalog(object): 6 | DATA_DIR = "" 7 | 8 | DATASETS = { 9 | "tacos_train":{ 10 | "audio_dir": "/hujingjing2/data/TACoS/audio_data2vec_feat", 11 | "ann_file": "./dataset/TACoS/train_audio_new.json", 12 | "feat_file": "/hujingjing2/data/TACoS/tall_c3d_features.hdf5", 13 | }, 14 | "tacos_val":{ 15 | "audio_dir": "/hujingjing2/data/TACoS/audio_data2vec_feat", 16 | "ann_file": "./dataset/TACoS/val_audio_new.json", 17 | "feat_file": "/hujingjing2/data/TACoS/tall_c3d_features.hdf5", 18 | }, 19 | "tacos_test":{ 20 | "audio_dir": "/hujingjing2/data/TACoS/audio_data2vec_feat", 21 | "ann_file": "./dataset/TACoS/test_audio_new.json", 22 | "feat_file": "/hujingjing2/data/TACoS/tall_c3d_features.hdf5", 23 | }, 24 | "activitynet_train":{ 25 | "audio_dir": "/hujingjing2/data/ActivityNet/text_distbert_feat", 26 | "ann_file": "./dataset/ActivityNet/train_audio_new.json", 27 | "feat_file": "/hujingjing2/Spoken_Video_Grounding/data/activity-c3d", 28 | }, 29 | "activitynet_val":{ 30 | "audio_dir": "/hujingjing2/data/ActivityNet/text_distbert_feat", #/hujingjing2/data/ActivityNet/text_distbert_feat 31 | "ann_file": "./dataset/ActivityNet/val_audio_new.json", 32 | "feat_file": "/hujingjing2/Spoken_Video_Grounding/data/activity-c3d", 33 | }, 34 | "activitynet_test":{ 35 | "audio_dir": "/hujingjing2/data/ActivityNet/text_distbert_feat", 36 | "ann_file": "./dataset/ActivityNet/test_audio_new.json", 37 | "feat_file": "/hujingjing2/Spoken_Video_Grounding/data/activity-c3d", 38 | }, 39 | "charades_train": { 40 | "audio_dir": "/hujingjing2/data/Charades_STA/text_distbert_feat", 41 | "ann_file": "./dataset/Charades_STA/train_audio_new.json", 42 | # "feat_file": "G:/Dataset/data/Charades_STA/vgg_rgb_features.hdf5", 43 | # "feat_file": "/hujingjing2/data/Charades_STA/C3D_unit16_overlap0.5", 44 | "feat_file": "/hujingjing2/data/data/features/i3d_finetuned", 45 | 46 | }, 47 | "charades_test": { 48 | "audio_dir": "/hujingjing2/data/Charades_STA/text_distbert_feat", 49 | "ann_file": "./dataset/Charades_STA/test_audio_new.json", 50 | # "feat_file": "G:/Dataset/data/Charades_STA/vgg_rgb_features.hdf5", 51 | # "feat_file": "/hujingjing2/data/Charades_STA/C3D_unit16_overlap0.5", 52 | "feat_file": "/hujingjing2/data/data/features/i3d_finetuned", 53 | 54 | }, 55 | } 56 | 57 | @staticmethod 58 | def get(name): 59 | data_dir = DatasetCatalog.DATA_DIR 60 | attrs = DatasetCatalog.DATASETS[name] 61 | 62 | if "charades" in name and cfg.MODEL.DTF.VIDEO_MODE == 'vgg': 63 | attrs["feat_file"] = "/hujingjing2/data/Charades_STA/vgg_rgb_features.hdf5" 64 | if "charades" in name and cfg.MODEL.DTF.VIDEO_MODE == 'c3d': 65 | attrs["feat_file"] = "/hujingjing2/data/Charades_STA/C3D_unit16_overlap0.5" 66 | 67 | if "charades" in name and cfg.MODEL.DTF.VIDEO_MODE == 'c3d_pca': 68 | attrs["feat_file"] = "/hujingjing2/data/Charades_STA/C3D_PCA_new" 69 | 70 | if "charades" in name and cfg.MODEL.DTF.VIDEO_MODE == 'i3d': 71 | attrs["feat_file"] = "/hujingjing2/data/data/features/i3d_finetuned" 72 | 73 | args = dict( 74 | audio_dir=os.path.join(data_dir, attrs["audio_dir"]), 75 | ann_file=os.path.join(data_dir, attrs["ann_file"]), 76 | feat_file=os.path.join(data_dir, attrs["feat_file"]), 77 | ) 78 | if "tacos" in name: 79 | 
return dict( 80 | factory="TACoSDataset", 81 | args=args, 82 | ) 83 | elif "activitynet" in name: 84 | return dict( 85 | factory = "ActivityNetDataset", 86 | args = args 87 | ) 88 | elif "charades" in name: 89 | return dict( 90 | factory = "CharadesDataset", 91 | args = args 92 | ) 93 | raise RuntimeError("Dataset not available: {}".format(name)) 94 | 95 | -------------------------------------------------------------------------------- /dtfnet/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | -------------------------------------------------------------------------------- /dtfnet/config/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dtfnet/config/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dtfnet/config/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /dtfnet/config/__pycache__/defaults.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/defaults.cpython-36.pyc -------------------------------------------------------------------------------- /dtfnet/config/__pycache__/defaults.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/defaults.cpython-38.pyc -------------------------------------------------------------------------------- /dtfnet/config/__pycache__/defaults.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/defaults.cpython-39.pyc -------------------------------------------------------------------------------- /dtfnet/config/__pycache__/paths_catalog.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/paths_catalog.cpython-36.pyc -------------------------------------------------------------------------------- /dtfnet/config/__pycache__/paths_catalog.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/paths_catalog.cpython-38.pyc -------------------------------------------------------------------------------- 
/dtfnet/config/__pycache__/paths_catalog.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/config/__pycache__/paths_catalog.cpython-39.pyc -------------------------------------------------------------------------------- /dtfnet/config/defaults.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from yacs.config import CfgNode as CN 4 | # ----------------------------------------------------------------------------- 5 | # Convention about Training / Test specific parameters 6 | # ----------------------------------------------------------------------------- 7 | # Config definition 8 | # ----------------------------------------------------------------------------- 9 | _C = CN() 10 | 11 | _C.MODEL = CN() 12 | _C.MODEL.DEVICE = "cuda" 13 | _C.MODEL.ARCHITECTURE = "DTF" 14 | 15 | # ----------------------------------------------------------------------------- 16 | # INPUT 17 | # ----------------------------------------------------------------------------- 18 | _C.INPUT = CN() 19 | _C.INPUT.NUM_PRE_CLIPS = 256 20 | 21 | # ----------------------------------------------------------------------------- 22 | # Dataset 23 | # ----------------------------------------------------------------------------- 24 | _C.DATASETS = CN() 25 | # List of the dataset names for training, as present in paths_catalog.py 26 | _C.DATASETS.TRAIN = () 27 | # List of the dataset names for testing, as present in paths_catalog.py 28 | _C.DATASETS.TEST = () 29 | _C.DATASETS.NAME = "" 30 | 31 | # ----------------------------------------------------------------------------- 32 | # DataLoader 33 | # ----------------------------------------------------------------------------- 34 | _C.DATALOADER = CN() 35 | # Number of data loading threads 36 | _C.DATALOADER.NUM_WORKERS = 4 37 | 38 | # ----------------------------------------------------------------------------- 39 | # Models 40 | # ----------------------------------------------------------------------------- 41 | _C.MODEL.DTF = CN() 42 | _C.MODEL.DTF.VIDEO_MODE = 'c3d' 43 | _C.MODEL.DTF.NUM_CLIPS = 128 44 | _C.MODEL.DTF.JOINT_SPACE_SIZE = 256 45 | 46 | _C.MODEL.DTF.FEATPOOL = CN() 47 | _C.MODEL.DTF.FEATPOOL.INPUT_SIZE = 4096 48 | _C.MODEL.DTF.FEATPOOL.HIDDEN_SIZE = 512 49 | _C.MODEL.DTF.FEATPOOL.KERNEL_SIZE = 2 50 | 51 | _C.MODEL.DTF.FEAT2D = CN() 52 | _C.MODEL.DTF.FEAT2D.NAME = "pool" 53 | _C.MODEL.DTF.FEAT2D.POOLING_COUNTS = [15, 8, 8, 8] 54 | 55 | _C.MODEL.DTF.TEXT_ENCODER = CN() 56 | _C.MODEL.DTF.TEXT_ENCODER.NAME = 'BERT' 57 | 58 | _C.MODEL.DTF.PREDICTOR = CN() 59 | _C.MODEL.DTF.PREDICTOR.HIDDEN_SIZE = 512 60 | _C.MODEL.DTF.PREDICTOR.KERNEL_SIZE = 5 61 | _C.MODEL.DTF.PREDICTOR.NUM_STACK_LAYERS = 8 62 | 63 | 64 | _C.MODEL.DTF.LOSS = CN() 65 | _C.MODEL.DTF.LOSS.MIN_IOU = 0.3 66 | _C.MODEL.DTF.LOSS.MAX_IOU = 0.7 67 | _C.MODEL.DTF.LOSS.BCE_WEIGHT = 1 68 | _C.MODEL.DTF.LOSS.NUM_POSTIVE_VIDEO_PROPOSAL = 1 69 | _C.MODEL.DTF.LOSS.NEGATIVE_VIDEO_IOU = 0.5 70 | _C.MODEL.DTF.LOSS.SENT_REMOVAL_IOU = 0.5 71 | _C.MODEL.DTF.LOSS.PAIRWISE_SENT_WEIGHT = 0.0 72 | _C.MODEL.DTF.LOSS.CONTRASTIVE_WEIGHT = 0.05 73 | _C.MODEL.DTF.LOSS.TAU_VIDEO = 0.2 74 | _C.MODEL.DTF.LOSS.TAU_SENT = 0.2 75 | _C.MODEL.DTF.LOSS.MARGIN = 0.2 76 | 77 | # ---------------------------------------------------------------------------- # 78 | # Solver 79 | # 
---------------------------------------------------------------------------- # 80 | _C.SOLVER = CN() 81 | _C.SOLVER.MAX_EPOCH = 12 82 | _C.SOLVER.LR = 0.01 83 | _C.SOLVER.CHECKPOINT_PERIOD = 1 84 | _C.SOLVER.TEST_PERIOD = 1 85 | _C.SOLVER.BATCH_SIZE = 32 86 | _C.SOLVER.MILESTONES = (8, 11) 87 | _C.SOLVER.RESUME = False 88 | _C.SOLVER.RESUME_EPOCH = 1 89 | _C.SOLVER.FREEZE_BERT = 4 90 | _C.SOLVER.ONLY_IOU = 7 91 | _C.SOLVER.SKIP_TEST = 0 92 | _C.SOLVER.USE_STATIC = True 93 | _C.SOLVER.USE_GNN = True 94 | _C.SOLVER.GNN_SPARSE = True 95 | _C.SOLVER.GNN_MODE = 'gauss' 96 | _C.SOLVER.POS_EMBED = 'sine' 97 | _C.SOLVER.GNN_LAYERS = 2 98 | _C.SOLVER.GNN_U = 5.0 99 | _C.SOLVER.GNN_STEP = 0.1 100 | # ---------------------------------------------------------------------------- # 101 | # Specific test options 102 | # ---------------------------------------------------------------------------- # 103 | _C.TEST = CN() 104 | _C.TEST.BATCH_SIZE = 64 105 | _C.TEST.NMS_THRESH = 0.5 106 | _C.TEST.CONTRASTIVE_SCORE_POW = 0.5 107 | 108 | # ---------------------------------------------------------------------------- # 109 | # Misc options 110 | # ---------------------------------------------------------------------------- # 111 | _C.OUTPUT_DIR = "./activity/textfull" 112 | _C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py") 113 | -------------------------------------------------------------------------------- /dtfnet/config/paths_catalog.py: -------------------------------------------------------------------------------- 1 | """Centralized catalog of paths.""" 2 | import os 3 | from dtfnet.config import cfg 4 | 5 | class DatasetCatalog(object): 6 | DATA_DIR = "" 7 | 8 | DATASETS = { 9 | "tacos_train":{ 10 | "audio_dir": "/hujingjing2/data/TACoS/audio_data2vec_feat", # text or audio feature path 11 | "ann_file": "./dataset/TACoS/train_audio_new.json", 12 | "feat_file": "/hujingjing2/data/TACoS/tall_c3d_features.hdf5", # video feature path 13 | }, 14 | "tacos_val":{ 15 | "audio_dir": "/hujingjing2/data/TACoS/audio_data2vec_feat", 16 | "ann_file": "./dataset/TACoS/val_audio_new.json", 17 | "feat_file": "/hujingjing2/data/TACoS/tall_c3d_features.hdf5", 18 | }, 19 | "tacos_test":{ 20 | "audio_dir": "/hujingjing2/data/TACoS/audio_data2vec_feat", 21 | "ann_file": "./dataset/TACoS/test_audio_new.json", 22 | "feat_file": "/hujingjing2/data/TACoS/tall_c3d_features.hdf5", 23 | }, 24 | "activitynet_train":{ 25 | "audio_dir": "/hujingjing2/data/ActivityNet/text_distbert_feat", 26 | "ann_file": "./dataset/ActivityNet/train_audio_new.json", 27 | "feat_file": "/hujingjing2/Spoken_Video_Grounding/data/activity-c3d", 28 | }, 29 | "activitynet_val":{ 30 | "audio_dir": "/hujingjing2/data/ActivityNet/text_distbert_feat", #/hujingjing2/data/ActivityNet/text_distbert_feat 31 | "ann_file": "./dataset/ActivityNet/val_audio_new.json", 32 | "feat_file": "/hujingjing2/Spoken_Video_Grounding/data/activity-c3d", 33 | }, 34 | "activitynet_test":{ 35 | "audio_dir": "/hujingjing2/data/ActivityNet/text_distbert_feat", 36 | "ann_file": "./dataset/ActivityNet/test_audio_new.json", 37 | "feat_file": "/hujingjing2/Spoken_Video_Grounding/data/activity-c3d", 38 | }, 39 | "charades_train": { 40 | "audio_dir": "/hujingjing2/data/Charades_STA/text_distbert_feat", 41 | "ann_file": "./dataset/Charades_STA/train_audio_new.json", 42 | "feat_file": "/hujingjing2/data/data/features/i3d_finetuned", 43 | 44 | }, 45 | "charades_test": { 46 | "audio_dir": "/hujingjing2/data/Charades_STA/text_distbert_feat", 47 | "ann_file": 
"./dataset/Charades_STA/test_audio_new.json", 48 | "feat_file": "/hujingjing2/data/data/features/i3d_finetuned", 49 | 50 | }, 51 | } 52 | 53 | @staticmethod 54 | def get(name): 55 | data_dir = DatasetCatalog.DATA_DIR 56 | attrs = DatasetCatalog.DATASETS[name] 57 | 58 | if "charades" in name and cfg.MODEL.DTF.VIDEO_MODE == 'vgg': 59 | attrs["feat_file"] = "/hujingjing2/data/Charades_STA/vgg_rgb_features.hdf5" # change this path with your own charades video feature path 60 | if "charades" in name and cfg.MODEL.DTF.VIDEO_MODE == 'c3d': 61 | attrs["feat_file"] = "/hujingjing2/data/Charades_STA/C3D_unit16_overlap0.5" 62 | 63 | if "charades" in name and cfg.MODEL.DTF.VIDEO_MODE == 'c3d_pca': 64 | attrs["feat_file"] = "/hujingjing2/data/Charades_STA/C3D_PCA_new" 65 | 66 | if "charades" in name and cfg.MODEL.DTF.VIDEO_MODE == 'i3d': 67 | attrs["feat_file"] = "/hujingjing2/data/data/features/i3d_finetuned" 68 | 69 | args = dict( 70 | audio_dir=os.path.join(data_dir, attrs["audio_dir"]), 71 | ann_file=os.path.join(data_dir, attrs["ann_file"]), 72 | feat_file=os.path.join(data_dir, attrs["feat_file"]), 73 | ) 74 | if "tacos" in name: 75 | return dict( 76 | factory="TACoSDataset", 77 | args=args, 78 | ) 79 | elif "activitynet" in name: 80 | return dict( 81 | factory = "ActivityNetDataset", 82 | args = args 83 | ) 84 | elif "charades" in name: 85 | return dict( 86 | factory = "CharadesDataset", 87 | args = args 88 | ) 89 | raise RuntimeError("Dataset not available: {}".format(name)) 90 | 91 | -------------------------------------------------------------------------------- /dtfnet/data/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from dtfnet.utils.comm import get_world_size 4 | from dtfnet.utils.imports import import_file 5 | from . 
import datasets as D 6 | from .samplers import DistributedSampler 7 | from .collate_batch import BatchCollator 8 | import os 9 | 10 | def build_dataset(dataset_list, dataset_catalog, cfg, is_train=True): 11 | # build specific dataset 12 | if not isinstance(dataset_list, (list, tuple)): 13 | raise RuntimeError( 14 | "dataset_list should be a list of strings, got {}".format( 15 | dataset_list 16 | ) 17 | ) 18 | datasets = [] 19 | for dataset_name in dataset_list: 20 | data = dataset_catalog.get(dataset_name) 21 | factory = getattr(D, data["factory"]) 22 | args = data["args"] 23 | args["num_pre_clips"] = cfg.INPUT.NUM_PRE_CLIPS 24 | args["num_clips"] = cfg.MODEL.DTF.NUM_CLIPS 25 | dataset = factory(**args) 26 | datasets.append(dataset) 27 | 28 | # for testing, return a list of datasets 29 | if not is_train: 30 | return datasets 31 | 32 | # for training, concatenate all datasets into a single one 33 | dataset = datasets[0] 34 | if len(datasets) > 1: 35 | dataset = D.ConcatDataset(datasets) 36 | return [dataset] 37 | 38 | def make_data_sampler(dataset, shuffle, distributed): 39 | if distributed: 40 | return DistributedSampler(dataset, shuffle=shuffle) 41 | if shuffle: 42 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 43 | else: 44 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 45 | return sampler 46 | 47 | def make_train_data_sampler(dataset, sampler, batch_size): 48 | batch_sampler = torch.utils.data.sampler.BatchSampler( 49 | sampler, batch_size, drop_last=False 50 | # TODO: check if drop_last=True helps 51 | ) 52 | return batch_sampler 53 | 54 | def make_test_data_sampler(dataset, sampler, batch_size): 55 | batch_sampler = torch.utils.data.sampler.BatchSampler( 56 | sampler, batch_size, drop_last=False 57 | ) 58 | return batch_sampler 59 | 60 | def make_data_loader(cfg, is_train=True, is_distributed=False, is_for_period=False): 61 | num_gpus = get_world_size() 62 | if is_train: 63 | batch_size = cfg.SOLVER.BATCH_SIZE 64 | assert ( 65 | batch_size % num_gpus == 0 66 | ), "SOLVER.BATCH_SIZE ({}) must be divisible by the number of GPUs ({}) used.".format( 67 | batch_size, num_gpus) 68 | batch_size_per_gpu = batch_size // num_gpus 69 | shuffle = True 70 | max_epoch = cfg.SOLVER.MAX_EPOCH 71 | else: 72 | batch_size = cfg.TEST.BATCH_SIZE 73 | assert ( 74 | batch_size % num_gpus == 0 75 | ), "TEST.BATCH_SIZE ({}) must be divisible by the number of GPUs ({}) used.".format( 76 | batch_size, num_gpus) 77 | batch_size_per_gpu = batch_size // num_gpus 78 | shuffle = True if not is_distributed else False # originally False 79 | 80 | if batch_size_per_gpu > 1: 81 | logger = logging.getLogger(__name__) 82 | 83 | paths_catalog = import_file( 84 | "mmn.cfg.paths_catalog", cfg.PATHS_CATALOG, True 85 | ) 86 | DatasetCatalog = paths_catalog.DatasetCatalog 87 | dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST 88 | datasets = build_dataset(dataset_list, DatasetCatalog, cfg, is_train=is_train or is_for_period) 89 | 90 | data_loaders = [] 91 | for dataset in datasets: 92 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 93 | if is_train: 94 | batch_sampler = make_train_data_sampler(dataset, sampler, batch_size_per_gpu) 95 | else: 96 | batch_sampler = make_test_data_sampler(dataset, sampler, batch_size_per_gpu) 97 | data_loader = torch.utils.data.DataLoader( 98 | dataset, 99 | num_workers=cfg.DATALOADER.NUM_WORKERS, 100 | batch_sampler=batch_sampler, 101 | collate_fn=BatchCollator(), 102 | ) 103 | data_loaders.append(data_loader) 104 | if is_train or 
is_for_period: 105 | # during training, a single (possibly concatenated) data_loader is returned 106 | assert len(data_loaders) == 1 107 | return data_loaders[0] 108 | return data_loaders 109 | -------------------------------------------------------------------------------- /dtfnet/data/.ipynb_checkpoints/collate_batch-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | from dtfnet.structures import TLGBatch 4 | 5 | 6 | class BatchCollator(object): 7 | """ 8 | Collect batch for dataloader 9 | """ 10 | 11 | def __init__(self, ): 12 | pass 13 | 14 | def __call__(self, batch): 15 | transposed_batch = list(zip(*batch)) 16 | # [xxx, xxx, xxx], [xxx, xxx, xxx] ...... 17 | feats, queries, wordlens, ious2d, moments, num_sentence, idxs, vid = transposed_batch 18 | 19 | return TLGBatch( 20 | feats=torch.stack(feats).float(), 21 | queries=queries, 22 | wordlens=wordlens, 23 | all_iou2d=ious2d, 24 | moments=moments, 25 | num_sentence=num_sentence, 26 | idxs=idxs, 27 | vid=vid, 28 | ), idxs 29 | -------------------------------------------------------------------------------- /dtfnet/data/.ipynb_checkpoints/samplers-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DistributedSampler as _DistributedSampler 3 | 4 | class DistributedSampler(_DistributedSampler): 5 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 6 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 7 | self.shuffle = shuffle 8 | 9 | def __iter__(self): 10 | if self.shuffle: 11 | # deterministically shuffle based on epoch 12 | g = torch.Generator() 13 | g.manual_seed(self.epoch) 14 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 15 | else: 16 | indices = torch.arange(len(self.dataset)).tolist() 17 | 18 | # add extra samples to make it evenly divisible 19 | indices += indices[: (self.total_size - len(indices))] 20 | assert len(indices) == self.total_size 21 | 22 | # subsample 23 | offset = self.num_samples * self.rank 24 | indices = indices[offset : offset + self.num_samples] 25 | assert len(indices) == self.num_samples 26 | return iter(indices) 27 | 28 | def __len__(self): 29 | return self.num_samples 30 | 31 | def set_epoch(self, epoch): 32 | self.epoch = epoch -------------------------------------------------------------------------------- /dtfnet/data/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from dtfnet.utils.comm import get_world_size 4 | from dtfnet.utils.imports import import_file 5 | from . 
import datasets as D 6 | from .samplers import DistributedSampler 7 | from .collate_batch import BatchCollator 8 | import os 9 | 10 | def build_dataset(dataset_list, dataset_catalog, cfg, is_train=True): 11 | # build specific dataset 12 | if not isinstance(dataset_list, (list, tuple)): 13 | raise RuntimeError( 14 | "dataset_list should be a list of strings, got {}".format( 15 | dataset_list 16 | ) 17 | ) 18 | datasets = [] 19 | for dataset_name in dataset_list: 20 | data = dataset_catalog.get(dataset_name) 21 | factory = getattr(D, data["factory"]) 22 | args = data["args"] 23 | args["num_pre_clips"] = cfg.INPUT.NUM_PRE_CLIPS 24 | args["num_clips"] = cfg.MODEL.DTF.NUM_CLIPS 25 | dataset = factory(**args) 26 | datasets.append(dataset) 27 | 28 | # for testing, return a list of datasets 29 | if not is_train: 30 | return datasets 31 | 32 | # for training, concatenate all datasets into a single one 33 | dataset = datasets[0] 34 | if len(datasets) > 1: 35 | dataset = D.ConcatDataset(datasets) 36 | return [dataset] 37 | 38 | def make_data_sampler(dataset, shuffle, distributed): 39 | if distributed: 40 | return DistributedSampler(dataset, shuffle=shuffle) 41 | if shuffle: 42 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 43 | else: 44 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 45 | return sampler 46 | 47 | def make_train_data_sampler(dataset, sampler, batch_size): 48 | batch_sampler = torch.utils.data.sampler.BatchSampler( 49 | sampler, batch_size, drop_last=False 50 | # TODO: check if drop_last=True helps 51 | ) 52 | return batch_sampler 53 | 54 | def make_test_data_sampler(dataset, sampler, batch_size): 55 | batch_sampler = torch.utils.data.sampler.BatchSampler( 56 | sampler, batch_size, drop_last=False 57 | ) 58 | return batch_sampler 59 | 60 | def make_data_loader(cfg, is_train=True, is_distributed=False, is_for_period=False): 61 | num_gpus = get_world_size() 62 | if is_train: 63 | batch_size = cfg.SOLVER.BATCH_SIZE 64 | assert ( 65 | batch_size % num_gpus == 0 66 | ), "SOLVER.BATCH_SIZE ({}) must be divisible by the number of GPUs ({}) used.".format( 67 | batch_size, num_gpus) 68 | batch_size_per_gpu = batch_size // num_gpus 69 | shuffle = True 70 | max_epoch = cfg.SOLVER.MAX_EPOCH 71 | else: 72 | batch_size = cfg.TEST.BATCH_SIZE 73 | assert ( 74 | batch_size % num_gpus == 0 75 | ), "TEST.BATCH_SIZE ({}) must be divisible by the number of GPUs ({}) used.".format( 76 | batch_size, num_gpus) 77 | batch_size_per_gpu = batch_size // num_gpus 78 | shuffle = True if not is_distributed else False # originally False 79 | 80 | if batch_size_per_gpu > 1: 81 | logger = logging.getLogger(__name__) 82 | 83 | paths_catalog = import_file( 84 | "mmn.cfg.paths_catalog", cfg.PATHS_CATALOG, True 85 | ) 86 | DatasetCatalog = paths_catalog.DatasetCatalog 87 | dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST 88 | datasets = build_dataset(dataset_list, DatasetCatalog, cfg, is_train=is_train or is_for_period) 89 | 90 | data_loaders = [] 91 | for dataset in datasets: 92 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 93 | if is_train: 94 | batch_sampler = make_train_data_sampler(dataset, sampler, batch_size_per_gpu) 95 | else: 96 | batch_sampler = make_test_data_sampler(dataset, sampler, batch_size_per_gpu) 97 | data_loader = torch.utils.data.DataLoader( 98 | dataset, 99 | num_workers=cfg.DATALOADER.NUM_WORKERS, 100 | batch_sampler=batch_sampler, 101 | collate_fn=BatchCollator(), 102 | ) 103 | data_loaders.append(data_loader) 104 | if is_train or 
105 |         # during training, a single (possibly concatenated) data_loader is returned
106 |         assert len(data_loaders) == 1
107 |         return data_loaders[0]
108 |     return data_loaders
109 |
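A minimal sketch of how make_data_loader is typically driven from a training script — the config import path and yaml file below are assumptions based on the repo layout, not code shown in this dump:

    from dtfnet.config import cfg            # assumed config entry point
    from dtfnet.data import make_data_loader

    cfg.merge_from_file("configs/activitynet.yaml")        # any yaml under configs/
    train_loader = make_data_loader(cfg, is_train=True)    # one (possibly concatenated) loader
    test_loaders = make_data_loader(cfg, is_train=False)   # a list, one loader per TEST dataset
    for batches, idxs in train_loader:                     # BatchCollator yields (TLGBatch, idxs)
        pass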
--------------------------------------------------------------------------------
/dtfnet/data/collate_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.utils.rnn import pad_sequence
3 | from dtfnet.structures import TLGBatch
4 |
5 |
6 | class BatchCollator(object):
7 |     """
8 |     Collate a batch of per-video samples for the dataloader
9 |     """
10 |
11 |     def __init__(self, ):
12 |         pass
13 |
14 |     def __call__(self, batch):
15 |         transposed_batch = list(zip(*batch))
16 |         # zip(*batch) groups the per-sample tuples field-wise: all feats together, all queries together, ...
17 |         feats, queries, ious2d, moments, num_sentence, idxs, vid = transposed_batch
18 |
19 |         return TLGBatch(
20 |             feats=torch.stack(feats).float(),
21 |             queries=queries,
22 |             all_iou2d=ious2d,
23 |             moments=moments,
24 |             num_sentence=num_sentence,
25 |             idxs=idxs,
26 |             vid=vid,
27 |         ), idxs
28 |
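To make the field-wise transposition concrete, here is a toy two-sample batch pushed through the collator; every shape is made up for illustration, and TLGBatch is assumed to accept exactly the keyword fields used above:

    import torch
    from dtfnet.data.collate_batch import BatchCollator

    def sample(i):
        return (torch.zeros(256, 500),    # feats: num_pre_clips x feat_dim (hypothetical sizes)
                torch.zeros(3, 768),      # queries: one feature per sentence
                torch.zeros(3, 64, 64),   # all_iou2d: num_sent x num_clips x num_clips
                torch.zeros(3, 2),        # moments: (start, end) per sentence
                3, i, f"v_{i}")           # num_sentence, sample index, video id

    batch, idxs = BatchCollator()([sample(0), sample(1)])
    print(batch.feats.shape, idxs)        # torch.Size([2, 256, 500]) (0, 1)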
--------------------------------------------------------------------------------
/dtfnet/data/datasets/.ipynb_checkpoints/utils-checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join, exists
3 | import h5py
4 | import numpy as np
5 |
6 | import torch
7 | import torchtext
8 | from torch.functional import F
9 |
10 |
11 | def iou(candidates, gt):
12 |     start, end = candidates[:,0], candidates[:,1]
13 |     s, e = gt[0].float(), gt[1].float()
14 |     # print(s.dtype, start.dtype)
15 |     inter = end.min(e) - start.max(s)
16 |     union = end.max(e) - start.min(s)
17 |     return inter.clamp(min=0) / union
18 |
19 | def box_iou(boxes1, boxes2):
20 |     area1 = box_length(boxes1)
21 |     area2 = box_length(boxes2)
22 |     max_start = torch.max(boxes1[:, None, 0], boxes2[:, 0])  # [N,M]
23 |     min_end = torch.min(boxes1[:, None, 1], boxes2[:, 1])  # [N,M]
24 |     inter = (min_end - max_start).clamp(min=0)  # [N,M]
25 |     union = area1[:, None] + area2 - inter
26 |     iou = inter / union
27 |     return iou
28 |
29 |
30 | def box_length(boxes):
31 |     return boxes[:, 1] - boxes[:, 0]
32 |
33 |
34 | def score2d_to_moments_scores(score2d, num_clips, duration):
35 |     grids = score2d.nonzero()
36 |     scores = score2d[grids[:, 0], grids[:, 1]]
37 |     grids[:, 1] += 1
38 |     moments = grids * duration / num_clips
39 |     return moments, scores
40 |
41 |
42 | def moment_to_iou2d(moment, num_clips, duration):
43 |     iou2d = torch.ones(num_clips, num_clips)
44 |     candidates, _ = score2d_to_moments_scores(iou2d, num_clips, duration)
45 |     iou2d = iou(candidates, moment).reshape(num_clips, num_clips)
46 |     return iou2d
47 |
48 |
49 | def avgfeats(feats, num_pre_clips):
50 |     # Pool each video's features into a fixed shape (e.g. 256*4096)
51 |     # Input Example: feats (torch.tensor, ?x4096); num_pre_clips (256)
52 |     num_src_clips = feats.size(0)
53 |     idxs = torch.arange(0, num_pre_clips+1, 1.0) / num_pre_clips * num_src_clips
54 |     idxs = idxs.round().long().clamp(max=num_src_clips-1)
55 |     # To prevent an empty selection, check the idxs
56 |     meanfeats = []
57 |     for i in range(num_pre_clips):
58 |         s, e = idxs[i], idxs[i+1]
59 |         if s < e:
60 |             meanfeats.append(feats[s:e].mean(dim=0))
61 |         else:
62 |             meanfeats.append(feats[s])
63 |     return torch.stack(meanfeats)
64 |
65 | def maxfeats(feats, num_pre_clips):
66 |     # Pool each video's features into a fixed shape (e.g. 256*4096)
67 |     # Input Example: feats (torch.tensor, ?x4096); num_pre_clips (256)
68 |     num_src_clips = feats.size(0)
69 |     idxs = torch.arange(0, num_pre_clips+1, 1.0) / num_pre_clips * num_src_clips
70 |     idxs = idxs.round().long().clamp(max=num_src_clips-1)
71 |     # To prevent an empty selection, check the idxs
72 |     maxfeats = []
73 |     for i in range(num_pre_clips):
74 |         s, e = idxs[i], idxs[i+1]
75 |         if s < e:
76 |             maxfeats.append(feats[s:e].max(dim=0)[0])
77 |         else:
78 |             maxfeats.append(feats[s])
79 |     return torch.stack(maxfeats)
80 |
81 | def video2feats(feat_file, vids, num_pre_clips, dataset_name):
82 |     assert exists(feat_file)
83 |     vid_feats = {}
84 |     with h5py.File(feat_file, 'r') as f:
85 |         for vid in vids:
86 |             if dataset_name == "activitynet":
87 |                 feat = f[vid]['c3d_features'][:]
88 |             else:
89 |                 feat = f[vid][:]
90 |             feat = F.normalize(torch.from_numpy(feat), dim=1)
91 |             vid_feats[vid] = avgfeats(feat, num_pre_clips)
92 |     return vid_feats
93 |
94 | def get_vid_feat(feat_file, vid, num_pre_clips, dataset_name):
95 |     assert exists(feat_file)
96 |
97 |     if dataset_name == "activitynet":
98 |         with h5py.File(os.path.join(feat_file, '%s.h5' % vid), 'r') as fr:
99 |             feat = np.asarray(fr['feature']).astype(np.float32)
100 |         feat = F.normalize(torch.from_numpy(feat), dim=1)
101 |
102 |     elif dataset_name == "charades" and 'i3d' in feat_file:
103 |         feat = np.load(os.path.join(feat_file, '%s.npy' % vid))
104 |         feat = F.normalize(torch.from_numpy(feat).squeeze(1).squeeze(1), dim=1)
105 |
106 |     elif dataset_name == "charades" and 'vgg' in feat_file:
107 |         with h5py.File(feat_file, 'r') as fr:
108 |             feat = np.asarray(fr[vid][:]).astype(np.float32)
109 |         feat = F.normalize(torch.from_numpy(feat), dim=1)
110 |     elif dataset_name == "charades" and 'C3D' in feat_file:
111 |         feat = torch.load(os.path.join(feat_file, '%s.pt' % vid))
112 |         feat = F.normalize(feat, dim=1)
113 |
114 |     # elif dataset_name == "charades":
115 |     #     with h5py.File(feat_file, 'r') as f:
116 |     #         feat = f[vid][:]
117 |     #         feat = F.normalize(torch.from_numpy(feat), dim=1)
118 |     else:
119 |         with h5py.File(feat_file, 'r') as f:
120 |             feat = f[vid][:]
121 |             feat = F.normalize(torch.from_numpy(feat), dim=1)
122 |
123 |     return avgfeats(feat, num_pre_clips)
124 |
125 | def get_feat_didemo(feat_file, vid):
126 |     assert exists(feat_file)
127 |     with h5py.File(feat_file, 'r') as f:
128 |         feat = f[vid][:]
129 |     return torch.from_numpy(feat)
130 |
131 | def get_c3d_charades(feat_file, num_pre_clips):
132 |     assert exists(feat_file)
133 |     feat = torch.load(feat_file)
134 |     #feat = F.normalize(feat, dim=1)
135 |     return maxfeats(feat, num_pre_clips)
136 |
137 | def bert_embedding(sentence, tokenizer):
138 |     query_token = tokenizer(sentence, return_tensors="pt", padding=True)
139 |     word_lens = query_token['attention_mask'].sum(dim=1)
140 |     queries = query_token['input_ids']
141 |     return queries, word_lens
142 |
143 |
144 | def glove_embedding(sentence, vocabs=[], embedders=[]):
145 |     if len(vocabs) == 0:
146 |         vocab = torchtext.vocab.pretrained_aliases["glove.840B.300d"]()
147 |         vocab.itos.extend([''])
148 |         vocab.stoi[''] = vocab.vectors.shape[0]
149 |         vocab.vectors = torch.cat([vocab.vectors, torch.zeros(1, vocab.dim)], dim=0)
150 |         vocabs.append(vocab)
151 |
152 |     if len(embedders) == 0:
153 |         embedder = torch.nn.Embedding.from_pretrained(vocab.vectors)
154 |         embedders.append(embedder)
155 |
156 |     vocab, embedder = vocabs[0], embedders[0]
157 |     word_idxs = torch.tensor([vocab.stoi.get(w.lower(), 400000) for w in sentence.split()], dtype=torch.long)
158 |     return embedder(word_idxs)
159 |
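A quick worked example of the 2D proposal map from moment_to_iou2d above — cell (i, j) holds the IoU between the candidate segment covering clips i..j and the ground-truth moment (the printed values assume the functions behave exactly as written):

    import torch
    from dtfnet.data.datasets.utils import moment_to_iou2d

    # ground-truth moment 12s-30s in a 60s video, on a 16x16 map
    iou2d = moment_to_iou2d(torch.tensor([12.0, 30.0]), num_clips=16, duration=60.0)
    print(iou2d.shape)                         # torch.Size([16, 16])
    i, j = (iou2d == iou2d.max()).nonzero()[0]
    # best cell is (3, 7): clips 3..7 span 11.25s-30.0s, IoU = 18 / 18.75 = 0.96
    print(i.item(), j.item(), round(iou2d.max().item(), 2))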
--------------------------------------------------------------------------------
/dtfnet/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .concat_dataset import ConcatDataset
2 | from .tacos import TACoSDataset
3 | from .activitynet import ActivityNetDataset
4 | from .charades import CharadesDataset
5 | __all__ = [
6 |     "ConcatDataset",
7 |     "TACoSDataset",
8 |     "ActivityNetDataset",
9 |     "CharadesDataset"
10 | ]
11 |
--------------------------------------------------------------------------------
/dtfnet/data/datasets/activitynet-Copy2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import torch
4 | from .utils import moment_to_iou2d, bert_embedding, get_vid_feat
5 | from transformers import DistilBertTokenizer
6 | import os
7 |
8 | class ActivityNetDataset(torch.utils.data.Dataset):
9 |     def __init__(self, audio_dir, ann_file, feat_file, num_pre_clips, num_clips):
10 |         super(ActivityNetDataset, self).__init__()
11 |         self.num_pre_clips = num_pre_clips
12 |         self.num_clips = num_clips
13 |         self.audio_dir = audio_dir
14 |         self.feat_file = feat_file
15 |         with open(ann_file, 'r', encoding='utf-8') as f:
16 |             annos = json.load(f)
17 |
18 |         self.annos = annos
19 |         self.data = list(annos.keys())
20 |
21 |         logger = logging.getLogger("dtf.trainer")
22 |         if 'train' in ann_file:
23 |             self.mode = 'train'
24 |         if 'val' in ann_file:
25 |             self.mode = 'val'
26 |         if 'test' in ann_file:
27 |             self.mode = 'test'
28 |
29 |         logger.info("-" * 60)
30 |         logger.info(f"Preparing {len(self.data)} {self.mode} data, please wait...")
31 |
32 |         self.feat_list = []
33 |         self.moments_list = []
34 |         self.all_iou2d_list = []
35 |         self.audios_list = []
36 |         self.num_audios_list = []
37 |         self.sent_list = []
38 |         for vid in self.data:
39 |             duration, timestamps, audios_name, sentences = annos[vid]['duration'], annos[vid]['timestamps'], \
40 |                                                            annos[vid]['audios'], annos[vid]['sentences']
41 |             feat = get_vid_feat(self.feat_file, vid, self.num_pre_clips, dataset_name="activitynet")
42 |             moments = []
43 |             all_iou2d = []
44 |             for timestamp in timestamps:
45 |                 time = torch.Tensor([max(timestamp[0], 0), min(timestamp[1], duration)])
46 |                 iou2d = moment_to_iou2d(time, self.num_clips, duration)
47 |                 moments.append(time)
48 |                 all_iou2d.append(iou2d)
49 |             moments = torch.stack(moments)
50 |             all_iou2d = torch.stack(all_iou2d)
51 |
52 |             assert moments.size(0) == all_iou2d.size(0)
53 |
54 |             num_audios = len(audios_name)
55 |             if num_audios == 1:
56 |                 audios = torch.load(os.path.join(self.audio_dir, f'{audios_name[0].split(".")[0]}.pt')).squeeze(
57 |                     dim=1).float()
58 |             elif num_audios > 1:
59 |                 audios = [torch.load(os.path.join(self.audio_dir, f'{audio_name.split(".")[0]}.pt')).squeeze(dim=1).float()
60 |                           for audio_name in audios_name]
61 |                 audios = torch.squeeze(torch.stack(audios, dim=0), dim=1)
62 |             else:
63 |                 raise ValueError("num_audios should be greater than 0!")
64 |
65 |             assert moments.size(0) == audios.size(0)
66 |             start_time = moments[:, 0]
67 |             index = torch.argsort(start_time)
68 |
69 |             audios = torch.index_select(audios, dim=0, index=index)
70 |             moments = torch.index_select(moments, dim=0, index=index)
71 |             all_iou2d = torch.index_select(all_iou2d, dim=0, index=index)
72 |
73 |             sent = [sentences[i] for i in index]
74 |             self.feat_list.append(feat)
75 |             self.audios_list.append(audios)
76 |             self.num_audios_list.append(num_audios)
77 |             self.moments_list.append(moments)
78 |             self.all_iou2d_list.append(all_iou2d)
79 |             self.sent_list.append(sent)
80 |
81 |
82 |
83 |     def __getitem__(self, idx):
84 |         vid = self.data[idx]
85 |
86 |         return self.feat_list[idx], self.audios_list[idx], self.all_iou2d_list[idx], self.moments_list[idx], self.num_audios_list[idx], idx, vid
87 |
88 |     def __len__(self):
89 |         return len(self.data)
90 |
91 |     def get_duration(self, idx):
92 |         vid = self.data[idx]
93 |         return self.annos[vid]['duration']
94 |
95 |     def get_sentence(self, idx):
96 |
97 |         return self.sent_list[idx]
98 |
99 |     def get_moment(self, idx):
100 |
101 |         return self.moments_list[idx]
102 |
103 |     def get_vid(self, idx):
104 |         vid = self.data[idx]
105 |         return vid
106 |
107 |     def get_iou2d(self, idx):
108 |
109 |         return self.all_iou2d_list[idx]
110 |
111 |     def get_num_audios(self, idx):
112 |         return self.num_audios_list[idx]
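A hedged sketch of pulling one sample from this dataset; the directory layout, file names, and feature sizes are assumptions for illustration, not values taken from the repo:

    from dtfnet.data.datasets.activitynet import ActivityNetDataset

    dataset = ActivityNetDataset(
        audio_dir="dataset/ActivityNet/audio_feats",          # assumed: one <name>.pt per audio query
        ann_file="dataset/ActivityNet/train_audio_new.json",
        feat_file="dataset/ActivityNet/c3d",                  # assumed: per-video <vid>.h5 files
        num_pre_clips=256,
        num_clips=64,
    )
    feat, audios, iou2d, moments, num_audios, idx, vid = dataset[0]
    # feat: (256, C) pooled video feature; audios, moments and iou2d are
    # sorted together by moment start time, as done in __init__ above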
--------------------------------------------------------------------------------
/dtfnet/data/datasets/concat_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import bisect
3 |
4 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
5 |
6 |
7 | class ConcatDataset(_ConcatDataset):
8 |     """
9 |     Same as torch.utils.data.dataset.ConcatDataset, but exposes extra
10 |     methods for mapping a global index back to its (dataset, sample) indices
11 |     """
12 |     def __init__(self, datasets):
13 |         super(ConcatDataset, self).__init__(datasets)
14 |
15 |     def get_idxs(self, idx):
16 |         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
17 |         if dataset_idx == 0:
18 |             sample_idx = idx
19 |         else:
20 |             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
21 |         return dataset_idx, sample_idx
22 |
23 |     def get_img_info(self, idx):
24 |         dataset_idx, sample_idx = self.get_idxs(idx)
25 |         return self.datasets[dataset_idx].get_img_info(sample_idx)
26 |
--------------------------------------------------------------------------------
/dtfnet/data/samplers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DistributedSampler as _DistributedSampler
3 |
4 | class DistributedSampler(_DistributedSampler):
5 |     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
6 |         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
7 |         self.shuffle = shuffle
8 |
9 |     def __iter__(self):
10 |         if self.shuffle:
11 |             # deterministically shuffle based on epoch
12 |             g = torch.Generator()
13 |             g.manual_seed(self.epoch)
14 |             indices = torch.randperm(len(self.dataset), generator=g).tolist()
15 |         else:
16 |             indices = torch.arange(len(self.dataset)).tolist()
17 |
18 |         # add extra samples to make it evenly divisible
19 |         indices += indices[: (self.total_size - len(indices))]
20 |         assert len(indices) == self.total_size
21 |
22 |         # subsample
23 |         offset = self.num_samples * self.rank
24 |         indices = indices[offset : offset + self.num_samples]
25 |         assert len(indices) == self.num_samples
26 |         return iter(indices)
27 |
28 |     def __len__(self):
29 |         return self.num_samples
30 |
31 |     def set_epoch(self, epoch):
32 |         self.epoch = epoch
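The epoch-seeded generator above is what makes distributed shuffling reproducible: every rank builds the same permutation, then keeps its own disjoint slice. A sketch of the intended call pattern (the two-process values and train_dataset are illustrative):

    from dtfnet.data.samplers import DistributedSampler

    sampler = DistributedSampler(train_dataset, num_replicas=2, rank=0, shuffle=True)
    for epoch in range(1, max_epoch + 1):
        sampler.set_epoch(epoch)   # re-seeds the shuffle; identical permutation on every rank
        for idx in sampler:        # this rank sees only its num_samples indices
            pass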
--------------------------------------------------------------------------------
/dtfnet/engine/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/dtfnet/engine/inference.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from dtfnet.data.datasets.evaluation import evaluate
4 | from ..utils.comm import is_main_process, get_world_size
5 | from ..utils.comm import all_gather
6 | from ..utils.comm import synchronize
7 | from ..utils.timer import Timer, get_time_str
8 |
9 |
10 | def compute_on_dataset(model, data_loader, device, timer=None):
11 |     model.eval()
12 |     results_dict = {}
13 |     cpu_device = torch.device("cpu")
14 |     for batch in data_loader:  # use tqdm(data_loader) for showing progress bar
15 |         batches, idxs = batch
16 |         with torch.no_grad():
17 |             if timer:
18 |                 timer.tic()
19 |             _,_,contrastive_output, iou_output = model(batches.to(device))
20 |             if timer:
21 |                 if not device.type == 'cpu':
22 |                     torch.cuda.synchronize()
23 |                 timer.toc()
24 |             contrastive_output, iou_output = [o.to(cpu_device) for o in contrastive_output], [o.to(cpu_device) for o in iou_output]
25 |         results_dict.update(
26 |             {video_id: {'contrastive': result1, 'iou': result2} for video_id, result1, result2 in zip(idxs, contrastive_output, iou_output)}
27 |         )
28 |     return results_dict
29 |
30 |
31 | def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
32 |     all_predictions = all_gather(predictions_per_gpu)
33 |     if not is_main_process():
34 |         return
35 |     # merge the list of dicts
36 |     predictions = {}
37 |     for p in all_predictions:
38 |         predictions.update(p)
39 |     # the merged dict is keyed by sample index; convert it into an index-ordered list
40 |     idxs = list(sorted(predictions.keys()))
41 |     if len(idxs) != idxs[-1] + 1:
42 |         logger = logging.getLogger("dtf.inference")
43 |         logger.warning(
44 |             "The sample indices gathered from multiple processes do not form "
45 |             "a contiguous set; some samples may be missing from the evaluation"
46 |         )
47 |
48 |     # convert to a list
49 |     predictions = [predictions[i] for i in idxs]
50 |     return predictions
51 |
52 | def inference(
53 |         cfg,
54 |         model,
55 |         data_loader,
56 |         dataset_name,
57 |         nms_thresh,
58 |         epoch,
59 |         device="cuda",
60 | ):
61 |     # convert to a torch.device for efficiency
62 |     device = torch.device(device)
63 |     num_devices = get_world_size()
64 |     logger = logging.getLogger("dtf.inference")
65 |     dataset = data_loader.dataset
66 |     logger.info("Start evaluation on {} dataset (Size: {}).".format(dataset_name, len(dataset)))
67 |     inference_timer = Timer()
68 |     predictions = compute_on_dataset(model, data_loader, device, inference_timer)
69 |     # wait for all processes to complete before measuring the time
70 |     synchronize()
71 |     total_infer_time = get_time_str(inference_timer.total_time)
72 |     logger.info(
73 |         "Model inference time: {} ({:.03f} s / inference per device, on {} devices)".format(
74 |             total_infer_time,
75 |             inference_timer.total_time * num_devices / 17031.0,
76 |             num_devices,
77 |         )
78 |     )
79 |
80 |     predictions = _accumulate_predictions_from_multiple_gpus(predictions)
81 |     '''
82 |     if not is_main_process():
83 |         return
84 |     '''
85 |     return evaluate(cfg, dataset=dataset, predictions=predictions, nms_thresh=nms_thresh, epoch=epoch)
86 |
--------------------------------------------------------------------------------
/dtfnet/engine/trainer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import os
4 | import time
5 | import gc
6 | import torch
7 | import torch.distributed as dist
8 |
9 | from dtfnet.data import make_data_loader
10 | from dtfnet.utils.comm import get_world_size, synchronize
11 | from dtfnet.utils.metric_logger import MetricLogger
12 | from dtfnet.engine.inference import inference
13 | from ..utils.comm import is_main_process
14 |
15 |
16 | def reduce_loss(loss):
17 |     world_size = get_world_size()
18 |     if
world_size < 2:
19 |         return loss
20 |     with torch.no_grad():
21 |         dist.reduce(loss, dst=0)
22 |         if dist.get_rank() == 0:
23 |             # only main process gets accumulated, so only divide by
24 |             # world_size in this case
25 |             loss /= world_size
26 |     loss = loss.item()
27 |     return loss
28 |
29 |
30 | def do_train(
31 |         cfg,
32 |         model,
33 |         data_loader,
34 |         data_loader_val,
35 |         optimizer,
36 |         scheduler,
37 |         checkpointer,
38 |         device,
39 |         checkpoint_period,
40 |         test_period,
41 |         arguments,
42 |         param_dict,
43 |         max_norm=5
44 | ):
45 |
46 |     logger = logging.getLogger("dtf.trainer")
47 |     logger.info("Start training")
48 |     meters = MetricLogger(delimiter=" ")
49 |     max_epoch = cfg.SOLVER.MAX_EPOCH
50 |
51 |     model.train()
52 |     start_training_time = time.time()
53 |     end = time.time()
54 |     max_iteration = len(data_loader)
55 |     writer_count = 0
56 |
57 |     for epoch in range(arguments["epoch"], max_epoch + 1):
58 |         rest_epoch_iteration = (max_epoch - epoch) * max_iteration
59 |         arguments["epoch"] = epoch
60 |         # data_loader.batch_sampler.sampler.set_epoch(epoch)
61 |         # if epoch <= cfg.SOLVER.FREEZE_BERT:
62 |         #     for param in param_dict['bert']:
63 |         #         param.requires_grad_(False)
64 |         # else:
65 |         #     for param in param_dict['bert']:
66 |         #         param.requires_grad_(True)
67 |         # logger.info("Start epoch {}. base_lr={:.1e}, bert_lr={:.1e}, bert.requires_grad={}".format(epoch, optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"], str(param_dict['bert'][0].requires_grad)))
68 |         if epoch <= cfg.SOLVER.ONLY_IOU:
69 |             logger.info("Using all losses")
70 |         else:
71 |             logger.info("Using only bce loss")
72 |         for iteration, (batches, idx) in enumerate(data_loader):
73 |             writer_count += 1
74 |             iteration += 1
75 |             batches = batches.to(device)
76 |             optimizer.zero_grad()
77 |             contr_weight = cfg.MODEL.DTF.LOSS.CONTRASTIVE_WEIGHT
78 |             loss_vid, loss_sent, loss_iou = model(batches, cur_epoch=epoch)
79 |             loss_vid, loss_sent = loss_vid * contr_weight, loss_sent * contr_weight
80 |             meters.update(loss_vid=loss_vid.detach(), loss_sent=loss_sent.detach(), loss_iou=loss_iou.detach())
81 |             loss = 0
82 |             if epoch <= cfg.SOLVER.ONLY_IOU:
83 |                 loss += loss_iou
84 |                 # loss += loss_sent + loss_vid
85 |             else:
86 |                 loss += loss_iou
87 |                 # loss += (loss_sent + loss_vid) * 0.01
88 |             loss.backward()
89 |             if max_norm > 0:
90 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
91 |             optimizer.step()
92 |
93 |             batch_time = time.time() - end
94 |             end = time.time()
95 |             meters.update(time=batch_time)
96 |             eta_seconds = meters.time.global_avg * (max_iteration - iteration + rest_epoch_iteration)
97 |             eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
98 |
99 |             if iteration % 10 == 0 or iteration == max_iteration:
100 |                 logger.info(
101 |                     meters.delimiter.join(
102 |                         [
103 |                             "eta: {eta}",
104 |                             "epoch: {epoch}/{max_epoch}",
105 |                             "iteration: {iteration}/{max_iteration}",
106 |                             "{meters}",
107 |                             "max mem: {memory:.0f}",
108 |                         ]
109 |                     ).format(
110 |                         eta=eta_string,
111 |                         epoch=epoch,
112 |                         max_epoch=max_epoch,
113 |                         iteration=iteration,
114 |                         max_iteration=max_iteration,
115 |                         meters=str(meters),
116 |                         memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
117 |                     )
118 |                 )
119 |             gc.collect()
120 |
121 |         scheduler.step()
122 |         if checkpoint_period != -1 and epoch % checkpoint_period == 0:
123 |             checkpointer.save(f"{cfg.MODEL.DTF.FEAT2D.NAME}_model_{epoch}e", **arguments)
124 |
125 |         if data_loader_val is not None and test_period > 0 and epoch % test_period == 0 and epoch >= cfg.SOLVER.SKIP_TEST:
126 |             synchronize()
127 |
torch.cuda.empty_cache()
128 |             result_dict = inference(
129 |                 cfg,
130 |                 model,
131 |                 data_loader_val,
132 |                 dataset_name=cfg.DATASETS.TEST,
133 |                 nms_thresh=cfg.TEST.NMS_THRESH,
134 |                 device=cfg.MODEL.DEVICE,
135 |                 epoch=epoch,
136 |             )
137 |             synchronize()
138 |             model.train()
139 |     total_training_time = time.time() - start_training_time
140 |     total_time_str = str(datetime.timedelta(seconds=total_training_time))
141 |     logger.info(
142 |         "Total training time: {} ({:.4f} s / it)".format(
143 |             total_time_str, total_training_time / (max_iteration)
144 |         )
145 |     )
146 |
--------------------------------------------------------------------------------
/dtfnet/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .dtf import DTF
2 | ARCHITECTURES = {"DTF": DTF}
3 |
4 | def build_model(cfg):
5 |     return ARCHITECTURES[cfg.MODEL.ARCHITECTURE](cfg)
6 |
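The registry keeps model construction purely config-driven; a minimal sketch of instantiating the network, assuming the default cfg is importable as below and sets MODEL.ARCHITECTURE to "DTF":

    from dtfnet.config import cfg           # assumed config entry point
    from dtfnet.modeling import build_model

    cfg.merge_from_file("configs/charades.yaml")
    model = build_model(cfg).to(cfg.MODEL.DEVICE)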
--------------------------------------------------------------------------------
/dtfnet/modeling/dtf/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | from .dtf_model import DTF
2 |
--------------------------------------------------------------------------------
/dtfnet/modeling/dtf/.ipynb_checkpoints/featpool-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class FeatAvgPool(nn.Module):
6 |     def __init__(self, input_size, hidden_size, kernel_size, stride):
7 |         super(FeatAvgPool, self).__init__()
8 |         self.conv = nn.Conv1d(input_size, hidden_size, 1, 1)
9 |         self.pool = nn.AvgPool1d(kernel_size, stride)
10 |
11 |     def forward(self, x):
12 |         x = x.transpose(1, 2)  # B, C, T
13 |         return self.pool(self.conv(x).relu())
14 |
15 | def build_featpool(cfg):
16 |     input_size = cfg.MODEL.DTF.FEATPOOL.INPUT_SIZE
17 |     hidden_size = cfg.MODEL.DTF.FEATPOOL.HIDDEN_SIZE
18 |     kernel_size = cfg.MODEL.DTF.FEATPOOL.KERNEL_SIZE  # 4 for anet, 2 for tacos, 16 for charades
19 |     stride = cfg.INPUT.NUM_PRE_CLIPS // cfg.MODEL.DTF.NUM_CLIPS
20 |     return FeatAvgPool(input_size, hidden_size, kernel_size, stride)
21 |
--------------------------------------------------------------------------------
/dtfnet/modeling/dtf/.ipynb_checkpoints/position_encoding-checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 |
9 |
10 | class TrainablePositionalEncoding(nn.Module):
11 |     """Learned positional embeddings, added to the input features then LayerNorm-ed and dropped out.
12 |     """
13 |     def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
14 |         super(TrainablePositionalEncoding, self).__init__()
15 |         self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
16 |         self.LayerNorm = nn.LayerNorm(hidden_size)
17 |         self.dropout = nn.Dropout(dropout)
18 |
19 |     def forward(self, input_feat, mask=None):
20 |         """
21 |         Args:
22 |             input_feat: (N, L, D)
23 |         """
24 |         bsz, seq_length = input_feat.shape[:2]
25 |         position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
26 |         position_ids = position_ids.unsqueeze(0).repeat(bsz, 1)  # (N, L)
27 |
28 |         position_embeddings = self.position_embeddings(position_ids)
29 |
30 |         embeddings = self.LayerNorm(input_feat + position_embeddings)
31 |         embeddings = self.dropout(embeddings)
32 |         return embeddings
33 |
34 |
35 | class PositionEmbeddingSine(nn.Module):
36 |     """
37 |     This is a more standard version of the position embedding, very similar to the one
38 |     used by the Attention Is All You Need paper, generalized to images and applied here to 1D sequences.
39 |     """
40 |     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
41 |         super().__init__()
42 |         self.num_pos_feats = num_pos_feats
43 |         self.temperature = temperature
44 |         self.normalize = normalize
45 |         if scale is not None and normalize is False:
46 |             raise ValueError("normalize should be True if scale is passed")
47 |         if scale is None:
48 |             scale = 2 * math.pi
49 |         self.scale = scale
50 |
51 |     def forward(self, x, mask):
52 |         """
53 |         Args:
54 |             x: torch.tensor, (batch_size, L, d)
55 |             mask: torch.tensor, (batch_size, L), with 1 as valid
56 |
57 |         Returns:
58 |
59 |         """
60 |         assert mask is not None
61 |         x_embed = mask.cumsum(1, dtype=torch.float32)  # (bsz, L)
62 |         if self.normalize:
63 |             eps = 1e-6
64 |             x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
65 |
66 |         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
67 |         # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
68 |         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats)
69 |         pos_x = x_embed[:, :, None] / dim_t  # (bsz, L, num_pos_feats)
70 |         pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)  # (bsz, L, num_pos_feats*2)
71 |         # import ipdb; ipdb.set_trace()
72 |         return pos_x  # .permute(0, 2, 1)  # (bsz, num_pos_feats*2, L)
73 |
74 |
75 | class PositionEmbeddingLearned(nn.Module):
76 |     """
77 |     Absolute pos embedding, learned.
78 |     """
79 |     def __init__(self, num_pos_feats=256):
80 |         super().__init__()
81 |         self.row_embed = nn.Embedding(50, num_pos_feats)
82 |         self.col_embed = nn.Embedding(50, num_pos_feats)
83 |         self.reset_parameters()
84 |
85 |     def reset_parameters(self):
86 |         nn.init.uniform_(self.row_embed.weight)
87 |         nn.init.uniform_(self.col_embed.weight)
88 |
89 |     def forward(self, x, mask):
90 |         h, w = x.shape[-2:]
91 |         i = torch.arange(w, device=x.device)
92 |         j = torch.arange(h, device=x.device)
93 |         x_emb = self.col_embed(i)
94 |         y_emb = self.row_embed(j)
95 |         pos = torch.cat([
96 |             x_emb.unsqueeze(0).repeat(h, 1, 1),
97 |             y_emb.unsqueeze(1).repeat(1, w, 1),
98 |         ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
99 |         return pos
100 |
101 |
102 | def build_position_encoding(vid_position_embedding, txt_position_embedding, x):
103 |     N_steps = x
104 |     if vid_position_embedding == 'trainable':
105 |         vid_pos_embed = TrainablePositionalEncoding(
106 |             max_position_embeddings=64,
107 |             hidden_size=x,
108 |             dropout=0.5
109 |         )
110 |     elif vid_position_embedding == 'sine':
111 |         vid_pos_embed = PositionEmbeddingSine(N_steps, normalize=True)
112 |     elif vid_position_embedding == 'learned':
113 |         vid_pos_embed = PositionEmbeddingLearned(N_steps)
114 |     else:
115 |         raise ValueError(f"not supported {vid_position_embedding}")
116 |
117 |     if txt_position_embedding == 'trainable':
118 |         txt_pos_embed = TrainablePositionalEncoding(
119 |             max_position_embeddings=4,
120 |             hidden_size=x,
121 |             dropout=0.5
122 |         )
123 |     elif txt_position_embedding == 'sine':
124 |         txt_pos_embed = PositionEmbeddingSine(N_steps, normalize=True)
125 |     elif txt_position_embedding == 'learned':
126 |         txt_pos_embed = PositionEmbeddingLearned(N_steps)
127 |     else:
128 |         raise ValueError(f"not supported {txt_position_embedding}")
129 |
130 |     return vid_pos_embed, txt_pos_embed
131 |
--------------------------------------------------------------------------------
/dtfnet/modeling/dtf/.ipynb_checkpoints/proposal_conv-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | def mask2weight(mask2d, mask_kernel, padding=0):
6 |     # mask2d comes from feat2d.py as an N x N map; [None, None] expands it to 4-D (1, 1, N, N) for conv2d
7 |     weight = F.conv2d(mask2d[None, None, :, :].float(),
8 |                       mask_kernel, padding=padding)[0, 0]
9 |     weight[weight > 0] = 1 / weight[weight > 0]
10 |     return weight
11 |
12 |
13 | def get_padded_mask_and_weight(mask, conv):
14 |     masked_weight = torch.round(F.conv2d(mask.clone().float(), torch.ones(1, 1, *conv.kernel_size).cuda(), stride=conv.stride, padding=conv.padding, dilation=conv.dilation))
15 |     masked_weight[masked_weight > 0] = 1 / masked_weight[masked_weight > 0]  #conv.kernel_size[0] * conv.kernel_size[1]
16 |     padded_mask = masked_weight > 0
17 |     return padded_mask, masked_weight
18 |
19 |
20 | class ProposalConv(nn.Module):
21 |     def __init__(self, input_size, hidden_size, k, num_stack_layers, output_size, mask2d, dataset):
22 |         super(ProposalConv, self).__init__()
23 |         self.num_stack_layers = num_stack_layers
24 |         self.dataset = dataset
25 |         self.mask2d = mask2d[None, None,:,:]
26 |         # Padding to ensure the dimension of the output map2d
27 |         first_padding = (k - 1) * num_stack_layers // 2
28 |         self.bn = nn.ModuleList([nn.BatchNorm2d(hidden_size)])
29 |         self.convs = nn.ModuleList(
30 |             [nn.Conv2d(input_size, hidden_size, k, padding=first_padding)]
31 |         )
32 |         for _ in range(num_stack_layers - 1):
self.convs.append(nn.Conv2d(hidden_size, hidden_size, k)) 34 | self.bn.append(nn.BatchNorm2d(hidden_size)) 35 | self.conv1x1_iou = nn.Conv2d(hidden_size, output_size, 1) 36 | self.conv1x1_contrastive = nn.Conv2d(hidden_size, output_size, 1) 37 | 38 | def forward(self, x): 39 | padded_mask = self.mask2d 40 | for i in range(self.num_stack_layers): 41 | x = self.bn[i](self.convs[i](x)).relu() 42 | padded_mask, masked_weight = get_padded_mask_and_weight(padded_mask, self.convs[i]) 43 | x = x * masked_weight 44 | out1 = self.conv1x1_contrastive(x) 45 | out2 = self.conv1x1_iou(x) 46 | return out1, out2 47 | 48 | 49 | def build_proposal_conv(cfg, mask2d,x): 50 | input_size = x 51 | hidden_size = cfg.MODEL.DTF.PREDICTOR.HIDDEN_SIZE 52 | kernel_size = cfg.MODEL.DTF.PREDICTOR.KERNEL_SIZE 53 | num_stack_layers = cfg.MODEL.DTF.PREDICTOR.NUM_STACK_LAYERS 54 | output_size = cfg.MODEL.DTF.JOINT_SPACE_SIZE 55 | dataset_name = cfg.DATASETS.NAME 56 | return ProposalConv(input_size, hidden_size, kernel_size, num_stack_layers, output_size, mask2d, dataset_name) -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/.ipynb_checkpoints/text_encoder-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | # from transformers import ASTModel 4 | 5 | 6 | class DistilBert(nn.Module): 7 | def __init__(self, joint_space_size, dataset): 8 | super().__init__() 9 | 10 | # self.bert = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593") 11 | self.fc_out1 = nn.Linear(768, joint_space_size) 12 | self.fc_out2 = nn.Linear(768, joint_space_size) 13 | self.dataset = dataset 14 | # self.layernorm = nn.LayerNorm(768) 15 | # self.avgpool = torch.nn.AdaptiveAvgPool1d(output_size=1) 16 | 17 | def forward(self, queries): 18 | ''' 19 | Average pooling over bert outputs among words to be sentence feature 20 | :param queries: 21 | :param wordlens: 22 | :param vid_avg_feat: B x C 23 | :return: list of [num_sent, C], len=Batch_size 24 | ''' 25 | sent_feat = [] 26 | sent_feat_iou = [] 27 | for query in queries: # each sample (several sentences) in a batch (of videos) 28 | query = query.cuda() 29 | 30 | # x = self.bert(input_values=query) 31 | # x = x.last_hidden_state 32 | # x = self.layernorm(x) 33 | # x = self.avgpool(x.transpose(1, 2)) 34 | # x = x.transpose(1, 2).squeeze(0) 35 | x = query 36 | # print(x.shape) 37 | out_iou = self.fc_out1(x).squeeze(0) 38 | out = self.fc_out2(x).squeeze(0) 39 | # print(out.shape) 40 | if out_iou.ndim == 1 and out.ndim == 1: 41 | out_iou = out_iou.view(1,-1) 42 | out = out.view(1, -1) 43 | 44 | sent_feat.append(out) 45 | sent_feat_iou.append(out_iou) 46 | return sent_feat, sent_feat_iou 47 | 48 | 49 | def build_text_encoder(cfg): 50 | joint_space_size = cfg.MODEL.DTF.JOINT_SPACE_SIZE 51 | dataset_name = cfg.DATASETS.NAME 52 | return DistilBert(joint_space_size, dataset_name) 53 | -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/.ipynb_checkpoints/text_out-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from torch.utils.checkpoint import checkpoint 5 | 6 | def apply_to_sample(f, sample): 7 | if hasattr(sample, '__len__') and len(sample) == 0: 8 | return {} 9 | 10 | def _apply(x): 11 | if torch.is_tensor(x): 12 | return f(x) 13 | elif isinstance(x, dict): 14 | return {key: _apply(value) for key, value in 
x.items()} 15 | elif isinstance(x, list): 16 | return [_apply(x) for x in x] 17 | else: 18 | return x 19 | 20 | return _apply(sample) 21 | 22 | 23 | def move_to_cuda(sample): 24 | def _move_to_cuda(tensor): 25 | return tensor.cuda() 26 | 27 | return apply_to_sample(_move_to_cuda, sample) 28 | 29 | class TextOut(nn.Module): 30 | def __init__(self, input_size, joint_space_size, dataset): 31 | super().__init__() 32 | 33 | self.fc_out1 = nn.Linear(input_size, joint_space_size) 34 | self.fc_out2 = nn.Linear(input_size, joint_space_size) 35 | self.dataset = dataset 36 | self.layernorm = nn.LayerNorm(joint_space_size) 37 | 38 | def forward(self, txts): 39 | 40 | txt_feat = [] 41 | txt_feat_iou = [] 42 | 43 | for txt in txts: # each sample (several sentences) in a batch (of videos) 44 | 45 | query = move_to_cuda(txt) 46 | # query = self.layernorm(query) 47 | out_iou = self.fc_out1(query) 48 | out = self.fc_out2(query) 49 | 50 | txt_feat.append(out.squeeze(0)) 51 | txt_feat_iou.append(out_iou.squeeze(0)) 52 | 53 | return txt_feat, txt_feat_iou 54 | 55 | 56 | def build_text_out(cfg): 57 | joint_space_size = cfg.MODEL.DTF.JOINT_SPACE_SIZE 58 | dataset_name = cfg.DATASETS.NAME 59 | return TextOut(joint_space_size, joint_space_size, dataset_name) 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | from mmn.config import cfg 65 | 66 | model = build_audio_encoder(cfg) 67 | model = model.cuda() 68 | model.eval() 69 | 70 | 71 | -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/.ipynb_checkpoints/utils-checkpoint.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | def uniform(size: int, value: Any): 9 | if isinstance(value, Tensor): 10 | bound = 1.0 / math.sqrt(size) 11 | value.data.uniform_(-bound, bound) 12 | else: 13 | for v in value.parameters() if hasattr(value, 'parameters') else []: 14 | uniform(size, v) 15 | for v in value.buffers() if hasattr(value, 'buffers') else []: 16 | uniform(size, v) 17 | 18 | 19 | def kaiming_uniform(value: Any, fan: int, a: float): 20 | if isinstance(value, Tensor): 21 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 22 | value.data.uniform_(-bound, bound) 23 | else: 24 | for v in value.parameters() if hasattr(value, 'parameters') else []: 25 | kaiming_uniform(v, fan, a) 26 | for v in value.buffers() if hasattr(value, 'buffers') else []: 27 | kaiming_uniform(v, fan, a) 28 | 29 | 30 | def glorot(value: Any): 31 | if isinstance(value, Tensor): 32 | stdv = math.sqrt(6.0 / (value.size(-2) + value.size(-1))) 33 | value.data.uniform_(-stdv, stdv) 34 | else: 35 | for v in value.parameters() if hasattr(value, 'parameters') else []: 36 | glorot(v) 37 | for v in value.buffers() if hasattr(value, 'buffers') else []: 38 | glorot(v) 39 | 40 | 41 | def glorot_orthogonal(tensor, scale): 42 | if tensor is not None: 43 | torch.nn.init.orthogonal_(tensor.data) 44 | scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var()) 45 | tensor.data *= scale.sqrt() 46 | 47 | 48 | def constant(value: Any, fill_value: float): 49 | if isinstance(value, Tensor): 50 | value.data.fill_(fill_value) 51 | else: 52 | for v in value.parameters() if hasattr(value, 'parameters') else []: 53 | constant(v, fill_value) 54 | for v in value.buffers() if hasattr(value, 'buffers') else []: 55 | constant(v, fill_value) 56 | 57 | 58 | def zeros(value: Any): 59 | constant(value, 0.) 
60 | 61 | 62 | def ones(tensor: Any): 63 | constant(tensor, 1.) 64 | 65 | 66 | def normal(value: Any, mean: float, std: float): 67 | if isinstance(value, Tensor): 68 | value.data.normal_(mean, std) 69 | else: 70 | for v in value.parameters() if hasattr(value, 'parameters') else []: 71 | normal(v, mean, std) 72 | for v in value.buffers() if hasattr(value, 'buffers') else []: 73 | normal(v, mean, std) 74 | 75 | 76 | def reset(value: Any): 77 | if hasattr(value, 'reset_parameters'): 78 | value.reset_parameters() 79 | else: 80 | for child in value.children() if hasattr(value, 'children') else []: 81 | reset(child) 82 | -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/__init__.py: -------------------------------------------------------------------------------- 1 | from .dtf_model import DTF 2 | -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/__pycache__/GCNNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/modeling/dtf/__pycache__/GCNNet.cpython-36.pyc -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/__pycache__/GCNNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/modeling/dtf/__pycache__/GCNNet.cpython-38.pyc -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/__pycache__/GCNNet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/modeling/dtf/__pycache__/GCNNet.cpython-39.pyc -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/modeling/dtf/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/modeling/dtf/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/modeling/dtf/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/__pycache__/dtf_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/modeling/dtf/__pycache__/dtf_model.cpython-36.pyc -------------------------------------------------------------------------------- 
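A minimal usage sketch for the FeatAvgPool module defined in featpool.py below. The shapes and sizes are illustrative (not taken from a shipped config), and it assumes the dtfnet package is importable:

import torch
from dtfnet.modeling.dtf.featpool import FeatAvgPool

# 256 pre-extracted clip features of dim 500, pooled down to 64 clips of dim 512
pool = FeatAvgPool(input_size=500, hidden_size=512, kernel_size=4, stride=4)
x = torch.rand(2, 256, 500)    # (B, T, C) as provided by the dataloader
out = pool(x)                  # 1x1 conv + ReLU + average pooling over time
print(out.shape)               # torch.Size([2, 512, 64])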
/dtfnet/modeling/dtf/featpool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class FeatAvgPool(nn.Module): 6 | def __init__(self, input_size, hidden_size, kernel_size, stride): 7 | super(FeatAvgPool, self).__init__() 8 | self.conv = nn.Conv1d(input_size, hidden_size, 1, 1) 9 | self.pool = nn.AvgPool1d(kernel_size, stride) 10 | 11 | def forward(self, x): 12 | x = x.transpose(1, 2) # B, C, T 13 | return self.pool(self.conv(x).relu()) 14 | 15 | def build_featpool(cfg): 16 | input_size = cfg.MODEL.DTF.FEATPOOL.INPUT_SIZE 17 | hidden_size = cfg.MODEL.DTF.FEATPOOL.HIDDEN_SIZE 18 | kernel_size = cfg.MODEL.DTF.FEATPOOL.KERNEL_SIZE # 4 for anet, 2 for tacos, 16 for charades 19 | stride = cfg.INPUT.NUM_PRE_CLIPS // cfg.MODEL.DTF.NUM_CLIPS 20 | return FeatAvgPool(input_size, hidden_size, 
kernel_size, stride) 21 | -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class TrainablePositionalEncoding(nn.Module): 11 | """Construct the embeddings from word, position and token_type embeddings. 12 | """ 13 | def __init__(self, max_position_embeddings, hidden_size, dropout=0.1): 14 | super(TrainablePositionalEncoding, self).__init__() 15 | self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) 16 | self.LayerNorm = nn.LayerNorm(hidden_size) 17 | self.dropout = nn.Dropout(dropout) 18 | 19 | def forward(self, input_feat, mask=None): 20 | """ 21 | Args: 22 | input_feat: (N, L, D) 23 | """ 24 | bsz, seq_length = input_feat.shape[:2] 25 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device) 26 | position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L) 27 | 28 | position_embeddings = self.position_embeddings(position_ids) 29 | 30 | embeddings = self.LayerNorm(input_feat + position_embeddings) 31 | embeddings = self.dropout(embeddings) 32 | return embeddings 33 | 34 | 35 | class PositionEmbeddingSine(nn.Module): 36 | """ 37 | This is a more standard version of the position embedding, very similar to the one 38 | used by the Attention is all you need paper, generalized to work on images. (To 1D sequences) 39 | """ 40 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 41 | super().__init__() 42 | self.num_pos_feats = num_pos_feats 43 | self.temperature = temperature 44 | self.normalize = normalize 45 | if scale is not None and normalize is False: 46 | raise ValueError("normalize should be True if scale is passed") 47 | if scale is None: 48 | scale = 2 * math.pi 49 | self.scale = scale 50 | 51 | def forward(self, x, mask): 52 | """ 53 | Args: 54 | x: torch.tensor, (batch_size, L, d) 55 | mask: torch.tensor, (batch_size, L), with 1 as valid 56 | 57 | Returns: 58 | 59 | """ 60 | assert mask is not None 61 | x_embed = mask.cumsum(1, dtype=torch.float32) # (bsz, L) 62 | if self.normalize: 63 | eps = 1e-6 64 | x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale 65 | 66 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 67 | # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 68 | dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats) 69 | pos_x = x_embed[:, :, None] / dim_t # (bsz, L, num_pos_feats) 70 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) # (bsz, L, num_pos_feats*2) 71 | # import ipdb; ipdb.set_trace() 72 | return pos_x # .permute(0, 2, 1) # (bsz, num_pos_feats*2, L) 73 | 74 | 75 | class PositionEmbeddingLearned(nn.Module): 76 | """ 77 | Absolute pos embedding, learned. 
78 | """ 79 | def __init__(self, num_pos_feats=256): 80 | super().__init__() 81 | self.row_embed = nn.Embedding(50, num_pos_feats) 82 | self.col_embed = nn.Embedding(50, num_pos_feats) 83 | self.reset_parameters() 84 | 85 | def reset_parameters(self): 86 | nn.init.uniform_(self.row_embed.weight) 87 | nn.init.uniform_(self.col_embed.weight) 88 | 89 | def forward(self, x, mask): 90 | h, w = x.shape[-2:] 91 | i = torch.arange(w, device=x.device) 92 | j = torch.arange(h, device=x.device) 93 | x_emb = self.col_embed(i) 94 | y_emb = self.row_embed(j) 95 | pos = torch.cat([ 96 | x_emb.unsqueeze(0).repeat(h, 1, 1), 97 | y_emb.unsqueeze(1).repeat(1, w, 1), 98 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 99 | return pos 100 | 101 | 102 | def build_position_encoding(vid_position_embedding, txt_position_embedding, x): 103 | N_steps = x 104 | if vid_position_embedding == 'trainable': 105 | vid_pos_embed = TrainablePositionalEncoding( 106 | max_position_embeddings=64, 107 | hidden_size=x, 108 | dropout=0.5 109 | ) 110 | elif vid_position_embedding == 'sine': 111 | vid_pos_embed = PositionEmbeddingSine(N_steps, normalize=True) 112 | elif vid_position_embedding == 'learned': 113 | vid_pos_embed = PositionEmbeddingLearned(N_steps) 114 | else: 115 | raise ValueError(f"not supported {vid_position_embedding}") 116 | 117 | if txt_position_embedding == 'trainable': 118 | txt_pos_embed = TrainablePositionalEncoding( 119 | max_position_embeddings=4, 120 | hidden_size=x, 121 | dropout=0.5 122 | ) 123 | elif txt_position_embedding == 'sine': 124 | txt_pos_embed = PositionEmbeddingSine(N_steps, normalize=True) 125 | elif txt_position_embedding == 'learned': 126 | txt_pos_embed = PositionEmbeddingLearned(N_steps) 127 | else: 128 | raise ValueError(f"not supported {txt_position_embedding}") 129 | 130 | return vid_pos_embed, txt_pos_embed 131 | -------------------------------------------------------------------------------- /dtfnet/modeling/dtf/proposal_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | def mask2weight(mask2d, mask_kernel, padding=0): 6 | # from the feat2d.py,we can know the mask2d is 4-d: B, D, N, N 7 | weight = F.conv2d(mask2d[None, None, :, :].float(), 8 | mask_kernel, padding=padding)[0, 0] 9 | weight[weight > 0] = 1 / weight[weight > 0] 10 | return weight 11 | 12 | 13 | def get_padded_mask_and_weight(mask, conv): 14 | masked_weight = torch.round(F.conv2d(mask.clone().float(), torch.ones(1, 1, *conv.kernel_size).cuda(), stride=conv.stride, padding=conv.padding, dilation=conv.dilation)) 15 | masked_weight[masked_weight > 0] = 1 / masked_weight[masked_weight > 0] #conv.kernel_size[0] * conv.kernel_size[1] 16 | padded_mask = masked_weight > 0 17 | return padded_mask, masked_weight 18 | 19 | 20 | class ProposalConv(nn.Module): 21 | def __init__(self, input_size, hidden_size, k, num_stack_layers, output_size, mask2d, dataset): 22 | super(ProposalConv, self).__init__() 23 | self.num_stack_layers = num_stack_layers 24 | self.dataset = dataset 25 | self.mask2d = mask2d[None, None,:,:] 26 | # Padding to ensure the dimension of the output map2d 27 | first_padding = (k - 1) * num_stack_layers // 2 28 | self.bn = nn.ModuleList([nn.BatchNorm2d(hidden_size)]) 29 | self.convs = nn.ModuleList( 30 | [nn.Conv2d(input_size, hidden_size, k, padding=first_padding)] 31 | ) 32 | for _ in range(num_stack_layers - 1): 33 | self.convs.append(nn.Conv2d(hidden_size, 
hidden_size, k)) 34 | self.bn.append(nn.BatchNorm2d(hidden_size)) 35 | self.conv1x1_iou = nn.Conv2d(hidden_size, output_size, 1) 36 | self.conv1x1_contrastive = nn.Conv2d(hidden_size, output_size, 1) 37 | 38 | def forward(self, x): 39 | padded_mask = self.mask2d 40 | for i in range(self.num_stack_layers): 41 | x = self.bn[i](self.convs[i](x)).relu() 42 | padded_mask, masked_weight = get_padded_mask_and_weight(padded_mask, self.convs[i]) 43 | x = x * masked_weight 44 | out1 = self.conv1x1_contrastive(x) 45 | out2 = self.conv1x1_iou(x) 46 | return out1, out2 47 | 48 | 49 | def build_proposal_conv(cfg, mask2d, x): 50 | input_size = x 51 | hidden_size = cfg.MODEL.DTF.PREDICTOR.HIDDEN_SIZE 52 | kernel_size = cfg.MODEL.DTF.PREDICTOR.KERNEL_SIZE 53 | num_stack_layers = cfg.MODEL.DTF.PREDICTOR.NUM_STACK_LAYERS 54 | output_size = cfg.MODEL.DTF.JOINT_SPACE_SIZE 55 | dataset_name = cfg.DATASETS.NAME 56 | return ProposalConv(input_size, hidden_size, kernel_size, num_stack_layers, output_size, mask2d, dataset_name) --------------------------------------------------------------------------------
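/dtfnet/modeling/dtf/text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | # from transformers import ASTModel 4 | 5 | 6 | class DistilBert(nn.Module): 7 | def __init__(self, joint_space_size, dataset): 8 | super().__init__() 9 | 10 | # self.bert = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593") 11 | self.fc_out1 = nn.Linear(768, joint_space_size) 12 | self.fc_out2 = nn.Linear(768, joint_space_size) 13 | self.dataset = dataset 14 | # self.layernorm = nn.LayerNorm(768) 15 | # self.avgpool = torch.nn.AdaptiveAvgPool1d(output_size=1) 16 | 17 | def forward(self, queries): 18 | ''' 19 | Project pre-extracted query features to sentence features (the BERT/AST pooling path is kept above as comments) 20 | :param queries: 21 | :param wordlens: 22 | :param vid_avg_feat: B x C 23 | :return: list of [num_sent, C], len=Batch_size 24 | ''' 25 | sent_feat = [] 26 | sent_feat_iou = [] 27 | for query in queries: # each sample (several sentences) in a batch (of videos) 28 | query = query.cuda() 29 | 30 | # x = self.bert(input_values=query) 31 | # x = x.last_hidden_state 32 | # x = self.layernorm(x) 33 | # x = self.avgpool(x.transpose(1, 2)) 34 | # x = x.transpose(1, 2).squeeze(0) 35 | x = query 36 | # print(x.shape) 37 | out_iou = self.fc_out1(x).squeeze(0) 38 | out = self.fc_out2(x).squeeze(0) 39 | # print(out.shape) 40 | if out_iou.ndim == 1 and out.ndim == 1: 41 | out_iou = out_iou.view(1,-1) 42 | out = out.view(1, -1) 43 | 44 | sent_feat.append(out) 45 | sent_feat_iou.append(out_iou) 46 | return sent_feat, sent_feat_iou 47 | 48 | 49 | def build_text_encoder(cfg): 50 | joint_space_size = cfg.MODEL.DTF.JOINT_SPACE_SIZE 51 | dataset_name = cfg.DATASETS.NAME 52 | return DistilBert(joint_space_size, dataset_name) 53 | --------------------------------------------------------------------------------

A quick shape sketch for the projection head above (sizes are illustrative; a CUDA device is assumed because the module moves each query to the GPU internally):

import torch
from dtfnet.modeling.dtf.text_encoder import DistilBert

enc = DistilBert(joint_space_size=256, dataset="activitynet").cuda()
queries = [torch.rand(3, 768), torch.rand(1, 768)]   # 3 and 1 sentences per video
sent_feat, sent_feat_iou = enc(queries)
print(sent_feat[0].shape)       # torch.Size([3, 256])
print(sent_feat_iou[1].shape)   # torch.Size([1, 256]); single sentences are re-expanded

-------------------------------------------------------------------------------- /dtfnet/modeling/dtf/text_out.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from torch.utils.checkpoint import checkpoint 5 | 6 | def apply_to_sample(f, sample): 7 | if hasattr(sample, '__len__') and len(sample) == 0: 8 | return {} 9 | 10 | def _apply(x): 11 | if torch.is_tensor(x): 12 | return f(x) 13 | elif isinstance(x, dict): 14 | return {key: _apply(value) for key, value in x.items()} 15 | elif isinstance(x, list): 16 | return [_apply(x) for x in x] 17 | else: 18 | return x 19 | 20 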
| return _apply(sample) 21 | 22 | 23 | def move_to_cuda(sample): 24 | def _move_to_cuda(tensor): 25 | return tensor.cuda() 26 | 27 | return apply_to_sample(_move_to_cuda, sample) 28 | 29 | class TextOut(nn.Module): 30 | def __init__(self, input_size, joint_space_size, dataset): 31 | super().__init__() 32 | 33 | self.fc_out1 = nn.Linear(input_size, joint_space_size) 34 | self.fc_out2 = nn.Linear(input_size, joint_space_size) 35 | self.dataset = dataset 36 | self.layernorm = nn.LayerNorm(joint_space_size) 37 | 38 | def forward(self, txts): 39 | 40 | txt_feat = [] 41 | txt_feat_iou = [] 42 | 43 | for txt in txts: # each sample (several sentences) in a batch (of videos) 44 | 45 | query = move_to_cuda(txt) 46 | # query = self.layernorm(query) 47 | out_iou = self.fc_out1(query) 48 | out = self.fc_out2(query) 49 | 50 | txt_feat.append(out.squeeze(0)) 51 | txt_feat_iou.append(out_iou.squeeze(0)) 52 | 53 | return txt_feat, txt_feat_iou 54 | 55 | 56 | def build_text_out(cfg): 57 | joint_space_size = cfg.MODEL.DTF.JOINT_SPACE_SIZE 58 | dataset_name = cfg.DATASETS.NAME 59 | return TextOut(joint_space_size, joint_space_size, dataset_name) 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | from dtfnet.config import cfg 65 | 66 | model = build_text_out(cfg) 67 | model = model.cuda() 68 | model.eval() 69 | 70 | 71 | --------------------------------------------------------------------------------
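The apply_to_sample helper above recurses through arbitrarily nested dicts and lists of tensors; move_to_cuda is the same pattern with tensor.cuda() as the mapped function. A small CPU-safe sketch (the sample contents are made up):

import torch
from dtfnet.modeling.dtf.text_out import apply_to_sample

sample = {"feats": torch.rand(2, 4), "metas": [torch.zeros(3), "vid_001"]}
doubled = apply_to_sample(lambda t: t * 2, sample)   # non-tensors pass through
print(doubled["feats"].shape)    # torch.Size([2, 4])
print(doubled["metas"][1])       # 'vid_001'

-------------------------------------------------------------------------------- /dtfnet/modeling/dtf/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | def uniform(size: int, value: Any): 9 | if isinstance(value, Tensor): 10 | bound = 1.0 / math.sqrt(size) 11 | value.data.uniform_(-bound, bound) 12 | else: 13 | for v in value.parameters() if hasattr(value, 'parameters') else []: 14 | uniform(size, v) 15 | for v in value.buffers() if hasattr(value, 'buffers') else []: 16 | uniform(size, v) 17 | 18 | 19 | def kaiming_uniform(value: Any, fan: int, a: float): 20 | if isinstance(value, Tensor): 21 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 22 | value.data.uniform_(-bound, bound) 23 | else: 24 | for v in value.parameters() if hasattr(value, 'parameters') else []: 25 | kaiming_uniform(v, fan, a) 26 | for v in value.buffers() if hasattr(value, 'buffers') else []: 27 | kaiming_uniform(v, fan, a) 28 | 29 | 30 | def glorot(value: Any): 31 | if isinstance(value, Tensor): 32 | stdv = math.sqrt(6.0 / (value.size(-2) + value.size(-1))) 33 | value.data.uniform_(-stdv, stdv) 34 | else: 35 | for v in value.parameters() if hasattr(value, 'parameters') else []: 36 | glorot(v) 37 | for v in value.buffers() if hasattr(value, 'buffers') else []: 38 | glorot(v) 39 | 40 | 41 | def glorot_orthogonal(tensor, scale): 42 | if tensor is not None: 43 | torch.nn.init.orthogonal_(tensor.data) 44 | scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var()) 45 | tensor.data *= scale.sqrt() 46 | 47 | 48 | def constant(value: Any, fill_value: float): 49 | if isinstance(value, Tensor): 50 | value.data.fill_(fill_value) 51 | else: 52 | for v in value.parameters() if hasattr(value, 'parameters') else []: 53 | constant(v, fill_value) 54 | for v in value.buffers() if hasattr(value, 'buffers') else []: 55 | constant(v, fill_value) 56 | 57 | 58 | def zeros(value: Any): 59 | constant(value, 0.) 60 | 61 | 62 | def ones(tensor: Any): 63 | constant(tensor, 1.)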
64 | 65 | 66 | def normal(value: Any, mean: float, std: float): 67 | if isinstance(value, Tensor): 68 | value.data.normal_(mean, std) 69 | else: 70 | for v in value.parameters() if hasattr(value, 'parameters') else []: 71 | normal(v, mean, std) 72 | for v in value.buffers() if hasattr(value, 'buffers') else []: 73 | normal(v, mean, std) 74 | 75 | 76 | def reset(value: Any): 77 | if hasattr(value, 'reset_parameters'): 78 | value.reset_parameters() 79 | else: 80 | for child in value.children() if hasattr(value, 'children') else []: 81 | reset(child) 82 | --------------------------------------------------------------------------------
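These initializers recurse into the parameters and buffers of arbitrary modules; a small CPU-safe sketch with an illustrative layer size:

import torch
from torch import nn
from dtfnet.modeling.dtf.utils import glorot, zeros, reset

lin = nn.Linear(128, 64)
glorot(lin.weight)   # Xavier/Glorot-uniform re-init of the weight tensor
zeros(lin.bias)      # fill the bias with 0
reset(lin)           # or simply defer to the module's own reset_parameters()

-------------------------------------------------------------------------------- /dtfnet/structures/__init__.py: -------------------------------------------------------------------------------- 1 | from .tlg_batch import TLGBatch 2 | --------------------------------------------------------------------------------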
/dtfnet/structures/tlg_batch.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | 4 | def apply_to_sample(f, sample): 5 | if hasattr(sample, '__len__') and len(sample) == 0: 6 | return {} 7 | 8 | def _apply(x): 9 | if torch.is_tensor(x): 10 | return f(x) 11 | elif isinstance(x, dict): 12 | return {key: _apply(value) for key, value in x.items()} 13 | elif isinstance(x, list): 14 | return [_apply(x) for x in x] 15 | else: 16 | return x 17 | 18 | return _apply(sample) 19 | 20 | 21 | def move_to_cuda(sample): 22 | def _move_to_cuda(tensor): 23 | return tensor.cuda() 24 | 25 | return apply_to_sample(_move_to_cuda, sample) 26 | 27 | # temporal localization grounding 28 | @dataclass 29 | class TLGBatch(object): 30 | # frames: list # [ImageList] 31 | feats: torch.tensor 32 | queries: list 33 | all_iou2d: list 34 | moments: list 35 | num_sentence: list 36 | idxs: torch.tensor 37 | vid: str 38 | 39 | def to(self, device): 40 | # self.frames = [f.to(device) for f in self.frames] 41 | self.feats = self.feats.to(device) 42 | self.queries = [query.to(device) for query in self.queries] 43 | self.all_iou2d = [iou2d.to(device) for iou2d in self.all_iou2d] 44 | self.moments = [moment.to(device) for moment in self.moments] 45 | self.idxs = move_to_cuda(self.idxs) 46 | self.vid = move_to_cuda(self.vid) 47 | 48 | return self 49 | 50 | 51 | --------------------------------------------------------------------------------
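A hedged construction sketch for TLGBatch (all shapes illustrative; a CUDA device is assumed because to() routes idxs and vid through move_to_cuda regardless of the requested device):

import torch
from dtfnet.structures import TLGBatch

batch = TLGBatch(
    feats=torch.rand(2, 256, 500),                      # (B, T, C) video features
    queries=[torch.rand(3, 768), torch.rand(2, 768)],   # per-video sentence features
    all_iou2d=[torch.rand(3, 64, 64), torch.rand(2, 64, 64)],
    moments=[torch.rand(3, 2), torch.rand(2, 2)],       # (start, end) pairs
    num_sentence=[3, 2],
    idxs=torch.arange(2),
    vid=["v_001", "v_002"],
)
batch = batch.to("cuda")   # strings pass through move_to_cuda unchanged

-------------------------------------------------------------------------------- /dtfnet/utils/.ipynb_checkpoints/checkpoint-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 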
All Rights Reserved. 2 | import logging 3 | import os 4 | 5 | import torch 6 | 7 | from dtfnet.utils.model_serialization import load_state_dict 8 | from dtfnet.utils.imports import import_file 9 | 10 | class Checkpointer(object): 11 | def __init__( 12 | self, 13 | model, 14 | optimizer=None, 15 | scheduler=None, 16 | save_dir="", 17 | save_to_disk=None, 18 | logger=None, 19 | ): 20 | self.model = model 21 | self.optimizer = optimizer 22 | self.scheduler = scheduler 23 | self.save_dir = save_dir 24 | self.save_to_disk = save_to_disk 25 | if logger is None: 26 | logger = logging.getLogger(__name__) 27 | self.logger = logger 28 | 29 | def save(self, name, **kwargs): 30 | if not self.save_dir: 31 | return 32 | 33 | if not self.save_to_disk: 34 | return 35 | 36 | data = {} 37 | data["model"] = self.model.state_dict() 38 | if self.optimizer is not None: 39 | data["optimizer"] = self.optimizer.state_dict() 40 | if self.scheduler is not None: 41 | data["scheduler"] = self.scheduler.state_dict() 42 | data.update(kwargs) 43 | 44 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 45 | self.logger.info("Saving checkpoint to {}".format(save_file)) 46 | torch.save(data, save_file) 47 | #self.tag_last_checkpoint(save_file) 48 | return 49 | 50 | def load(self, f=None, use_latest=True): 51 | if self.has_checkpoint() and use_latest: 52 | # override argument with existing checkpoint 53 | f = self.get_checkpoint_file() 54 | if not f: 55 | # no checkpoint could be found 56 | self.logger.info("No checkpoint found. Initializing model from scratch") 57 | return {} 58 | self.logger.info("Loading checkpoint from {}".format(f)) 59 | checkpoint = self._load_file(f) 60 | self._load_model(checkpoint) 61 | if "optimizer" in checkpoint and self.optimizer: 62 | self.logger.info("Loading optimizer from {}".format(f)) 63 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 64 | if "scheduler" in checkpoint and self.scheduler: 65 | self.logger.info("Loading scheduler from {}".format(f)) 66 | self.scheduler.load_state_dict(checkpoint.pop("scheduler")) 67 | # return any further checkpoint data 68 | return self.model, self.optimizer, self.scheduler 69 | 70 | def has_checkpoint(self): 71 | save_file = os.path.join(self.save_dir, "last_checkpoint") 72 | return os.path.exists(save_file) 73 | 74 | def get_checkpoint_file(self): 75 | save_file = os.path.join(self.save_dir, "last_checkpoint") 76 | try: 77 | with open(save_file, "r") as f: 78 | last_saved = f.read() 79 | last_saved = last_saved.strip() 80 | except IOError: 81 | # if file doesn't exist, maybe because it has just been 82 | # deleted by a separate process 83 | last_saved = "" 84 | return last_saved 85 | 86 | def tag_last_checkpoint(self, last_filename): 87 | save_file = os.path.join(self.save_dir, "last_checkpoint") 88 | with open(save_file, "w") as f: 89 | f.write(last_filename) 90 | 91 | def _load_file(self, f): 92 | return torch.load(f, map_location=torch.device("cpu")) 93 | 94 | def _load_model(self, checkpoint): 95 | load_state_dict(self.model, checkpoint.pop("model")) 96 | 97 | 98 | class MmnCheckpointer(Checkpointer): 99 | def __init__( 100 | self, 101 | cfg, 102 | model, 103 | optimizer=None, 104 | scheduler=None, 105 | save_dir="", 106 | save_to_disk=None, 107 | logger=None, 108 | ): 109 | super(MmnCheckpointer, self).__init__( 110 | model, optimizer, scheduler, save_dir, save_to_disk, logger 111 | ) 112 | self.cfg = cfg.clone() 113 | 114 | def _load_file(self, f): 115 | loaded = super(MmnCheckpointer, self)._load_file(f) 116 | if "model" 
not in loaded: 117 | loaded = dict(model=loaded) 118 | return loaded 119 | -------------------------------------------------------------------------------- /dtfnet/utils/.ipynb_checkpoints/comm-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 4 | """ 5 | 6 | import time 7 | import pickle 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def get_world_size(): 13 | if not dist.is_available(): 14 | return 1 15 | if not dist.is_initialized(): 16 | return 1 17 | return dist.get_world_size() 18 | 19 | 20 | def get_rank(): 21 | if not dist.is_available(): 22 | return 0 23 | if not dist.is_initialized(): 24 | return 0 25 | return dist.get_rank() 26 | 27 | 28 | def is_main_process(): 29 | return get_rank() == 0 30 | 31 | 32 | def synchronize(): 33 | """ 34 | Helper function to synchronize (barrier) among all processes when 35 | using distributed training 36 | """ 37 | if not dist.is_available(): 38 | return 39 | if not dist.is_initialized(): 40 | return 41 | world_size = dist.get_world_size() 42 | if world_size == 1: 43 | return 44 | dist.barrier() 45 | 46 | 47 | def all_gather(data): 48 | """ 49 | Run all_gather on arbitrary picklable data (not necessarily tensors) 50 | Args: 51 | data: any picklable object 52 | Returns: 53 | list[data]: list of data gathered from each rank 54 | """ 55 | world_size = get_world_size() 56 | if world_size == 1: 57 | return [data] 58 | 59 | buffer = pickle.dumps(data) 60 | storage = torch.ByteStorage.from_buffer(buffer) 61 | tensor = torch.ByteTensor(storage).to("cuda") 62 | # obtain Tensor size of each rank 63 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 64 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 65 | dist.all_gather(size_list, local_size) 66 | size_list = [int(size.item()) for size in size_list] 67 | max_size = max(size_list) 68 | 69 | # receiving Tensor from all ranks 70 | # we pad the tensor because torch all_gather does not support 71 | # gathering tensors of different shapes 72 | tensor_list = [] 73 | for _ in size_list: 74 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 75 | if local_size != max_size: 76 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 77 | tensor = torch.cat((tensor, padding), dim=0) 78 | dist.all_gather(tensor_list, tensor) 79 | 80 | data_list = [] 81 | for size, tensor in zip(size_list, tensor_list): 82 | buffer = tensor.cpu().numpy().tobytes()[:size] 83 | data_list.append(pickle.loads(buffer)) 84 | 85 | return data_list 86 | 87 | 88 | def reduce_dict(input_dict, average=True): 89 | """ 90 | Args: 91 | input_dict (dict): all the values will be reduced 92 | average (bool): whether to do average or sum 93 | Reduce the values in the dictionary from all processes so that process with rank 94 | 0 has the averaged results. Returns a dict with the same fields as 95 | input_dict, after reduction. 
96 | """ 97 | world_size = get_world_size() 98 | if world_size < 2: 99 | return input_dict 100 | with torch.no_grad(): 101 | names = [] 102 | values = [] 103 | # sort the keys so that they are consistent across processes 104 | for k in sorted(input_dict.keys()): 105 | names.append(k) 106 | values.append(input_dict[k]) 107 | values = torch.stack(values, dim=0) 108 | dist.reduce(values, dst=0) 109 | if dist.get_rank() == 0 and average: 110 | # only main process gets accumulated, so only divide by 111 | # world_size in this case 112 | values /= world_size 113 | reduced_dict = {k: v for k, v in zip(names, values)} 114 | return reduced_dict 115 | 116 | 117 | def apply_to_sample(f, sample): 118 | if hasattr(sample, '__len__') and len(sample) == 0: 119 | return {} 120 | 121 | def _apply(x): 122 | if torch.is_tensor(x): 123 | return f(x) 124 | elif isinstance(x, dict): 125 | return {key: _apply(value) for key, value in x.items()} 126 | elif isinstance(x, list): 127 | return [_apply(x) for x in x] 128 | else: 129 | return x 130 | 131 | return _apply(sample) 132 | 133 | 134 | def move_to_cuda(sample): 135 | def _move_to_cuda(tensor): 136 | return tensor.cuda() 137 | 138 | return apply_to_sample(_move_to_cuda, sample) -------------------------------------------------------------------------------- /dtfnet/utils/.ipynb_checkpoints/imports-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import importlib 3 | import importlib.util 4 | import sys 5 | 6 | 7 | # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa 8 | def import_file(module_name, file_path, make_importable=False): 9 | spec = importlib.util.spec_from_file_location(module_name, file_path) 10 | module = importlib.util.module_from_spec(spec) 11 | spec.loader.exec_module(module) 12 | if make_importable: 13 | sys.modules[module_name] = module 14 | return module 15 | -------------------------------------------------------------------------------- /dtfnet/utils/.ipynb_checkpoints/logger-checkpoint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.DEBUG) 8 | # don't log results for the non-master process 9 | if distributed_rank > 0: 10 | return logger 11 | ch = logging.StreamHandler(stream=sys.stdout) 12 | ch.setLevel(logging.DEBUG) 13 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 14 | ch.setFormatter(formatter) 15 | logger.addHandler(ch) 16 | 17 | if save_dir: 18 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 19 | fh.setLevel(logging.DEBUG) 20 | fh.setFormatter(formatter) 21 | logger.addHandler(fh) 22 | 23 | return logger 24 | -------------------------------------------------------------------------------- /dtfnet/utils/.ipynb_checkpoints/metric_logger-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import defaultdict 3 | from collections import deque 4 | 5 | import torch 6 | 7 | 8 | class SmoothedValue(object): 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 
11 | """ 12 | 13 | def __init__(self, window_size=10): 14 | self.deque = deque(maxlen=window_size) 15 | #self.series = [] 16 | self.total = 0.0 17 | self.count = 0 18 | 19 | def update(self, value): 20 | self.deque.append(value) 21 | #self.series.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def last(self): 32 | d = torch.tensor(list(self.deque)[-1]) 33 | return d.item() 34 | 35 | @property 36 | def avg(self): 37 | d = torch.tensor(list(self.deque)) 38 | return d.mean().item() 39 | 40 | @property 41 | def global_avg(self): 42 | return self.total / self.count 43 | 44 | 45 | class MetricLogger(object): 46 | def __init__(self, delimiter="\t"): 47 | self.meters = defaultdict(SmoothedValue) 48 | self.delimiter = delimiter 49 | 50 | def update(self, **kwargs): 51 | for k, v in kwargs.items(): 52 | if isinstance(v, torch.Tensor): 53 | v = v.item() 54 | assert isinstance(v, (float, int)) 55 | self.meters[k].update(v) 56 | 57 | def __getattr__(self, attr): 58 | if attr in self.meters: 59 | return self.meters[attr] 60 | if attr in self.__dict__: 61 | return self.__dict__[attr] 62 | raise AttributeError("'{}' object has no attribute '{}'".format( 63 | type(self).__name__, attr)) 64 | 65 | def __str__(self): 66 | loss_str = [] 67 | for name, meter in self.meters.items(): 68 | loss_str.append("{}: {:.2f} ".format(name, meter.avg)) # ({:.2f}) meter.median, 69 | return self.delimiter.join(loss_str) 70 | -------------------------------------------------------------------------------- /dtfnet/utils/.ipynb_checkpoints/miscellaneous-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import errno 3 | import json 4 | import logging 5 | import os 6 | from .comm import is_main_process 7 | 8 | 9 | def mkdir(path): 10 | try: 11 | os.makedirs(path) 12 | except OSError as e: 13 | if e.errno != errno.EEXIST: 14 | raise 15 | 16 | 17 | def save_labels(dataset_list, output_dir): 18 | if is_main_process(): 19 | logger = logging.getLogger(__name__) 20 | 21 | ids_to_labels = {} 22 | for dataset in dataset_list: 23 | if hasattr(dataset, 'categories'): 24 | ids_to_labels.update(dataset.categories) 25 | else: 26 | logger.warning("Dataset [{}] has no categories attribute, labels.json file won't be created".format( 27 | dataset.__class__.__name__)) 28 | 29 | if ids_to_labels: 30 | labels_file = os.path.join(output_dir, 'labels.json') 31 | logger.info("Saving labels mapping into {}".format(labels_file)) 32 | with open(labels_file, 'w') as f: 33 | json.dump(ids_to_labels, f, indent=2) 34 | 35 | 36 | def save_config(cfg, path): 37 | if is_main_process(): 38 | with open(path, 'w') as f: 39 | f.write(cfg.dump()) 40 | -------------------------------------------------------------------------------- /dtfnet/utils/.ipynb_checkpoints/model_serialization-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | from collections import OrderedDict 3 | import logging 4 | 5 | import torch 6 | 7 | from dtfnet.utils.imports import import_file 8 | 9 | 10 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict): 11 | """ 12 | Strategy: suppose that the models that we will create will have prefixes appended 13 | to each of its keys, for example due to an extra level of nesting that the original 14 | pre-trained weights from ImageNet won't contain. For example, model.state_dict() 15 | might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains 16 | res2.conv1.weight. We thus want to match both parameters together. 17 | For that, we look for each model weight, look among all loaded keys if there is one 18 | that is a suffix of the current weight name, and use it if that's the case. 19 | If multiple matches exist, take the one with longest size 20 | of the corresponding name. For example, for the same model as before, the pretrained 21 | weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, 22 | we want to match backbone[0].body.conv1.weight to conv1.weight, and 23 | backbone[0].body.res2.conv1.weight to res2.conv1.weight. 24 | """ 25 | current_keys = sorted(list(model_state_dict.keys())) 26 | loaded_keys = sorted(list(loaded_state_dict.keys())) 27 | # get a matrix of string matches, where each (i, j) entry corresponds to the size of the 28 | # loaded_key string, if it matches 29 | match_matrix = [ 30 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys 31 | ] 32 | match_matrix = torch.as_tensor(match_matrix).view( 33 | len(current_keys), len(loaded_keys) 34 | ) 35 | max_match_size, idxs = match_matrix.max(1) 36 | # remove indices that correspond to no-match 37 | idxs[max_match_size == 0] = -1 38 | 39 | # used for logging 40 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 41 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 42 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}" 43 | logger = logging.getLogger(__name__) 44 | for idx_new, idx_old in enumerate(idxs.tolist()): 45 | if idx_old == -1: 46 | continue 47 | key = current_keys[idx_new] 48 | key_old = loaded_keys[idx_old] 49 | model_state_dict[key] = loaded_state_dict[key_old] 50 | logger.info( 51 | log_str_template.format( 52 | key, 53 | max_size, 54 | key_old, 55 | max_size_loaded, 56 | tuple(loaded_state_dict[key_old].shape), 57 | ) 58 | ) 59 | 60 | 61 | def strip_prefix_if_present(state_dict, prefix): 62 | keys = sorted(state_dict.keys()) 63 | if not all(key.startswith(prefix) for key in keys): 64 | return state_dict 65 | stripped_state_dict = OrderedDict() 66 | for key, value in state_dict.items(): 67 | stripped_state_dict[key.replace(prefix, "")] = value 68 | return stripped_state_dict 69 | 70 | 71 | def load_state_dict(model, loaded_state_dict): 72 | model_state_dict = model.state_dict() 73 | # if the state_dict comes from a model that was wrapped in a 74 | # DataParallel or DistributedDataParallel during serialization, 75 | # remove the "module" prefix before performing the matching 76 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 77 | align_and_update_state_dicts(model_state_dict, loaded_state_dict) 78 | 79 | # use strict loading 80 | model.load_state_dict(model_state_dict) 81 | --------------------------------------------------------------------------------
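A small sketch of the suffix-matching loader above; the import path comes from this repo's own checkpoint code, and the tiny model is hypothetical:

import torch
from torch import nn
from dtfnet.utils.model_serialization import load_state_dict

model = nn.Sequential(nn.Linear(8, 8))
# simulate a checkpoint saved from a DataParallel wrapper: keys carry "module."
ckpt = {"module." + k: v for k, v in model.state_dict().items()}
load_state_dict(model, ckpt)   # prefix stripped, suffix-matched, loaded strictly

-------------------------------------------------------------------------------- /dtfnet/utils/.ipynb_checkpoints/registry-checkpoint.py: 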
--------------------------------------------------------------------------------
/dtfnet/utils/README.md:
--------------------------------------------------------------------------------
1 | # Utility functions
2 | 
3 | This folder contains utility functions that are not used in the
4 | core library, but are useful for building models or training
5 | code using the config system.
6 | 
--------------------------------------------------------------------------------
/dtfnet/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xian-sh/UniSDNet/f6f46eb55f050372cfe5679ee0e572234886576a/dtfnet/utils/__init__.py
--------------------------------------------------------------------------------
/dtfnet/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging
3 | import os
4 | 
5 | import torch
6 | 
7 | from dtfnet.utils.model_serialization import load_state_dict
8 | from dtfnet.utils.imports import import_file
9 | 
10 | class Checkpointer(object):
11 |     def __init__(
12 |         self,
13 |         model,
14 |         optimizer=None,
15 |         scheduler=None,
16 |         save_dir="",
17 |         save_to_disk=None,
18 |         logger=None,
19 |     ):
20 |         self.model = model
21 |         self.optimizer = optimizer
22 |         self.scheduler = scheduler
23 |         self.save_dir = save_dir
24 |         self.save_to_disk = save_to_disk
25 |         if logger is None:
26 |             logger = logging.getLogger(__name__)
27 |         self.logger = logger
28 | 
29 |     def save(self, name, **kwargs):
30 |         if not self.save_dir:
31 |             return
32 | 
33 |         if not self.save_to_disk:
34 |             return
35 | 
36 |         data = {}
37 |         data["model"] = self.model.state_dict()
38 |         if self.optimizer is not None:
39 |             data["optimizer"] = self.optimizer.state_dict()
40 |         if self.scheduler is not None:
41 |             data["scheduler"] = self.scheduler.state_dict()
42 |         data.update(kwargs)
43 | 
44 |         save_file = os.path.join(self.save_dir, "{}.pth".format(name))
45 |         self.logger.info("Saving checkpoint to {}".format(save_file))
46 |         torch.save(data, save_file)
47 |         #self.tag_last_checkpoint(save_file)
48 |         return
49 | 
50 |     def load(self, f=None, use_latest=True):
51 |         if self.has_checkpoint() and use_latest:
52 |             # override argument with existing checkpoint
53 |             f = self.get_checkpoint_file()
54 |         if not f:
55 |             # no checkpoint could be found
56 |             self.logger.info("No checkpoint found. Initializing model from scratch")
57 |             return {}
58 |         self.logger.info("Loading checkpoint from {}".format(f))
59 |         checkpoint = self._load_file(f)
60 |         self._load_model(checkpoint)
61 |         if "optimizer" in checkpoint and self.optimizer:
62 |             self.logger.info("Loading optimizer from {}".format(f))
63 |             self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
64 |         if "scheduler" in checkpoint and self.scheduler:
65 |             self.logger.info("Loading scheduler from {}".format(f))
66 |             self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
67 |         # return the loaded model, optimizer and scheduler
68 |         return self.model, self.optimizer, self.scheduler
69 | 
70 |     def has_checkpoint(self):
71 |         save_file = os.path.join(self.save_dir, "last_checkpoint")
72 |         return os.path.exists(save_file)
73 | 
74 |     def get_checkpoint_file(self):
75 |         save_file = os.path.join(self.save_dir, "last_checkpoint")
76 |         try:
77 |             with open(save_file, "r") as f:
78 |                 last_saved = f.read()
79 |                 last_saved = last_saved.strip()
80 |         except IOError:
81 |             # if file doesn't exist, maybe because it has just been
82 |             # deleted by a separate process
83 |             last_saved = ""
84 |         return last_saved
85 | 
86 |     def tag_last_checkpoint(self, last_filename):
87 |         save_file = os.path.join(self.save_dir, "last_checkpoint")
88 |         with open(save_file, "w") as f:
89 |             f.write(last_filename)
90 | 
91 |     def _load_file(self, f):
92 |         return torch.load(f, map_location=torch.device("cpu"))
93 | 
94 |     def _load_model(self, checkpoint):
95 |         load_state_dict(self.model, checkpoint.pop("model"))
96 | 
97 | 
98 | class MmnCheckpointer(Checkpointer):
99 |     def __init__(
100 |         self,
101 |         cfg,
102 |         model,
103 |         optimizer=None,
104 |         scheduler=None,
105 |         save_dir="",
106 |         save_to_disk=None,
107 |         logger=None,
108 |     ):
109 |         super(MmnCheckpointer, self).__init__(
110 |             model, optimizer, scheduler, save_dir, save_to_disk, logger
111 |         )
112 |         self.cfg = cfg.clone()
113 | 
114 |     def _load_file(self, f):
115 |         loaded = super(MmnCheckpointer, self)._load_file(f)
116 |         if "model" not in loaded:
117 |             loaded = dict(model=loaded)
118 |         return loaded
119 | 
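For orientation, a minimal sketch of how this checkpointer can be driven; the toy model, optimizer, and directory below are illustrative stand-ins, not DTFNet's real training objects:

    import os
    import torch
    from dtfnet.utils.checkpoint import Checkpointer

    model = torch.nn.Linear(4, 2)  # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    os.makedirs("./tmp_ckpt", exist_ok=True)
    ckpt = Checkpointer(model, optimizer, save_dir="./tmp_ckpt", save_to_disk=True)
    ckpt.save("toy_model_1e", epoch=1)  # writes ./tmp_ckpt/toy_model_1e.pth
    ckpt.load("./tmp_ckpt/toy_model_1e.pth", use_latest=False)

Note that the tag_last_checkpoint call in save() is commented out, so has_checkpoint() never becomes true and load() must be given an explicit path (as test_net.py does via --ckpt).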
--------------------------------------------------------------------------------
/dtfnet/utils/comm.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains primitives for multi-gpu communication.
3 | This is useful when doing distributed training.
4 | """
5 | 
6 | import time
7 | import pickle
8 | import torch
9 | import torch.distributed as dist
10 | 
11 | 
12 | def get_world_size():
13 |     if not dist.is_available():
14 |         return 1
15 |     if not dist.is_initialized():
16 |         return 1
17 |     return dist.get_world_size()
18 | 
19 | 
20 | def get_rank():
21 |     if not dist.is_available():
22 |         return 0
23 |     if not dist.is_initialized():
24 |         return 0
25 |     return dist.get_rank()
26 | 
27 | 
28 | def is_main_process():
29 |     return get_rank() == 0
30 | 
31 | 
32 | def synchronize():
33 |     """
34 |     Helper function to synchronize (barrier) among all processes when
35 |     using distributed training
36 |     """
37 |     if not dist.is_available():
38 |         return
39 |     if not dist.is_initialized():
40 |         return
41 |     world_size = dist.get_world_size()
42 |     if world_size == 1:
43 |         return
44 |     dist.barrier()
45 | 
46 | 
47 | def all_gather(data):
48 |     """
49 |     Run all_gather on arbitrary picklable data (not necessarily tensors)
50 |     Args:
51 |         data: any picklable object
52 |     Returns:
53 |         list[data]: list of data gathered from each rank
54 |     """
55 |     world_size = get_world_size()
56 |     if world_size == 1:
57 |         return [data]
58 | 
59 |     buffer = pickle.dumps(data)
60 |     storage = torch.ByteStorage.from_buffer(buffer)
61 |     tensor = torch.ByteTensor(storage).to("cuda")
62 |     # obtain Tensor size of each rank
63 |     local_size = torch.LongTensor([tensor.numel()]).to("cuda")
64 |     size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
65 |     dist.all_gather(size_list, local_size)
66 |     size_list = [int(size.item()) for size in size_list]
67 |     max_size = max(size_list)
68 | 
69 |     # receiving Tensor from all ranks
70 |     # we pad the tensor because torch all_gather does not support
71 |     # gathering tensors of different shapes
72 |     tensor_list = []
73 |     for _ in size_list:
74 |         tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
75 |     if local_size != max_size:
76 |         padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
77 |         tensor = torch.cat((tensor, padding), dim=0)
78 |     dist.all_gather(tensor_list, tensor)
79 | 
80 |     data_list = []
81 |     for size, tensor in zip(size_list, tensor_list):
82 |         buffer = tensor.cpu().numpy().tobytes()[:size]
83 |         data_list.append(pickle.loads(buffer))
84 | 
85 |     return data_list
86 | 
87 | 
88 | def reduce_dict(input_dict, average=True):
89 |     """
90 |     Args:
91 |         input_dict (dict): all the values will be reduced
92 |         average (bool): whether to do average or sum
93 |     Reduce the values in the dictionary from all processes so that process with rank
94 |     0 has the averaged results. Returns a dict with the same fields as
95 |     input_dict, after reduction.
96 |     """
97 |     world_size = get_world_size()
98 |     if world_size < 2:
99 |         return input_dict
100 |     with torch.no_grad():
101 |         names = []
102 |         values = []
103 |         # sort the keys so that they are consistent across processes
104 |         for k in sorted(input_dict.keys()):
105 |             names.append(k)
106 |             values.append(input_dict[k])
107 |         values = torch.stack(values, dim=0)
108 |         dist.reduce(values, dst=0)
109 |         if dist.get_rank() == 0 and average:
110 |             # only main process gets accumulated, so only divide by
111 |             # world_size in this case
112 |             values /= world_size
113 |         reduced_dict = {k: v for k, v in zip(names, values)}
114 |     return reduced_dict
115 | 
116 | 
117 | def apply_to_sample(f, sample):
118 |     if hasattr(sample, '__len__') and len(sample) == 0:
119 |         return {}
120 | 
121 |     def _apply(x):
122 |         if torch.is_tensor(x):
123 |             return f(x)
124 |         elif isinstance(x, dict):
125 |             return {key: _apply(value) for key, value in x.items()}
126 |         elif isinstance(x, list):
127 |             return [_apply(x) for x in x]
128 |         else:
129 |             return x
130 | 
131 |     return _apply(sample)
132 | 
133 | 
134 | def move_to_cuda(sample):
135 |     def _move_to_cuda(tensor):
136 |         return tensor.cuda()
137 | 
138 |     return apply_to_sample(_move_to_cuda, sample)
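A quick sketch of how these primitives are typically used around a loss dict in distributed training; the loss names here are illustrative, and with a single process reduce_dict simply returns its input unchanged:

    import torch
    from dtfnet.utils.comm import reduce_dict, is_main_process

    loss_dict = {"loss_iou": torch.tensor(0.7), "loss_contrast": torch.tensor(1.2)}
    loss_dict_reduced = reduce_dict(loss_dict)  # averaged across ranks onto rank 0
    if is_main_process():
        total = sum(loss_dict_reduced.values())
        print("total loss:", float(total))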
96 | """ 97 | world_size = get_world_size() 98 | if world_size < 2: 99 | return input_dict 100 | with torch.no_grad(): 101 | names = [] 102 | values = [] 103 | # sort the keys so that they are consistent across processes 104 | for k in sorted(input_dict.keys()): 105 | names.append(k) 106 | values.append(input_dict[k]) 107 | values = torch.stack(values, dim=0) 108 | dist.reduce(values, dst=0) 109 | if dist.get_rank() == 0 and average: 110 | # only main process gets accumulated, so only divide by 111 | # world_size in this case 112 | values /= world_size 113 | reduced_dict = {k: v for k, v in zip(names, values)} 114 | return reduced_dict 115 | 116 | 117 | def apply_to_sample(f, sample): 118 | if hasattr(sample, '__len__') and len(sample) == 0: 119 | return {} 120 | 121 | def _apply(x): 122 | if torch.is_tensor(x): 123 | return f(x) 124 | elif isinstance(x, dict): 125 | return {key: _apply(value) for key, value in x.items()} 126 | elif isinstance(x, list): 127 | return [_apply(x) for x in x] 128 | else: 129 | return x 130 | 131 | return _apply(sample) 132 | 133 | 134 | def move_to_cuda(sample): 135 | def _move_to_cuda(tensor): 136 | return tensor.cuda() 137 | 138 | return apply_to_sample(_move_to_cuda, sample) -------------------------------------------------------------------------------- /dtfnet/utils/imports.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import importlib 3 | import importlib.util 4 | import sys 5 | 6 | 7 | # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa 8 | def import_file(module_name, file_path, make_importable=False): 9 | spec = importlib.util.spec_from_file_location(module_name, file_path) 10 | module = importlib.util.module_from_spec(spec) 11 | spec.loader.exec_module(module) 12 | if make_importable: 13 | sys.modules[module_name] = module 14 | return module 15 | -------------------------------------------------------------------------------- /dtfnet/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.DEBUG) 8 | # don't log results for the non-master process 9 | if distributed_rank > 0: 10 | return logger 11 | ch = logging.StreamHandler(stream=sys.stdout) 12 | ch.setLevel(logging.DEBUG) 13 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 14 | ch.setFormatter(formatter) 15 | logger.addHandler(ch) 16 | 17 | if save_dir: 18 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 19 | fh.setLevel(logging.DEBUG) 20 | fh.setFormatter(formatter) 21 | logger.addHandler(fh) 22 | 23 | return logger 24 | -------------------------------------------------------------------------------- /dtfnet/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import defaultdict 3 | from collections import deque 4 | 5 | import torch 6 | 7 | 8 | class SmoothedValue(object): 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 
11 | """ 12 | 13 | def __init__(self, window_size=10): 14 | self.deque = deque(maxlen=window_size) 15 | #self.series = [] 16 | self.total = 0.0 17 | self.count = 0 18 | 19 | def update(self, value): 20 | self.deque.append(value) 21 | #self.series.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def last(self): 32 | d = torch.tensor(list(self.deque)[-1]) 33 | return d.item() 34 | 35 | @property 36 | def avg(self): 37 | d = torch.tensor(list(self.deque)) 38 | return d.mean().item() 39 | 40 | @property 41 | def global_avg(self): 42 | return self.total / self.count 43 | 44 | 45 | class MetricLogger(object): 46 | def __init__(self, delimiter="\t"): 47 | self.meters = defaultdict(SmoothedValue) 48 | self.delimiter = delimiter 49 | 50 | def update(self, **kwargs): 51 | for k, v in kwargs.items(): 52 | if isinstance(v, torch.Tensor): 53 | v = v.item() 54 | assert isinstance(v, (float, int)) 55 | self.meters[k].update(v) 56 | 57 | def __getattr__(self, attr): 58 | if attr in self.meters: 59 | return self.meters[attr] 60 | if attr in self.__dict__: 61 | return self.__dict__[attr] 62 | raise AttributeError("'{}' object has no attribute '{}'".format( 63 | type(self).__name__, attr)) 64 | 65 | def __str__(self): 66 | loss_str = [] 67 | for name, meter in self.meters.items(): 68 | loss_str.append("{}: {:.2f} ".format(name, meter.avg)) # ({:.2f}) meter.median, 69 | return self.delimiter.join(loss_str) 70 | -------------------------------------------------------------------------------- /dtfnet/utils/miscellaneous.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import errno 3 | import json 4 | import logging 5 | import os 6 | from .comm import is_main_process 7 | 8 | 9 | def mkdir(path): 10 | try: 11 | os.makedirs(path) 12 | except OSError as e: 13 | if e.errno != errno.EEXIST: 14 | raise 15 | 16 | 17 | def save_labels(dataset_list, output_dir): 18 | if is_main_process(): 19 | logger = logging.getLogger(__name__) 20 | 21 | ids_to_labels = {} 22 | for dataset in dataset_list: 23 | if hasattr(dataset, 'categories'): 24 | ids_to_labels.update(dataset.categories) 25 | else: 26 | logger.warning("Dataset [{}] has no categories attribute, labels.json file won't be created".format( 27 | dataset.__class__.__name__)) 28 | 29 | if ids_to_labels: 30 | labels_file = os.path.join(output_dir, 'labels.json') 31 | logger.info("Saving labels mapping into {}".format(labels_file)) 32 | with open(labels_file, 'w') as f: 33 | json.dump(ids_to_labels, f, indent=2) 34 | 35 | 36 | def save_config(cfg, path): 37 | if is_main_process(): 38 | with open(path, 'w') as f: 39 | f.write(cfg.dump()) 40 | -------------------------------------------------------------------------------- /dtfnet/utils/model_serialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
--------------------------------------------------------------------------------
/dtfnet/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | 
3 | 
4 | def _register_generic(module_dict, module_name, module):
5 |     assert module_name not in module_dict
6 |     module_dict[module_name] = module
7 | 
8 | 
9 | class Registry(dict):
10 |     '''
11 |     A helper class for managing module registration; it extends a dictionary
12 |     and provides a register function.
13 | 
14 |     Eg. creating a registry:
15 |         some_registry = Registry({"default": default_module})
16 | 
17 |     There are two ways of registering new modules:
18 |     1): normal way is just calling register function:
19 |         def foo():
20 |             ...
21 |         some_registry.register("foo_module", foo)
22 |     2): used as decorator when declaring the module:
23 |         @some_registry.register("foo_module")
24 |         @some_registry.register("foo_module_nickname")
25 |         def foo():
26 |             ...
27 | 
28 |     Accessing a module is just like using a dictionary, eg:
29 |         f = some_registry["foo_module"]
30 |     '''
31 |     def __init__(self, *args, **kwargs):
32 |         super(Registry, self).__init__(*args, **kwargs)
33 | 
34 |     def register(self, module_name, module=None):
35 |         # used as function call
36 |         if module is not None:
37 |             _register_generic(self, module_name, module)
38 |             return
39 | 
40 |         # used as decorator
41 |         def register_fn(fn):
42 |             _register_generic(self, module_name, fn)
43 |             return fn
44 | 
45 |         return register_fn
46 | 
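A runnable version of the docstring's decorator pattern; the registry and function names are illustrative:

    from dtfnet.utils.registry import Registry

    POOLERS = Registry()

    @POOLERS.register("mean")
    def mean_pool(xs):
        return sum(xs) / len(xs)

    print(POOLERS["mean"]([1.0, 2.0, 3.0]))  # 2.0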
--------------------------------------------------------------------------------
/dtfnet/utils/timer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | 
3 | 
4 | import time
5 | import datetime
6 | 
7 | 
8 | class Timer(object):
9 |     def __init__(self):
10 |         self.reset()
11 | 
12 |     @property
13 |     def average_time(self):
14 |         return self.total_time / self.calls if self.calls > 0 else 0.0
15 | 
16 |     def tic(self):
17 |         # using time.time instead of time.clock because time.clock
18 |         # does not normalize for multithreading
19 |         self.start_time = time.time()
20 | 
21 |     def toc(self, average=True):
22 |         self.add(time.time() - self.start_time)
23 |         if average:
24 |             return self.average_time
25 |         else:
26 |             return self.diff
27 | 
28 |     def add(self, time_diff):
29 |         self.diff = time_diff
30 |         self.total_time += self.diff
31 |         self.calls += 1
32 | 
33 |     def reset(self):
34 |         self.total_time = 0.0
35 |         self.calls = 0
36 |         self.start_time = 0.0
37 |         self.diff = 0.0
38 | 
39 |     def avg_time_str(self):
40 |         time_str = str(datetime.timedelta(seconds=self.average_time))
41 |         return time_str
42 | 
43 | 
44 | def get_time_str(time_diff):
45 |     time_str = str(datetime.timedelta(seconds=time_diff))
46 |     return time_str
47 | 
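A small sketch of the tic/toc pattern this class supports; the timed work is a placeholder:

    import time
    from dtfnet.utils.timer import Timer

    timer = Timer()
    for _ in range(3):
        timer.tic()
        time.sleep(0.1)          # stand-in for one batch of real work
        timer.toc()              # accumulates total_time and calls
    print(timer.avg_time_str())  # roughly 0:00:00.100000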
--------------------------------------------------------------------------------
/pre_process/feat_pca.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | from sklearn.decomposition import PCA
4 | 
5 | feat_file = "G:/Dataset/data/TACoS/tall_c3d_features.hdf5"
6 | target_file = "G:/Dataset/data/TACoS/tall_c3d_features_pca.hdf5"
7 | 
8 | n_components = 500
9 | pca = PCA(n_components=n_components)
10 | 
11 | f = h5py.File(feat_file, 'r')
12 | f2 = h5py.File(target_file, 'a')
13 | 
14 | feature_all = []
15 | feat_len = []
16 | feat_name = []
17 | 
18 | for vid in f.keys():
19 |     feat_name.append(vid)
20 |     features = f[vid][:]  # (n, 4096)
21 |     features = np.asarray(features, dtype=np.float32)
22 |     feature_all.append(features)
23 |     feat_len.append(features.shape[0])
24 | 
25 | 
26 | feature_all = np.concatenate(feature_all, axis=0)  # (n, 4096)
27 | print(feature_all.shape)  # e.g. TACoS: (88569, 4096)
28 | 
29 | feature_all = pca.fit_transform(feature_all)  # (n, 500)
30 | print(feature_all.shape)
31 | 
32 | for i, vid in enumerate(feat_name):
33 |     start = sum(feat_len[:i])
34 |     end = sum(feat_len[:i+1])
35 |     f2.create_dataset(vid, data=feature_all[start:end])
36 | 
37 | # close the HDF5 handles so the PCA-reduced file is flushed to disk
38 | f.close()
39 | f2.close()
40 | 
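A quick sanity check of the reduced features, assuming the script above has been run (the path mirrors target_file):

    import h5py

    with h5py.File("G:/Dataset/data/TACoS/tall_c3d_features_pca.hdf5", "r") as f:
        vid = next(iter(f.keys()))
        print(vid, f[vid].shape)  # expect (n_clips, 500) per video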
--------------------------------------------------------------------------------
/pre_process/text_encode.py:
--------------------------------------------------------------------------------
1 | '''
2 | Note: the torch version that sentence_transformers depends on is 2.0.1,
3 | so it is recommended to create a new Conda environment and then
4 | pip install sentence_transformers
5 | '''
6 | 
7 | import os
8 | import json
9 | import torch
10 | from tqdm import tqdm
11 | from sentence_transformers import SentenceTransformer
12 | from transformers import AutoTokenizer, Data2VecTextModel, DistilBertModel
13 | 
14 | def text_glove_feat_extract(text_path, text_dir):
15 |     model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d')  # 'sentence-transformers/average_word_embeddings_glove.840B.300d'
16 |     if not os.path.exists(text_dir):
17 |         os.mkdir(text_dir)
18 | 
19 |     data = json.load(open(text_path, 'r', encoding='utf-8'))
20 | 
21 |     for value in tqdm(data.values()):
22 |         sent = value['sentences']
23 |         aid_name = value['audios']
24 |         for s, a in zip(sent, aid_name):
25 |             save_path = os.path.join(text_dir, a.split(".")[0] + '.pt')
26 |             if os.path.exists(save_path):
27 |                 continue
28 | 
29 |             outputs = model.encode(s, convert_to_numpy=False).unsqueeze(0)
30 |             torch.save(outputs, save_path)
31 | 
32 | class Data2vecModel(torch.nn.Module):
33 |     def __init__(self):
34 |         super(Data2vecModel, self).__init__()
35 | 
36 |         self.tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
37 |         self.model = Data2VecTextModel.from_pretrained("facebook/data2vec-text-base")
38 | 
39 |         self.norm = torch.nn.LayerNorm(768, eps=1e-05, elementwise_affine=True)
40 |         self.avgpool = torch.nn.AdaptiveAvgPool1d(output_size=1)
41 | 
42 |     def forward(self, data):
43 |         inputs = self.tokenizer(data, return_tensors="pt")
44 |         with torch.no_grad():
45 |             x = self.model(**inputs)
46 |         x = x.last_hidden_state
47 |         x = self.norm(x)
48 |         x = self.avgpool(x.transpose(1, 2))
49 |         x = x.transpose(1, 2).squeeze(1)
50 |         # print(x.shape)
51 |         return x
52 | 
53 | 
54 | # data2vec
55 | def text_data2vec_feat_extract(text_path, text_dir):
56 | 
57 |     text_model = Data2vecModel()
58 | 
59 |     if not os.path.exists(text_dir):
60 |         os.mkdir(text_dir)
61 | 
62 |     data = json.load(open(text_path, 'r', encoding='utf-8'))
63 | 
64 |     for value in tqdm(data.values()):
65 |         sent = value['sentences']
66 |         aid_name = value['audios']
67 |         for s, a in zip(sent, aid_name):
68 |             save_path = os.path.join(text_dir, a.split(".")[0] + '.pt')
69 |             if os.path.exists(save_path):
70 |                 continue
71 | 
72 |             outputs = text_model(s)
73 |             torch.save(outputs, save_path)
74 | 
75 | 
76 | class DistilBert(torch.nn.Module):
77 |     def __init__(self):
78 |         super(DistilBert, self).__init__()
79 |         self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
80 |         self.model = DistilBertModel.from_pretrained("distilbert-base-uncased")
81 | 
82 |         self.norm = torch.nn.LayerNorm(768, eps=1e-05, elementwise_affine=True)
83 |         self.avgpool = torch.nn.AdaptiveAvgPool1d(output_size=1)
84 | 
85 |     def forward(self, data):
86 |         inputs = self.tokenizer(data, return_tensors="pt")
87 |         with torch.no_grad():
88 |             x = self.model(**inputs)
89 |         x = x.last_hidden_state
90 |         x = self.norm(x)
91 |         x = self.avgpool(x.transpose(1, 2))
92 |         x = x.transpose(1, 2).squeeze(1)
93 |         # print(x.shape)
94 |         return x
95 | 
96 | 
97 | def text_distbert_feat_extract(text_path, text_dir):
98 | 
99 |     text_model = DistilBert()
100 | 
101 |     if not os.path.exists(text_dir):
102 |         os.mkdir(text_dir)
103 | 
104 |     data = json.load(open(text_path, 'r', encoding='utf-8'))
105 | 
106 |     for value in tqdm(data.values()):
107 |         sent = value['sentences']
108 |         aid_name = value['audios']
109 |         for s, a in zip(sent, aid_name):
110 |             save_path = os.path.join(text_dir, a.split(".")[0] + '.pt')
111 |             if os.path.exists(save_path):
112 |                 continue
113 | 
114 |             outputs = text_model(s)
115 |             torch.save(outputs, save_path)
116 | 
117 | 
118 | if __name__ == '__main__':
119 |     text_path1 = r'path\TACoS\train_audio_new.json'
120 |     text_path2 = r'path\TACoS\test_audio_new.json'
121 |     text_path3 = r'path\TACoS\val_audio_new.json'
122 | 
123 |     text_dir1 = r'path\TACoS\text_glove_feat_new'
124 |     text_dir2 = r'path\TACoS\text_distbert_feat_new'
125 | 
126 |     text_glove_feat_extract(text_path1, text_dir1)
127 |     text_glove_feat_extract(text_path2, text_dir1)
128 |     text_glove_feat_extract(text_path3, text_dir1)
129 | 
130 |     text_distbert_feat_extract(text_path1, text_dir2)
131 |     text_distbert_feat_extract(text_path2, text_dir2)
132 |     text_distbert_feat_extract(text_path3, text_dir2)
133 | 
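Each saved .pt file holds a single sentence embedding. For instance (the file name here is a hypothetical audio id):

    import torch

    emb = torch.load(r'path\TACoS\text_glove_feat_new\some_audio_id.pt')
    print(emb.shape)  # torch.Size([1, 300]) for GloVe; [1, 768] for data2vec/DistilBERT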
--------------------------------------------------------------------------------
/test_net.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from dtfnet.config import cfg
5 | from dtfnet.data import make_data_loader
6 | from dtfnet.engine.inference import inference
7 | from dtfnet.modeling import build_model
8 | from dtfnet.utils.checkpoint import MmnCheckpointer
9 | from dtfnet.utils.comm import synchronize, get_rank
10 | from dtfnet.utils.logger import setup_logger
11 | 
12 | def main():
13 |     torch.multiprocessing.set_sharing_strategy('file_system')
14 |     parser = argparse.ArgumentParser(description="DTFNet")
15 |     parser.add_argument(
16 |         "--config-file",
17 |         default="activity/text_new_230919/text_1_3w/config.yml",
18 |         metavar="FILE",
19 |         help="path to config file",
20 |     )
21 |     parser.add_argument("--local_rank", type=int, default=0)
22 |     parser.add_argument(
23 |         "--ckpt",
24 |         help="The path to the checkpoint for test, default is the latest checkpoint.",
25 |         default=None,
26 |     )
27 |     parser.add_argument(
28 |         "opts",
29 |         help="Modify config options using the command-line",
30 |         default=None,
31 |         nargs=argparse.REMAINDER,
32 |     )
33 | 
34 |     args = parser.parse_args()
35 | 
36 |     num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
37 |     distributed = num_gpus > 1
38 | 
39 |     if distributed:
40 |         torch.cuda.set_device(args.local_rank)
41 |         torch.distributed.init_process_group(
42 |             backend="nccl", init_method="env://"
43 |         )
44 |         synchronize()
45 | 
46 |     cfg.merge_from_file(args.config_file)
47 |     cfg.merge_from_list(args.opts)
48 |     cfg.freeze()
49 | 
50 |     save_dir = ""
51 |     logger = setup_logger("dtf", save_dir, get_rank())
52 |     logger.info("Using {} GPUs".format(num_gpus))
53 |     logger.info(cfg)
54 | 
55 |     model = build_model(cfg)
56 |     model.to(cfg.MODEL.DEVICE)
57 |     model.eval()
58 | 
59 |     output_dir = cfg.OUTPUT_DIR
60 |     checkpointer = MmnCheckpointer(cfg, model, save_dir=output_dir)
61 |     # ckpt = cfg.MODEL.WEIGHT if args.ckpt is None else args.ckpt
62 |     _ = checkpointer.load(args.ckpt, use_latest=args.ckpt is None)
63 | 
64 |     dataset_names = cfg.DATASETS.TEST
65 |     data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)[0]
66 |     _ = inference(
67 |         cfg,
68 |         model,
69 |         data_loaders_val,
70 |         dataset_name=dataset_names,
71 |         nms_thresh=cfg.TEST.NMS_THRESH,
72 |         device=cfg.MODEL.DEVICE,
73 |         epoch=999,
74 |     )
75 |     synchronize()
76 | 
77 | if __name__ == "__main__":
78 |     main()
79 | 
--------------------------------------------------------------------------------
/train_net.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import torch
5 | from torch import optim
6 | import torch.multiprocessing as mp
7 | mp.set_sharing_strategy('file_system')
8 | import sys
9 | sys.path.append('./')
10 | from dtfnet.data import make_data_loader
11 | from dtfnet.config import cfg
12 | from dtfnet.engine.inference import inference
13 | from dtfnet.engine.trainer import do_train
14 | from dtfnet.modeling import build_model
15 | from dtfnet.utils.checkpoint import MmnCheckpointer
16 | from dtfnet.utils.comm import synchronize, get_rank
17 | #from dtfnet.utils.imports import import_file
18 | from dtfnet.utils.logger import setup_logger
19 | from dtfnet.utils.miscellaneous import mkdir, save_config
20 | import logging
21 | 
22 | def train(cfg, local_rank, distributed):
23 |     model = build_model(cfg)
24 |     logger = logging.getLogger("dtf.trainer")
25 | 
26 |     for name, module in model.named_children():
27 |         total_params = sum(p.numel() for p in module.parameters()) / 1e6
28 |         logger.info(f"{name}: {total_params:.2f}M")
29 | 
30 |     device = torch.device(cfg.MODEL.DEVICE)
31 |     model.to(device)
32 |     if distributed:
33 |         model = torch.nn.parallel.DistributedDataParallel(
34 |             model, device_ids=[local_rank], output_device=local_rank,
35 |             # this should be removed if we update BatchNorm stats
36 |             broadcast_buffers=False, find_unused_parameters=True
37 |         )
38 |     learning_rate = cfg.SOLVER.LR * 1.0
39 |     data_loader = make_data_loader(cfg, is_train=True, is_distributed=distributed)
40 | 
41 |     bert_params = []
42 |     base_params = []
43 |     for name, param in model.named_parameters():
44 |         if "bert" in name:
45 |             bert_params.append(param)
46 |         else:
47 |             base_params.append(param)
48 | 
49 |     param_dict = {'bert': bert_params, 'base': base_params}
50 |     optimizer = optim.AdamW([{'params': base_params},
51 |                              {'params': bert_params, 'lr': learning_rate * 0.1}], lr=learning_rate, betas=(0.9, 0.99), weight_decay=1e-5)
52 |     scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.SOLVER.MILESTONES, gamma=0.1)
53 |     output_dir = cfg.OUTPUT_DIR
54 |     save_to_disk = get_rank() == 0
55 |     checkpointer = MmnCheckpointer(cfg, model, optimizer, scheduler, output_dir, save_to_disk)
56 |     arguments = {"epoch": 1}
57 | 
58 |     if cfg.SOLVER.RESUME:
59 |         arguments = {"epoch": cfg.SOLVER.RESUME_EPOCH}
60 |         path = '%s_model_%de.pth' % (cfg.MODEL.DTF.FEAT2D.NAME, cfg.SOLVER.RESUME_EPOCH - 1)
61 |         weight_path = os.path.join(cfg.OUTPUT_DIR, path)
62 |         weight_file = torch.load(weight_path, map_location=torch.device("cpu"))
63 |         model.load_state_dict(weight_file.pop("model"))
64 |         for _ in range(1, cfg.SOLVER.RESUME_EPOCH):
65 |             optimizer.step()
66 |             scheduler.step()
67 | 
68 |     test_period = cfg.SOLVER.TEST_PERIOD
69 |     if test_period > 0:
70 |         data_loader_val = make_data_loader(cfg, is_train=False, is_distributed=distributed, is_for_period=True)
71 |     else:
72 |         data_loader_val = None
73 | 
74 |     checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
75 | 
76 |     do_train(
77 |         cfg,
78 |         model,
79 |         data_loader,
80 |         data_loader_val,
81 |         optimizer,
82 |         scheduler,
83 |         checkpointer,
84 |         device,
85 |         checkpoint_period,
86 |         test_period,
87 |         arguments,
88 |         param_dict
89 |     )
90 |     return model
91 | 
92 | 
93 | def run_test(cfg, model, distributed):
94 |     if distributed:
95 |         model = model.module
96 |     torch.cuda.empty_cache()
97 |     dataset_names = cfg.DATASETS.TEST
98 |     data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
99 |     for dataset_name, data_loader_val in zip(dataset_names, data_loaders_val):
100 |         inference(
101 |             cfg,
102 |             model,
103 |             data_loader_val,
104 |             dataset_name=dataset_name,
105 |             nms_thresh=cfg.TEST.NMS_THRESH,
106 |             device=cfg.MODEL.DEVICE,
107 |             epoch=999,
108 |         )
109 |         synchronize()
110 | 
111 | 
112 | def main():
113 |     parser = argparse.ArgumentParser(description="Mutual Matching Network")
114 |     parser.add_argument(
115 |         "--config-file",
116 |         default="./configs/activitynet.yaml",
117 |         metavar="FILE",
118 |         help="path to config file",
119 |         type=str,
120 |     )
121 |     parser.add_argument("--local_rank", type=int, default=0)
122 |     parser.add_argument(
123 |         "--skip-test",
124 |         dest="skip_test",
125 |         help="Do not test the final model",
126 |         action="store_true",
127 |     )
128 |     parser.add_argument(
129 |         "opts",
130 |         help="Modify config options using the command-line",
131 |         default=None,
132 |         nargs=argparse.REMAINDER,
133 |     )
134 | 
135 |     args = parser.parse_args()
136 |     seed = 25285
137 |     random.seed(seed)
138 |     torch.manual_seed(seed)
139 |     torch.cuda.manual_seed_all(seed)
140 |     torch.backends.cudnn.deterministic = True
141 | 
142 |     num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
143 |     args.distributed = False
144 | 
145 |     if args.distributed:
146 |         torch.cuda.set_device(args.local_rank)
147 |         torch.distributed.init_process_group(
148 |             backend="nccl", init_method="env://"
149 |         )
150 |         synchronize()
151 | 
152 |     cfg.merge_from_file(args.config_file)
153 |     cfg.merge_from_list(args.opts)
154 |     cfg.freeze()
155 | 
156 |     output_dir = cfg.OUTPUT_DIR
157 |     if output_dir:
158 |         mkdir(output_dir)
159 | 
160 |     logger = setup_logger("dtf", output_dir, get_rank())
161 |     logger.info("Using {} GPUs".format(num_gpus))
162 |     logger.info(args)
163 | 
164 |     logger.info("Loaded configuration file {}".format(args.config_file))
165 |     with open(args.config_file, "r") as cf:
166 |         config_str = "\n" + cf.read()
167 |     # logger.info(config_str)
168 |     # logger.info("Running with config:\n{}".format(cfg))
169 | 
170 |     output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
171 |     logger.info("Saving config into: {}".format(output_config_path))
172 |     # save overloaded model config in the output directory
173 |     save_config(cfg, output_config_path)
174 | 
175 |     model = train(cfg, args.local_rank, args.distributed)
176 | 
177 |     # if not args.skip_test:
178 |     #     run_test(cfg, model, args.distributed)
179 | 
180 | 
181 | if __name__ == "__main__":
182 |     #mp.set_start_method('spawn')
183 |     #
184 |     main()
185 | 
--------------------------------------------------------------------------------
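Both entry points forward trailing KEY VALUE pairs on the command line to cfg.merge_from_list. A minimal sketch of the programmatic equivalent, with an illustrative config path and override:

    from dtfnet.config import cfg

    cfg.merge_from_file("configs/tacos.yaml")
    cfg.merge_from_list(["SOLVER.LR", 0.0008])  # same effect as trailing CLI opts
    cfg.freeze()
    print(cfg.SOLVER.LR)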