├── .gitignore ├── train ├── __init__.py ├── Model.py └── image_captioning │ └── char_model.py ├── data ├── README.md ├── default.jpg ├── models │ ├── vg-decoder-5-3000.pkl │ ├── coco-decoder-5-3000.pkl │ ├── coco-encoder-5-3000.pkl │ ├── lang_mod-decoder-5-3000.pkl │ ├── vg-encoder-5-3000.pkl │ └── lang_mod-encoder-5-3000.pkl └── google_images │ ├── imghttps:i.ytimg.comviSfLV8hD7zX4maxresdefault.jpg.jpg │ ├── imghttps:upload.wikimedia.orgwikipediacommonsthumbaacArriva_T6_nearside.JPG1200px-Arriva_T6_nearside.JPG.jpg │ ├── imghttp:cdn2-www.dogtime.comassetsuploads201101file_23244_what-is-the-appenzeller-sennenhunde-dog-300x189.jpg.jpg │ └── imghttps:upload.wikimedia.orgwikipediacommonsthumbdd9First_Student_IC_school_bus_202076.jpg220px-First_Student_IC_school_bus_202076.jpg.jpg ├── utils ├── pycocotools │ ├── __init__.py │ ├── mask.py │ ├── _mask.pyx │ ├── coco.py │ └── cocoeval.py ├── shell │ ├── getReps.sh │ ├── getGlove.sh │ └── getData.sh ├── image_utils.py ├── README.md ├── numpy_functions.py ├── config.py ├── test_data.py ├── urls.py ├── find_pairs.py ├── imagenet_utils.py ├── sample.py ├── generate_clusters.py ├── compute_results.py ├── build_vocab.py ├── data_stream.py ├── build_data.py └── image_and_text_utils.py ├── paper └── naacl-camera-ready.pdf ├── requirements.txt ├── .gitattributes ├── bayesian_agents ├── rsaWorld.py ├── rsaState.py └── joint_rsa.py ├── README.md ├── main.py └── recursion_schemes └── recursion_schemes.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | #init 2 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | the neural models 2 | -------------------------------------------------------------------------------- /utils/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /utils/shell/getReps.sh: -------------------------------------------------------------------------------- 1 | ipython3 charpragcap/utils/build_data.py 2 | -------------------------------------------------------------------------------- /data/default.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reubenharry/Recurrent-RSA/HEAD/data/default.jpg -------------------------------------------------------------------------------- /paper/naacl-camera-ready.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reubenharry/Recurrent-RSA/HEAD/paper/naacl-camera-ready.pdf -------------------------------------------------------------------------------- /data/models/vg-decoder-5-3000.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7ea5b7528d866cb75d7e432c25997818536fe692343c45869f8a1f7fc0b85cfc 3 | size 6401104 4 | -------------------------------------------------------------------------------- /data/models/coco-decoder-5-3000.pkl: -------------------------------------------------------------------------------- 1 | version 
https://git-lfs.github.com/spec/v1 2 | oid sha256:df7a4a6c6feff5bdf10c373970285bfd18a8446e3ca4e24377e8e360edc9dc8e 3 | size 6401159 4 | -------------------------------------------------------------------------------- /data/models/coco-encoder-5-3000.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3fc770312d06eadc77ed0e4703f6c0ad40d1fa4962226b9c914b332950e65d64 3 | size 235407686 4 | -------------------------------------------------------------------------------- /data/models/lang_mod-decoder-5-3000.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bcb605fb32d5073e725f6c331744726798476603addb3c730e9717c9ff6a761f 3 | size 6401104 4 | -------------------------------------------------------------------------------- /data/models/vg-encoder-5-3000.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9fe520ac8a12c06204749d06b3819d1b2ea8cf66b43f6b331ab281f17946dc0f 3 | size 235399860 4 | -------------------------------------------------------------------------------- /data/models/lang_mod-encoder-5-3000.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d90dfbb838af063c5cfd1a5574a02b6d8186e60af5ef24a0f8de88afad764614 3 | size 235399866 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | matplotlib 3 | requests 4 | nltk 5 | torch 6 | torchvision 7 | jupyter 8 | h5py 9 | tensorflow 10 | numpy 11 | keras 12 | pillow 13 | pandas 14 | more_itertools 15 | -------------------------------------------------------------------------------- /data/google_images/imghttps:i.ytimg.comviSfLV8hD7zX4maxresdefault.jpg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reubenharry/Recurrent-RSA/HEAD/data/google_images/imghttps:i.ytimg.comviSfLV8hD7zX4maxresdefault.jpg.jpg -------------------------------------------------------------------------------- /utils/shell/getGlove.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | mkdir -p resources/wordEmbeddings 4 | cd resources/wordEmbeddings 5 | wget http://nlp.stanford.edu/data/glove.6B.zip 6 | unzip glove.6B.zip 7 | rm glove.6B.zip 8 | 9 | cd ../.. 
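The script above only downloads and unzips the embeddings; nothing shown in this dump indicates how (or whether) the repo consumes them downstream, so the following is just a minimal, hypothetical sketch of reading one of the unzipped files (glove.6B.50d.txt ships inside glove.6B.zip) into a word-to-vector dictionary:

import numpy as np

embeddings = {}
# path assumed from the mkdir/cd commands in getGlove.sh above
with open("resources/wordEmbeddings/glove.6B.50d.txt", encoding="utf-8") as f:
    for line in f:
        word, *values = line.rstrip().split(" ")
        embeddings[word] = np.asarray(values, dtype=np.float32)

print(embeddings["bus"].shape)   # (50,)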
10 | -------------------------------------------------------------------------------- /data/google_images/imghttps:upload.wikimedia.orgwikipediacommonsthumbaacArriva_T6_nearside.JPG1200px-Arriva_T6_nearside.JPG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reubenharry/Recurrent-RSA/HEAD/data/google_images/imghttps:upload.wikimedia.orgwikipediacommonsthumbaacArriva_T6_nearside.JPG1200px-Arriva_T6_nearside.JPG.jpg -------------------------------------------------------------------------------- /data/google_images/imghttp:cdn2-www.dogtime.comassetsuploads201101file_23244_what-is-the-appenzeller-sennenhunde-dog-300x189.jpg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reubenharry/Recurrent-RSA/HEAD/data/google_images/imghttp:cdn2-www.dogtime.comassetsuploads201101file_23244_what-is-the-appenzeller-sennenhunde-dog-300x189.jpg.jpg -------------------------------------------------------------------------------- /data/google_images/imghttps:upload.wikimedia.orgwikipediacommonsthumbdd9First_Student_IC_school_bus_202076.jpg220px-First_Student_IC_school_bus_202076.jpg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reubenharry/Recurrent-RSA/HEAD/data/google_images/imghttps:upload.wikimedia.orgwikipediacommonsthumbdd9First_Student_IC_school_bus_202076.jpg220px-First_Student_IC_school_bus_202076.jpg.jpg -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | def get_rep_from_url(url): 2 | 3 | import urllib.request 4 | from PIL import Image as PIL_Image 5 | import shutil 6 | import requests 7 | 8 | response = requests.get(url, stream=True) 9 | with open('charpragcap/resources/img.jpg', 'wb') as out_file: 10 | shutil.copyfileobj(response.raw, out_file) 11 | del response -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | todo: 2 | 3 | setup notebook and display cropped images: heck that they are reasonable 4 | finish making dictionary 5 | script to find total number of items and create id lists of train,val and test 6 | write option to supply distractors and ids: find nice way of writing this options: e.g. stream of dicts or something? 
7 | give stream an argument for which set of ids to iterate over -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | data/models/coco-decoder-5-3000.pkl filter=lfs diff=lfs merge=lfs -text 2 | data/models/coco-encoder-5-3000.pkl filter=lfs diff=lfs merge=lfs -text 3 | data/models/lang_mod-decoder-5-3000.pkl filter=lfs diff=lfs merge=lfs -text 4 | data/models/lang_mod-encoder-5-3000.pkl filter=lfs diff=lfs merge=lfs -text 5 | data/models/vg-decoder-5-3000.pkl filter=lfs diff=lfs merge=lfs -text 6 | data/models/vg-encoder-5-3000.pkl filter=lfs diff=lfs merge=lfs -text 7 | -------------------------------------------------------------------------------- /utils/numpy_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def softmax(x): 4 | """Compute softmax values for each sets of scores in x.""" 5 | e_x = np.exp(x - np.max(x)) 6 | return e_x / e_x.sum() 7 | 8 | def uniform_vector(length): 9 | return np.ones((length))/length 10 | 11 | def make_initial_prior(initial_image_prior,initial_rationality_prior,initial_speaker_prior): 12 | 13 | return np.log(np.multiply.outer(initial_image_prior,np.multiply.outer(initial_rationality_prior,initial_speaker_prior))) 14 | -------------------------------------------------------------------------------- /utils/shell/getData.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | mkdir -p charpragcap/resources/visual_genome_data 4 | cd charpragcap/resources/visual_genome_data 5 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip 6 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip 7 | unzip images.zip 8 | unzip images2.zip 9 | #rm images.zip 10 | #rm images2.zip 11 | 12 | mv VG_100K_2/* VG_100K/ 13 | rmdir VG_100K_2 14 | 15 | mkdir ../visual_genome_JSON 16 | cd ../visual_genome_JSON 17 | wget http://visualgenome.org/static/data/dataset/image_data.json.zip 18 | unzip image_data.json.zip 19 | rm image_data.json.zip 20 | 21 | wget http://visualgenome.org/static/data/dataset/region_descriptions.json.zip 22 | unzip region_descriptions.json.zip 23 | rm region_descriptions.json.zip 24 | cd ../.. 
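Since utils/numpy_functions.py above defines the joint prior as a log outer product, here is a small illustrative sketch (not part of the repo) of what make_initial_prior yields for the configuration used in main.py, namely two candidate images, one rationality value and one speaker model:

import numpy as np
from utils.numpy_functions import uniform_vector, make_initial_prior

# image prior over 2 images, degenerate priors over rationality and speaker
prior = make_initial_prior(uniform_vector(2), uniform_vector(1), uniform_vector(1))
print(prior.shape)      # (2, 1, 1): image x rationality x speaker
print(np.exp(prior))    # [[[0.5]], [[0.5]]], i.e. uniform over the two images

The (2, 1, 1) shape is exactly what listener_simple in bayesian_agents/joint_rsa.py asserts for its world prior.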
25 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | IMG_DATA_PATH="charpragcap/resources/visual_genome_data/" 3 | REP_DATA_PATH="charpragcap/resources/resnet_reps/" 4 | TRAINED_MODEL_PATH="data/models/" 5 | WEIGHTS_PATH="charpragcap/resources/weights/" 6 | S0_WEIGHTS_PATH="s0_weights" 7 | S0_PRIME_WEIGHTS_PATH="s0_prime_weights" 8 | caption,region = 0,1 9 | start_token = {"word":"","char":'^'} 10 | stop_token = {"word":"","char":'$'} 11 | pad_token = '&' 12 | sym_set = list('&^$ abcdefghijklmnopqrstuvwxyz') 13 | stride_length = 10 14 | start_index = 1 15 | stop_index = 2 16 | pad_index = 0 17 | batch_size = 50 18 | max_sentence_length = 60 19 | 20 | train_size,val_size,test_size = 0.98,0.01,0.01 21 | rep_size = 2048 22 | img_rep_layer = 'hiddenrep' 23 | 24 | char_to_index = defaultdict(int) 25 | for i,x in enumerate(sym_set): 26 | char_to_index[x] = i 27 | index_to_char = defaultdict(lambda:'') 28 | for i,x in enumerate(sym_set): 29 | index_to_char[i] = x 30 | -------------------------------------------------------------------------------- /bayesian_agents/rsaWorld.py: -------------------------------------------------------------------------------- 1 | class RSA_World: 2 | 3 | def __init__( 4 | self, 5 | target, 6 | speaker, 7 | rationality="DEFAULTBAD", 8 | ): 9 | 10 | self.target=target 11 | self.rationality=rationality 12 | self.speaker=speaker 13 | 14 | def __hash__(self): 15 | return hash((self.target,self.speaker,self.rationality)) 16 | 17 | def __eq__(self,other): 18 | return self.target==other.target and self.speaker==other.speaker and self.rationality==other.rationality 19 | 20 | def set_values(self,values): 21 | 22 | self.target=values[0] 23 | self.rationality=values[1] 24 | 25 | def __repr__(self): 26 | return "" % (self.target,self.rationality, self.speaker) 27 | 28 | # ADD IN 29 | # self.speaker=values[2] 30 | 31 | 32 | # def initial_prior(): 33 | # pass 34 | # #something like: np.zeros() 35 | # return np.outer(image_prior,speaker_prior,rationality_prior) 36 | 37 | # def timestep_prior(): 38 | # out = np.zeros(timestep,*dimensions) 39 | # out[0]=initial_prior() 40 | # return out -------------------------------------------------------------------------------- /utils/test_data.py: -------------------------------------------------------------------------------- 1 | # show a first images and their ground truth captions 2 | # check that stored rep = generated rep for a few random ones 3 | # cehck that unmemoized is same as memoized 4 | 5 | # tests that the saved reps are aved in the right order and so on 6 | 7 | def check_reps(): 8 | 9 | repsandcaps = single_stream(train,X0_type='rep') 10 | idsandcaps = single_stream(train,X0_type='id') 11 | for i in range(10): 12 | repandcaps = next(repsandcaps) 13 | idandcaps = next(idsandcaps) 14 | img_id = idandcaps[0] 15 | img_rep = repandcaps[0] 16 | print("img id:",img_id) 17 | real_rep = item_to_rep(img_id) 18 | stored_rep = np.expand_dims(img_rep,0) 19 | 20 | print("real rep",real_rep) 21 | print("stored rep",stored_rep) 22 | print(real_rep==stored_rep) 23 | assert(np.array_equal(real_rep,stored_rep)) 24 | 25 | def view_data(): 26 | for full_id,cap_in,cap_out in single_stream(test,X0_type='id'): 27 | img = get_img_from_id(full_id,id_to_caption) 28 | display(img) 29 | print("".join([index_to_char[x] for x in cap_in])) 30 | print(''.join([index_to_char[np.argmax(x)] 
for x in cap_out])) 31 | break 32 | 33 | -------------------------------------------------------------------------------- /bayesian_agents/rsaState.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.image_and_text_utils import max_sentence_length,vectorize_caption 3 | 4 | class RSA_State: 5 | 6 | def __init__( 7 | self, 8 | initial_world_prior, 9 | listener_rationality=1.0 10 | ): 11 | # should deprecate these two above 12 | 13 | 14 | self.context_sentence=np.expand_dims(np.expand_dims(vectorize_caption("")[0],0),-1) 15 | 16 | # priors for the various dimensions of the world 17 | # keep track of the dimensions of the prior 18 | self.dim = {"image":0,"rationality":1,"speaker":2} 19 | 20 | # the priors at t>0 only matter if we aren't updating the prior at each step 21 | self.world_priors=np.asarray([initial_world_prior for x in range(max_sentence_length+1)]) 22 | 23 | self.listener_rationality=listener_rationality 24 | 25 | # this is a bit confusing, isn't it 26 | self.timestep=1 27 | 28 | 29 | def __hash__(self): 30 | return hash((self.timestep,self.listener_rationality,tuple(self.context_sentence))) 31 | 32 | # def __eq__(self,other): 33 | # return self.target==other.target and self.speaker==other.speaker and self.rationality==other.rationality 34 | 35 | 36 | -------------------------------------------------------------------------------- /utils/urls.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | url1 = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d9/First_Student_IC_school_bus_202076.jpg/220px-First_Student_IC_school_bus_202076.jpg" 4 | url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/a/ac/Arriva_T6_nearside.JPG/1200px-Arriva_T6_nearside.JPG" 5 | url3 = "http://www.petmd.com/sites/default/files/what-does-it-mean-when-cat-wags-tail.jpg" 6 | url4 = "https://www.petfinder.com/wp-content/uploads/2012/11/91615172-find-a-lump-on-cats-skin-632x475.jpg" 7 | url5 = "https://ichef.bbci.co.uk/news/976/media/images/83351000/jpg/_83351965_explorer273lincolnshirewoldssouthpicturebynicholassilkstone.jpg" 8 | url6 = "http://www.vipbuspartyhire.com/wp-content/uploads/2013/08/RedLondonDoubleDeckerBusHire.jpg" 9 | url7 = "https://wallpaperbrowse.com/media/images/pictures-2.jpg" 10 | url8 = "https://static.pexels.com/photos/219998/pexels-photo-219998.jpeg" 11 | urls = [url1,url2,url3,url4,url5,url6,url7,url8] 12 | 13 | url9 = "https://assets.bwbx.io/images/users/iqjWHBFdfxIU/i6bua_ZLxNG0/v1/800x-1.jpg" 14 | 15 | # img_and_reps = pickle.load(open("charpragcap/resources/img_and_reps",'rb')) 16 | 17 | 18 | # urls = [url1,url2] 19 | # reps = [img_and_reps[url] for url in urls] -------------------------------------------------------------------------------- /utils/find_pairs.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pickle 3 | from charpragcap.utils.image_and_text_utils import split_dataset,index_to_char,devectorize_caption 4 | from nltk.corpus import stopwords 5 | 6 | 7 | id_to_caption = pickle.load(open("charpragcap/resources/id_to_caption",'rb')) 8 | train,val,test = split_dataset(id_to_caption) 9 | print(len(test)) 10 | 11 | def cap_to_words(cap): 12 | return [word for word in re.sub('[^a-z ]',"",devectorize_caption(cap)).split() if word not in stopwords.words('english')],devectorize_caption(cap) 13 | 14 | 15 | 16 | 17 | 18 | def find_pairs(item): 19 | 20 | cap = 
set(cap_to_words(id_to_caption[item][0][0])[0]) 21 | fst_half_id,snd_half_id = item.split('_') 22 | l=[] 23 | 24 | for t in test: 25 | fst_half_new_id,snd_half_new_id = t.split('_') 26 | if fst_half_id!=fst_half_new_id: 27 | new_cap,full_cap = cap_to_words(id_to_caption[t][0][0]) 28 | l.append((len(set(new_cap).intersection(cap)),full_cap,t)) 29 | 30 | return (sorted(l,key=lambda x: x[0],reverse=True)[:30]) 31 | 32 | def all_pairs(): 33 | out = [] 34 | for i,t in enumerate(test): 35 | print(i) 36 | pairs = find_pairs(t) 37 | # if pairs[1][0]>1: 38 | print(pairs[:3]) 39 | if pairs[2][0]>1: 40 | out.append(pairs[:3]) 41 | if i>5000: 42 | break 43 | 44 | return sorted(out,key=lambda x:x[-1]) 45 | # break 46 | 47 | if __name__ == '__main__': 48 | out = list(all_pairs()) 49 | print(out) 50 | pickle.dump(out,open('charpragcap/resources/distractor_pairs_short','wb')) 51 | -------------------------------------------------------------------------------- /utils/imagenet_utils.py: -------------------------------------------------------------------------------- 1 | #CODE FROM KERAS 2 | 3 | import numpy as np 4 | import json 5 | 6 | from keras.utils.data_utils import get_file 7 | from keras import backend as K 8 | 9 | CLASS_INDEX = None 10 | CLASS_INDEX_PATH = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json' 11 | 12 | def preprocess_input(x, dim_ordering='default'): 13 | if dim_ordering == 'default': 14 | dim_ordering = K.image_dim_ordering() 15 | assert dim_ordering in {'tf', 'th'} 16 | 17 | if dim_ordering == 'th': 18 | x[:, 0, :, :] -= 103.939 19 | x[:, 1, :, :] -= 116.779 20 | x[:, 2, :, :] -= 123.68 21 | # 'RGB'->'BGR' 22 | x = x[:, ::-1, :, :] 23 | else: 24 | x[:, :, :, 0] -= 103.939 25 | x[:, :, :, 1] -= 116.779 26 | x[:, :, :, 2] -= 123.68 27 | # 'RGB'->'BGR' 28 | x = x[:, :, :, ::-1] 29 | return x 30 | 31 | 32 | def decode_predictions(preds, top=5): 33 | global CLASS_INDEX 34 | if len(preds.shape) != 2 or preds.shape[1] != 1000: 35 | raise ValueError('`decode_predictions` expects ' 36 | 'a batch of predictions ' 37 | '(i.e. a 2D array of shape (samples, 1000)). ' 38 | 'Found array with shape: ' + str(preds.shape)) 39 | if CLASS_INDEX is None: 40 | fpath = get_file('imagenet_class_index.json', 41 | CLASS_INDEX_PATH, 42 | cache_subdir='models') 43 | CLASS_INDEX = json.load(open(fpath)) 44 | results = [] 45 | for pred in preds: 46 | top_indices = pred.argsort()[-top:][::-1] 47 | result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices] 48 | results.append(result) 49 | return results 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pragmatic-Image-Captioning 2 | 3 | This codebase implements Bayesian pragmatics (i.e. the Rational Speech Acts model - RSA) over the top of a deep neural image 4 | captioning model. These are desirable to combine, since RSA gives rise to linguistically realistic effects, while deep models can capture (at least some) of the flexibility and expressivity of natural language. 5 | 6 | Summary: 7 | 8 | * Suppose we have a space of possible sentences U 9 | 10 | * Choosing the sentence which is the most informative caption for identifying image w out of a set of images W is a useful task (moreover, it represents a key instance of natural language pragmatics) 11 | 12 | * Viewed as an inference problem (of a speaker agent P(U|W=w) ), this task is intractable when U is large. 
13 | 14 | * But if the space of possible sentences U is recursively generated, there's a solution: at each stage of the recursive generation of a sentence u, we perform a local inference as to the most informative next step 15 | 16 | * Category theoretic perspective (very roughly): this amounts to mapping the inference onto the coalgebra of the anamorphism used to generate the distribution over U 17 | 18 | * Linguistic perspective (very broadly): we're pushing pragmatics into the lower levels of language, rather than adding it on top 19 | 20 | * Computational perspective: this provides us a way to get the power of Bayesian models of pragmatics (see Rational Speech Acts) with deep machine learning models powerful enough to model natural language 21 | 22 | 23 | 24 | Setup: 25 | 26 | To run the model, you'll need python3.6, and to have cloned the repo with git lfs. Run main.py - you can supply urls for your own images, but on the current settings, the captions probably won't be great (needs beam search, and the example uses greedy) 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /utils/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import argparse 5 | import pickle 6 | import os 7 | from torch.autograd import Variable 8 | from torchvision import transforms 9 | from utils.build_vocab import Vocabulary 10 | from PIL import Image 11 | import re 12 | 13 | def to_var(x, volatile=False): 14 | if torch.cuda.is_available(): 15 | x = x.cuda() 16 | return Variable(x, volatile=volatile) 17 | 18 | 19 | 20 | 21 | 22 | def load_image_from_path(path, transform=None): 23 | 24 | from PIL import Image as PIL_Image 25 | 26 | 27 | image = Image.open(path) 28 | image = image.resize([224, 224], Image.LANCZOS) 29 | # image = image.crop([0,0,224,224]) 30 | if transform is not None: 31 | image = transform(image).unsqueeze(0) 32 | 33 | return image 34 | 35 | def load_image(url, transform=None): 36 | 37 | import urllib.request 38 | from PIL import Image as PIL_Image 39 | import shutil 40 | import requests 41 | 42 | hashed_url = re.sub('/','',url) 43 | 44 | response = requests.get(url, stream=True) 45 | with open('data/google_images/img'+hashed_url+'.jpg', 'wb') as out_file: 46 | shutil.copyfileobj(response.raw, out_file) 47 | # del response 48 | print(url,response) 49 | # print(os.listdir()) 50 | 51 | 52 | image = Image.open('data/google_images/img'+hashed_url+'.jpg') 53 | # print("image loaded (sample.py)") 54 | image = image.resize([224, 224], Image.LANCZOS) 55 | # width = image.size[0] 56 | # height = image.size[1] 57 | 58 | # if width>height: 59 | # new_height=224 60 | # new_width=224 * (width/height) 61 | # else: 62 | # new_width=224 63 | # new_height=224 * (height/width) 64 | 65 | # b = image.resize([int(new_width),int(new_height)],PIL_Image.LANCZOS) 66 | # # b = a.thumbnail([224, 224], PIL_Image.LANCZOS) 67 | # # c = b.crop([0,0,224,224]) 68 | # image = b.crop([0,0,224,224]) 69 | 70 | if transform is not None: 71 | image = transform(image).unsqueeze(0) 72 | 73 | # image = transforms.ToTensor()(image).unsqueeze(0) 74 | 75 | return image 76 | 77 | -------------------------------------------------------------------------------- /utils/generate_clusters.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import re 4 | from nltk.corpus import stopwords 5 | 
from charpragcap.utils.image_and_text_utils import devectorize_caption,split_dataset,get_rep_from_id 6 | import copy 7 | 8 | def cap_to_words(cap): 9 | return [word for word in re.sub('[^a-z ]',"",devectorize_caption(cap)).split() if word not in stopwords.words('english')] 10 | 11 | if __name__ == '__main__': 12 | 13 | make = False 14 | name = "test_clusters" 15 | 16 | id_to_caption = pickle.load(open("charpragcap/resources/id_to_caption",'rb')) 17 | trains,vals,tests = split_dataset() 18 | 19 | # new_dic = {} 20 | # for x in val: 21 | # new_dic[x]=id_to_caption[x] 22 | # id_to_caption=new_dic 23 | 24 | # print(len(list(id_to_caption))) 25 | id_to_caption = pickle.load(open("charpragcap/resources/id_to_caption",'rb')) 26 | ids = [x for x in trains+tests+vals if int(x.split("_")[0])>414114][:10000] 27 | 28 | # vocab = {} 29 | # for x in ids: 30 | # words = [word for word in re.sub('[^a-z ]',"",cap).split() if word not in stopwords.words('english')] 31 | # for word in words: 32 | # vocab.add(word) 33 | 34 | # vocab = sorted(list(vocab)) 35 | 36 | if make: 37 | 38 | print(len(ids)) 39 | 40 | id_to_words={} 41 | 42 | cluster_mat = np.zeros((len(ids),len(ids))) 43 | 44 | for i,idx in enumerate(ids): 45 | print(i) 46 | 47 | try: sent_1=id_to_words[idx] 48 | except: 49 | sent_1=set(cap_to_words(id_to_caption[idx][0][0])) 50 | id_to_words[idx]=sent_1 51 | 52 | for j,idx2 in enumerate(ids): 53 | 54 | try: sent_2 = id_to_words[idx2] 55 | except: 56 | sent_2=set(cap_to_words(id_to_caption[idx2][0][0])) 57 | id_to_words[idx2]=sent_2 58 | 59 | overlap= len(sent_1.intersection(sent_2)) 60 | cluster_mat[i,j]=overlap 61 | 62 | print("MADE MATRIX") 63 | pickle.dump(cluster_mat,open("charpragcap/resources/cluster_mat",'wb')) 64 | pickle.dump(id_to_words,open("charpragcap/resources/id_to_words",'wb')) 65 | 66 | cluster_mat=pickle.load(open("charpragcap/resources/cluster_mat",'rb')) 67 | id_to_words=pickle.load(open("charpragcap/resources/id_to_words",'rb')) 68 | 69 | excluded=[] 70 | clusters = [] 71 | for i in range(len(ids)): 72 | if ids[i] not in excluded: 73 | 74 | cluster = [ids[x] for x in np.argsort(-cluster_mat[i])[:10]] 75 | print(cluster) 76 | print([id_to_words[x] for x in cluster]) 77 | clusters.append(cluster) 78 | excluded+=cluster 79 | 80 | pickle.dump(clusters, open("charpragcap/resources/cluster_dicts/"+name,'wb')) 81 | 82 | 83 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # this code will generate a literal caption and a pragmatic caption (referring expression) for the first of the urls provided in the context of the rest 2 | 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import re 6 | import requests 7 | import time 8 | import pickle 9 | import tensorflow as tf 10 | import numpy as np 11 | from keras.preprocessing import image 12 | from collections import defaultdict 13 | 14 | from utils.config import * 15 | from utils.numpy_functions import uniform_vector, make_initial_prior 16 | from recursion_schemes.recursion_schemes import ana_greedy,ana_beam 17 | from bayesian_agents.joint_rsa import RSA 18 | 19 | 20 | urls = [ 21 | "https://upload.wikimedia.org/wikipedia/commons/thumb/a/ac/Arriva_T6_nearside.JPG/1200px-Arriva_T6_nearside.JPG", 22 | "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d9/First_Student_IC_school_bus_202076.jpg/220px-First_Student_IC_school_bus_202076.jpg" 23 | ] 24 | 25 | # code is written to be able to jointly infer speaker's 
rationality and neural model, but for simplicity, let's assume these are fixed 26 | # the rationality of the S1 27 | rat = [100.0] 28 | # the neural model: captions trained on MSCOCO ("coco") are more verbose than VisualGenome ("vg") 29 | model = ["vg"] 30 | number_of_images = len(urls) 31 | # the model starts of assuming it's equally likely any image is the intended referent 32 | initial_image_prior=uniform_vector(number_of_images) 33 | initial_rationality_prior=uniform_vector(1) 34 | initial_speaker_prior=uniform_vector(1) 35 | initial_world_prior = make_initial_prior(initial_image_prior,initial_rationality_prior,initial_speaker_prior) 36 | 37 | # make a character level speaker, using torch model (instead of tensorflow model) 38 | speaker_model = RSA(seg_type="char",tf=False) 39 | speaker_model.initialize_speakers(model) 40 | # set the possible images and rationalities 41 | speaker_model.speaker_prior.set_features(images=urls,tf=False,rationalities=rat) 42 | speaker_model.initial_speakers[0].set_features(images=urls,tf=False,rationalities=rat) 43 | # generate a sentence by unfolding stepwise, from the speaker: greedy unrolling used here, not beam search: much better to use beam search generally 44 | literal_caption = ana_greedy( 45 | speaker_model, 46 | target=0, 47 | depth=0, 48 | speaker_rationality=0, 49 | speaker=0, 50 | start_from=list(""), 51 | initial_world_prior=initial_world_prior) 52 | 53 | pragmatic_caption = ana_greedy( 54 | speaker_model, 55 | target=0, 56 | depth=1, 57 | speaker_rationality=0, 58 | speaker=0, 59 | start_from=list(""), 60 | initial_world_prior=initial_world_prior) 61 | 62 | print("Literal caption:\n",literal_caption) 63 | print("Pragmatic caption:\n",pragmatic_caption) 64 | -------------------------------------------------------------------------------- /utils/compute_results.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from charpragcap.rsa.cataRSA import CataRSA 4 | from charpragcap.rsa.cataRSA_working_beam import CataRSA as CataRSA_working_beam 5 | from charpragcap.utils.config import * 6 | from charpragcap.utils.image_and_text_utils import char_to_index,vectorize_caption,get_rep_from_img_id,split_dataset 7 | from charpragcap.utils.generate_clusters import generate_clusters 8 | from charpragcap.utils.urls import reps 9 | 10 | 11 | id_to_caption = pickle.load(open("charpragcap/resources/id_to_caption",'rb')) 12 | train,val,test = split_dataset(id_to_caption) 13 | 14 | full_test_ids = sorted(list(set([x.split('_')[0] for x in test]))) 15 | 16 | cataRSA = CataRSA( 17 | imgs=[reps[0]], 18 | img_paths=[], 19 | trained_s0_path=WEIGHTS_PATH+S0_WEIGHTS_PATH, 20 | trained_s0_prime_path=WEIGHTS_PATH+S0_PRIME_WEIGHTS_PATH, 21 | l0_type='from_s0', 22 | ) 23 | 24 | def compute_results(model,images): 25 | 26 | s0_results = {} 27 | for img_id in images: 28 | model.images=np.array([get_rep_from_img_id(img_id)]) 29 | # out = model.ana_greedy(speaker_rationality=1.0, listener_rationality=1.0, depth=0,start_from="",img_prior=np.log(np.asarray([0.5]))) 30 | out = model.ana_beam( 31 | speaker_rationality=1.0, 32 | listener_rationality=1.0, 33 | depth=0,start_from="", 34 | decay_rate=-1.0, 35 | img_prior=np.log(np.asarray([0.5])) 36 | ) 37 | 38 | s0_results[img_id]=out 39 | print("RESULTS:",out) 40 | model._speaker_cache = {} 41 | model._listener_cache = {} 42 | 43 | 44 | pickle.dump(s0_results,open("charpragcap/resources/s0_results_beam",'wb')) 45 | 46 | # 
compute_results(cataRSA,full_test_ids[:100]) 47 | generate_clusters() 48 | 49 | 50 | def compute_results_pragmatic(model,depth,name): 51 | clusters = pickle.load(open("charpragcap/resources/clusters",'rb')) 52 | s0_results = {} 53 | out = model.ana_beam( 54 | speaker_rationality=1.0, 55 | listener_rationality=1.0, 56 | depth=depth,start_from="", 57 | decay_rate=-1.0, 58 | img_prior=np.log(np.asarray([1/2,1/2])) 59 | ) 60 | for items in clusters[:20]: 61 | model.images=np.array([get_rep_from_img_id(item[1]) for item in items]) 62 | # out = model.ana_greedy(speaker_rationality=1.0, listener_rationality=1.0, depth=0,start_from="",img_prior=np.log(np.asarray([0.5]))) 63 | 64 | s0_results[tuple([(item[1],item[0]) for item in items])]=out 65 | print(out) 66 | print(items) 67 | model._speaker_cache = {} 68 | model._listener_cache = {} 69 | 70 | 71 | pickle.dump(s0_results,open("charpragcap/resources/s0_results_"+name,'wb')) 72 | 73 | compute_results_pragmatic(cataRSA,1,"depth_1") 74 | compute_results_pragmatic(cataRSA,2,"depth_2") 75 | compute_results_pragmatic(cataRSA,3,"depth_3") -------------------------------------------------------------------------------- /utils/build_vocab.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import pickle 3 | import argparse 4 | from collections import Counter 5 | from utils.pycocotools.coco import COCO 6 | 7 | 8 | class Vocabulary(object): 9 | """Simple vocabulary wrapper.""" 10 | def __init__(self): 11 | self.word2idx = {} 12 | self.idx2word = {} 13 | self.idx = 0 14 | 15 | def add_word(self, word): 16 | if not word in self.word2idx: 17 | self.word2idx[word] = self.idx 18 | self.idx2word[self.idx] = word 19 | self.idx += 1 20 | 21 | def __call__(self, word): 22 | if not word in self.word2idx: 23 | return self.word2idx[''] 24 | return self.word2idx[word] 25 | 26 | def __len__(self): 27 | return len(self.word2idx) 28 | 29 | def build_vocab(json, threshold): 30 | """Build a simple vocabulary wrapper.""" 31 | coco = COCO(json) 32 | counter = Counter() 33 | ids = coco.anns.keys() 34 | for i, id in enumerate(ids): 35 | caption = str(coco.anns[id]['caption']) 36 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 37 | counter.update(tokens) 38 | 39 | if i % 1000 == 0: 40 | print("[%d/%d] Tokenized the captions." %(i, len(ids))) 41 | 42 | # If the word frequency is less than 'threshold', then the word is discarded. 43 | words = [word for word, cnt in counter.items() if cnt >= threshold] 44 | 45 | # Creates a vocab wrapper and add some special tokens. 46 | vocab = Vocabulary() 47 | vocab.add_word('') 48 | vocab.add_word('') 49 | vocab.add_word('') 50 | vocab.add_word('') 51 | 52 | # Adds the words to the vocabulary. 
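# (illustrative note, not part of the original script) once the loop below has run,
# lookups on the wrapper behave roughly like:
#   vocab('dog')    -> the integer index assigned to 'dog', if it met the count threshold
#   vocab('xyzzy')  -> the index of the unknown-word token, the fallback in __call__ above
#   len(vocab)      -> number of kept words plus the special tokens added above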
53 | for i, word in enumerate(words): 54 | vocab.add_word(word) 55 | return vocab 56 | 57 | def main(args): 58 | vocab = build_vocab(json=args.caption_path, 59 | threshold=args.threshold) 60 | vocab_path = args.vocab_path 61 | with open(vocab_path, 'wb') as f: 62 | pickle.dump(vocab, f) 63 | print("Total vocabulary size: %d" %len(vocab)) 64 | print("Saved the vocabulary wrapper to '%s'" %vocab_path) 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--caption_path', type=str, 70 | default='/usr/share/mscoco/annotations/captions_train2014.json', 71 | help='path for train annotation file') 72 | parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl', 73 | help='path for saving vocabulary wrapper') 74 | parser.add_argument('--threshold', type=int, default=4, 75 | help='minimum word count threshold') 76 | args = parser.parse_args() 77 | main(args) 78 | -------------------------------------------------------------------------------- /utils/data_stream.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import more_itertools 4 | from PIL import Image as PIL_Image 5 | from keras.preprocessing import image 6 | from charpragcap.utils.image_and_text_utils import split_dataset,index_to_char,get_img_from_id 7 | from charpragcap.utils.config import * 8 | from charpragcap.resources.models.resnet import resnet 9 | 10 | 11 | id_to_caption = pickle.load(open("charpragcap/resources/id_to_caption",'rb')) 12 | reps = pickle.load(open(REP_DATA_PATH+'reps.pickle','rb')) 13 | fc_resnet = resnet(img_rep_layer) 14 | 15 | 16 | def single_stream(ids,X0_type='rep'): 17 | 18 | # TODO 19 | #find the other ids with same first part: hopefully can just look down the dict, if sorted right: check that 20 | #iterate through this list, taking each as a starting id, and returning the whole list of ids 21 | #use the lookup in a vectorized way (implemented in pandas) to get reps, if reps needed 22 | 23 | for full_id in ids: 24 | pairs=False 25 | if type(full_id)==tuple: pairs=True 26 | 27 | 28 | 29 | if pairs: 30 | 31 | 32 | fst_id,snd_id = full_id 33 | if X0_type=='rep': 34 | try: 35 | fst_X0,snd_X0 = reps.ix[fst_id].values,reps.ix[snd_id].values 36 | except Exception: 37 | out = [] 38 | for idx in [fst_id,snd_id]: 39 | img = get_img_from_id(idx,id_to_caption) 40 | img_vector = image.img_to_array(img) 41 | out.append(fc_resnet.predict([img_vector])) 42 | fst_X0,snd_X0 = tuple(out) 43 | elif X0_type=='id': 44 | fst_X0,snd_X0 = fst_id,snd_id 45 | fst_X1,fst_Y = id_to_caption[fst_id][caption] 46 | snd_X1,snd_Y = id_to_caption[snd_id][caption] 47 | yield (fst_X0,fst_X1,fst_Y,snd_X0,snd_X1,snd_Y) 48 | 49 | else: 50 | if X0_type=='rep': 51 | try: 52 | X0 = reps.ix[full_id].values 53 | except KeyError: 54 | img = get_img_from_id(full_id,id_to_caption) 55 | 56 | img_vector = np.expand_dims(image.img_to_array(img),0) 57 | print("\n\n\nshape",img_vector.shape) 58 | X0 = fc_resnet.predict(img_vector) 59 | 60 | print("\n\n\nGOT IMAGE\n\n\n") 61 | 62 | 63 | elif X0_type=='id': X0 = full_id 64 | X1,Y = id_to_caption[full_id][caption] 65 | yield (X0,X1,Y) 66 | 67 | #divides stream into chunks, i.e. 
minibatches 68 | def chunked_stream(ids): 69 | chunks = more_itertools.chunked(single_stream(ids),batch_size) 70 | 71 | for chunk in chunks: 72 | x1s,x2s,ys = list(zip(*chunk)) 73 | yield ([np.asarray(x1s),np.expand_dims(np.asarray(x2s),-1)],np.asarray(ys)) 74 | 75 | #cycles chunked stream 76 | def data(ids): 77 | # if X0_type=='rep': 78 | # reps = pickle.load(open(REP_DATA_PATH+'reps.pickle','rb')) 79 | while True: 80 | yield from chunked_stream(ids) 81 | 82 | train,val,test = split_dataset() 83 | 84 | 85 | 86 | #dimension checks 87 | 88 | # a = (next(data(test))) 89 | # assert a[0][0].shape[1] == rep_size 90 | # assert a[0][1].shape[1] == a[1].shape[1] == (max_sentence_length+1) 91 | # assert a[1].shape[2] == len(sym_set) 92 | # print("PASSED DIMENSION CHECKS") -------------------------------------------------------------------------------- /train/Model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torchvision import transforms 5 | from utils.build_vocab import Vocabulary 6 | from train.image_captioning.char_model import EncoderCNN, DecoderRNN 7 | from PIL import Image 8 | import torch 9 | from utils.config import * 10 | from utils.numpy_functions import softmax 11 | 12 | 13 | class Model: 14 | 15 | def __init__(self,path,dictionaries): 16 | 17 | self.seg2idx,self.idx2seg=dictionaries 18 | self.path=path 19 | self.vocab_path='data/vocab.pkl' 20 | self.encoder_path=TRAINED_MODEL_PATH+path+"-encoder-5-3000.pkl" 21 | self.decoder_path=TRAINED_MODEL_PATH+path+"-decoder-5-3000.pkl" 22 | 23 | #todo: change 24 | embed_size=256 25 | hidden_size=512 26 | num_layers=1 27 | 28 | output_size=30 29 | 30 | transform = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize((0.485, 0.456, 0.406), 33 | (0.229, 0.224, 0.225))]) 34 | 35 | self.transform = transform 36 | # Load vocabulary wrapper 37 | 38 | 39 | # Build Models 40 | self.encoder = EncoderCNN(embed_size) 41 | self.encoder.eval() # evaluation mode (BN uses moving mean/variance) 42 | self.decoder = DecoderRNN(embed_size, hidden_size, 43 | output_size, num_layers) 44 | 45 | # Load the trained model parameters 46 | self.encoder.load_state_dict(torch.load(self.encoder_path,map_location={'cuda:0': 'cpu'})) 47 | self.decoder.load_state_dict(torch.load(self.decoder_path,map_location={'cuda:0': 'cpu'})) 48 | 49 | if torch.cuda.is_available(): 50 | self.encoder.cuda() 51 | self.decoder.cuda() 52 | 53 | 54 | 55 | def forward(self,world,state): 56 | 57 | 58 | inputs = self.features[world.target].unsqueeze(1) 59 | 60 | 61 | states=None 62 | 63 | for seg in state.context_sentence: # maximum sampling length 64 | hiddens, states = self.decoder.lstm(inputs, states) # (batch_size, 1, hidden_size), 65 | 66 | outputs = self.decoder.linear(hiddens.squeeze(1)) 67 | 68 | predicted = outputs.max(1)[1] 69 | 70 | predicted[0] = self.seg2idx[seg] 71 | inputs = self.decoder.embed(predicted) 72 | inputs = inputs.unsqueeze(1) # (batch_size, vocab_size) 73 | 74 | hiddens, states = self.decoder.lstm(inputs, states) # (batch_size, 1, hidden_size), 75 | outputs = self.decoder.linear(hiddens.squeeze(1)) 76 | output_array = outputs.squeeze(0).data.cpu().numpy() 77 | 78 | log_softmax_array = np.log(softmax(output_array)) 79 | 80 | 81 | 82 | return log_softmax_array 83 | 84 | def set_features(self,images,rationalities,tf): 85 | 86 | self.number_of_images = len(images) 87 | self.number_of_rationalities = len(rationalities) 88 | 
self.rationality_support=rationalities 89 | 90 | if tf: 91 | pass 92 | 93 | else: 94 | from utils.sample import to_var,load_image,load_image_from_path 95 | self.features = [self.encoder(to_var(load_image(url, self.transform), volatile=True)) for url in images] 96 | self.default_image = self.encoder(to_var(load_image_from_path("data/default.jpg", self.transform), volatile=True)) 97 | 98 | 99 | # self.speakers = [Model(path) for path in paths] 100 | 101 | # imgs = [load_image(url) for url in urls] 102 | # self.images=[] 103 | 104 | # for img in imgs: 105 | # img_array = np.expand_dims(image.img_to_array(img),0) 106 | # img_rep = resnet(img_rep_layer).predict(img_array) 107 | # self.images.append(img_rep) 108 | 109 | # self.images = np.asarray(self.images) 110 | 111 | 112 | -------------------------------------------------------------------------------- /train/image_captioning/char_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | from torch.autograd import Variable 6 | 7 | from utils.build_vocab import Vocabulary 8 | import pickle 9 | 10 | class EncoderCNN(nn.Module): 11 | def __init__(self, embed_size): 12 | """Load the pretrained ResNet-152 and replace top fc layer.""" 13 | super(EncoderCNN, self).__init__() 14 | resnet = models.resnet152(pretrained=True) 15 | modules = list(resnet.children())[:-1] # delete the last fc layer. 16 | self.resnet = nn.Sequential(*modules) 17 | self.linear = nn.Linear(resnet.fc.in_features, embed_size) 18 | self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) 19 | self.init_weights() 20 | 21 | def init_weights(self): 22 | """Initialize the weights.""" 23 | self.linear.weight.data.normal_(0.0, 0.02) 24 | self.linear.bias.data.fill_(0) 25 | 26 | def forward(self, images): 27 | """Extract the image feature vectors.""" 28 | features = self.resnet(images) 29 | features = Variable(features.data) 30 | features = features.view(features.size(0), -1) 31 | features = self.bn(self.linear(features)) 32 | return features 33 | 34 | 35 | class DecoderRNN(nn.Module): 36 | def __init__(self, embed_size, hidden_size, vocab_size, num_layers): 37 | """Set the hyper-parameters and build the layers.""" 38 | super(DecoderRNN, self).__init__() 39 | self.embed = nn.Embedding(vocab_size, embed_size) 40 | self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True) 41 | self.linear = nn.Linear(hidden_size, vocab_size) 42 | self.init_weights() 43 | 44 | def init_weights(self): 45 | """Initialize weights.""" 46 | self.embed.weight.data.uniform_(-0.1, 0.1) 47 | self.linear.weight.data.uniform_(-0.1, 0.1) 48 | self.linear.bias.data.fill_(0) 49 | 50 | def forward(self, features, captions, lengths): 51 | """Decode image feature vectors and generates captions.""" 52 | embeddings = self.embed(captions) 53 | embeddings = torch.cat((features.unsqueeze(1), embeddings), 1) 54 | packed = pack_padded_sequence(embeddings, lengths, batch_first=True) 55 | hiddens, _ = self.lstm(packed) 56 | outputs = self.linear(hiddens[0]) 57 | return outputs 58 | 59 | def sample(self, features, states=None): 60 | """Samples captions for given image features (Greedy search).""" 61 | sampled_ids = [] 62 | inputs = features.unsqueeze(1) 63 | with open("./image_captioning/data/vocab.pkl", 'rb') as f: 64 | vocab = pickle.load(f) 65 | for i in range(20): # maximum sampling length 66 | hiddens, states = self.lstm(inputs, states) # 
(batch_size, 1, hidden_size), 67 | 68 | # print(hiddens.size()) 69 | # print(states[0].size(),states[1].size()) 70 | 71 | outputs = self.linear(hiddens.squeeze(1)) # (batch_size, vocab_size) 72 | predicted = outputs.max(1)[1] 73 | 74 | # print("stuff",type(predicted.data),predicted.data) 75 | # print(vocab.idx2word[1]) 76 | # print("\nNNASDFKLASDJF\n\n",vocab.idx2word[predicted.data.cpu().numpy()[0]]) 77 | 78 | sampled_ids.append(predicted) 79 | inputs = self.embed(predicted) 80 | inputs = inputs.unsqueeze(1) # (batch_size, 1, embed_size) 81 | 82 | # print("SAMPLED IDS",sampled_ids.size()) 83 | sampled_ids = torch.cat(sampled_ids, 0) # (batch_size, 20) 84 | return sampled_ids.squeeze() 85 | -------------------------------------------------------------------------------- /utils/build_data.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from PIL import Image as PIL_Image 3 | import math 4 | from subprocess import call 5 | from charpragcap.utils.image_and_text_utils import vectorize_caption,valid_item,index_to_char,char_to_index,edit_region,get_img_from_id 6 | from charpragcap.utils.config import * 7 | 8 | print("rep_size",rep_size) 9 | 10 | if __name__ == "__main__": 11 | 12 | if os.path.isfile('charpragcap/resources/resnet_reps/reps.pickle') : 13 | proceed = input('Are you sure you want to rebuild the data? (y/n) ') 14 | else: 15 | proceed = 'y' 16 | 17 | if proceed=='y': 18 | 19 | import json 20 | import numpy as np 21 | import pandas as pnd 22 | import pickle 23 | from keras.models import Model 24 | from keras.preprocessing import image 25 | from PIL import Image as PIL_Image 26 | from charpragcap.resources.models.resnet import resnet 27 | 28 | #define resnet from input to fully connected layer 29 | fc_resnet = resnet(img_rep_layer) 30 | 31 | def make_id_to_caption(): 32 | valids = 0 33 | invalids = 0 34 | id_to_caption = {} 35 | json_data=json.loads(open('charpragcap/resources/visual_genome_JSON/region_descriptions.json','r').read()) 36 | print("READ JSON, len:",len(json_data)) 37 | 38 | 39 | 40 | for i,image in enumerate(json_data): 41 | 42 | for s in image['regions']: 43 | 44 | x_coordinate = s['x'] 45 | y_coordinate = s['y'] 46 | height = s['height'] 47 | width = s['width'] 48 | sentence = s['phrase'].lower() 49 | img_id = str(s['image_id']) 50 | region_id = str(s['region_id']) 51 | 52 | is_valid = valid_item(height,width,sentence,img_id) 53 | 54 | if is_valid: 55 | valids+=1 56 | box = edit_region(height,width,x_coordinate,y_coordinate) 57 | id_to_caption[img_id+'_'+region_id] = (vectorize_caption(sentence),box) 58 | else: invalids+=1 59 | 60 | if i%1000==0 and i>0: 61 | print("PROGRESS:",i) 62 | 63 | # if i >6000: 64 | # break 65 | # print(len(id_to_caption)) 66 | # print(id_to_caption) 67 | print(len(id_to_caption)) 68 | print("num valid/ num invalid",valids,invalids) 69 | pickle.dump(id_to_caption,open('charpragcap/resources/id_to_caption','wb')) 70 | 71 | 72 | # e.g.: id_to_caption = {'10_1382': ('the boy with ice cream',(139,82,421,87)),'11_1382': ('the man with ice cream',(139,82,421,87))}... 
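# (descriptive note, hedged) elsewhere in the repo each entry is unpacked as
#   (cap_in, cap_out), box = id_to_caption[full_id]
# (see utils/data_stream.py and utils/test_data.py), so vectorize_caption(sentence)
# appears to return the index-encoded caption together with its one-hot target sequence,
# and box is the region rectangle produced by edit_region above.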
73 | print("MAKING id_to_caption") 74 | make_id_to_caption() 75 | print("COMPUTING AND STORING image reps") 76 | 77 | #feed each image corresponding to a region in image from id in id_to_caption into a [rep_size] dim vector (or consider fewer) 78 | #save as pandas dataframe, with labelled columns 79 | def store_image_reps(): 80 | 81 | id_to_caption = pickle.load(open("charpragcap/resources/id_to_caption",'rb')) 82 | print("len id_to_caption",len(id_to_caption)) 83 | 84 | size = 1000 85 | num_images = len(id_to_caption) 86 | full_output = np.random.randn(len(id_to_caption),rep_size) 87 | mod_num = num_images % size 88 | r = math.ceil(num_images/size) 89 | for j in range(math.ceil(len(sorted(list(id_to_caption)))/size)): 90 | print("RUNNING IMAGES THROUGH RESNET: step",j+1,"out of", 91 | len(range(math.ceil(len(list(id_to_caption))/size)))) 92 | if j == r -1: 93 | num = mod_num 94 | else: 95 | num = size 96 | img_tensor = np.zeros((num, 224,224,3)) 97 | 98 | for i,item in enumerate(sorted(list(id_to_caption))[j*size:((j*size)+num)]): 99 | 100 | img = get_img_from_id(item,id_to_caption) 101 | img_vector = image.img_to_array(img) 102 | img_tensor[i] = img_vector 103 | 104 | reps = fc_resnet.predict(img_tensor) 105 | # print("check",reps.shape[0],len(list(id_to_caption)[j*size:((j*size)+num)])) 106 | assert reps.shape[0]==len(list(id_to_caption)[j*size:((j*size)+num)]) 107 | full_output[j*size:j*size+num] = reps[:num] 108 | 109 | 110 | 111 | df = pnd.DataFrame(full_output,index=sorted(list(id_to_caption))) 112 | 113 | assert df.shape == (len(id_to_caption),rep_size) 114 | 115 | df.to_pickle(REP_DATA_PATH+"reps.pickle") 116 | if j%10==0: df.to_pickle(REP_DATA_PATH+"reps.pickle_backup") 117 | 118 | # store_image_reps() 119 | -------------------------------------------------------------------------------- /utils/pycocotools/mask.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | import cocoapi.PythonAPI.pycocotools._mask as _mask 4 | 5 | # Interface for manipulating masks stored in RLE format. 6 | # 7 | # RLE is a simple yet efficient format for storing binary masks. RLE 8 | # first divides a vector (or vectorized image) into a series of piecewise 9 | # constant regions and then for each piece simply stores the length of 10 | # that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would 11 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 12 | # (note that the odd counts are always the numbers of zeros). Instead of 13 | # storing the counts directly, additional compression is achieved with a 14 | # variable bitrate representation based on a common scheme called LEB128. 15 | # 16 | # Compression is greatest given large piecewise constant regions. 17 | # Specifically, the size of the RLE is proportional to the number of 18 | # *boundaries* in M (or for an image the number of boundaries in the y 19 | # direction). Assuming fairly simple shapes, the RLE representation is 20 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 21 | # is substantially lower, especially for large simple objects (large n). 22 | # 23 | # Many common operations on masks can be computed directly using the RLE 24 | # (without need for decoding). This includes computations such as area, 25 | # union, intersection, etc. All of these operations are linear in the 26 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 27 | # of the object. 
Computing these operations on the original mask is O(n). 28 | # Thus, using the RLE can result in substantial computational savings. 29 | # 30 | # The following API functions are defined: 31 | # encode - Encode binary masks using RLE. 32 | # decode - Decode binary masks encoded via RLE. 33 | # merge - Compute union or intersection of encoded masks. 34 | # iou - Compute intersection over union between masks. 35 | # area - Compute area of encoded masks. 36 | # toBbox - Get bounding boxes surrounding encoded masks. 37 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 38 | # 39 | # Usage: 40 | # Rs = encode( masks ) 41 | # masks = decode( Rs ) 42 | # R = merge( Rs, intersect=false ) 43 | # o = iou( dt, gt, iscrowd ) 44 | # a = area( Rs ) 45 | # bbs = toBbox( Rs ) 46 | # Rs = frPyObjects( [pyObjects], h, w ) 47 | # 48 | # In the API the following formats are used: 49 | # Rs - [dict] Run-length encoding of binary masks 50 | # R - dict Run-length encoding of binary mask 51 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 52 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 53 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 54 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 55 | # dt,gt - May be either bounding boxes or encoded masks 56 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 57 | # 58 | # Finally, a note about the intersection over union (iou) computation. 59 | # The standard iou of a ground truth (gt) and detected (dt) object is 60 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 61 | # For "crowd" regions, we use a modified criteria. If a gt object is 62 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 63 | # Choosing gt' in the crowd gt that best matches the dt can be done using 64 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 65 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 66 | # For crowd gt regions we use this modified criteria above for the iou. 67 | # 68 | # To compile run "python setup.py build_ext --inplace" 69 | # Please do not contact us for help with compiling. 70 | # 71 | # Microsoft COCO Toolbox. version 2.0 72 | # Data, paper, and tutorials available at: http://mscoco.org/ 73 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 
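# A concrete toy instance of the calls listed under "Usage" above (illustrative only,
# and it assumes the Cython extension _mask has been built so that this module imports):
#   import numpy as np
#   from utils.pycocotools import mask as mask_utils
#   m = np.asfortranarray(np.array([[0, 1], [0, 1]], dtype=np.uint8))  # 2x2 binary mask
#   rle = mask_utils.encode(m)                  # RLE dict with 'size' and 'counts'
#   mask_utils.area(rle)                        # 2, the number of foreground pixels
#   np.array_equal(mask_utils.decode(rle), m)   # True: decode round-trips encode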
74 | # Licensed under the Simplified BSD License [see coco/license.txt] 75 | 76 | iou = _mask.iou 77 | merge = _mask.merge 78 | frPyObjects = _mask.frPyObjects 79 | 80 | def encode(bimask): 81 | if len(bimask.shape) == 3: 82 | return _mask.encode(bimask) 83 | elif len(bimask.shape) == 2: 84 | h, w = bimask.shape 85 | return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] 86 | 87 | def decode(rleObjs): 88 | if type(rleObjs) == list: 89 | return _mask.decode(rleObjs) 90 | else: 91 | return _mask.decode([rleObjs])[:,:,0] 92 | 93 | def area(rleObjs): 94 | if type(rleObjs) == list: 95 | return _mask.area(rleObjs) 96 | else: 97 | return _mask.area([rleObjs])[0] 98 | 99 | def toBbox(rleObjs): 100 | if type(rleObjs) == list: 101 | return _mask.toBbox(rleObjs) 102 | else: 103 | return _mask.toBbox([rleObjs])[0] 104 | -------------------------------------------------------------------------------- /bayesian_agents/joint_rsa.py: -------------------------------------------------------------------------------- 1 | import time 2 | import itertools 3 | import scipy 4 | import scipy.stats 5 | import numpy as np 6 | import math 7 | from PIL import Image as PIL_Image 8 | from keras.preprocessing import image 9 | from keras.models import load_model 10 | from utils.image_and_text_utils import index_to_char,char_to_index 11 | from utils.config import * 12 | from bayesian_agents.rsaWorld import RSA_World 13 | from utils.numpy_functions import softmax 14 | from train.Model import Model 15 | 16 | class RSA: 17 | 18 | def __init__( 19 | self, 20 | seg_type, 21 | tf, 22 | ): 23 | 24 | 25 | self.tf=tf 26 | self.seg_type=seg_type 27 | self.char=self.seg_type="char" 28 | 29 | #caches for memoization 30 | self._speaker_cache = {} 31 | self._listener_cache = {} 32 | self._speaker_prior_cache = {} 33 | 34 | if self.char: 35 | self.idx2seg=index_to_char 36 | self.seg2idx=char_to_index 37 | 38 | 39 | 40 | 41 | def initialize_speakers(self,paths): 42 | 43 | 44 | self.initial_speakers = [Model(path=path, 45 | dictionaries=(self.seg2idx,self.idx2seg)) for path in paths] 46 | self.speaker_prior = Model(path="lang_mod", 47 | dictionaries=(self.seg2idx,self.idx2seg)) 48 | # self.initial_speaker.set_features() 49 | 50 | # self.speaker_prior 51 | 52 | # self.images=images 53 | # print("NUMBER OF IMAGES:",self.number_of_images) 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | def flush_cache(self): 62 | 63 | self._speaker_cache = {} 64 | self._listener_cache = {} 65 | self._speaker_prior_cache = {} 66 | 67 | # memoization is crucial for speed of the RSA, which is recursive: memoization via decorators for speaker and listener 68 | # def memoize_speaker_prior(f): 69 | # def helper(self,state,world): 70 | 71 | # # world_prior_list = np.ndarray.tolist(np.ndarray.flatten(state.world_priors)) 72 | # hashable_args = state,world 73 | 74 | # if hashable_args not in self._speaker_cache: 75 | # self._speaker_prior_cache[hashable_args] = f(self,state,world) 76 | # # else: print("cached") 77 | # return self._speaker_prior_cache[hashable_args] 78 | # return helper 79 | 80 | def memoize_speaker(f): 81 | def helper(self,state,world,depth): 82 | 83 | # world_prior_list = np.ndarray.tolist(np.ndarray.flatten(state.world_priors)) 84 | hashable_args = state,world,depth 85 | 86 | if hashable_args not in self._speaker_cache: 87 | self._speaker_cache[hashable_args] = f(self,state,world,depth) 88 | # else: print("cached") 89 | return self._speaker_cache[hashable_args] 90 | return helper 91 | 92 | def memoize_listener(f): 93 | def 
helper(self,state,utterance,depth): 94 | 95 | # world_prior_list = np.ndarray.tolist(np.ndarray.flatten(state.world_priors)) 96 | hashable_args = state,utterance,depth 97 | 98 | if hashable_args not in self._listener_cache: 99 | self._listener_cache[hashable_args] = f(self,state,utterance,depth) 100 | # else: print("cached") 101 | 102 | return self._listener_cache[hashable_args] 103 | return helper 104 | 105 | 106 | 107 | 108 | # @memoize_speaker_prior 109 | # def speaker_prior(self,state,world): 110 | # # print("SPEAKER PRIOR",(world.target,world.speaker,world.rationality)) 111 | 112 | # pass 113 | 114 | 115 | @memoize_speaker 116 | def speaker(self,state,world,depth): 117 | # print("rationality",world.rationality) 118 | # print("world prior shape",state.world_priors[0].shape) 119 | # print("SPEAKER\n\n",depth) 120 | 121 | 122 | if depth==0: 123 | # print("S0") 124 | # print("TIMESTEP:",state.timestep,"INITIAL SPEAKER CALL") 125 | 126 | 127 | 128 | return self.initial_speakers[world.speaker].forward(state=state,world=world) 129 | 130 | else: 131 | 132 | prior = self.speaker(state,world,depth=0) 133 | # self.initial_speakers[world.speaker].forward(state=state,world=world) 134 | # prior = self.speaker_prior.forward(state=state,world=world) 135 | 136 | # self.speaker(state=state,world=world,depth=0) 137 | if depth==1: 138 | 139 | scores = [] 140 | for k in range(prior.shape[0]): 141 | # print(world.target,world.rationality,"FIRST") 142 | out = self.listener(state=state,utterance=k,depth=depth-1) 143 | 144 | 145 | scores.append(out[world.target,world.rationality,world.speaker]) 146 | 147 | scores = np.asarray(scores) 148 | # print("SCORES",scores) 149 | # rationality in traditional RSA sense 150 | scores = scores*(self.initial_speakers[world.speaker].rationality_support[world.rationality]) 151 | # update prior to posterior 152 | # print(scores.shape,prior.shape) 153 | posterior = (scores + prior) - scipy.misc.logsumexp(scores + prior) 154 | # print("POSTERIOR",posterior) 155 | 156 | return posterior 157 | 158 | elif depth==2: 159 | 160 | scores = [] 161 | for k in range(prior.shape[0]): 162 | 163 | # print(world.rationality,"rat") 164 | out = self.listener(state=state,utterance=k,depth=depth-1) 165 | scores.append(out[world.target,world.rationality,world.speaker]) 166 | 167 | scores = np.asarray(scores) 168 | # rationality not present at s2 169 | # update prior to posterior 170 | posterior = (scores + prior) - scipy.misc.logsumexp(scores + prior) 171 | 172 | return posterior 173 | 174 | @memoize_listener 175 | def listener(self,state,utterance,depth): 176 | 177 | # base case listener is either neurally trained, or inferred from neural s0, given the state's current prior on images 178 | 179 | # world = RSA_World(target=0,speaker=0,rationality=0) 180 | # image_prior = self.listener(state=state,utterance=utterance,depth=depth-1) 181 | # rationality_prior = np.asarray([0.3,0.7]) 182 | 183 | world_prior = state.world_priors[state.timestep-1] 184 | # print("world prior",np.exp(world_prior)) 185 | # if state.timestep < 4: 186 | # print("world priors",np.exp(state.world_priors[:4])) 187 | # print("timestep",state.timestep) 188 | 189 | # if depth==0: 190 | # else: world_prior = self.listener(state=state,utterance=utterance,depth=0) 191 | # print(world_prior.shape) 192 | 193 | # I could write: itertools product axes 194 | scores = np.zeros((world_prior.shape)) 195 | for n_tuple in itertools.product(*[list(range(x)) for x in world_prior.shape]): 196 | # print(n_tuple) 197 | # print(world_prior.shape) 
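# Each n_tuple indexes one cell of the joint world prior
# (image target x speaker rationality x speaker model); the utterance is scored
# under the speaker for that world at the same depth ("NOT DEPTH-1" below); after
# scaling by listener_rationality, the scores are renormalized against the prior
# with logsumexp, i.e. a Bayes update carried out in log space.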
198 | # for j in range(self.number_of_images): 199 | # for i in range(len(rationality_prior)): 200 | # world.target=j 201 | world = RSA_World(target=n_tuple[state.dim["image"]],rationality=n_tuple[state.dim["rationality"]],speaker=n_tuple[state.dim["speaker"]]) 202 | # world.set_values(n_tuple) 203 | 204 | # world.rationality=rationality_prior[i] 205 | # NOTE THAT NOT DEPTH-1 HERE 206 | out = self.speaker(state=state,world=world,depth=depth) 207 | # out = np.squeeze(out) 208 | 209 | # print(out,depth) 210 | scores[n_tuple]=out[utterance] 211 | 212 | scores = scores*state.listener_rationality 213 | world_posterior = (scores + world_prior) - scipy.misc.logsumexp(scores + world_prior) 214 | # print("world posterior listener complex shape",world_posterior.shape) 215 | return world_posterior 216 | 217 | def listener_simple(self,state,utterance,depth): 218 | 219 | 220 | # base case listener is either neurally trained, or inferred from neural s0, given the state's current prior on images 221 | 222 | # world = RSA_World(target=0,speaker=0,rationality=0) 223 | # image_prior = self.listener(state=state,utterance=utterance,depth=depth-1) 224 | # rationality_prior = np.asarray([0.3,0.7]) 225 | 226 | world_prior = state.world_priors[state.timestep-1] 227 | assert world_prior.shape == (2,1,1) 228 | print("world prior",np.exp(world_prior)) 229 | # world_prior = np.log(np.asarray([0.5,0.5])) 230 | # if depth==0: 231 | # else: world_prior = self.listener(state=state,utterance=utterance,depth=0) 232 | # print(world_prior.shape) 233 | 234 | # I could write: itertools product axes 235 | scores = np.zeros((2,1,1)) 236 | for i in range(2): 237 | # print(n_tuple) 238 | # print(world_prior.shape) 239 | # for j in range(self.number_of_images): 240 | # for i in range(len(rationality_prior)): 241 | # world.target=j 242 | world = RSA_World(target=i,rationality=0,speaker=0) 243 | # world.set_values(n_tuple) 244 | 245 | # world.rationality=rationality_prior[i] 246 | # NOTE THAT NOT DEPTH-1 HERE 247 | out = self.speaker(state=state,world=world,depth=depth) 248 | # out = np.squeeze(out) 249 | 250 | # print(out,depth) 251 | scores[i]=out[utterance] 252 | 253 | scores = scores*state.listener_rationality 254 | world_posterior = (scores + world_prior) - scipy.misc.logsumexp(scores + world_prior) 255 | # print("world posterior listener simple shape",world_posterior.shape) 256 | 257 | return world_posterior 258 | 259 | 260 | 261 | -------------------------------------------------------------------------------- /recursion_schemes/recursion_schemes.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import time 4 | import math 5 | import scipy 6 | import scipy.stats 7 | import copy 8 | from utils.config import * 9 | from bayesian_agents.rsaState import RSA_State 10 | from bayesian_agents.rsaWorld import RSA_World 11 | from utils.image_and_text_utils import max_sentence_length,index_to_char,char_to_index,devectorize_caption,\ 12 | sentence_likelihood,largest_indices,vectorize_caption 13 | 14 | ### 15 | """ 16 | Recursion schemes for the RSA 17 | """ 18 | ### 19 | 20 | # you should call this so that running unroll beam prints this out 21 | # calculates likelihood of caption being generated by a speaker 22 | def likelihood(self,depth=0,which_image=0,speaker_rationality=1.0,speaker=0,listener_rationality=1.0,img_prior=np.log(np.asarray([0.5,0.5])),start_from=""): 23 | 24 | self.speaker_rationality=speaker_rationality 25 | 
self.listener_rationality=listener_rationality 26 | self.image_priors=np.log(np.ones((max_sentence_length+1,self.images.shape[0]))*(1/self.images.shape[0])) 27 | self.image_priors[0]=img_prior 28 | sentence = np.expand_dims(np.expand_dims(vectorize_caption(start_from)[0],0),-1) 29 | self.context_sentence = copy.deepcopy(sentence) 30 | 31 | likelihood = {} 32 | 33 | for i in range(1,max_sentence_length): 34 | self.i = i 35 | char = np.squeeze(sentence)[i] 36 | s = np.squeeze(self.speaker(img_idx=which_image,depth=depth)) 37 | likelihood[i] = s[char] 38 | if index_to_char[char] == stop_token: 39 | break 40 | 41 | print(start_from,np.prod([np.exp(p) for (l,p) in likelihood.items()])) 42 | 43 | def ana_greedy(rsa,initial_world_prior,speaker_rationality,speaker, target, pass_prior=True,listener_rationality=1.0,depth=0,start_from=[]): 44 | 45 | 46 | """ 47 | speaker_rationality,listener_rationality: 48 | 49 | see speaker and listener code for what they do: control strength of conditioning 50 | 51 | depth: 52 | 53 | the number of levels of RSA: depth 0 uses listeral speaker to unroll, depth n uses speaker n to unroll, and listener n to update at each step 54 | 55 | start_from: 56 | 57 | a partial caption you start the unrolling from 58 | 59 | img_prior: 60 | 61 | a prior on the world to start with 62 | """ 63 | 64 | 65 | # this RSA passes along a state: see rsa_state 66 | state = RSA_State(initial_world_prior, listener_rationality=listener_rationality) 67 | # state.image_priors[:]=img_prior 68 | 69 | context_sentence = ['^']+start_from 70 | state.context_sentence=context_sentence 71 | 72 | 73 | world=RSA_World(target=target,rationality=speaker_rationality,speaker=speaker) 74 | 75 | probs=[] 76 | for timestep in tqdm(range(len(start_from)+1,max_sentence_length)): 77 | 78 | state.timestep=timestep 79 | s = rsa.speaker(state=state,world=world,depth=depth) 80 | # print("S:",s) 81 | # print(s) 82 | segment = np.argmax(s) 83 | # print("s",rsa.idx2seg[segment]) 84 | prob = np.max(s) 85 | probs.append(prob) 86 | 87 | if pass_prior: 88 | 89 | l = rsa.listener(state=state,utterance=segment,depth=depth) 90 | state.world_priors[state.timestep]=l 91 | state.context_sentence += [rsa.idx2seg[segment]] 92 | if (rsa.idx2seg[segment] == stop_token[rsa.seg_type]): 93 | break 94 | 95 | summed_probs = np.sum(np.asarray(probs)) 96 | 97 | world_posterior = state.world_priors[:state.timestep+1][:5] 98 | 99 | return [("".join(state.context_sentence),summed_probs)] 100 | 101 | #But within the n-th order ethno-metapragmatic perspective, this creative indexical effect is the motivated realization, or performable execution, of an already constituted framework of semiotic value. 102 | def ana_beam(rsa,initial_world_prior,speaker_rationality, target,speaker, pass_prior=True,listener_rationality=1.0,depth=0,start_from=[],beam_width=len(sym_set),cut_rate=1,decay_rate=0.0,beam_decay=0,): 103 | """ 104 | speaker_rationality,listener_rationality: 105 | 106 | see speaker and listener code for what they do: control strength of conditioning 107 | 108 | depth: 109 | 110 | the number of levels of RSA: depth 0 uses listeral speaker to unroll, depth n uses speaker n to unroll, and listener n to update at each step 111 | 112 | start_from: 113 | a partial caption you start the unrolling from 114 | img_prior: 115 | a prior on the world to start with 116 | which_image: 117 | which of the images in the prior should be targeted? 
118 | beam width: width beam is cut down to every cut_rate iterations of the unrolling 119 | cut_rate: how often beam is cut down to beam_width 120 | beam_decay: amount by which beam_width is lessened after each iteration 121 | decay_rate: a multiplier that makes later decisions in the unrolling matter less: 0.0 does no decay. negative decay makes start matter more 122 | """ 123 | 124 | state = RSA_State(initial_world_prior, listener_rationality=listener_rationality) 125 | # state.image_priors[:]=img_prior 126 | 127 | context_sentence = start_from 128 | state.context_sentence=context_sentence 129 | 130 | 131 | world=RSA_World(target=target,rationality=speaker_rationality,speaker=speaker) 132 | 133 | 134 | 135 | context_sentence = start_from 136 | state.context_sentence=context_sentence 137 | 138 | 139 | sent_worldprior_prob = [(state.context_sentence,state.world_priors,0.0)] 140 | 141 | final_sentences=[] 142 | 143 | toc = time.time() 144 | for timestep in tqdm(range(len(start_from)+1,max_sentence_length)): 145 | 146 | 147 | state.timestep=timestep 148 | 149 | new_sent_worldprior_prob = [] 150 | for sent,worldpriors,old_prob in sent_worldprior_prob: 151 | 152 | state.world_priors=worldpriors 153 | 154 | if state.timestep > 1: 155 | 156 | state.context_sentence = sent[:-1] 157 | seg = sent[-1] 158 | 159 | if depth>0 and pass_prior: 160 | 161 | l=rsa.listener(state=state,utterance=rsa.seg2idx[seg],depth=depth) 162 | state.world_priors[state.timestep-1]=copy.deepcopy(l) 163 | 164 | state.context_sentence = sent 165 | 166 | # out = rsa.speaker(state=state,img_idx=which_image,depth=depth) 167 | s = rsa.speaker(state=state,world=world,depth=depth) 168 | 169 | for seg,prob in enumerate(np.squeeze(s)): 170 | 171 | new_sentence = copy.deepcopy(sent) 172 | 173 | # conditional to deal with captions longer than max sentence length 174 | # if state.timestep=50: 211 | # # print("beam unroll time",tic-toc) 212 | # # print(state.image_priors[:]) 213 | sentences = sorted(final_sentences,key=lambda x : x[-1],reverse=True) 214 | output = [] 215 | for i,(sent,prob) in enumerate(sentences): 216 | 217 | output.append(("".join(sent),prob)) 218 | 219 | return output 220 | # # print(sentences) 221 | # for i,(sent,prob) in enumerate(sentences): 222 | 223 | # output.append(("".join([rsa.idx2word[idx] for idx in np.squeeze(sent)]),prob)) 224 | 225 | # return output 226 | # return "COMPLETE" 227 | # return "".join([rsa.idx2word[idx] for idx in np.squeeze(final_sentences[0])]) 228 | 229 | if beam_decay < beam_width: 230 | beam_width-=beam_decay 231 | # print("decayed beam width by "+str(beam_decay)+"; beam_width now: "+str(beam_width)) 232 | 233 | else: 234 | sentences = sorted(final_sentences,key=lambda x : x[-1],reverse=True) 235 | 236 | output = [] 237 | # print(sentences) 238 | for i,(sent,prob) in enumerate(sentences): 239 | 240 | output.append(("".join(sent),prob)) 241 | 242 | return output 243 | 244 | -------------------------------------------------------------------------------- /utils/image_and_text_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | import os.path 4 | from utils.config import * 5 | from keras.preprocessing import image 6 | 7 | ###IMAGE UTILS 8 | 9 | 10 | 11 | 12 | def overlap(target_box, candidate_box): 13 | 14 | saved_x1 = target_box[0] 15 | saved_y1 = target_box[1] 16 | saved_x2 = target_box[2] 17 | saved_y2 = target_box[3] 18 | 19 | x1 = candidate_box[0] 20 | y1 = candidate_box[1] 21 
| x2 = candidate_box[2] 22 | y2 = candidate_box[3] 23 | 24 | cond1 = saved_x1 < x1 and x1 < saved_x2 25 | cond2 = saved_y1 < y1 and y1 < saved_y2 26 | cond3 = saved_x1 < x2 and x2 < saved_x2 27 | cond4 = saved_y1 < y2 and y2 < saved_y2 28 | 29 | return (cond1 or cond2 or cond3 or cond4) 30 | 31 | def edit_region(height,width,x_coordinate,y_coordinate): 32 | if (width > height): 33 | # check if image recentering causes box to go off the image up 34 | if(y_coordinate+(height/2)-(width/2) < 0.0): 35 | box = (x_coordinate,y_coordinate, x_coordinate+ \ 36 | max(width,height),y_coordinate+max(width,height)) 37 | else: 38 | box = (x_coordinate,y_coordinate+(height/2)-(width/2), \ 39 | x_coordinate+max(width,height),y_coordinate+(height/2)-(width/2)+max(width,height)) 40 | else: 41 | # check if image recentering causes box to go off the image to the left 42 | if(x_coordinate+(width/2)-(height/2) < 0.0): 43 | box = (x_coordinate,y_coordinate, x_coordinate+ \ 44 | max(width,height),y_coordinate+max(width,height)) 45 | else: 46 | box = (x_coordinate+(width/2)-(height/2),y_coordinate, \ 47 | x_coordinate+(width/2)-(height/2)+max(width,height),y_coordinate+max(width,height)) 48 | 49 | return box 50 | 51 | 52 | 53 | #determine if a region and caption are suitable for inclusion in data 54 | def valid_item(height,width,sentence,img_id): 55 | 56 | ratio = ((float(max(height,width))) / float(min(height,width))) 57 | size = float(height) 58 | file_exists = os.path.isfile(IMG_DATA_PATH+"VG_100K/"+str(img_id)+".jpg") 59 | good_length = len(sentence) < max_sentence_length 60 | no_punctuation = all((char in sym_set) for char in sentence) 61 | return ratio<1.25 and size>100.0 and file_exists and good_length and no_punctuation 62 | 63 | def get_img_from_id(item,id_to_caption): 64 | 65 | 66 | from PIL import Image as PIL_Image 67 | 68 | img_id,region_id = item.split('_') 69 | path = IMG_DATA_PATH+'VG_100K/'+img_id+".jpg" 70 | img = PIL_Image.open(path) 71 | #crop region from img 72 | box = id_to_caption[item][region] 73 | # print("box",box) 74 | cropped_img = img.crop(box) 75 | # print("cropped_img", image.img_to_array(cropped_img)) 76 | #resize into square 77 | resized_img = cropped_img.resize([224,224],PIL_Image.LANCZOS) 78 | 79 | 80 | return resized_img 81 | 82 | def get_rep_from_id(item,id_to_caption): 83 | 84 | from PIL import Image as PIL_Image 85 | from charpragcap.resources.models.resnet import resnet 86 | from keras.preprocessing import image 87 | 88 | img_id,region_id = item.split('_') 89 | path = IMG_DATA_PATH+'VG_100K/'+img_id+".jpg" 90 | img = PIL_Image.open(path) 91 | #crop region from img 92 | box = id_to_caption[item][region] 93 | # print("box",box) 94 | cropped_img = img.crop(box) 95 | # print("cropped_img", image.img_to_array(cropped_img)) 96 | #resize into square 97 | resized_img = cropped_img.resize([224,224],PIL_Image.ANTIALIAS) 98 | 99 | display(resized_img) 100 | 101 | img = np.expand_dims(image.img_to_array(resized_img),0) 102 | 103 | img = resnet(img_rep_layer).predict(img) 104 | return img 105 | 106 | 107 | 108 | # nb: only the first part of the id, i.e. 
the image, not the region: doesn't crop 109 | def get_rep_from_img_id(img_id): 110 | 111 | from PIL import Image as PIL_Image 112 | import urllib.request 113 | from charpragcap.resources.models.resnet import resnet 114 | from keras.preprocessing import image 115 | 116 | path = IMG_DATA_PATH+'VG_100K/'+img_id+".jpg" 117 | img = PIL_Image.open(path) 118 | resized_img = img.resize([224,224],PIL_Image.ANTIALIAS) 119 | 120 | img = np.expand_dims(image.img_to_array(resized_img),0) 121 | 122 | img = resnet(img_rep_layer).predict(img) 123 | return img 124 | 125 | #TODO use or remove 126 | 127 | def get_img_from_url(url): 128 | import urllib.request 129 | from charpragcap.resources.models.resnet import resnet 130 | from PIL import Image as PIL_Image 131 | import shutil 132 | import requests 133 | from keras.preprocessing import image 134 | from PIL import Image as PIL_Image 135 | response = requests.get(url, stream=True) 136 | with open('charpragcap/resources/img.jpg', 'wb') as out_file: 137 | shutil.copyfileobj(response.raw, out_file) 138 | del response 139 | 140 | img = PIL_Image.open('charpragcap/resources/img.jpg') 141 | 142 | return img 143 | 144 | # model = resnet(img_rep_layer) 145 | 146 | # # file_name = "charpragcap/resources/local-filename.jpg" 147 | # # urllib.request.urlretrieve(url, file_name) 148 | # # img = PIL_Image.open(file_name) 149 | 150 | # img = img.resize([224,224],PIL_Image.ANTIALIAS) 151 | # display(img) 152 | # img = np.expand_dims(image.img_to_array(img),0) 153 | 154 | # rep = resnet(img_rep_layer).predict(img) 155 | # return rep 156 | 157 | def get_rep_from_url(url,model): 158 | import urllib.request 159 | from keras.preprocessing import image 160 | from PIL import Image as PIL_Image 161 | import shutil 162 | import requests 163 | from keras.preprocessing import image 164 | from PIL import Image as PIL_Image 165 | response = requests.get(url, stream=True) 166 | with open('charpragcap/resources/img.jpg', 'wb') as out_file: 167 | shutil.copyfileobj(response.raw, out_file) 168 | del response 169 | 170 | img = PIL_Image.open('charpragcap/resources/img.jpg') 171 | 172 | 173 | # file_name = "charpragcap/resources/local-filename.jpg" 174 | # urllib.request.urlretrieve(url, file_name) 175 | # img = PIL_Image.open(file_name) 176 | 177 | img = img.resize([224,224],PIL_Image.ANTIALIAS) 178 | # display(img) 179 | img = np.expand_dims(image.img_to_array(img),0) 180 | 181 | rep = model.predict(img) 182 | return rep 183 | 184 | #for ipython image displaying 185 | def display_image(number): 186 | 187 | import pickle 188 | from PIL import Image as PIL_Image 189 | 190 | id_to_caption = pickle.load(open("charpragcap/resources/id_to_caption",'rb')) 191 | chosen_id = list(id_to_caption)[number] 192 | 193 | img_path = "data/VG_100K/"+str(chosen_id)+".jpg" 194 | box = id_to_caption[chosen_id][1] 195 | img_id,region_id = chosen_id.split("_") 196 | img = PIL_Image.open(IMG_DATA_PATH+"VG_100K/"+str(img_id)+".jpg") 197 | display(img) 198 | region = img.crop(box) 199 | region = region.resize([224,224],PIL_Image.ANTIALIAS) 200 | display(region) 201 | 202 | def display_img_from_url(url): 203 | import shutil 204 | import requests 205 | from keras.preprocessing import image 206 | from PIL import Image as PIL_Image 207 | response = requests.get(url, stream=True) 208 | with open('charpragcap/resources/img.jpg', 'wb') as out_file: 209 | shutil.copyfileobj(response.raw, out_file) 210 | del response 211 | img = PIL_Image.open('charpragcap/resources/img.jpg') 212 | img = 
img.resize([224,224],PIL_Image.ANTIALIAS) 213 | display(img) 214 | 215 | def get_img(url): 216 | import shutil 217 | import requests 218 | from keras.preprocessing import image 219 | from PIL import Image as PIL_Image 220 | from charpragcap.resources.models.resnet import resnet 221 | from charpragcap.utils.config import img_rep_layer 222 | 223 | 224 | response = requests.get(url, stream=True) 225 | with open('charpragcap/resources/img.jpg', 'wb') as out_file: 226 | shutil.copyfileobj(response.raw, out_file) 227 | del response 228 | img = PIL_Image.open('charpragcap/resources/img.jpg') 229 | img = img.resize([224,224],PIL_Image.ANTIALIAS) 230 | display(img) 231 | rep = np.expand_dims(image.img_to_array(img),0) 232 | 233 | rep = resnet(img_rep_layer).predict(img) 234 | return rep 235 | 236 | def item_to_rep(item,id_to_caption): 237 | import numpy as np 238 | from charpragcap.resources.models.resnet import resnet 239 | from keras.preprocessing import image 240 | 241 | original_image = get_img_from_id(item,id_to_caption) 242 | original_image_vector = np.expand_dims(image.img_to_array(original_image),axis=0) 243 | input_image = resnet(img_rep_layer).predict(original_image_vector) 244 | return input_image 245 | 246 | ###TEXT UTILS 247 | 248 | #convert caption into vector: SHAPE? 249 | def vectorize_caption(sentence): 250 | if len(sentence) > 0 and sentence[-1] in list("!?."): 251 | sentence = sentence[:-1] 252 | sentence = start_token["char"] + sentence + stop_token["char"] 253 | sentence = list(sentence) 254 | while len(sentence) < max_sentence_length+2: 255 | sentence.append(pad_token) 256 | 257 | caption_in = sentence[:-1] 258 | caption_out = sentence[1:] 259 | caption_in = np.asarray([char_to_index[x] for x in caption_in]) 260 | caption_out = np.expand_dims(np.asarray([char_to_index[x] for x in caption_out]),0) 261 | one_hot = np.zeros((caption_out.shape[1], len(sym_set))) 262 | one_hot[np.arange(caption_out.shape[1]), caption_out] = 1 263 | caption_out = one_hot 264 | return caption_in,caption_out 265 | 266 | # takes (1,39,1) and returns string 267 | def devectorize_caption(ary): 268 | # print("ARY",ary.shape,ary) 269 | return "".join([index_to_char[idx] for idx in np.squeeze(ary)]) 270 | 271 | 272 | 273 | 274 | 275 | #OTHER UTILS 276 | def sentence_likelihood(img,sent): 277 | sentence = text_to_vecs(sent,words=True) 278 | # print(sentence.shape,img.shape) 279 | probs = s_zero.predict([img,sentence]) 280 | probs = [x[word_to_index[sent[i+1]]] for i,x in enumerate(probs[0][:-1])] 281 | # print(np.sum(np.log(probs))) 282 | 283 | def largest_indices(self,ary, n): 284 | flat = ary.flatten() 285 | indices = np.argpartition(flat, -n)[-n:] 286 | indices = indices[np.argsort(-flat[indices])] 287 | return np.unravel_index(indices, ary.shape) 288 | 289 | # gives a train,val,test split, by way of three sets of id_to_caption keys 290 | 291 | 292 | def split_dataset(trains=train_size,vals=val_size,tests=test_size): 293 | 294 | import pickle 295 | id_to_caption = pickle.load(open("charpragcap/resources/id_to_caption",'rb')) 296 | 297 | ids = sorted(list(id_to_caption)) 298 | num_ids = (len(ids)) 299 | assert trains+vals+tests == 1.0 300 | 301 | num_train = int(num_ids*trains) 302 | num_val = num_train + int(num_ids*vals) 303 | num_test = num_val + int(num_ids*tests) 304 | 305 | trains,vals,tests = ids[0:num_train],ids[num_train:num_val],ids[num_val:num_test] 306 | return trains,vals,tests 307 | 308 | 309 | -------------------------------------------------------------------------------- 
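As a quick illustration of the caption vectorization above: vectorize_caption wraps a string in start/stop tokens, pads it to a fixed length, and returns the character-index sequence as model input together with the next-character targets as a one-hot matrix. The sketch below is a self-contained toy version of that idea; the symbol set, token characters and max_sentence_length are assumed stand-ins for the real sym_set, char_to_index and token constants defined in utils.config.

import numpy as np

# Toy stand-ins for the repo's config (assumed values, not the real sym_set/char_to_index):
START, STOP, PAD = "^", "$", "&"
sym_set = [PAD, START, STOP] + list("abcdefghijklmnopqrstuvwxyz ")
char_to_index = {c: i for i, c in enumerate(sym_set)}
max_sentence_length = 12

def toy_vectorize_caption(sentence):
    # start token + characters + stop token, padded out to a fixed length
    chars = [START] + list(sentence) + [STOP]
    chars += [PAD] * (max_sentence_length + 2 - len(chars))
    idx = np.asarray([char_to_index[c] for c in chars])
    caption_in = idx[:-1]                                  # input: all symbols but the last
    one_hot = np.zeros((len(idx) - 1, len(sym_set)))
    one_hot[np.arange(len(idx) - 1), idx[1:]] = 1          # target: the next symbol, one-hot
    return caption_in, one_hot

caption_in, caption_out = toy_vectorize_caption("a dog")
print(caption_in.shape, caption_out.shape)                 # (13,) (13, 30)

The real implementation additionally strips trailing punctuation ("!?.") before wrapping the caption, and valid_item filters out captions longer than max_sentence_length before they reach this step.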
/utils/pycocotools/_mask.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c 2 | # distutils: sources = ../common/maskApi.c 3 | 4 | #************************************************************************** 5 | # Microsoft COCO Toolbox. version 2.0 6 | # Data, paper, and tutorials available at: http://mscoco.org/ 7 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 8 | # Licensed under the Simplified BSD License [see coco/license.txt] 9 | #************************************************************************** 10 | 11 | __author__ = 'tsungyi' 12 | 13 | import sys 14 | PYTHON_VERSION = sys.version_info[0] 15 | 16 | # import both Python-level and C-level symbols of Numpy 17 | # the API uses Numpy to interface C and Python 18 | import numpy as np 19 | cimport numpy as np 20 | from libc.stdlib cimport malloc, free 21 | 22 | # intialized Numpy. must do. 23 | np.import_array() 24 | 25 | # import numpy C function 26 | # we use PyArray_ENABLEFLAGS to make Numpy ndarray responsible to memoery management 27 | cdef extern from "numpy/arrayobject.h": 28 | void PyArray_ENABLEFLAGS(np.ndarray arr, int flags) 29 | 30 | # Declare the prototype of the C functions in MaskApi.h 31 | cdef extern from "maskApi.h": 32 | ctypedef unsigned int uint 33 | ctypedef unsigned long siz 34 | ctypedef unsigned char byte 35 | ctypedef double* BB 36 | ctypedef struct RLE: 37 | siz h, 38 | siz w, 39 | siz m, 40 | uint* cnts, 41 | void rlesInit( RLE **R, siz n ) 42 | void rleEncode( RLE *R, const byte *M, siz h, siz w, siz n ) 43 | void rleDecode( const RLE *R, byte *mask, siz n ) 44 | void rleMerge( const RLE *R, RLE *M, siz n, int intersect ) 45 | void rleArea( const RLE *R, siz n, uint *a ) 46 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ) 47 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) 48 | void rleToBbox( const RLE *R, BB bb, siz n ) 49 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ) 50 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ) 51 | char* rleToString( const RLE *R ) 52 | void rleFrString( RLE *R, char *s, siz h, siz w ) 53 | 54 | # python class to wrap RLE array in C 55 | # the class handles the memory allocation and deallocation 56 | cdef class RLEs: 57 | cdef RLE *_R 58 | cdef siz _n 59 | 60 | def __cinit__(self, siz n =0): 61 | rlesInit(&self._R, n) 62 | self._n = n 63 | 64 | # free the RLE array here 65 | def __dealloc__(self): 66 | if self._R is not NULL: 67 | for i in range(self._n): 68 | free(self._R[i].cnts) 69 | free(self._R) 70 | def __getattr__(self, key): 71 | if key == 'n': 72 | return self._n 73 | raise AttributeError(key) 74 | 75 | # python class to wrap Mask array in C 76 | # the class handles the memory allocation and deallocation 77 | cdef class Masks: 78 | cdef byte *_mask 79 | cdef siz _h 80 | cdef siz _w 81 | cdef siz _n 82 | 83 | def __cinit__(self, h, w, n): 84 | self._mask = malloc(h*w*n* sizeof(byte)) 85 | self._h = h 86 | self._w = w 87 | self._n = n 88 | # def __dealloc__(self): 89 | # the memory management of _mask has been passed to np.ndarray 90 | # it doesn't need to be freed here 91 | 92 | # called when passing into np.array() and return an np.ndarray in column-major order 93 | def __array__(self): 94 | cdef np.npy_intp shape[1] 95 | shape[0] = self._h*self._w*self._n 96 | # Create a 1D array, and reshape it to fortran/Matlab column-major array 97 | ndarray = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT8, 
self._mask).reshape((self._h, self._w, self._n), order='F') 98 | # The _mask allocated by Masks is now handled by ndarray 99 | PyArray_ENABLEFLAGS(ndarray, np.NPY_OWNDATA) 100 | return ndarray 101 | 102 | # internal conversion from Python RLEs object to compressed RLE format 103 | def _toString(RLEs Rs): 104 | cdef siz n = Rs.n 105 | cdef bytes py_string 106 | cdef char* c_string 107 | objs = [] 108 | for i in range(n): 109 | c_string = rleToString( &Rs._R[i] ) 110 | py_string = c_string 111 | objs.append({ 112 | 'size': [Rs._R[i].h, Rs._R[i].w], 113 | 'counts': py_string 114 | }) 115 | free(c_string) 116 | return objs 117 | 118 | # internal conversion from compressed RLE format to Python RLEs object 119 | def _frString(rleObjs): 120 | cdef siz n = len(rleObjs) 121 | Rs = RLEs(n) 122 | cdef bytes py_string 123 | cdef char* c_string 124 | for i, obj in enumerate(rleObjs): 125 | if PYTHON_VERSION == 2: 126 | py_string = str(obj['counts']).encode('utf8') 127 | elif PYTHON_VERSION == 3: 128 | py_string = str.encode(obj['counts']) if type(obj['counts']) == str else obj['counts'] 129 | else: 130 | raise Exception('Python version must be 2 or 3') 131 | c_string = py_string 132 | rleFrString( &Rs._R[i], c_string, obj['size'][0], obj['size'][1] ) 133 | return Rs 134 | 135 | # encode mask to RLEs objects 136 | # list of RLE string can be generated by RLEs member function 137 | def encode(np.ndarray[np.uint8_t, ndim=3, mode='fortran'] mask): 138 | h, w, n = mask.shape[0], mask.shape[1], mask.shape[2] 139 | cdef RLEs Rs = RLEs(n) 140 | rleEncode(Rs._R,mask.data,h,w,n) 141 | objs = _toString(Rs) 142 | return objs 143 | 144 | # decode mask from compressed list of RLE string or RLEs object 145 | def decode(rleObjs): 146 | cdef RLEs Rs = _frString(rleObjs) 147 | h, w, n = Rs._R[0].h, Rs._R[0].w, Rs._n 148 | masks = Masks(h, w, n) 149 | rleDecode(Rs._R, masks._mask, n); 150 | return np.array(masks) 151 | 152 | def merge(rleObjs, intersect=0): 153 | cdef RLEs Rs = _frString(rleObjs) 154 | cdef RLEs R = RLEs(1) 155 | rleMerge(Rs._R, R._R, Rs._n, intersect) 156 | obj = _toString(R)[0] 157 | return obj 158 | 159 | def area(rleObjs): 160 | cdef RLEs Rs = _frString(rleObjs) 161 | cdef uint* _a = malloc(Rs._n* sizeof(uint)) 162 | rleArea(Rs._R, Rs._n, _a) 163 | cdef np.npy_intp shape[1] 164 | shape[0] = Rs._n 165 | a = np.array((Rs._n, ), dtype=np.uint8) 166 | a = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT32, _a) 167 | PyArray_ENABLEFLAGS(a, np.NPY_OWNDATA) 168 | return a 169 | 170 | # iou computation. support function overload (RLEs-RLEs and bbox-bbox). 
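# A brief usage sketch (boxes as [x y w h], one iscrowd flag per gt box; values illustrative):
#   dt = np.array([[0., 0., 10., 10.]])
#   gt = np.array([[5., 5., 10., 10.]])
#   iou(dt, gt, [0])   # -> 1x1 array; overlap is 5x5=25, union is 175, so ~0.14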
171 | def iou( dt, gt, pyiscrowd ): 172 | def _preproc(objs): 173 | if len(objs) == 0: 174 | return objs 175 | if type(objs) == np.ndarray: 176 | if len(objs.shape) == 1: 177 | objs = objs.reshape((objs[0], 1)) 178 | # check if it's Nx4 bbox 179 | if not len(objs.shape) == 2 or not objs.shape[1] == 4: 180 | raise Exception('numpy ndarray input is only for *bounding boxes* and should have Nx4 dimension') 181 | objs = objs.astype(np.double) 182 | elif type(objs) == list: 183 | # check if list is in box format and convert it to np.ndarray 184 | isbox = np.all(np.array([(len(obj)==4) and ((type(obj)==list) or (type(obj)==np.ndarray)) for obj in objs])) 185 | isrle = np.all(np.array([type(obj) == dict for obj in objs])) 186 | if isbox: 187 | objs = np.array(objs, dtype=np.double) 188 | if len(objs.shape) == 1: 189 | objs = objs.reshape((1,objs.shape[0])) 190 | elif isrle: 191 | objs = _frString(objs) 192 | else: 193 | raise Exception('list input can be bounding box (Nx4) or RLEs ([RLE])') 194 | else: 195 | raise Exception('unrecognized type. The following type: RLEs (rle), np.ndarray (box), and list (box) are supported.') 196 | return objs 197 | def _rleIou(RLEs dt, RLEs gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 198 | rleIou( dt._R, gt._R, m, n, iscrowd.data, _iou.data ) 199 | def _bbIou(np.ndarray[np.double_t, ndim=2] dt, np.ndarray[np.double_t, ndim=2] gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 200 | bbIou( dt.data, gt.data, m, n, iscrowd.data, _iou.data ) 201 | def _len(obj): 202 | cdef siz N = 0 203 | if type(obj) == RLEs: 204 | N = obj.n 205 | elif len(obj)==0: 206 | pass 207 | elif type(obj) == np.ndarray: 208 | N = obj.shape[0] 209 | return N 210 | # convert iscrowd to numpy array 211 | cdef np.ndarray[np.uint8_t, ndim=1] iscrowd = np.array(pyiscrowd, dtype=np.uint8) 212 | # simple type checking 213 | cdef siz m, n 214 | dt = _preproc(dt) 215 | gt = _preproc(gt) 216 | m = _len(dt) 217 | n = _len(gt) 218 | if m == 0 or n == 0: 219 | return [] 220 | if not type(dt) == type(gt): 221 | raise Exception('The dt and gt should have the same data type, either RLEs, list or np.ndarray') 222 | 223 | # define local variables 224 | cdef double* _iou = 0 225 | cdef np.npy_intp shape[1] 226 | # check type and assign iou function 227 | if type(dt) == RLEs: 228 | _iouFun = _rleIou 229 | elif type(dt) == np.ndarray: 230 | _iouFun = _bbIou 231 | else: 232 | raise Exception('input data type not allowed.') 233 | _iou = malloc(m*n* sizeof(double)) 234 | iou = np.zeros((m*n, ), dtype=np.double) 235 | shape[0] = m*n 236 | iou = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _iou) 237 | PyArray_ENABLEFLAGS(iou, np.NPY_OWNDATA) 238 | _iouFun(dt, gt, iscrowd, m, n, iou) 239 | return iou.reshape((m,n), order='F') 240 | 241 | def toBbox( rleObjs ): 242 | cdef RLEs Rs = _frString(rleObjs) 243 | cdef siz n = Rs.n 244 | cdef BB _bb = malloc(4*n* sizeof(double)) 245 | rleToBbox( Rs._R, _bb, n ) 246 | cdef np.npy_intp shape[1] 247 | shape[0] = 4*n 248 | bb = np.array((1,4*n), dtype=np.double) 249 | bb = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _bb).reshape((n, 4)) 250 | PyArray_ENABLEFLAGS(bb, np.NPY_OWNDATA) 251 | return bb 252 | 253 | def frBbox(np.ndarray[np.double_t, ndim=2] bb, siz h, siz w ): 254 | cdef siz n = bb.shape[0] 255 | Rs = RLEs(n) 256 | rleFrBbox( Rs._R, bb.data, h, w, n ) 257 | objs = _toString(Rs) 258 | return objs 259 | 260 | def frPoly( poly, siz h, siz w ): 261 | cdef 
np.ndarray[np.double_t, ndim=1] np_poly 262 | n = len(poly) 263 | Rs = RLEs(n) 264 | for i, p in enumerate(poly): 265 | np_poly = np.array(p, dtype=np.double, order='F') 266 | rleFrPoly( &Rs._R[i], np_poly.data, int(len(p)/2), h, w ) 267 | objs = _toString(Rs) 268 | return objs 269 | 270 | def frUncompressedRLE(ucRles, siz h, siz w): 271 | cdef np.ndarray[np.uint32_t, ndim=1] cnts 272 | cdef RLE R 273 | cdef uint *data 274 | n = len(ucRles) 275 | objs = [] 276 | for i in range(n): 277 | Rs = RLEs(1) 278 | cnts = np.array(ucRles[i]['counts'], dtype=np.uint32) 279 | # time for malloc can be saved here but it's fine 280 | data = malloc(len(cnts)* sizeof(uint)) 281 | for j in range(len(cnts)): 282 | data[j] = cnts[j] 283 | R = RLE(ucRles[i]['size'][0], ucRles[i]['size'][1], len(cnts), data) 284 | Rs._R[0] = R 285 | objs.append(_toString(Rs)[0]) 286 | return objs 287 | 288 | def frPyObjects(pyobj, h, w): 289 | # encode rle from a list of python objects 290 | if type(pyobj) == np.ndarray: 291 | objs = frBbox(pyobj, h, w) 292 | elif type(pyobj) == list and len(pyobj[0]) == 4: 293 | objs = frBbox(pyobj, h, w) 294 | elif type(pyobj) == list and len(pyobj[0]) > 4: 295 | objs = frPoly(pyobj, h, w) 296 | elif type(pyobj) == list and type(pyobj[0]) == dict \ 297 | and 'counts' in pyobj[0] and 'size' in pyobj[0]: 298 | objs = frUncompressedRLE(pyobj, h, w) 299 | # encode rle from single python object 300 | elif type(pyobj) == list and len(pyobj) == 4: 301 | objs = frBbox([pyobj], h, w)[0] 302 | elif type(pyobj) == list and len(pyobj) > 4: 303 | objs = frPoly([pyobj], h, w)[0] 304 | elif type(pyobj) == dict and 'counts' in pyobj and 'size' in pyobj: 305 | objs = frUncompressedRLE([pyobj], h, w)[0] 306 | else: 307 | raise Exception('input type is not supported.') 308 | return objs 309 | -------------------------------------------------------------------------------- /utils/pycocotools/coco.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | __version__ = '2.0' 3 | # Interface for accessing the Microsoft COCO dataset. 4 | 5 | # Microsoft COCO is a large image dataset designed for object detection, 6 | # segmentation, and caption generation. pycocotools is a Python API that 7 | # assists in loading, parsing and visualizing the annotations in COCO. 8 | # Please visit http://mscoco.org/ for more information on COCO, including 9 | # for the data, paper, and tutorials. The exact format of the annotations 10 | # is also described on the COCO website. For example usage of the pycocotools 11 | # please see pycocotools_demo.ipynb. In addition to this API, please download both 12 | # the COCO images and annotations in order to run the demo. 13 | 14 | # An alternative to using the API is to load the annotations directly 15 | # into Python dictionary 16 | # Using the API provides additional utility functions. Note that this API 17 | # supports both *instance* and *caption* annotations. In the case of 18 | # captions not all functions are defined (e.g. categories are undefined). 19 | 20 | # The following API functions are defined: 21 | # COCO - COCO api class that loads COCO annotation file and prepare data structures. 22 | # decodeMask - Decode binary mask M encoded via run-length encoding. 23 | # encodeMask - Encode binary mask M using run-length encoding. 24 | # getAnnIds - Get ann ids that satisfy given filter conditions. 25 | # getCatIds - Get cat ids that satisfy given filter conditions. 
26 | # getImgIds - Get img ids that satisfy given filter conditions. 27 | # loadAnns - Load anns with the specified ids. 28 | # loadCats - Load cats with the specified ids. 29 | # loadImgs - Load imgs with the specified ids. 30 | # annToMask - Convert segmentation in an annotation to binary mask. 31 | # showAnns - Display the specified annotations. 32 | # loadRes - Load algorithm results and create API for accessing them. 33 | # download - Download COCO images from mscoco.org server. 34 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 35 | # Help on each functions can be accessed by: "help COCO>function". 36 | 37 | # See also COCO>decodeMask, 38 | # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds, 39 | # COCO>getImgIds, COCO>loadAnns, COCO>loadCats, 40 | # COCO>loadImgs, COCO>annToMask, COCO>showAnns 41 | 42 | # Microsoft COCO Toolbox. version 2.0 43 | # Data, paper, and tutorials available at: http://mscoco.org/ 44 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2014. 45 | # Licensed under the Simplified BSD License [see bsd.txt] 46 | 47 | import json 48 | import time 49 | import matplotlib.pyplot as plt 50 | from matplotlib.collections import PatchCollection 51 | from matplotlib.patches import Polygon 52 | import numpy as np 53 | import copy 54 | import itertools 55 | # from cocoapi.PythonAPI.pycocotools import mask as maskUtils 56 | import os 57 | from collections import defaultdict 58 | import sys 59 | PYTHON_VERSION = sys.version_info[0] 60 | if PYTHON_VERSION == 2: 61 | from urllib import urlretrieve 62 | elif PYTHON_VERSION == 3: 63 | from urllib.request import urlretrieve 64 | 65 | 66 | def _isArrayLike(obj): 67 | return hasattr(obj, '__iter__') and hasattr(obj, '__len__') 68 | 69 | 70 | class COCO: 71 | def __init__(self, annotation_file=None): 72 | """ 73 | Constructor of Microsoft COCO helper class for reading and visualizing annotations. 74 | :param annotation_file (str): location of annotation file 75 | :param image_folder (str): location to the folder that hosts images. 
76 | :return: 77 | """ 78 | # load dataset 79 | self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict() 80 | self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) 81 | if not annotation_file == None: 82 | print('loading annotations into memory...') 83 | tic = time.time() 84 | dataset = json.load(open(annotation_file, 'r')) 85 | assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset)) 86 | print('Done (t={:0.2f}s)'.format(time.time()- tic)) 87 | self.dataset = dataset 88 | self.createIndex() 89 | 90 | def createIndex(self): 91 | # create index 92 | print('creating index...') 93 | anns, cats, imgs = {}, {}, {} 94 | imgToAnns,catToImgs = defaultdict(list),defaultdict(list) 95 | if 'annotations' in self.dataset: 96 | for ann in self.dataset['annotations']: 97 | imgToAnns[ann['image_id']].append(ann) 98 | anns[ann['id']] = ann 99 | 100 | if 'images' in self.dataset: 101 | for img in self.dataset['images']: 102 | imgs[img['id']] = img 103 | 104 | if 'categories' in self.dataset: 105 | for cat in self.dataset['categories']: 106 | cats[cat['id']] = cat 107 | 108 | if 'annotations' in self.dataset and 'categories' in self.dataset: 109 | for ann in self.dataset['annotations']: 110 | catToImgs[ann['category_id']].append(ann['image_id']) 111 | 112 | print('index created!') 113 | 114 | # create class members 115 | self.anns = anns 116 | self.imgToAnns = imgToAnns 117 | self.catToImgs = catToImgs 118 | self.imgs = imgs 119 | self.cats = cats 120 | 121 | def info(self): 122 | """ 123 | Print information about the annotation file. 124 | :return: 125 | """ 126 | for key, value in self.dataset['info'].items(): 127 | print('{}: {}'.format(key, value)) 128 | 129 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): 130 | """ 131 | Get ann ids that satisfy given filter conditions. default skips that filter 132 | :param imgIds (int array) : get anns for given imgs 133 | catIds (int array) : get anns for given cats 134 | areaRng (float array) : get anns for given area range (e.g. [0 inf]) 135 | iscrowd (boolean) : get anns for given crowd label (False or True) 136 | :return: ids (int array) : integer array of ann ids 137 | """ 138 | imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] 139 | catIds = catIds if _isArrayLike(catIds) else [catIds] 140 | 141 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 142 | anns = self.dataset['annotations'] 143 | else: 144 | if not len(imgIds) == 0: 145 | lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns] 146 | anns = list(itertools.chain.from_iterable(lists)) 147 | else: 148 | anns = self.dataset['annotations'] 149 | anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds] 150 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] 151 | if not iscrowd == None: 152 | ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] 153 | else: 154 | ids = [ann['id'] for ann in anns] 155 | return ids 156 | 157 | def getCatIds(self, catNms=[], supNms=[], catIds=[]): 158 | """ 159 | filtering parameters. default skips that filter. 
160 | :param catNms (str array) : get cats for given cat names 161 | :param supNms (str array) : get cats for given supercategory names 162 | :param catIds (int array) : get cats for given cat ids 163 | :return: ids (int array) : integer array of cat ids 164 | """ 165 | catNms = catNms if _isArrayLike(catNms) else [catNms] 166 | supNms = supNms if _isArrayLike(supNms) else [supNms] 167 | catIds = catIds if _isArrayLike(catIds) else [catIds] 168 | 169 | if len(catNms) == len(supNms) == len(catIds) == 0: 170 | cats = self.dataset['categories'] 171 | else: 172 | cats = self.dataset['categories'] 173 | cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms] 174 | cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms] 175 | cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds] 176 | ids = [cat['id'] for cat in cats] 177 | return ids 178 | 179 | def getImgIds(self, imgIds=[], catIds=[]): 180 | ''' 181 | Get img ids that satisfy given filter conditions. 182 | :param imgIds (int array) : get imgs for given ids 183 | :param catIds (int array) : get imgs with all given cats 184 | :return: ids (int array) : integer array of img ids 185 | ''' 186 | imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] 187 | catIds = catIds if _isArrayLike(catIds) else [catIds] 188 | 189 | if len(imgIds) == len(catIds) == 0: 190 | ids = self.imgs.keys() 191 | else: 192 | ids = set(imgIds) 193 | for i, catId in enumerate(catIds): 194 | if i == 0 and len(ids) == 0: 195 | ids = set(self.catToImgs[catId]) 196 | else: 197 | ids &= set(self.catToImgs[catId]) 198 | return list(ids) 199 | 200 | def loadAnns(self, ids=[]): 201 | """ 202 | Load anns with the specified ids. 203 | :param ids (int array) : integer ids specifying anns 204 | :return: anns (object array) : loaded ann objects 205 | """ 206 | if _isArrayLike(ids): 207 | return [self.anns[id] for id in ids] 208 | elif type(ids) == int: 209 | return [self.anns[ids]] 210 | 211 | def loadCats(self, ids=[]): 212 | """ 213 | Load cats with the specified ids. 214 | :param ids (int array) : integer ids specifying cats 215 | :return: cats (object array) : loaded cat objects 216 | """ 217 | if _isArrayLike(ids): 218 | return [self.cats[id] for id in ids] 219 | elif type(ids) == int: 220 | return [self.cats[ids]] 221 | 222 | def loadImgs(self, ids=[]): 223 | """ 224 | Load anns with the specified ids. 225 | :param ids (int array) : integer ids specifying img 226 | :return: imgs (object array) : loaded img objects 227 | """ 228 | if _isArrayLike(ids): 229 | return [self.imgs[id] for id in ids] 230 | elif type(ids) == int: 231 | return [self.imgs[ids]] 232 | 233 | def showAnns(self, anns): 234 | """ 235 | Display the specified annotations. 
236 | :param anns (array of object): annotations to display 237 | :return: None 238 | """ 239 | if len(anns) == 0: 240 | return 0 241 | if 'segmentation' in anns[0] or 'keypoints' in anns[0]: 242 | datasetType = 'instances' 243 | elif 'caption' in anns[0]: 244 | datasetType = 'captions' 245 | else: 246 | raise Exception('datasetType not supported') 247 | if datasetType == 'instances': 248 | ax = plt.gca() 249 | ax.set_autoscale_on(False) 250 | polygons = [] 251 | color = [] 252 | for ann in anns: 253 | c = (np.random.random((1, 3))*0.6+0.4).tolist()[0] 254 | if 'segmentation' in ann: 255 | if type(ann['segmentation']) == list: 256 | # polygon 257 | for seg in ann['segmentation']: 258 | poly = np.array(seg).reshape((int(len(seg)/2), 2)) 259 | polygons.append(Polygon(poly)) 260 | color.append(c) 261 | else: 262 | # mask 263 | t = self.imgs[ann['image_id']] 264 | if type(ann['segmentation']['counts']) == list: 265 | rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width']) 266 | else: 267 | rle = [ann['segmentation']] 268 | m = maskUtils.decode(rle) 269 | img = np.ones( (m.shape[0], m.shape[1], 3) ) 270 | if ann['iscrowd'] == 1: 271 | color_mask = np.array([2.0,166.0,101.0])/255 272 | if ann['iscrowd'] == 0: 273 | color_mask = np.random.random((1, 3)).tolist()[0] 274 | for i in range(3): 275 | img[:,:,i] = color_mask[i] 276 | ax.imshow(np.dstack( (img, m*0.5) )) 277 | if 'keypoints' in ann and type(ann['keypoints']) == list: 278 | # turn skeleton into zero-based index 279 | sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1 280 | kp = np.array(ann['keypoints']) 281 | x = kp[0::3] 282 | y = kp[1::3] 283 | v = kp[2::3] 284 | for sk in sks: 285 | if np.all(v[sk]>0): 286 | plt.plot(x[sk],y[sk], linewidth=3, color=c) 287 | plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2) 288 | plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2) 289 | p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4) 290 | ax.add_collection(p) 291 | p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2) 292 | ax.add_collection(p) 293 | elif datasetType == 'captions': 294 | for ann in anns: 295 | print(ann['caption']) 296 | 297 | def loadRes(self, resFile): 298 | """ 299 | Load result file and return a result api object. 
300 | :param resFile (str) : file name of result file 301 | :return: res (obj) : result api object 302 | """ 303 | res = COCO() 304 | res.dataset['images'] = [img for img in self.dataset['images']] 305 | 306 | print('Loading and preparing results...') 307 | tic = time.time() 308 | if type(resFile) == str or type(resFile) == unicode: 309 | anns = json.load(open(resFile)) 310 | elif type(resFile) == np.ndarray: 311 | anns = self.loadNumpyAnnotations(resFile) 312 | else: 313 | anns = resFile 314 | assert type(anns) == list, 'results in not an array of objects' 315 | annsImgIds = [ann['image_id'] for ann in anns] 316 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 317 | 'Results do not correspond to current coco set' 318 | if 'caption' in anns[0]: 319 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) 320 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] 321 | for id, ann in enumerate(anns): 322 | ann['id'] = id+1 323 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: 324 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 325 | for id, ann in enumerate(anns): 326 | bb = ann['bbox'] 327 | x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]] 328 | if not 'segmentation' in ann: 329 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 330 | ann['area'] = bb[2]*bb[3] 331 | ann['id'] = id+1 332 | ann['iscrowd'] = 0 333 | elif 'segmentation' in anns[0]: 334 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 335 | for id, ann in enumerate(anns): 336 | # now only support compressed RLE format as segmentation results 337 | ann['area'] = maskUtils.area(ann['segmentation']) 338 | if not 'bbox' in ann: 339 | ann['bbox'] = maskUtils.toBbox(ann['segmentation']) 340 | ann['id'] = id+1 341 | ann['iscrowd'] = 0 342 | elif 'keypoints' in anns[0]: 343 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 344 | for id, ann in enumerate(anns): 345 | s = ann['keypoints'] 346 | x = s[0::3] 347 | y = s[1::3] 348 | x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y) 349 | ann['area'] = (x1-x0)*(y1-y0) 350 | ann['id'] = id + 1 351 | ann['bbox'] = [x0,y0,x1-x0,y1-y0] 352 | print('DONE (t={:0.2f}s)'.format(time.time()- tic)) 353 | 354 | res.dataset['annotations'] = anns 355 | res.createIndex() 356 | return res 357 | 358 | def download(self, tarDir = None, imgIds = [] ): 359 | ''' 360 | Download COCO images from mscoco.org server. 
361 | :param tarDir (str): COCO results directory name 362 | imgIds (list): images to be downloaded 363 | :return: 364 | ''' 365 | if tarDir is None: 366 | print('Please specify target directory') 367 | return -1 368 | if len(imgIds) == 0: 369 | imgs = self.imgs.values() 370 | else: 371 | imgs = self.loadImgs(imgIds) 372 | N = len(imgs) 373 | if not os.path.exists(tarDir): 374 | os.makedirs(tarDir) 375 | for i, img in enumerate(imgs): 376 | tic = time.time() 377 | fname = os.path.join(tarDir, img['file_name']) 378 | if not os.path.exists(fname): 379 | urlretrieve(img['coco_url'], fname) 380 | print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic)) 381 | 382 | def loadNumpyAnnotations(self, data): 383 | """ 384 | Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class} 385 | :param data (numpy.ndarray) 386 | :return: annotations (python nested list) 387 | """ 388 | print('Converting ndarray to lists...') 389 | assert(type(data) == np.ndarray) 390 | print(data.shape) 391 | assert(data.shape[1] == 7) 392 | N = data.shape[0] 393 | ann = [] 394 | for i in range(N): 395 | if i % 1000000 == 0: 396 | print('{}/{}'.format(i,N)) 397 | ann += [{ 398 | 'image_id' : int(data[i, 0]), 399 | 'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ], 400 | 'score' : data[i, 5], 401 | 'category_id': int(data[i, 6]), 402 | }] 403 | return ann 404 | 405 | def annToRLE(self, ann): 406 | """ 407 | Convert annotation which can be polygons, uncompressed RLE to RLE. 408 | :return: binary mask (numpy 2D array) 409 | """ 410 | t = self.imgs[ann['image_id']] 411 | h, w = t['height'], t['width'] 412 | segm = ann['segmentation'] 413 | if type(segm) == list: 414 | # polygon -- a single object might consist of multiple parts 415 | # we merge all parts into one mask rle code 416 | rles = maskUtils.frPyObjects(segm, h, w) 417 | rle = maskUtils.merge(rles) 418 | elif type(segm['counts']) == list: 419 | # uncompressed RLE 420 | rle = maskUtils.frPyObjects(segm, h, w) 421 | else: 422 | # rle 423 | rle = ann['segmentation'] 424 | return rle 425 | 426 | def annToMask(self, ann): 427 | """ 428 | Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask. 429 | :return: binary mask (numpy 2D array) 430 | """ 431 | rle = self.annToRLE(ann) 432 | m = maskUtils.decode(rle) 433 | return m -------------------------------------------------------------------------------- /utils/pycocotools/cocoeval.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | import numpy as np 4 | import datetime 5 | import time 6 | from collections import defaultdict 7 | from . import mask as maskUtils 8 | import copy 9 | 10 | class COCOeval: 11 | # Interface for evaluating detection on the Microsoft COCO dataset. 12 | # 13 | # The usage for CocoEval is as follows: 14 | # cocoGt=..., cocoDt=... # load dataset and results 15 | # E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object 16 | # E.params.recThrs = ...; # set parameters as desired 17 | # E.evaluate(); # run per image evaluation 18 | # E.accumulate(); # accumulate per image results 19 | # E.summarize(); # display summary metrics of results 20 | # For example usage see evalDemo.m and http://mscoco.org/. 
21 | # 22 | # The evaluation parameters are as follows (defaults in brackets): 23 | # imgIds - [all] N img ids to use for evaluation 24 | # catIds - [all] K cat ids to use for evaluation 25 | # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation 26 | # recThrs - [0:.01:1] R=101 recall thresholds for evaluation 27 | # areaRng - [...] A=4 object area ranges for evaluation 28 | # maxDets - [1 10 100] M=3 thresholds on max detections per image 29 | # iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints' 30 | # iouType replaced the now DEPRECATED useSegm parameter. 31 | # useCats - [1] if true use category labels for evaluation 32 | # Note: if useCats=0 category labels are ignored as in proposal scoring. 33 | # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified. 34 | # 35 | # evaluate(): evaluates detections on every image and every category and 36 | # concats the results into the "evalImgs" with fields: 37 | # dtIds - [1xD] id for each of the D detections (dt) 38 | # gtIds - [1xG] id for each of the G ground truths (gt) 39 | # dtMatches - [TxD] matching gt id at each IoU or 0 40 | # gtMatches - [TxG] matching dt id at each IoU or 0 41 | # dtScores - [1xD] confidence of each dt 42 | # gtIgnore - [1xG] ignore flag for each gt 43 | # dtIgnore - [TxD] ignore flag for each dt at each IoU 44 | # 45 | # accumulate(): accumulates the per-image, per-category evaluation 46 | # results in "evalImgs" into the dictionary "eval" with fields: 47 | # params - parameters used for evaluation 48 | # date - date evaluation was performed 49 | # counts - [T,R,K,A,M] parameter dimensions (see above) 50 | # precision - [TxRxKxAxM] precision for every evaluation setting 51 | # recall - [TxKxAxM] max recall for every evaluation setting 52 | # Note: precision and recall==-1 for settings with no gt objects. 53 | # 54 | # See also coco, mask, pycocoDemo, pycocoEvalDemo 55 | # 56 | # Microsoft COCO Toolbox. version 2.0 57 | # Data, paper, and tutorials available at: http://mscoco.org/ 58 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 59 | # Licensed under the Simplified BSD License [see coco/license.txt] 60 | def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'): 61 | ''' 62 | Initialize CocoEval using coco APIs for gt and dt 63 | :param cocoGt: coco object with ground truth annotations 64 | :param cocoDt: coco object with detection results 65 | :return: None 66 | ''' 67 | if not iouType: 68 | print('iouType not specified. 
use default iouType segm') 69 | self.cocoGt = cocoGt # ground truth COCO API 70 | self.cocoDt = cocoDt # detections COCO API 71 | self.params = {} # evaluation parameters 72 | self.evalImgs = defaultdict(list) # per-image per-category evaluation results [KxAxI] elements 73 | self.eval = {} # accumulated evaluation results 74 | self._gts = defaultdict(list) # gt for evaluation 75 | self._dts = defaultdict(list) # dt for evaluation 76 | self.params = Params(iouType=iouType) # parameters 77 | self._paramsEval = {} # parameters for evaluation 78 | self.stats = [] # result summarization 79 | self.ious = {} # ious between all gts and dts 80 | if not cocoGt is None: 81 | self.params.imgIds = sorted(cocoGt.getImgIds()) 82 | self.params.catIds = sorted(cocoGt.getCatIds()) 83 | 84 | 85 | def _prepare(self): 86 | ''' 87 | Prepare ._gts and ._dts for evaluation based on params 88 | :return: None 89 | ''' 90 | def _toMask(anns, coco): 91 | # modify ann['segmentation'] by reference 92 | for ann in anns: 93 | rle = coco.annToRLE(ann) 94 | ann['segmentation'] = rle 95 | p = self.params 96 | if p.useCats: 97 | gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) 98 | dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) 99 | else: 100 | gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) 101 | dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) 102 | 103 | # convert ground truth to mask if iouType == 'segm' 104 | if p.iouType == 'segm': 105 | _toMask(gts, self.cocoGt) 106 | _toMask(dts, self.cocoDt) 107 | # set ignore flag 108 | for gt in gts: 109 | gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0 110 | gt['ignore'] = 'iscrowd' in gt and gt['iscrowd'] 111 | if p.iouType == 'keypoints': 112 | gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore'] 113 | self._gts = defaultdict(list) # gt for evaluation 114 | self._dts = defaultdict(list) # dt for evaluation 115 | for gt in gts: 116 | self._gts[gt['image_id'], gt['category_id']].append(gt) 117 | for dt in dts: 118 | self._dts[dt['image_id'], dt['category_id']].append(dt) 119 | self.evalImgs = defaultdict(list) # per-image per-category evaluation results 120 | self.eval = {} # accumulated evaluation results 121 | 122 | def evaluate(self): 123 | ''' 124 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 125 | :return: None 126 | ''' 127 | tic = time.time() 128 | print('Running per image evaluation...') 129 | p = self.params 130 | # add backward compatibility if useSegm is specified in params 131 | if not p.useSegm is None: 132 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 133 | print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) 134 | print('Evaluate annotation type *{}*'.format(p.iouType)) 135 | p.imgIds = list(np.unique(p.imgIds)) 136 | if p.useCats: 137 | p.catIds = list(np.unique(p.catIds)) 138 | p.maxDets = sorted(p.maxDets) 139 | self.params=p 140 | 141 | self._prepare() 142 | # loop through images, area range, max detection number 143 | catIds = p.catIds if p.useCats else [-1] 144 | 145 | if p.iouType == 'segm' or p.iouType == 'bbox': 146 | computeIoU = self.computeIoU 147 | elif p.iouType == 'keypoints': 148 | computeIoU = self.computeOks 149 | self.ious = {(imgId, catId): computeIoU(imgId, catId) \ 150 | for imgId in p.imgIds 151 | for catId in catIds} 152 | 153 | evaluateImg = self.evaluateImg 154 | maxDet = p.maxDets[-1] 155 | self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet) 156 | for catId in catIds 157 | for areaRng in p.areaRng 158 | for imgId in p.imgIds 159 | ] 160 | self._paramsEval = copy.deepcopy(self.params) 161 | toc = time.time() 162 | print('DONE (t={:0.2f}s).'.format(toc-tic)) 163 | 164 | def computeIoU(self, imgId, catId): 165 | p = self.params 166 | if p.useCats: 167 | gt = self._gts[imgId,catId] 168 | dt = self._dts[imgId,catId] 169 | else: 170 | gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]] 171 | dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]] 172 | if len(gt) == 0 and len(dt) ==0: 173 | return [] 174 | inds = np.argsort([-d['score'] for d in dt], kind='mergesort') 175 | dt = [dt[i] for i in inds] 176 | if len(dt) > p.maxDets[-1]: 177 | dt=dt[0:p.maxDets[-1]] 178 | 179 | if p.iouType == 'segm': 180 | g = [g['segmentation'] for g in gt] 181 | d = [d['segmentation'] for d in dt] 182 | elif p.iouType == 'bbox': 183 | g = [g['bbox'] for g in gt] 184 | d = [d['bbox'] for d in dt] 185 | else: 186 | raise Exception('unknown iouType for iou computation') 187 | 188 | # compute iou between each dt and gt region 189 | iscrowd = [int(o['iscrowd']) for o in gt] 190 | ious = maskUtils.iou(d,g,iscrowd) 191 | return ious 192 | 193 | def computeOks(self, imgId, catId): 194 | p = self.params 195 | # dimension here should be Nxm 196 | gts = self._gts[imgId, catId] 197 | dts = self._dts[imgId, catId] 198 | inds = np.argsort([-d['score'] for d in dts], kind='mergesort') 199 | dts = [dts[i] for i in inds] 200 | if len(dts) > p.maxDets[-1]: 201 | dts = dts[0:p.maxDets[-1]] 202 | # if len(gts) == 0 and len(dts) == 0: 203 | if len(gts) == 0 or len(dts) == 0: 204 | return [] 205 | ious = np.zeros((len(dts), len(gts))) 206 | sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0 207 | vars = (sigmas * 2)**2 208 | k = len(sigmas) 209 | # compute oks between each detection and ground truth object 210 | for j, gt in enumerate(gts): 211 | # create bounds for ignore regions (double the gt bbox) 212 | g = np.array(gt['keypoints']) 213 | xg = g[0::3]; yg = g[1::3]; vg = g[2::3] 214 | k1 = np.count_nonzero(vg > 0) 215 | bb = gt['bbox'] 216 | x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2 217 | y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2 218 | for i, dt in enumerate(dts): 219 | d = np.array(dt['keypoints']) 220 | xd = d[0::3]; yd = d[1::3] 221 | if k1>0: 222 | # measure the per-keypoint distance if keypoints visible 223 | dx = xd - xg 224 | dy = yd - yg 225 | else: 226 | # measure minimum distance to keypoints in (x0,y0) & (x1,y1) 227 | z = np.zeros((k)) 228 | dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0) 229 | dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0) 230 | e = (dx**2 + 
dy**2) / vars / (gt['area']+np.spacing(1)) / 2 231 | if k1 > 0: 232 | e=e[vg > 0] 233 | ious[i, j] = np.sum(np.exp(-e)) / e.shape[0] 234 | return ious 235 | 236 | def evaluateImg(self, imgId, catId, aRng, maxDet): 237 | ''' 238 | perform evaluation for single category and image 239 | :return: dict (single image results) 240 | ''' 241 | p = self.params 242 | if p.useCats: 243 | gt = self._gts[imgId,catId] 244 | dt = self._dts[imgId,catId] 245 | else: 246 | gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]] 247 | dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]] 248 | if len(gt) == 0 and len(dt) ==0: 249 | return None 250 | 251 | for g in gt: 252 | if g['ignore'] or (g['area']<aRng[0] or g['area']>aRng[1]): 253 | g['_ignore'] = 1 254 | else: 255 | g['_ignore'] = 0 256 | 257 | # sort dt highest score first, sort gt ignore last 258 | gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort') 259 | gt = [gt[i] for i in gtind] 260 | dtind = np.argsort([-d['score'] for d in dt], kind='mergesort') 261 | dt = [dt[i] for i in dtind[0:maxDet]] 262 | iscrowd = [int(o['iscrowd']) for o in gt] 263 | # load computed ious 264 | ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId] 265 | 266 | T = len(p.iouThrs) 267 | G = len(gt) 268 | D = len(dt) 269 | gtm = np.zeros((T,G)) 270 | dtm = np.zeros((T,D)) 271 | gtIg = np.array([g['_ignore'] for g in gt]) 272 | dtIg = np.zeros((T,D)) 273 | if not len(ious)==0: 274 | for tind, t in enumerate(p.iouThrs): 275 | for dind, d in enumerate(dt): 276 | # information about best match so far (m=-1 -> unmatched) 277 | iou = min([t,1-1e-10]) 278 | m = -1 279 | for gind, g in enumerate(gt): 280 | # if this gt already matched, and not a crowd, continue 281 | if gtm[tind,gind]>0 and not iscrowd[gind]: 282 | continue 283 | # if dt matched to reg gt, and on ignore gt, stop 284 | if m>-1 and gtIg[m]==0 and gtIg[gind]==1: 285 | break 286 | # continue to next gt unless better match made 287 | if ious[dind,gind] < iou: 288 | continue 289 | # if match successful and best so far, store appropriately 290 | iou=ious[dind,gind] 291 | m=gind 292 | # if match made store id of match for both dt and gt 293 | if m ==-1: 294 | continue 295 | dtIg[tind,dind] = gtIg[m] 296 | dtm[tind,dind] = gt[m]['id'] 297 | gtm[tind,m] = d['id'] 298 | # set unmatched detections outside of area range to ignore 299 | a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt))) 300 | dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0))) 301 | # store results for given image and category 302 | return { 303 | 'image_id': imgId, 304 | 'category_id': catId, 305 | 'aRng': aRng, 306 | 'maxDet': maxDet, 307 | 'dtIds': [d['id'] for d in dt], 308 | 'gtIds': [g['id'] for g in gt], 309 | 'dtMatches': dtm, 310 | 'gtMatches': gtm, 311 | 'dtScores': [d['score'] for d in dt], 312 | 'gtIgnore': gtIg, 313 | 'dtIgnore': dtIg, 314 | } 315 | 316 | def accumulate(self, p = None): 317 | ''' 318 | Accumulate per image evaluation results and store the result in self.eval 319 | :param p: input params for evaluation 320 | :return: None 321 | ''' 322 | print('Accumulating evaluation results...') 323 | tic = time.time() 324 | if not self.evalImgs: 325 | print('Please run evaluate() first') 326 | # allows input customized parameters 327 | if p is None: 328 | p = self.params 329 | p.catIds = p.catIds if p.useCats == 1 else [-1] 330 | T = len(p.iouThrs) 331 | R = len(p.recThrs) 332 | K = len(p.catIds) if p.useCats else 1 333 | A = len(p.areaRng) 334 | M = len(p.maxDets) 
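# Added clarification (not part of the upstream pycocotools source): evaluate() fills self.evalImgs
# with catId as the outermost loop, then areaRng, then imgId, so the flat list holds K*A*I per-image
# results. The offsets computed further below, Nk = k0*A0*I0 and Na = a0*I0, index into that layout
# when gathering the per-image results E for each (category, area range, maxDets) setting.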
335 | precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories 336 | recall = -np.ones((T,K,A,M)) 337 | scores = -np.ones((T,R,K,A,M)) 338 | 339 | # create dictionary for future indexing 340 | _pe = self._paramsEval 341 | catIds = _pe.catIds if _pe.useCats else [-1] 342 | setK = set(catIds) 343 | setA = set(map(tuple, _pe.areaRng)) 344 | setM = set(_pe.maxDets) 345 | setI = set(_pe.imgIds) 346 | # get inds to evaluate 347 | k_list = [n for n, k in enumerate(p.catIds) if k in setK] 348 | m_list = [m for n, m in enumerate(p.maxDets) if m in setM] 349 | a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA] 350 | i_list = [n for n, i in enumerate(p.imgIds) if i in setI] 351 | I0 = len(_pe.imgIds) 352 | A0 = len(_pe.areaRng) 353 | # retrieve E at each category, area range, and max number of detections 354 | for k, k0 in enumerate(k_list): 355 | Nk = k0*A0*I0 356 | for a, a0 in enumerate(a_list): 357 | Na = a0*I0 358 | for m, maxDet in enumerate(m_list): 359 | E = [self.evalImgs[Nk + Na + i] for i in i_list] 360 | E = [e for e in E if not e is None] 361 | if len(E) == 0: 362 | continue 363 | dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E]) 364 | 365 | # different sorting method generates slightly different results. 366 | # mergesort is used to be consistent as Matlab implementation. 367 | inds = np.argsort(-dtScores, kind='mergesort') 368 | dtScoresSorted = dtScores[inds] 369 | 370 | dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds] 371 | dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds] 372 | gtIg = np.concatenate([e['gtIgnore'] for e in E]) 373 | npig = np.count_nonzero(gtIg==0 ) 374 | if npig == 0: 375 | continue 376 | tps = np.logical_and( dtm, np.logical_not(dtIg) ) 377 | fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) ) 378 | 379 | tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float) 380 | fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float) 381 | for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): 382 | tp = np.array(tp) 383 | fp = np.array(fp) 384 | nd = len(tp) 385 | rc = tp / npig 386 | pr = tp / (fp+tp+np.spacing(1)) 387 | q = np.zeros((R,)) 388 | ss = np.zeros((R,)) 389 | 390 | if nd: 391 | recall[t,k,a,m] = rc[-1] 392 | else: 393 | recall[t,k,a,m] = 0 394 | 395 | # numpy is slow without cython optimization for accessing elements 396 | # use python array gets significant speed improvement 397 | pr = pr.tolist(); q = q.tolist() 398 | 399 | for i in range(nd-1, 0, -1): 400 | if pr[i] > pr[i-1]: 401 | pr[i-1] = pr[i] 402 | 403 | inds = np.searchsorted(rc, p.recThrs, side='left') 404 | try: 405 | for ri, pi in enumerate(inds): 406 | q[ri] = pr[pi] 407 | ss[ri] = dtScoresSorted[pi] 408 | except: 409 | pass 410 | precision[t,:,k,a,m] = np.array(q) 411 | scores[t,:,k,a,m] = np.array(ss) 412 | self.eval = { 413 | 'params': p, 414 | 'counts': [T, R, K, A, M], 415 | 'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 416 | 'precision': precision, 417 | 'recall': recall, 418 | 'scores': scores, 419 | } 420 | toc = time.time() 421 | print('DONE (t={:0.2f}s).'.format( toc-tic)) 422 | 423 | def summarize(self): 424 | ''' 425 | Compute and display summary metrics for evaluation results. 
426 | Note this function can *only* be applied on the default parameter setting 427 | ''' 428 | def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ): 429 | p = self.params 430 | iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}' 431 | titleStr = 'Average Precision' if ap == 1 else 'Average Recall' 432 | typeStr = '(AP)' if ap==1 else '(AR)' 433 | iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \ 434 | if iouThr is None else '{:0.2f}'.format(iouThr) 435 | 436 | aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] 437 | mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] 438 | if ap == 1: 439 | # dimension of precision: [TxRxKxAxM] 440 | s = self.eval['precision'] 441 | # IoU 442 | if iouThr is not None: 443 | t = np.where(iouThr == p.iouThrs)[0] 444 | s = s[t] 445 | s = s[:,:,:,aind,mind] 446 | else: 447 | # dimension of recall: [TxKxAxM] 448 | s = self.eval['recall'] 449 | if iouThr is not None: 450 | t = np.where(iouThr == p.iouThrs)[0] 451 | s = s[t] 452 | s = s[:,:,aind,mind] 453 | if len(s[s>-1])==0: 454 | mean_s = -1 455 | else: 456 | mean_s = np.mean(s[s>-1]) 457 | print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)) 458 | return mean_s 459 | def _summarizeDets(): 460 | stats = np.zeros((12,)) 461 | stats[0] = _summarize(1) 462 | stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2]) 463 | stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2]) 464 | stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2]) 465 | stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2]) 466 | stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2]) 467 | stats[6] = _summarize(0, maxDets=self.params.maxDets[0]) 468 | stats[7] = _summarize(0, maxDets=self.params.maxDets[1]) 469 | stats[8] = _summarize(0, maxDets=self.params.maxDets[2]) 470 | stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2]) 471 | stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2]) 472 | stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2]) 473 | return stats 474 | def _summarizeKps(): 475 | stats = np.zeros((10,)) 476 | stats[0] = _summarize(1, maxDets=20) 477 | stats[1] = _summarize(1, maxDets=20, iouThr=.5) 478 | stats[2] = _summarize(1, maxDets=20, iouThr=.75) 479 | stats[3] = _summarize(1, maxDets=20, areaRng='medium') 480 | stats[4] = _summarize(1, maxDets=20, areaRng='large') 481 | stats[5] = _summarize(0, maxDets=20) 482 | stats[6] = _summarize(0, maxDets=20, iouThr=.5) 483 | stats[7] = _summarize(0, maxDets=20, iouThr=.75) 484 | stats[8] = _summarize(0, maxDets=20, areaRng='medium') 485 | stats[9] = _summarize(0, maxDets=20, areaRng='large') 486 | return stats 487 | if not self.eval: 488 | raise Exception('Please run accumulate() first') 489 | iouType = self.params.iouType 490 | if iouType == 'segm' or iouType == 'bbox': 491 | summarize = _summarizeDets 492 | elif iouType == 'keypoints': 493 | summarize = _summarizeKps 494 | self.stats = summarize() 495 | 496 | def __str__(self): 497 | self.summarize() 498 | 499 | class Params: 500 | ''' 501 | Params for coco evaluation api 502 | ''' 503 | def setDetParams(self): 504 | self.imgIds = [] 505 | self.catIds = [] 506 | # np.arange causes trouble. 
the data point on arange is slightly larger than the true value 507 | self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True) 508 | self.recThrs = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True) 509 | self.maxDets = [1, 10, 100] 510 | self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 32 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]] 511 | self.areaRngLbl = ['all', 'small', 'medium', 'large'] 512 | self.useCats = 1 513 | 514 | def setKpParams(self): 515 | self.imgIds = [] 516 | self.catIds = [] 517 | # np.arange causes trouble. the data point on arange is slightly larger than the true value 518 | self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True) 519 | self.recThrs = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True) 520 | self.maxDets = [20] 521 | self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]] 522 | self.areaRngLbl = ['all', 'medium', 'large'] 523 | self.useCats = 1 524 | 525 | def __init__(self, iouType='segm'): 526 | if iouType == 'segm' or iouType == 'bbox': 527 | self.setDetParams() 528 | elif iouType == 'keypoints': 529 | self.setKpParams() 530 | else: 531 | raise Exception('iouType not supported') 532 | self.iouType = iouType 533 | # useSegm is deprecated 534 | self.useSegm = None --------------------------------------------------------------------------------
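For reference, the evaluation API defined in the file above is driven in three steps: evaluate() populates evalImgs with per-image, per-category matches, accumulate() builds the precision/recall arrays in eval, and summarize() prints the summary metrics into stats. The sketch below illustrates that flow; the class name COCOeval, the import paths, and the JSON file names are assumptions based on the standard pycocotools interface rather than details taken from this repository.

from utils.pycocotools.coco import COCO
from utils.pycocotools.cocoeval import COCOeval

# hypothetical file paths; any COCO-format annotation and result files would do
cocoGt = COCO('annotations/instances_val2014.json')  # ground-truth annotations
cocoDt = cocoGt.loadRes('detections.json')           # detections in COCO result format
E = COCOeval(cocoGt, cocoDt, iouType='bbox')
E.evaluate()    # per-image, per-category matching -> E.evalImgs
E.accumulate()  # precision/recall arrays -> E.eval
E.summarize()   # prints the 12 AP/AR numbers -> E.stats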