├── LICENSE
├── README.md
├── requirements.txt
├── resmem
    ├── __init__.py
    ├── __main__.py
    ├── analysis.py
    ├── model.py
    └── utils.py
├── sample.py
└── setup.py


/LICENSE:
--------------------------------------------------------------------------------
ResMem Non-commercial License
© 2021 The University of Chicago.

Redistribution and use for noncommercial purposes in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. The software is used solely for noncommercial purposes. It may not be used indirectly for commercial use, such as on a website that accepts advertising money for content. Noncommercial use does include use by a for-profit company in its research. For commercial use rights, contact the Brain Bridge Lab at The University of Chicago, at wilma@uchicago.edu.

2. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

3. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

4. Neither the name of The University of Chicago nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF CHICAGO AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE UNIVERSITY OF CHICAGO OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ResMem

This is a package that wraps [ResMem](https://coen.needell.co/project/memnet/), a residual neural network that
estimates the memorability of an image on a scale of 0 to 1.

## How to Use
To install from PyPI:
```shell
pip install resmem
```

The package contains two objects: ResMem itself, and a preprocessing transformer function built on torchvision.
```python
from resmem.model import ResMem, transformer

model = ResMem(pretrained=True)

```

The `transformer` is designed to be used with Pillow.

```python
from PIL import Image

img = Image.open('./path/to/image')  # This loads your image into memory
img = img.convert('RGB')
# This converts your image to RGB, for instance if it's a PNG (RGBA) or if it's black and white.

model.eval()
# Set the model to inference mode.

image_x = transformer(img)
# Run the preprocessing function

prediction = model(image_x.view(-1, 3, 227, 227))
# For a single image, the image must be reshaped into a batch
# with size 1.
# Get your prediction!
```

For more advanced usage, see the `sample.py` file in this repository.
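The `227 x 227` shape in the `view` call above is the input size that the AlexNet-style branch of the model expects. As a sanity check, here is a minimal sketch that works through the layer arithmetic using the kernel sizes, strides, and paddings listed in `resmem/model.py`. The helper `out_size` is a hypothetical name introduced only for this illustration, and it assumes the standard output-size formula for convolution and pooling layers.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 227
size = out_size(size, kernel=11, stride=4)   # conv1
size = out_size(size, kernel=3, stride=2)    # pool1
size = out_size(size, kernel=5, padding=2)   # conv2
size = out_size(size, kernel=3, stride=2)    # pool2
for _ in range(3):                           # conv3, conv4, conv5
    size = out_size(size, kernel=3, padding=1)
size = out_size(size, kernel=3, stride=2)    # pool5

alexnet_features = 256 * size * size         # conv5 has 256 output channels
regression_inputs = alexnet_features + 1000  # plus the 1000 ResNet outputs
print(size, alexnet_features, regression_inputs)
```

The final value agrees with the `in_features=9216 + 1000` of the first fully connected layer (`fc6`), which is why the preprocessing crops every image to this size.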

## Description of the Model

Historically, the next big advancement in using neural networks after AlexNet, the basis for MemNet, was ResNet. This allowed convolutional neural networks to be built deeper, with more layers, without the gradient exploding or vanishing. Knowing that, let's try to include this in our model. What we will do is take a pre-trained ResNet (the whole thing, not just the convolutional features) and add its output as an input feature for our regression step. The code for this is [here](https://www.coeneedell.com/appendix/memnet_extras/#resmem).

![ResMem Diagram](ResMem.jpg)

For the following discussion, while ResMem is initialized to a pre-trained optimum, I have allowed it to retrain for our problem. The thought is that given a larger set of parameters the final model *should* be more generalizable. Using Weights & Biases, we can run a hyperparameter sweep.

![ResMem Testing](resnetsweep.png)

Here we can see much higher peaks, reaching into the range of 0.66-0.67! All of these runs were both trained and validated on a dataset that was constructed from both the MemCat and LaMem databases.

## GitHub Instructions

If you want to fiddle around with the raw GitHub source code, go right ahead. But be aware that we use Git LFS for the
model storage. You'll have to install the package at https://git-lfs.github.com/ and then run:
```shell
git lfs install
git lfs pull
```
to get access to the full model. Also note that the `tests.py` file that's included with this repository references a
database that we, at this time, do not have permission to distribute, so consider that test file to be a guide rather
than a hard-and-fast test.

If you cannot access the model through Git LFS, you can get a fresh copy of the model file from the PyPI version of the package, or you can download the weights from [OSF](https://osf.io/qf5ry/).
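
If loading the pretrained weights fails after cloning, one likely cause is that `resmem/model.pt` is still a Git LFS pointer file rather than the real weights. The following is a minimal sketch of how that could be checked. It relies on the fact that pointer files are small text files beginning with `version https://git-lfs.github.com/spec/v1`. The helper name `is_lfs_pointer` is hypothetical, and the two temporary files contain dummy data for demonstration only.

```python
import tempfile
from pathlib import Path

LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path) -> bool:
    """Return True if the file looks like a Git LFS pointer, not real data."""
    with Path(path).open("rb") as fh:
        return fh.read(len(LFS_HEADER)) == LFS_HEADER


# Demonstration on two throwaway files (the contents are made up).
with tempfile.TemporaryDirectory() as tmp:
    pointer = Path(tmp) / "pointer.pt"
    pointer.write_bytes(LFS_HEADER + b"\noid sha256:0000\nsize 1\n")
    weights = Path(tmp) / "weights.pt"
    weights.write_bytes(b"PK\x03\x04 binary payload")
    pointer_check = is_lfs_pointer(pointer)
    weights_check = is_lfs_pointer(weights)

print(pointer_check, weights_check)
```

In a real checkout you would call `is_lfs_pointer("resmem/model.pt")`; if it returns `True`, running `git lfs pull` should replace the pointer with the actual file.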

## Analysis

New in 2023:

ResMem now includes an analysis module. This is a set of tools for accessing the direct activations of specific filters within the convolutional neural network.
Generally, features can be accessed by:

1. Selecting a filter to record data from.
2. Running an image through the model.
3. Recording the activation of the selected filter.

This behavior is implemented using the `SaveFeatures` class in the `analysis` module.
For example:

```python
from resmem.model import ResMem, transformer
from resmem.analysis import SaveFeatures
import numpy as np
from PIL import Image

# Generating a random image for demonstration purposes. You would want to use the image you're analyzing.
# Pillow expects 8-bit data for an RGB image, so the array is created as uint8.
randim = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
randim = Image.fromarray(randim)

# Set up ResMem for prediction
model = ResMem(pretrained=True)
model.eval()

activation = SaveFeatures(list(list(model.features.children())[5].children())[4])
# Notice the numbers here: we have to select specific features a priori. We can't just record everything, because it would quickly fill up the computer's RAM.
# In the SaveFeatures object, we are selecting a specific *Residual Block*. The blocks are organized into "sections". Above we select section 5, block 4. Note that the children selector can also select non-convolutional modules within ResMem.
memorability = model(transformer(randim).view(-1, 3, 227, 227))
# Now the activation object contains some data.
# The first index here should always be zero, it's selecting a batch index. In the case where you're running this in batches, you'll want to slice across that index instead.
# The second index selects a filter in that layer (the filter associated with a certain channel).
activation.features[0, 1]
# Even then, this filter activation is a 29x29 tensor, one for each time the filter was applied to some subimage.
# For analyzing filter activations, you may simply average these individual neuron-level activations.
activation.features[0, 1].mean()
```

This is all pretty complicated, and requires referencing a map of the submodules within ResMem.
For example, "sections" 0-3 are all single-module transformations, so their activation information is pretty much useless, and there are no children to descend into.
"Section" 4 has three residual blocks, each of which has three convolutional layers, three batch norms, a relu, and a resampling filter.
The point here is that sections are not of consistent size.
So, we've implemented a few helper structures for common tasks.

### Getting convolutional layers by depth

Despite being named "ResNet-152," there are 155 convolutional layers in ResMem's resnet branch.
Using the new interface, we can examine a specific depth, ignoring non-convolutional modules in the model.

```python
...

activation = SaveFeatures(model.resnet_layers[24])
memorability = model(transformer(randim).view(-1, 3, 227, 227))
activation.features[0, 1].mean()
# or to get all features from that layer
activation.features[0, :].mean(dim=(1,2))
```

We wrap all of this in the new `ResMem.resnet_activation()` method.

```python
activations = model.resnet_activation(transformer(randim).view(-1, 3, 227, 227), 24)
```

This method does support batching, and will return the activation of every filter in the selected layer.
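
Under the hood, `SaveFeatures` relies on PyTorch forward hooks. The following is a minimal sketch of that mechanism on a toy network rather than on ResMem itself, so it runs without the pretrained weights. `FeatureRecorder` and `toy` are hypothetical names used only here, and the toy layer sizes are arbitrary assumptions.

```python
import torch
from torch import nn


class FeatureRecorder:
    """Hypothetical stand-in for SaveFeatures: stores a module's latest output."""

    def __init__(self, module: nn.Module):
        self.features = None
        self.hook = module.register_forward_hook(self._hook_fn)

    def _hook_fn(self, module, inputs, output):
        # Keep a detached copy of the most recent output of the hooked module.
        self.features = output.detach().clone()

    def close(self):
        # Remove the hook so later forward passes are no longer recorded.
        self.hook.remove()


# A toy network; ResMem itself is not needed to see the mechanism.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
).eval()

recorder = FeatureRecorder(toy[2])  # record the second convolution
with torch.no_grad():
    toy(torch.rand(4, 3, 32, 32))   # a batch of four random "images"
recorder.close()

# One mean activation per image and per filter: shape (batch, filters).
per_filter = recorder.features.mean(dim=(-1, -2))
# Index of the most active filter for each image in the batch.
top_filter = per_filter.argmax(dim=1)
print(tuple(per_filter.shape), tuple(top_filter.shape))
```

Averaging over the last two dimensions is the same reduction that `resnet_activation()` applies, so the result has one row per image and one column per filter.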



## Citation

```
@article{Needell_Bainbridge_2022,
  title={Embracing New Techniques in Deep Learning for Estimating Image Memorability},
  volume={5},
  ISSN={2522-0861, 2522-087X},
  url={https://link.springer.com/10.1007/s42113-022-00126-5},
  DOI={10.1007/s42113-022-00126-5},
  number={2},
  journal={Computational Brain & Behavior},
  author={Needell, Coen D. and Bainbridge, Wilma A.},
  year={2022},
  month=jun,
  pages={168–184},
  language={en}
}
```

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
click

--------------------------------------------------------------------------------
/resmem/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/resmem/__main__.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser
from .model import ResMem, transformer
from .utils import DirectoryImageset
from torch.utils.data import DataLoader
import click

@click.group()
def cli():
    pass

@cli.command()
@click.argument('directory', type=click.Path(exists=True))
@click.argument('layer', type=click.INT)
@click.option('--output', type=click.Path(), help='Output csv')
def resmem_features(directory, layer, output=None):
    """Get resmem activations for all filters in a layer, for all images in a directory."""
    import pandas as pd
    dataset = DirectoryImageset(directory, transform=transformer)
    dl = DataLoader(dataset, batch_size=1, shuffle=False)
    model = ResMem(pretrained=True)
    model.eval()
    outputs = {}
    for name, image in dl:
        activations = model.resnet_activation(image.view(-1, 3, 227, 227), layer).detach().cpu().numpy().flatten()
        outputs[name] = activations
    df = pd.DataFrame.from_dict(outputs, orient='index')
    if output:
        df.to_csv(output)
    else:
        df.to_csv(f'{directory}_resmem_features.csv')


cli.add_command(resmem_features)
if __name__ == "__main__":
    cli()

--------------------------------------------------------------------------------
/resmem/analysis.py:
--------------------------------------------------------------------------------
import math
import os
import shutil
import numpy as np

import torch
from torch import nn
from torch.utils.data import Dataset
from resmem.model import ResMem
from resmem.utils import SaveFeatures
from torchvision.transforms import ToTensor, ToPILImage, Normalize, Resize
from PIL import Image, ImageFilter
from tqdm import tqdm



def get_gaussian_kernel(kernel_size=3, sigma=1, channels=3):
    """
    https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/7
    """
    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float().cuda()

    mean = (kernel_size - 1) / 2.
    variance = sigma ** 2.

    # Calculate the 2-dimensional gaussian kernel which is
    # the product of two gaussian distributions for two different
    # variables (in this case called x and y)
    gaussian_kernel = (1. / (2. * math.pi * variance)) * \
                      torch.exp(
                          -torch.sum((xy_grid - mean) ** 2., dim=-1) /
                          (2 * variance)
                      )

    # Make sure sum of values in gaussian kernel equals 1.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to 2d depthwise convolutional weight
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
    gaussian_kernel.cuda()

    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
                                kernel_size=kernel_size, groups=channels, bias=False)
    gaussian_filter.cuda()
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False

    return gaussian_filter


class Viz:
    """
    This class will be used to generate a maximally-activating image for a given filter.
    """
    def __init__(self, size=67, upscaling_steps=16, upscaling_factor=1.2, device='cuda', branch='alex', layer=0):
        self.size, self.upscaling_steps, self.upscaling_factor = size, upscaling_steps, upscaling_factor
        self.device = device
        self.branch = branch  # save() needs to know which branch was visualized
        self.model = ResMem(pretrained=True).to(device).eval()
        if branch == 'alex':
            self.target = self.model.alexnet_layers[layer]
        elif branch == 'resnet':
            self.target = self.model.resnet_layers[layer]
        else:
            raise ValueError('Branch must be alex or resnet.')

        self.output = None
        self.normer = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def visualize(self, filt, lr=1e-2, opt_steps=36):
        sz = self.size
        img = Image.fromarray(np.uint8(np.random.uniform(150, 180, (sz, sz, 3))))
        activations = SaveFeatures(self.target)
        gaussian_filter = get_gaussian_kernel()
        self.model.zero_grad()
        for outer in tqdm(range(self.upscaling_steps), leave=False):
            img_var = torch.unsqueeze(ToTensor()(img), 0).to(self.device).requires_grad_(True)
            img_var.requires_grad_(True).to(self.device)
            optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)

            pbar = tqdm(range(opt_steps), leave=False)
            for n in pbar:
                optimizer.zero_grad()
                self.model.conv_forward(img_var)
                loss = -activations.features[0, filt].mean() + 0.00*torch.norm(img_var)
                loss.backward()
                pbar.set_description(f'Loss: {loss.item()}')
                optimizer.step()

            sz = int(sz * self.upscaling_factor)
            img = ToPILImage()(img_var.squeeze(0))
            if outer != self.upscaling_steps:
                img = img.resize((sz, sz))
                img = img.filter(ImageFilter.BoxBlur(2))
        self.output = img.copy()
        activations.close()

    def save(self, layer, filt):
        # Create the output folder ('alex' or 'resnet') on demand.
        os.makedirs(self.branch, exist_ok=True)
        if self.branch == 'alex':
            self.output.save(f'alex/layer_{layer}_filter_{filt}.jpg', 'JPEG')
        else:
            self.output.save(f'resnet/layer_{layer}_filter_{filt}.jpg', 'JPEG')

class ImgActivations:
    """
    This class can be used to extract the top 9 images in a dataset that activate a given filter the most.
    """
    def __init__(self, dset: Dataset, branch='alex', layer=0):
        self.branch = branch
        self.layer = layer  # used by draw() to name the output folder
        self.model = ResMem(pretrained=True).cuda().eval()
        self.target = self.model
        self.output = []
        self.dset = dset

        if branch == 'alex':
            self.target = self.model.alexnet_layers[layer]
        elif branch == 'resnet':
            self.target = self.model.resnet_layers[layer]
        else:
            raise ValueError('Branch must be alex or resnet.')


    def calculate(self, filt):
        activations = SaveFeatures(self.target)
        levels = []
        names = []
        for img in self.dset:
            x, y, name = img
            self.model(x.cuda().view(-1, 3, 227, 227))
            levels.append(activations.features[0, filt].mean().detach().cpu())
            names.append(name)
        activations.close()  # remove the hook once every image has been scored

        self.output = np.array(names)[np.argsort(levels)[-9:]].flatten()

    def draw(self, filt):
        # One folder per branch, layer and filter; created on demand.
        out_dir = f'./figs/{self.branch}-{self.layer}-{filt}'
        os.makedirs(out_dir, exist_ok=True)
        for f in self.output:
            shortfname = f.split('/')[-1]
            shutil.copy2(f, os.path.join(out_dir, shortfname))


--------------------------------------------------------------------------------
/resmem/model.py:
--------------------------------------------------------------------------------
from torch import nn
from torchvision.models import resnet152, ResNet152_Weights
from torchvision import transforms
import torch
import torch.nn.functional as f
from resmem.utils import SaveFeatures
from pickle import UnpicklingError
from pathlib import Path

path = Path(__file__).parent / "../resmem/model.pt"

transformer = transforms.Compose((
    transforms.Resize((256, 256)),
    transforms.CenterCrop(227),
    transforms.ToTensor()
)
)

cpu = torch.device('cpu')

class ResMem(nn.Module):
    def __init__(self, learning_rate=1e-5, momentum=.9, cruise_altitude=384, pretrained=False):
        super().__init__()
        if pretrained:
            weights = ResNet152_Weights.DEFAULT
        else:
            weights = None

        self.features = resnet152(weights=weights)
        for param in self.features.parameters():
            param.requires_grad = False

        self.conv1 = nn.Conv2d(3, 48, kernel_size=11, stride=4)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.lrn1 = nn.LocalResponseNorm(5)
        self.conv2 = nn.Conv2d(48, 256, kernel_size=5, stride=1, padding=2, groups=2)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.lrn2 = nn.LocalResponseNorm(5)
        self.conv3 = nn.Conv2d(256, cruise_altitude, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
        self.conv4 = nn.Conv2d(cruise_altitude, cruise_altitude, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
                               groups=2)
        self.conv5 = nn.Conv2d(cruise_altitude, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)

        self.fc6 = nn.Linear(in_features=9216 + 1000,
                             out_features=4096, bias=True)
        self.drp6 = nn.Dropout(p=0.5, inplace=False)
        self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
        self.drp7 = nn.Dropout(p=0.5, inplace=False)
        self.fc8 = nn.Linear(in_features=4096, out_features=2048, bias=True)
        self.drp8 = nn.Dropout(p=0.5, inplace=False)
        self.fc9 = nn.Linear(in_features=2048, out_features=1024, bias=True)
        self.fc10 = nn.Linear(in_features=1024, out_features=512, bias=True)
        self.fc11 = nn.Linear(in_features=512, out_features=256, bias=True)
        self.fc12 = nn.Linear(in_features=256, out_features=1, bias=True)

        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate, momentum=momentum)
        self.resnet_layers = self.get_layers(self.features)
        self.alexnet_layers = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]

        self.learning_rate = learning_rate
        if pretrained:
            try:
                print(path)
                self.load_state_dict(torch.load(path, map_location=cpu))
            except UnpicklingError:
                raise TypeError("Could not find the model, try running git lfs pull. If you haven't installed git lfs, "
                                "you can here: https://git-lfs.github.com/")

    @staticmethod
    def get_layers(module: nn.Module):
        # Recursively collect every Conv2d in the module tree, in definition order.
        layers = []
        for child in module.children():
            if isinstance(child, torch.nn.modules.conv.Conv2d):
                layers.append(child)
            else:
                layers += ResMem.get_layers(child)
        return layers


    def resnet_activation(self, x, depth):
        activation = SaveFeatures(self.resnet_layers[depth])
        was_training = self.training
        if self.training:
            self.eval()
        with torch.no_grad():
            self.forward(x)
        activation.close()  # remove the hook so repeated calls do not pile up hooks
        if was_training:
            self.train()
        # Average over the spatial dimensions: one value per image and per filter.
        return activation.features[:, :].mean(dim=(-1,-2))

    def conv_forward(self, x):
        cnv = f.relu(self.conv1(x))
        cnv = self.pool1(cnv)
        cnv = self.lrn1(cnv)
        cnv = f.relu(self.conv2(cnv))
        cnv = self.pool2(cnv)
        cnv = self.lrn2(cnv)
        cnv = f.relu(self.conv3(cnv))
        cnv = f.relu(self.conv4(cnv))
        cnv = f.relu(self.conv5(cnv))
        cnv = self.pool5(cnv)
        resfeat = self.features(x)
        return cnv, resfeat


    def forward(self, x):
        cnv, resfeat = self.conv_forward(x)
        feat = cnv.view(-1, 9216)

        catfeat = torch.cat((feat, resfeat), 1)

        hid = f.relu(self.fc6(catfeat))
        hid = self.drp6(hid)
        hid = f.relu(self.fc7(hid))
        hid = self.drp7(hid)
        hid = f.relu(self.fc8(hid))
        hid = f.relu(self.fc9(hid))
        hid = f.relu(self.fc10(hid))
        hid = f.relu(self.fc11(hid))

        pry = torch.sigmoid(self.fc12(hid))

        return pry

--------------------------------------------------------------------------------
/resmem/utils.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
import os
from PIL import Image


class SaveFeatures:
    """
    This class wraps around a pytorch module and saves the output of the
    module to an instance attribute (`features`).
    """
    def __init__(self, module: torch.nn.Module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.features = object()

    def hook_fn(self, module, i, output):
        # Keep the copy on the module's own device so the hook also works without CUDA.
        self.features = output.clone().requires_grad_(True)

    def close(self):
        self.hook.remove()


class DirectoryImageset(Dataset):
    def __init__(self, path: str, transform=None):
        self.path = path
        self.transform = transform
        self.image_names = os.listdir(path)
        self.image_names.sort()

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = Image.open(os.path.join(self.path, self.image_names[idx]))
        if self.transform:
            image = self.transform(image)
        return self.image_names[idx], image


class csvImageset(Dataset):
    def __init__(self, csv_path: str, basename=os.path.abspath("./"), img_col="img", score_col="memorability", transform=None):
        import pandas as pd
        self.csv_path = csv_path
        self.basename = basename
        self.transform = transform
        self.csv = pd.read_csv(csv_path)
        self.img_col = img_col
        self.score_col = score_col

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = Image.open(os.path.join(self.basename, self.csv[self.img_col][idx]))
        if self.transform:
            image = self.transform(image)
        return image, self.csv[self.score_col][idx]

--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
import numpy as np
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
import torch
from scipy.stats import spearmanr
from resmem.model import ResMem, transformer
from matplotlib import pyplot as plt
import seaborn as sns
from torchvision import transforms
import pandas as pd
from PIL import Image
import tqdm
from csv import reader
from torch import nn
from torchvision.transforms.transforms import CenterCrop
import glob

ORDINAL = 1  # index of the CUDA device to run on
class MemCatDataset(Dataset):
    def __init__(self, loc='./Sources/memcat/', transform=transformer):
        self.loc = loc
        self.transform = transform
        with open(f'{loc}data/memcat_image_data.csv', 'r') as file:
            r = reader(file)
            next(r)  # skip the header row
            data = [d for d in r]
            self.memcat_frame = np.array(data)

    def __len__(self):
        return len(self.memcat_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.memcat_frame[idx, 1]
        cat = self.memcat_frame[idx, 2]
        scat = self.memcat_frame[idx, 3]
        img = Image.open(f'{self.loc}images/{cat}/{scat}/{img_name}').convert('RGB')
        y = self.memcat_frame[idx, 12]
        y = torch.Tensor([float(y)])
        image_x = self.transform(img)
        return [image_x, y, img_name]


class LamemDataset(Dataset):
    def __init__(self, loc='./Sources/lamem/', transform=transformer):
        self.lamem_frame = np.array(np.loadtxt(f'{loc}splits/full.txt', delimiter=' ', dtype=str))
        self.loc = loc
        self.transform = transform

    def __len__(self):
        return self.lamem_frame.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.lamem_frame[idx, 0]
        image = Image.open(f'{self.loc}images/{img_name}')
        image = image.convert('RGB')
        y = self.lamem_frame[idx, 1]
        y = torch.Tensor([float(y)])
        image_x = self.transform(image)
        return [image_x, y, img_name]


dt = ConcatDataset((LamemDataset(), MemCatDataset()))
_, d_test = random_split(dt, [63741, 5000])
d_test = DataLoader(d_test, batch_size=32, num_workers=4, pin_memory=True)
model = ResMem(pretrained=True).cuda(ORDINAL)

distvis='ResMem with Feature Retraining'
model.eval()
if len(d_test):
    model.eval()
    # If you're using a separate database for testing, and you aren't just splitting stuff out
    with torch.no_grad():
        rloss = 0
        preds = []
        ys = []
        names = []
        t = 1
        for batch in d_test:
            x, y, name = batch
            ys += y.squeeze().tolist()
            bs, c, h, w = x.size()
            ypred = model.forward(x.cuda(ORDINAL).view(-1, c, h, w)).view(bs, -1).mean(1)
            preds += ypred.squeeze().tolist()
            names += name
    rcorr = spearmanr(ys, preds)[0]
    loss = ((np.array(ys) - np.array(preds)) ** 2).mean()
    if distvis:
        sns.distplot(ys, label='Ground Truth')
        sns.distplot(preds, label='Predictions')
        # len(d_test) counts batches, so the sample count comes from the underlying dataset.
        plt.title(f'{distvis} prediction distribution on {len(d_test.dataset)} samples')
        plt.legend()
        plt.savefig(f'{distvis}.png', dpi=500)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="resmem",
    version="1.1.3",
    author="Coën D. Needell",
    author_email="coen@needell.co",
    description="A package that wraps the ResMem pytorch model.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Brain-Bridge-Lab/resmem",
    packages=setuptools.find_packages(),
    # package_data patterns are resolved relative to the package directory.
    package_data={'resmem': ['*.pt']},
    include_package_data=True,
    classifiers=[
        "Programming Language :: Python :: 3",
        # Matches the custom non-commercial LICENSE file shipped with the repository.
        "License :: Other/Proprietary License",
        "Operating System :: OS Independent",
    ],
    install_requires=[
        'torch',
        'torchvision',
        'click',
    ],
    python_requires='>=3.6',
)

--------------------------------------------------------------------------------