├── CNN_RNN
│   ├── model_attention.py
│   └── train.py
├── README.md
├── data_process
│   ├── build_vocab.py
│   └── one_hot.py
├── deepmiml
│   └── deepmiml.py
└── visual_concept
    ├── train_visual.py
    └── visual_concept.py

/CNN_RNN/model_attention.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable


def to_var(x, volatile=False):
    """Move a tensor to the GPU when available and wrap it in a Variable (legacy API)."""
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, volatile=volatile)


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        """Load the pretrained VGG-16 and keep only its convolutional layers."""
        super(EncoderCNN, self).__init__()
        vgg = models.vgg16(pretrained=True)
        # Keep modules 0-28 of vgg.features (up to conv5_3); the classifier and the
        # final ReLU/max-pool are dropped, so a 224x224 image yields a 512 x 14 x 14 map.
        modules = list(vgg.features[i] for i in range(29))
        self.vgg = nn.Sequential(*modules)

    def forward(self, images):
        """Extract feature vectors from input images."""
        with torch.no_grad():
            features = self.vgg(images)
        N, C, H, W = features.size()
        features = features.view(N, C, H * W)
        features = features.permute(0, 2, 1)  # (N, H*W, C) = (N, 196, 512)
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=40):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm_cell = nn.LSTMCell(embed_size * 2, hidden_size)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size
        self.vis_dim = 512      # channel dimension of the CNN feature map
        self.hidden_dim = 1024  # LSTM hidden size
        vis_num = 196           # number of spatial locations (14 x 14)
        self.att_vw = nn.Linear(self.vis_dim, self.vis_dim, bias=False)
        self.att_hw = nn.Linear(self.hidden_dim, self.vis_dim, bias=False)
        self.att_bias = nn.Parameter(torch.zeros(vis_num))
        self.att_w = nn.Linear(self.vis_dim, 1, bias=False)

    def attention(self, features, hiddens):
        """Additive attention over the 196 spatial feature vectors."""
        att_fea = self.att_vw(features)              # (N, 196, 512)
        att_h = self.att_hw(hiddens).unsqueeze(1)    # (N, 1, 512)
        att_full = nn.ReLU()(att_fea + att_h + self.att_bias.view(1, -1, 1))
        att_out = self.att_w(att_full)               # (N, 196, 1)
        alpha = nn.Softmax(dim=1)(att_out)
        context = torch.sum(features * alpha, 1)     # (N, 512)
        return context, alpha

    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generate captions."""
        embeddings = self.embed(captions)
        feats = torch.mean(features, 1).unsqueeze(1)
        embeddings = torch.cat((feats, embeddings), 1)
        batch_size, time_step = captions.size()
        predicts = to_var(torch.zeros(batch_size, time_step, self.vocab_size))
        hx = to_var(torch.zeros(batch_size, 1024))
        cx = to_var(torch.zeros(batch_size, 1024))
        for i in range(time_step):
            feas, _ = self.attention(features, hx)
            lstm_input = torch.cat((feas, embeddings[:, i, :]), -1)
            hx, cx = self.lstm_cell(lstm_input, (hx, cx))
            output = self.linear(hx)
            predicts[:, i, :] = output
        return predicts

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        sampled_ids = []
        hx = to_var(torch.zeros(1, 1024))
        cx = to_var(torch.zeros(1, 1024))
        inputs = torch.mean(features, 1)            # (1, 512) mean-pooled image feature
        alphas = []
        for i in range(self.max_seq_length):
            feas, alpha = self.attention(features, hx)
            alphas.append(alpha)
            inputs = torch.cat((feas, inputs), -1)  # (1, 1024) context + previous embedding
            hx, cx = self.lstm_cell(inputs, (hx, cx))
            outputs = self.linear(hx)               # (1, vocab_size)
            _, predicted = outputs.max(1)           # (1,)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted)          # (1, embed_size)
        sampled_ids = torch.stack(sampled_ids, 1)   # (1, max_seq_length)
        return sampled_ids, alphas
--------------------------------------------------------------------------------
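A minimal usage sketch for the encoder and decoder above (not part of the repository): the batch size, vocabulary size of 1000 and caption length of 15 are made-up values, random tensors stand in for real images and captions, and the import path assumes the script is run from inside CNN_RNN.

import torch
from model_attention import EncoderCNN, DecoderRNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = EncoderCNN(embed_size=512).to(device)
decoder = DecoderRNN(embed_size=512, hidden_size=1024, vocab_size=1000, num_layers=1).to(device)

images = torch.randn(2, 3, 224, 224).to(device)        # two fake RGB crops
captions = torch.randint(0, 1000, (2, 15)).to(device)  # two fake tag sequences of length 15

features = encoder(images)                    # (2, 196, 512): 14x14 grid of VGG features
outputs = decoder(features, captions, lengths=[15, 15])
print(outputs.shape)                          # torch.Size([2, 15, 1000])

sampled_ids, alphas = decoder.sample(features[:1])   # greedy decoding for one image
print(sampled_ids.shape, len(alphas))         # torch.Size([1, 40]) and 40 attention maps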
/CNN_RNN/train.py:
--------------------------------------------------------------------------------
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
from data_loader_v2 import get_loader
from build_vocab_v2 import Vocabulary
from model_attention import EncoderCNN, DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing; ImageNet normalization for the pretrained VGG-16
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    print("load vocabulary ...")
    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    print("build data loader ...")
    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("build the models ...")
    # Build the models
    encoder = EncoderCNN(args.embed_size).to(device)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)
    # Optionally resume from a checkpoint:
    # encoder.load_state_dict(torch.load("models/encoder-2-1000.ckpt"))
    # decoder.load_state_dict(torch.load("models/decoder-2-1000.ckpt"))

    # Loss and optimizer; the VGG encoder is frozen, so only the decoder is trained
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):

            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            outputs = pack_padded_sequence(outputs, lengths, batch_first=True)[0]
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, args.num_epochs, i, total_step, loss.item(),
                              np.exp(loss.item())))

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    args.model_path, 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                torch.save(encoder.state_dict(), os.path.join(
                    args.model_path, 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='models/', help='path for saving trained models')
    parser.add_argument('--crop_size', type=int, default=224, help='size for randomly cropping images')
    parser.add_argument('--vocab_path', type=str, default='data/zh_vocab.pkl', help='path for vocabulary wrapper')
    parser.add_argument('--image_dir', type=str, default='data/resized2014', help='directory for resized images')
    parser.add_argument('--caption_path', type=str, default='data/annotations/img_tag.txt', help='path for train annotation file')
    parser.add_argument('--log_step', type=int, default=10, help='step size for printing log info')
    parser.add_argument('--save_step', type=int, default=1000, help='step size for saving trained models')

    # Model parameters
    parser.add_argument('--embed_size', type=int, default=512, help='dimension of word embedding vectors')
    parser.add_argument('--hidden_size', type=int, default=1024, help='dimension of lstm hidden states')
    parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')

    parser.add_argument('--num_epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    args = parser.parse_args()
    print(args)
    main(args)
--------------------------------------------------------------------------------
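The training loop above packs both the decoder outputs and the targets with the same lengths; a small self-contained illustration (dummy values, not from the repository) of what the packed data tensor looks like:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Two "captions" of lengths 3 and 2, already sorted by decreasing length; 0 is padding.
captions = torch.tensor([[5, 6, 7],
                         [8, 9, 0]])
lengths = [3, 2]

packed = pack_padded_sequence(captions, lengths, batch_first=True)
print(packed.data)         # tensor([5, 8, 6, 9, 7]) -- padding removed, time-major order
print(packed.batch_sizes)  # tensor([2, 2, 1])

Because outputs and targets are packed with identical lengths, the logits and labels stay aligned element for element and CrossEntropyLoss never sees padded positions.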
/README.md:
--------------------------------------------------------------------------------
# Multiple-instance-learning

PyTorch implementations of three multiple-instance learning / multi-label image classification papers. Among the three, the visual_concept method gives the best performance.

* data_process: builds the vocabulary id dictionary; shared by all three methods.
* CNN-RNN: A Unified Framework for Multi-label Image Classification https://arxiv.org/abs/1604.04573
* Visual_concept: From captions to visual concepts and back https://arxiv.org/abs/1411.4952?context=cs
* DeepMIML: https://cs.nju.edu.cn/zhouzh/zhouzh.files/publication/aaai17deepMIML.pdf

## Data preparation

The original dataset is not provided, but you can build the required files from your own data: **resized2014** is the directory of resized images and **img_tag.txt** maps each image to its tags. Given those, you can generate the **zh_vocab.pkl** vocabulary file with https://github.com/Epiphqny/Multiple-instance-learning/blob/master/data_process/build_vocab.py

### Examples

img_tag.txt (numeric ids stand in for the image names):

1\tab girl,bottle,car

2\tab boy

3\tab child,bike

...

zh_vocab.pkl:

self.idx2word={1:girl,2:bottle,3:boy,4:car...}

self.word2idx={girl:1,bottle:2,boy:3,car:4...}

This is just an example and the actual files may vary slightly; note that the lines of the text file are stored in JSON format.
--------------------------------------------------------------------------------
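Since build_vocab.py (below) parses each line of img_tag.txt with json.loads, every line is a JSON array of the form [image_id, tag_list]. A made-up example; whether the id is stored as a string or an integer depends on how you write the file:

import json

line = '["1", ["girl", "bottle", "car"]]'   # one line of data/annotations/img_tag.txt
img_id, tokens = json.loads(line)
print(img_id)   # 1  (image name / id)
print(tokens)   # ['girl', 'bottle', 'car']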
/data_process/build_vocab.py:
--------------------------------------------------------------------------------
#import nltk
import pickle
import argparse
from collections import Counter
#from pycocotools.coco import COCO
import json


class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)


def build_vocab(caption_path, threshold):
    """Build a simple vocabulary wrapper."""
    counter = Counter()
    with open(caption_path, 'r') as f:
        for line in f:
            id, tokens = json.loads(line)
            counter.update(tokens)

    # If the word frequency is less than 'threshold', the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab


def main(args):
    vocab = build_vocab(caption_path=args.caption_path, threshold=args.threshold)
    print('vocab', vocab.word2idx.keys())
    vocab_path = args.vocab_path
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)
    print("Total vocabulary size: {}".format(len(vocab)))
    print("Saved the vocabulary wrapper to '{}'".format(vocab_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--caption_path', type=str,
                        default='data/annotations/img_tag.txt',
                        help='path for train annotation file')
    parser.add_argument('--vocab_path', type=str, default='./data/zh_vocab.pkl',
                        help='path for saving vocabulary wrapper')
    parser.add_argument('--threshold', type=int, default=4,
                        help='minimum word count threshold')
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/data_process/one_hot.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from build_vocab_v3 import Vocabulary
import pickle
import json

with open('data/zh_vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)


f_img = open('data/annotations/img_tag.txt', 'r')
f_save = open('data/annotations/-1_stop.txt', 'w')

for line in f_img:
    id, tokens = json.loads(line)
    # Fixed-length label index vector: 1032 slots, padded with -1.
    l = [-1 for i in range(1032)]
    for j in range(len(tokens)):
        try:
            l[j] = vocab.word2idx[tokens[j]]
        except KeyError:
            # Token not in the vocabulary (filtered out by the frequency threshold); skip it.
            pass
    x = [id, l, len(tokens)]
    f_save.write('%s\n' % json.dumps(x))

f_img.close()
f_save.close()
--------------------------------------------------------------------------------
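For reference, a sketch (made-up ids) of what one line of the generated -1_stop.txt looks like: the 1032-slot list holds vocabulary indices padded with -1, and the last field is the number of tags.

import json

# Suppose image "3" has tags ["child", "bike"] whose vocabulary ids are 17 and 42.
l = [-1] * 1032
l[0], l[1] = 17, 42
line = json.dumps(["3", l, 2])
print(line[:40])   # the saved line starts like: ["3", [17, 42, -1, -1, ...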
/deepmiml/deepmiml.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.models as models


class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        vgg = models.vgg16(pretrained=True)
        # Keep the VGG-16 convolutional layers up to conv5_3 (512 x 14 x 14 output).
        modules = list(vgg.features[i] for i in range(29))
        self.vgg = nn.Sequential(*modules)

    def forward(self, images):
        with torch.no_grad():
            features = self.vgg(images)
        return features


class DeepMIML(nn.Module):
    def __init__(self, L=1032, K=100):
        """L: number of labels (label/vocabulary size), K: number of sub-concepts per label."""
        super(DeepMIML, self).__init__()
        self.L = L
        self.K = K
        # 1x1 convolution producing a K-dimensional sub-concept score for every label
        # at every spatial location (instance).
        self.conv1 = nn.Conv2d(in_channels=512, out_channels=L * K, kernel_size=1)
        self.pool1 = nn.MaxPool2d((K, 1), stride=(1, 1))        # max over sub-concepts
        self.activation = nn.Sigmoid()
        self.pool2 = nn.MaxPool2d((1, 14 * 14), stride=(1, 1))  # max over instances

    def forward(self, features):
        N, C, H, W = features.size()
        n_instances = H * W
        conv1 = self.conv1(features)                        # (N, L*K, H, W)
        conv1 = conv1.view(N, self.L, self.K, n_instances)  # (N, L, K, instances)
        pool1 = self.pool1(conv1)                           # (N, L, 1, instances)
        act = self.activation(pool1)
        pool2 = self.pool2(act)                             # (N, L, 1, 1)

        out = pool2.view(N, self.L)                         # one score per label
        # print('out', out[0])  # debug
        return out
--------------------------------------------------------------------------------
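A quick shape check for the sub-concept layer above (random tensors, not from the repository; small L and K are used here so it runs instantly, whereas the real model uses L=1032, K=100):

import torch
from deepmiml import DeepMIML

model = DeepMIML(L=10, K=4)              # small values just for the shape check
features = torch.randn(2, 512, 14, 14)   # what EncoderCNN produces for 224x224 inputs
scores = model(features)
print(scores.shape)                      # torch.Size([2, 10]): one score per label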
/visual_concept/train_visual.py:
--------------------------------------------------------------------------------
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
from data_loader_v3 import get_loader
from build_vocab_v3 import Vocabulary
from visual_concept import EncoderCNN, Decoder
from torchvision import transforms
from torch.autograd import Variable
import torch.nn.functional as F
import time


class Sample_loss(torch.nn.Module):
    """Per-sample product loss over the ground-truth labels (not used in main below).
    Note: re-wrapping the result in Variable(..., requires_grad=True) creates a new
    leaf tensor, so gradients may not reach the model parameters."""
    def __init__(self):
        super(Sample_loss, self).__init__()

    def forward(self, x, y, lengths):
        loss = 0
        batch_size = len(lengths) // 8
        for i in range(batch_size):
            label_index = y[i][:lengths[i]]
            values = 1 - x[i][label_index]
            prod = 1
            for value in values:
                prod = prod * value
            loss += 1 - prod
        loss = Variable(loss, requires_grad=True).unsqueeze(0)
        return loss


class bce_loss(torch.nn.Module):
    """Thin wrapper around F.binary_cross_entropy (not used in main below)."""
    def __init__(self):
        super(bce_loss, self).__init__()

    def forward(self, x, y):
        loss = F.binary_cross_entropy(x, y)
        loss = Variable(loss.cuda(), requires_grad=True)
        return loss


def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing; ImageNet normalization for the pretrained VGG-16
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    print("load vocabulary ...")
    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    print("build data loader ...")
    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("build the models ...")
    # Build the models
    encoder = nn.DataParallel(EncoderCNN()).cuda()
    decoder = nn.DataParallel(Decoder()).cuda()

    params = list(decoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    time_start = time.time()

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, targets, lengths) in enumerate(data_loader):
            # Set mini-batch dataset
            images = images.cuda()
            targets = targets.cuda()

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features)

            # Weighted BCE: "pos" only penalizes the positive labels (outputs masked by
            # the targets), "neg" only the negative ones; positives are up-weighted by
            # 1e3 since each image has only a few positive tags out of 1032 labels.
            pos = nn.functional.binary_cross_entropy(outputs * targets, targets) * 1e3
            neg = nn.functional.binary_cross_entropy(outputs * (1 - targets) + targets, targets) * 1

            loss = pos + neg
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                time_end = time.time()
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Time:{}'
                      .format(epoch, args.num_epochs, i, total_step, loss.item(), time_end - time_start))
                time_start = time_end

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    args.model_path, 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                torch.save(encoder.state_dict(), os.path.join(
                    args.model_path, 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='models/', help='path for saving trained models')
    parser.add_argument('--crop_size', type=int, default=224, help='size for randomly cropping images')
    parser.add_argument('--vocab_path', type=str, default='data/zh_vocab.pkl', help='path for vocabulary wrapper')
    parser.add_argument('--image_dir', type=str, default='data/resized2014', help='directory for resized images')
    parser.add_argument('--caption_path', type=str, default='data/annotations/img_vector.txt', help='path for train annotation file')
    parser.add_argument('--log_step', type=int, default=1, help='step size for printing log info')
    parser.add_argument('--save_step', type=int, default=390, help='step size for saving trained models')

    # Model parameters

    parser.add_argument('--num_epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    args = parser.parse_args()
    print(args)
    main(args)
--------------------------------------------------------------------------------
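A small numeric check (made-up probabilities, not from the repository) of the masking trick used for the pos and neg terms above: each term only "sees" one side of the labels.

import torch
import torch.nn.functional as F

outputs = torch.tensor([[0.9, 0.2, 0.6]])   # predicted probabilities for 3 labels
targets = torch.tensor([[1.0, 0.0, 0.0]])   # only the first label is positive

pos = F.binary_cross_entropy(outputs * targets, targets)
neg = F.binary_cross_entropy(outputs * (1 - targets) + targets, targets)

# pos only reflects the positive label:  -log(0.9) / 3
# neg only reflects the negatives:      (-log(0.8) - log(0.4)) / 3
print(pos.item(), neg.item())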
/visual_concept/visual_concept.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.models as models


class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        vgg = models.vgg16(pretrained=True)
        # Keep the VGG-16 convolutional layers up to conv5_3 (512 x 14 x 14 output).
        modules = list(vgg.features[i] for i in range(29))
        self.vgg = nn.Sequential(*modules)

    def forward(self, images):
        with torch.no_grad():
            features = self.vgg(images)
        return features


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=512, out_channels=4096, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=4096, out_channels=4096, kernel_size=3)
        self.relu0 = nn.ReLU()
        self.conv3 = nn.Conv2d(in_channels=4096, out_channels=1032, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        # The 14x14 features shrink to 8x8 after conv1 (k=5) and conv2 (k=3);
        # this pooling reduces the remaining 8x8 grid to a single value per label.
        self.pool_mil = nn.MaxPool2d(8, stride=8)

    def forward(self, features):
        N, C, H, W = features.size()
        relu0 = self.relu0(features)
        conv1 = self.conv1(relu0)
        relu1 = self.relu1(conv1)
        conv2 = self.conv2(relu1)
        relu2 = self.relu2(conv2)
        conv3 = self.conv3(relu2)        # (N, 1032, 8, 8): per-region label scores

        sigmoid = self.sigmoid(conv3)    # per-region probabilities
        pool = self.pool_mil(sigmoid)    # (N, 1032, 1, 1): max over regions
        x = pool.squeeze(2).squeeze(2)   # (N, 1032)
        # Noisy-OR over the regions: 1 - prod_i (1 - p_i)
        x1 = torch.add(torch.mul(sigmoid.view(x.size(0), 1032, -1), -1), 1)
        cumprod = torch.prod(x1, 2)
        out = torch.min(x, torch.add(torch.mul(cumprod, -1), 1))
        # print('out', out[0])  # debug
        return out
--------------------------------------------------------------------------------
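The final out above takes the element-wise minimum of the max-pooled probability and the noisy-OR probability 1 - prod(1 - p_i). A tiny standalone illustration with made-up region probabilities for a single label:

import torch

p = torch.tensor([0.1, 0.3, 0.05, 0.2])   # probability the label fires in each region

max_pool = p.max()                   # 0.30: evidence from the strongest single region
noisy_or = 1 - torch.prod(1 - p)     # 0.5212: probability that at least one region fires
out = torch.min(max_pool, noisy_or)  # 0.30: the decoder keeps the more conservative score
print(max_pool.item(), noisy_or.item(), out.item())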