├── models │   ├── __init__.py │   ├── bert.py │   ├── fakeTransformer.py │   └── vl_model.py ├── dataset │   ├── __init__.py │   └── dataset.py ├── utils │   ├── __init__.py │   ├── load.py │   └── config.py ├── moco_box.yml ├── LICENSE └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .load import load_wenlan_model 2 | -------------------------------------------------------------------------------- /models/bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AutoModel 4 | 5 | 6 | class Bert(nn.Module): 7 | 8 | def __init__(self, cfg): 9 | super(Bert, self).__init__() 10 | self.cfg = cfg 11 | self.bert = AutoModel.from_pretrained(cfg.ENCODER) 12 | 13 | def forward(self, x): 14 | y = self.bert(x, return_dict=True).last_hidden_state 15 | return y 16 | -------------------------------------------------------------------------------- /models/fakeTransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FakeTransformer(nn.Module): 6 | 7 | def __init__(self, input_dim, hidden_dim, output_dim): 8 | super(FakeTransformer, self).__init__() 9 | 10 | self.fc1 = nn.Linear(input_dim, hidden_dim) 11 | self.dropout1 = nn.Dropout(0.5)  # defined here but not applied in forward 12 | self.relu = nn.ReLU() 13 | self.fc2 = nn.Linear(hidden_dim, output_dim) 14 | 15 | def forward(self, x): 16 | out = self.fc1(x) 17 | 18 | out = self.relu(out) 19 | out = self.fc2(out) 20 | return out 21 | -------------------------------------------------------------------------------- /utils/load.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | 4 | ############################################# 5 | # File Name: load.py 6 | # Author: Haoyu Lu 7 | # Mail: lhy1998@ruc.edu.cn 8 | # Created Time: 2022-3-18 19:17:34 9 | ############################################# 10 | 11 | import torch 12 | import sys 13 | 14 | from ..models import build_network 15 | from ..utils.config import cfg_from_yaml_file, cfg 16 | 17 | def load_wenlan_model(load_checkpoint='../wenlan-video-model.pth', cfg_file='none', device='cpu'): 18 | 19 | cfg_from_yaml_file(cfg_file, cfg) 20 | 21 | cfg.MODEL.IMG_SIZE = 384 22 | cfg.MODEL.IS_EXTRACT = True 23 | cfg.DATASET.TEST_SET = 'test' 24 | 25 | model = build_network(cfg.MODEL) 26 | model.load_state_dict(torch.load(load_checkpoint, map_location=device))  # map_location keeps CUDA-trained checkpoints loadable on CPU-only machines 27 | 28 | model = model.to(device) 29 | model.eval() 30 | 31 | return model -------------------------------------------------------------------------------- /moco_box.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: VL 3 | MODE: coco 4 | 5 | ENCODER: 'hfl/chinese-roberta-wwm-ext-large' 6 | 7 | POS_DNCODER: False 8 | POSITION_DROP: 0.0 9 | IMG_FEATURE_DIM: 2560 10 | 11 | IMG_TRANSFORMER_HEAD: 4 12 | IMG_TRANSFORMER_LAYER: 4 # 3.0-v0: 1 13 | MAX_IMG_LEN: 26 14 | 15 | 16 | TEXT_FEATURE_DIM: 1024 17 | TEXT_TRANSFORMER_HEAD: 4 18 | TEXT_TRANSFORMER_LAYER: 4 19 | 20 | 21 | 
HIDDEN_DIM_1: 2560 22 | HIDDEN_DIM_2: 2560 23 | MAX_TEXT_LEN: 80 24 | COTRANSFORMER_CFG: [1, 3, 2048, 2048, 2048, 2048] 25 | 26 | FIX_ENCODER: True 27 | IS_EXTRACT: False 28 | CNN: "tf_efficientnet_b7_ns" 29 | 30 | IMG_SIZE: 384 31 | 32 | GRID_SIZE: [1, 5] 33 | ROI_GRID_SIZE: 4 34 | num_frame: 10 35 | num_clip: 2 36 | 37 | DATASET: 38 | NAME: 'MSRDataset_boxes' 39 | DATADIR: '/data1/xyb' 40 | JSONPATH: 'xyb_all_bbox.jsonl' 41 | WORKERS: 8 42 | 43 | 44 | OPTIMIZATION: 45 | EPOCH: 200 46 | BATCH_SIZE: 1 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 rucmlcv 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Wenlan-Video-Public 2 | 3 | Wenlan-Video-Public is the WenLan video multi-modal pre-trained model, built on Wenlan 2.0 (the first large-scale Chinese general-purpose image-text multi-modal pre-trained model). 4 | 5 | Wenlan 1.0 paper: [WenLan: Bridging Vision and Language by Large-Scale Multi-Modal Pre-Training](https://arxiv.org/abs/2103.06561) 6 | 7 | Wenlan 2.0 paper: [WenLan 2.0: Make AI Imagine via a Multimodal Foundation Model](https://arxiv.org/abs/2110.14378) 8 | 9 | ## Use Cases 10 | 11 | Example use cases: video-to-text retrieval, text-to-video retrieval, video tagging, zero-shot video classification, and input features for other downstream multi-modal tasks. 12 | 13 | ## Technical Highlights 14 | 15 | 1. Wenlan-Video-Public uses contrastive learning to map images/videos and text into the same feature space, bridging the gap between visual and textual features. 16 | 2. Built on the assumption of weak vision-language correlation, it understands descriptive captions of images/videos and also captures the more abstract connections between vision and text. 17 | 3. The visual encoder and text encoder can run independently, which simplifies deployment in production environments. 18 | 4. Jointly trained on 300 million general-domain image-text pairs plus 500 thousand iQIYI videos, giving strong generalization and broad applicability. 19 | 20 | 21 | ## Requirements 22 | ``` 23 | python 3.8 24 | torch 1.8 25 | jsonlines 26 | tqdm 27 | easydict 28 | torchvision 29 | transformers 30 | timm 31 | ``` 32 | 33 | ## Model Download 34 | https://pan.baidu.com/s/1TbSqCEKquxndWWv9yUECMg?pwd=31pf 35 | or 36 | https://drive.google.com/file/d/1iHt06QSiKOLXsybOatMUtggTwTXXSca8/view?usp=sharing 37 | 38 | ## Quick Start 39 | 40 | 1. Install the wenlan-video-public package 41 | ``` 42 | pip install wenlan-video-public==1.0.2 43 | ``` 44 | 45 | 2. Load the model 46 | 47 | ``` 48 | from wenlan_video import load_wenlan_model 49 | 50 | # load_checkpoint: path to the downloaded model weights 51 | # cfg_file: path to the moco_box.yml file in this repository 52 | model = load_wenlan_model(load_checkpoint, cfg_file, device=device) 53 | ``` 54 | 
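For reference, a minimal sketch of the variables the snippet above assumes; the paths below are placeholders rather than fixed names, so point them at wherever you saved the files:
```
import torch

load_checkpoint = './wenlan-video-model.pth'  # downloaded weights (placeholder path)
cfg_file = './moco_box.yml'                   # config file from this repository (placeholder path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
```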
读取视频/文本 56 | 57 | ``` 58 | from wenlan_video import wenlan_transforms 59 | wenlan_transforms = wenlan_transforms() 60 | 61 | # VIDEO_PATH为视频抽帧好的帧(jpg, png),帧数大于10,命名按顺序标号 62 | video, video_boxes = wenlan_transforms.video_transform(VIDEO_PATH, device=device) 63 | text, textMask = wenlan_transforms.text_transform(‘Hello Wenlan’, device=device) 64 | ``` 65 | 66 | 4. 同时抽取视频/文本特征 67 | ``` 68 | videoFea, textFea = model(video, video_boxes, text.unsqueeze(0), textMask.unsqueeze(0)) 69 | ``` 70 | 71 | 5. 分别抽取视频/文本特征 72 | ``` 73 | videoFea = model.encode_video(videoFea, video_boxes) 74 | textFea = model.encode_text(texts.unsqueeze(0), maskTexts.unsqueeze(0)) 75 | ``` 76 | 77 | ## Have Fun! 78 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | 7 | import json 8 | import random 9 | 10 | 11 | import jsonlines 12 | import argparse 13 | 14 | from PIL import ImageFilter 15 | from transformers import AutoTokenizer 16 | from PIL import Image 17 | from PIL import ImageFile 18 | 19 | 20 | def visual_transforms_box(is_train=True, new_size=384): 21 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 22 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 23 | 24 | return transforms.Compose([ 25 | transforms.ToTensor(), 26 | transforms.Resize((new_size, new_size)), 27 | normalize]) 28 | 29 | 30 | 31 | 32 | def getLanMask(seq_lens, max_len): 33 | # seq_lens (bs) 34 | mask = torch.ones((seq_lens.size(0), max_len)) # (bs, max_len) 35 | idxs = torch.arange(max_len).unsqueeze(dim=0) # (1, max_len) 36 | seq_lens = seq_lens.unsqueeze(-1) # (bs, 1) 37 | mask = torch.where(idxs < seq_lens, mask, torch.Tensor([0.0])) 38 | return mask 39 | 40 | 41 | 42 | class wenlan_transforms(): 43 | def __init__(self): 44 | self.tokenizer = AutoTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext-large') 45 | 46 | def text_transform(self, text, max_length=50, device='cpu'): 47 | # text_transform = tokenizer #AutoTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext-large') 48 | text_info = self.tokenizer(text, padding='max_length', truncation=True, 49 | max_length=max_length, return_tensors='pt') 50 | text = text_info.input_ids.reshape(-1).to(device) 51 | text_len = torch.sum(text_info.attention_mask) 52 | textMask = getLanMask(text_len.unsqueeze(0), max_length).squeeze(0).to(device) 53 | 54 | return text, textMask 55 | 56 | def video_transform(self, video_path, device='cpu'): 57 | visual_transform = visual_transforms_box(False, 384) 58 | 59 | image_boxes = [] 60 | images = [] 61 | sorted_frames = os.listdir(video_path) 62 | for frame in sorted_frames: 63 | 64 | img_path = os.path.join(video_path, frame) 65 | image = Image.open(img_path).convert('RGB') 66 | image = visual_transform(image) 67 | images.append(image) 68 | img_box_s = [] 69 | 70 | 71 | for grid_num in [1, 5]: 72 | for i in range(grid_num): 73 | for j in range(grid_num): 74 | img_box_s.append(torch.from_numpy(np.array( 75 | [i * (384 / grid_num), j * (384 / grid_num), (i + 1) * (384 / grid_num), 76 | (j + 1) * (384 / grid_num)]))) 77 | image_boxes.append(torch.stack(img_box_s, 0)) 78 | 79 | image_boxes = torch.stack(image_boxes, 0).to(device) 80 | images = torch.stack(images, 0).to(device) 81 | 82 | 83 | return images, image_boxes 84 | 
-------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | 4 | ############################################# 5 | # File Name: config.py 6 | # Author: Haoyu Lu 7 | # Mail: lhy1998@ruc.edu.cn 8 | # Created Time: 2022-3-18 19:17:34 9 | ############################################# 10 | 11 | from pathlib import Path 12 | 13 | import yaml 14 | from easydict import EasyDict 15 | 16 | 17 | def log_config_to_file(cfg, pre='cfg', logger=None): 18 | for key, val in cfg.items(): 19 | if isinstance(cfg[key], EasyDict): 20 | logger.info('\n%s.%s = edict()' % (pre, key)) 21 | log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger) 22 | continue 23 | logger.info('%s.%s: %s' % (pre, key, val)) 24 | 25 | 26 | def cfg_from_list(cfg_list, config): 27 | """Set config keys via list (e.g., from command line).""" 28 | from ast import literal_eval 29 | assert len(cfg_list) % 2 == 0 30 | for k, v in zip(cfg_list[0::2], cfg_list[1::2]): 31 | key_list = k.split('.') 32 | d = config 33 | for subkey in key_list[:-1]: 34 | assert subkey in d, 'NotFoundKey: %s' % subkey 35 | d = d[subkey] 36 | subkey = key_list[-1] 37 | assert subkey in d, 'NotFoundKey: %s' % subkey 38 | try: 39 | value = literal_eval(v) 40 | except: 41 | value = v 42 | 43 | if type(value) != type(d[subkey]) and isinstance(d[subkey], EasyDict): 44 | key_val_list = value.split(',') 45 | for src in key_val_list: 46 | cur_key, cur_val = src.split(':') 47 | val_type = type(d[subkey][cur_key]) 48 | cur_val = val_type(cur_val) 49 | d[subkey][cur_key] = cur_val 50 | elif type(value) != type(d[subkey]) and isinstance(d[subkey], list): 51 | val_list = value.split(',') 52 | for k, x in enumerate(val_list): 53 | val_list[k] = type(d[subkey][0])(x) 54 | d[subkey] = val_list 55 | else: 56 | assert type(value) == type(d[subkey]), \ 57 | 'type {} does not match original type {}'.format(type(value), type(d[subkey])) 58 | d[subkey] = value 59 | 60 | 61 | def merge_new_config(config, new_config): 62 | if '_BASE_CONFIG_' in new_config: 63 | with open(new_config['_BASE_CONFIG_'], 'r') as f: 64 | try: 65 | yaml_config = yaml.load(f, Loader=yaml.FullLoader) 66 | except: 67 | yaml_config = yaml.load(f) 68 | config.update(EasyDict(yaml_config)) 69 | 70 | for key, val in new_config.items(): 71 | if not isinstance(val, dict): 72 | config[key] = val 73 | continue 74 | if key not in config: 75 | config[key] = EasyDict() 76 | merge_new_config(config[key], val) 77 | 78 | return config 79 | 80 | 81 | def cfg_from_yaml_file(cfg_file, config): 82 | with open(cfg_file, 'r') as f: 83 | try: 84 | new_config = yaml.load(f, Loader=yaml.FullLoader) 85 | except: 86 | new_config = yaml.load(f) 87 | 88 | merge_new_config(config=config, new_config=new_config) 89 | 90 | return config 91 | 92 | 93 | cfg = EasyDict() 94 | cfg.ROOT_DIR = (Path(__file__).resolve().parent / '../').resolve() 95 | # cfg.LOCAL_RANK = 0 96 | -------------------------------------------------------------------------------- /models/vl_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | 4 | ############################################# 5 | # File Name: vl_model.py 6 | # Author: Haoyu Lu 7 | # Mail: lhy1998@ruc.edu.cn 8 | # Created Time: 2022-3-18 19:17:34 9 | ############################################# 10 | 11 | import torch 12 | import torch.nn as 
nn 13 | import torchvision 14 | from transformers import AutoTokenizer 15 | from .fakeTransformer import FakeTransformer 16 | from .bert import Bert 17 | 18 | 19 | import torch.nn.functional as F 20 | import timm 21 | 22 | import numpy as np 23 | import math 24 | 25 | 26 | 27 | 28 | class ImgLearnableEncoder(nn.Module): 29 | def __init__(self, model_cfg, args): 30 | super(ImgLearnableEncoder, self).__init__() 31 | 32 | 33 | self.backbone = timm.create_model(model_cfg.CNN, pretrained=False) 34 | 35 | self.model_cfg = model_cfg 36 | self.learnable = nn.ModuleDict() 37 | 38 | img_encoder_layer = nn.TransformerEncoderLayer(d_model=self.model_cfg.IMG_FEATURE_DIM, nhead=self.model_cfg.IMG_TRANSFORMER_HEAD) 39 | self.learnable['imgAtt'] = nn.TransformerEncoder(img_encoder_layer, num_layers=self.model_cfg.IMG_TRANSFORMER_LAYER) 40 | self.learnable['imgFC'] = FakeTransformer(model_cfg.IMG_FEATURE_DIM, model_cfg.HIDDEN_DIM_1, model_cfg.HIDDEN_DIM_2) 41 | self.learnable['max_pool'] = nn.Sequential(nn.AvgPool2d(model_cfg.ROI_GRID_SIZE, stride=1)) 42 | 43 | # self.learnable['final_mlp'] = FakeTransformer(model_cfg.HIDDEN_DIM_2*2, model_cfg.HIDDEN_DIM_2, model_cfg.HIDDEN_DIM_2) 44 | 45 | 46 | def roi_grid_pool(self, spatial_features_2d, rois): 47 | """ 48 | Args: 49 | rois: (B, num_rois, 4) 50 | spatial_features_2d: (B, C, H, W) 51 | Returns: 52 | pooled_features : (B, num_rois, C) 53 | """ 54 | batch_size = spatial_features_2d.size(0) 55 | rois = rois.detach() 56 | height, width = spatial_features_2d.size(2), spatial_features_2d.size(3) 57 | down_sample_ratio = self.model_cfg.IMG_SIZE / height 58 | pooled_features_list = [] 59 | torch.backends.cudnn.enabled = False 60 | for b_id in range(batch_size): 61 | # Map global boxes coordinates to feature map coordinates 62 | x1 = rois[b_id, :, 0] / down_sample_ratio 63 | y1 = rois[b_id, :, 1] / down_sample_ratio 64 | x2 = rois[b_id, :, 2] / down_sample_ratio 65 | y2 = rois[b_id, :, 3] / down_sample_ratio 66 | angle = torch.zeros((1), device=spatial_features_2d.device) 67 | cosa = torch.cos(angle) 68 | sina = torch.sin(angle) 69 | 70 | theta = torch.stack(( 71 | (x2 - x1) / (width - 1) * cosa, (x2 - x1) / (width - 1) * (-sina), (x1 + x2 - width + 1) / (width - 1), 72 | (y2 - y1) / (height - 1) * sina, (y2 - y1) / (height - 1) * cosa, (y1 + y2 - height + 1) / (height - 1) 73 | ), dim=1).view(-1, 2, 3).float() 74 | 75 | grid_size = self.model_cfg.ROI_GRID_SIZE 76 | grid = nn.functional.affine_grid( 77 | theta, 78 | torch.Size((rois.size(1), spatial_features_2d.size(1), grid_size, grid_size)) 79 | ) 80 | 81 | pooled_features = nn.functional.grid_sample( 82 | spatial_features_2d[b_id].unsqueeze(0).expand(rois.size(1), spatial_features_2d.size(1), height, width), 83 | grid 84 | ) 85 | pooled_features = self.learnable['max_pool'](pooled_features) 86 | pooled_features_list.append(pooled_features.squeeze()) 87 | 88 | torch.backends.cudnn.enabled = True 89 | pooled_features = torch.stack(pooled_features_list, dim=0) 90 | 91 | return pooled_features 92 | 93 | def forward(self, imgFea, image_boxs): 94 | imgFea = self.backbone.forward_features(imgFea) 95 | imgFea = self.roi_grid_pool(imgFea, image_boxs) 96 | imgFea = F.normalize(imgFea, p=2, dim=-1) 97 | 98 | imgFea = self.learnable['imgAtt'](imgFea.transpose(0, 1)).transpose(0,1) # TODO 99 | imgFea = imgFea.mean(1) 100 | imgFea = self.learnable['imgFC'](imgFea) 101 | 102 | 103 | return imgFea 104 | 105 | 106 | class TextLearnableEncoder(nn.Module): 107 | def __init__(self, model_cfg): 108 | super(TextLearnableEncoder, 
self).__init__() 109 | 110 | self.backbone = Bert(model_cfg) 111 | self.model_cfg = model_cfg # TODO: model_cfg 112 | 113 | self.learnable = nn.ModuleDict() 114 | 115 | text_encoder_layer = nn.TransformerEncoderLayer(d_model=model_cfg.TEXT_FEATURE_DIM, nhead=model_cfg.TEXT_TRANSFORMER_HEAD) 116 | self.learnable['textAtt'] = nn.TransformerEncoder(text_encoder_layer, num_layers=model_cfg.TEXT_TRANSFORMER_LAYER) 117 | 118 | self.learnable['textFC'] = FakeTransformer(model_cfg.TEXT_FEATURE_DIM, model_cfg.HIDDEN_DIM_1, model_cfg.HIDDEN_DIM_2) 119 | 120 | 121 | def forward(self, textFea, maskTexts): 122 | textFea = self.backbone(textFea) 123 | textFea = F.normalize(textFea, p=2, dim=-1) 124 | if self.model_cfg.TEXT_TRANSFORMER_LAYER != 0: 125 | textFea = self.learnable['textAtt'](textFea.transpose(0, 1), src_key_padding_mask=(maskTexts == 0)).transpose(0,1) 126 | tmpMask = torch.where(maskTexts == 1, torch.tensor([1.0], device=maskTexts.device), 127 | torch.tensor([0.0], device=maskTexts.device)) 128 | textFea = (textFea * tmpMask.unsqueeze(-1)).sum(dim=1) / tmpMask.sum(dim=1).unsqueeze(-1) # (bs, dim) 129 | textFea = self.learnable['textFC'](textFea) 130 | 131 | return textFea 132 | 133 | 134 | class VL_model(nn.Module): 135 | 136 | def __init__(self, model_cfg, args): 137 | super(VL_model, self).__init__() 138 | 139 | self.model_cfg = model_cfg 140 | self.num_frame = model_cfg.num_frame 141 | 142 | self.learnable = nn.ModuleDict() 143 | self.learnable['imgencoder'] = ImgLearnableEncoder(model_cfg, args) 144 | self.learnable['textencoder'] = TextLearnableEncoder(model_cfg) 145 | 146 | video_layer = nn.TransformerEncoderLayer(d_model=model_cfg.HIDDEN_DIM_2, nhead=self.model_cfg.IMG_TRANSFORMER_HEAD) 147 | self.learnable['videoAtt'] = nn.TransformerEncoder(video_layer, num_layers=self.model_cfg.IMG_TRANSFORMER_LAYER) 148 | 149 | self.tokenizer = AutoTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext-large') 150 | 151 | def imgfea2videofea(self,t): 152 | t = t.reshape(-1, self.num_frame, t.shape[1]) 153 | # t = self.learnable['videoAtt'](t.transpose(0,1)).transpose(0,1) 154 | v = t.mean(dim = 1) 155 | return v 156 | 157 | 158 | def forward(self, imgFea, image_boxs, texts, maskTexts): 159 | with torch.no_grad(): 160 | imgFea = self.learnable['imgencoder'](imgFea, image_boxs) # 161 | imgFea = self.imgfea2videofea(imgFea) 162 | textFea = self.learnable['textencoder'](texts, maskTexts) # 163 | 164 | imgFea = F.normalize(imgFea, p=2, dim=-1) 165 | textFea = F.normalize(textFea, p=2, dim=-1) 166 | 167 | return [imgFea, textFea] 168 | 169 | 170 | def encode_video(self, imgFea, image_boxs): 171 | with torch.no_grad(): 172 | imgFea = self.learnable['imgencoder'](imgFea, image_boxs) # 173 | imgFea = self.imgfea2videofea(imgFea) 174 | imgFea = F.normalize(imgFea, p=2, dim=-1) 175 | return imgFea 176 | 177 | def encode_text(self, texts, maskTexts): 178 | with torch.no_grad(): 179 | textFea = self.learnable['textencoder'](texts, maskTexts) # 180 | textFea = F.normalize(textFea, p=2, dim=-1) 181 | return textFea 182 | 183 | 184 | 185 | 186 | 187 | --------------------------------------------------------------------------------
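To tie the pieces together, here is a hedged end-to-end sketch that strings the Quick Start calls from the README into a simple video-caption scoring loop; the checkpoint/config/frame paths and the candidate captions are placeholders, and it assumes a frame directory prepared as described above (ordered frames, e.g. 10 of them, giving a single clip feature):
```
import torch
from wenlan_video import load_wenlan_model, wenlan_transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_wenlan_model('./wenlan-video-model.pth', './moco_box.yml', device=device)
wl_transforms = wenlan_transforms()

# video side: encode_video returns L2-normalized clip features of shape (num_clips, 2560)
video, video_boxes = wl_transforms.video_transform('./frames/my_clip', device=device)
videoFea = model.encode_video(video, video_boxes)

# text side: score a few candidate captions against the clip
captions = ['一只小狗在草地上奔跑', '两个人在厨房做饭']
for caption in captions:
    text, textMask = wl_transforms.text_transform(caption, device=device)
    textFea = model.encode_text(text.unsqueeze(0), textMask.unsqueeze(0))
    # both features are unit-norm, so the dot product equals the cosine similarity
    score = (videoFea * textFea).sum(dim=-1).item()
    print(caption, score)
```
Because the two encoders run independently (see `encode_video` / `encode_text` in vl_model.py above), the video features can be computed once and cached while new captions are scored on the fly.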