├── README.md
├── ddi_dataset.py
├── ddi_model.py
├── eval_data.py
└── eval_ddi.py

/README.md:
--------------------------------------------------------------------------------
# DDI Dataset Codebase
Code for loading DDI data and the models from our paper:
      ***Disparities in Dermatology AI Performance on a Diverse, Curated Clinical Image Set***

For more information, please visit our project page [here](https://ddi-dataset.github.io/) and read our paper [here](https://www.science.org/doi/full/10.1126/sciadv.abq6147).

Our models can be downloaded [here](https://drive.google.com/drive/folders/1oQ53WH_Tp6rcLZjRp_-UBOQcMl-b1kkP) or through the provided code.


## Description
We include code to download and load our models (`ddi_model.py`), load the DDI dataset (`ddi_dataset.py`), evaluate our models on the DDI dataset (`eval_ddi.py`), and evaluate our models on an arbitrary dataset (`eval_data.py`). For `eval_ddi.py` and `eval_data.py`, we provide a command line interface with the following arguments:
- `model_dir`: File path for where to save models.
- `model`: Name of the model to load (HAM10000, DeepDerm, GroupDRO, CORAL, or CDANN).
- `no_download`: Set to disable downloading models.
- `data_dir`: Folder containing the dataset to load. In `eval_ddi.py`, `data_dir` should be the root directory and contain (1) a subfolder called `images` containing all the DDI images and (2) a CSV file called `ddi_metadata.csv`. In `eval_data.py`, the structure should match the root directory in [torchvision.datasets.ImageFolder](https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.ImageFolder) with 2 classes: benign (class 0) and malignant (class 1).
- `eval_dir`: Folder to store evaluation results.
- `use_gpu`: Set to use GPU for evaluation.
- `plot`: Set to show ROC plot.


### Example usage
- Evaluate the `DeepDerm` model on the DDI dataset. Data (not included in this repo) is stored in the `DDI` directory, and results will be saved in the `DDI-results` directory.
```bash
python3 eval_ddi.py --model=DeepDerm --data_dir=DDI --eval_dir=DDI-results
```
- Evaluate the `DeepDerm` model on your own dataset (must be annotated as benign/malignant). Data (not included in this repo) is stored in the `MyData` directory, and results will be saved in the `DDI-results` directory.
```bash
python3 eval_data.py --model=DeepDerm --data_dir=MyData --eval_dir=DDI-results
```
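- Load the models and data from Python rather than the command line. The following is a minimal sketch, assuming the `DDI` folder layout described above:
```python
from ddi_dataset import DDI_Dataset, test_transform
from ddi_model import load_model
from eval_ddi import eval_model

model = load_model("DeepDerm")                     # downloads weights on first use
dataset = DDI_Dataset("DDI", transform=test_transform)
fst56 = dataset.subset(skin_tone=[56])             # restrict to FST V-VI images
results = eval_model(model, fst56)
print(results["ROC_AUC"])
```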


## Citation
If you find this code useful or use the DDI dataset in your research, please cite:
```
@article{daneshjou2022disparities,
  title={Disparities in dermatology AI performance on a diverse, curated clinical image set},
  author={Daneshjou, Roxana and Vodrahalli, Kailas and Novoa, Roberto A and Jenkins, Melissa and Liang, Weixin and Rotemberg, Veronica and Ko, Justin and Swetter, Susan M and Bailey, Elizabeth E and Gevaert, Olivier and others},
  journal={Science advances},
  volume={8},
  number={31},
  pages={eabq6147},
  year={2022},
  publisher={American Association for the Advancement of Science}
}
```

--------------------------------------------------------------------------------
/ddi_dataset.py:
--------------------------------------------------------------------------------
"""Code for loading the DDI Dataset."""

from torch.utils.data import Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
import os
import pandas as pd
import numpy as np

means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]
test_transform = T.Compose([
    lambda x: x.convert('RGB'),
    T.Resize(299),
    T.CenterCrop(299),
    T.ToTensor(),
    T.Normalize(mean=means, std=stds)
])


class DDI_Dataset(ImageFolder):
    """DDI Dataset.

    Note: assumes DDI data is organized as
        ./DDI
            /images
                /000001.png
                /000002.png
                ...
            /ddi_metadata.csv

    (After downloading from the Stanford AIMI repository, this requires moving
    all .png files into a new subdirectory titled "images".)

    Args:
        root (str): Root directory of dataset.
        csv_path (str): Path to the metadata CSV file. Defaults to
            `{root}/ddi_metadata.csv`.
        download (bool): If the metadata CSV is missing, point the user to the
            DDI download page (the dataset cannot be downloaded automatically).
        transform: Function to transform and collate image input. (Can use
            `test_transform` from this file.)
    """
    _DDI_download_link = "https://stanfordaimi.azurewebsites.net/datasets/35866158-8196-48d8-87bf-50dca81df965"

    def __init__(self, root, csv_path=None, download=True, transform=None, *args, **kwargs):
        if csv_path is None:
            csv_path = os.path.join(root, "ddi_metadata.csv")
        if not os.path.exists(csv_path) and download:
            raise Exception(f"Please visit <{DDI_Dataset._DDI_download_link}> to download the DDI dataset.")
        assert os.path.exists(csv_path), f"Path not found <{csv_path}>."
        super(DDI_Dataset, self).__init__(root, *args, transform=transform, **kwargs)
        self.annotations = pd.read_csv(csv_path)
        m_key = 'malignant'
        if m_key not in self.annotations:
            self.annotations[m_key] = self.annotations['malignancy(malig=1)'].apply(lambda x: x==1)

    def __getitem__(self, index):
        img, target = super(DDI_Dataset, self).__getitem__(index)
        path = self.imgs[index][0]
        annotation = dict(self.annotations[self.annotations.DDI_file==os.path.basename(path)])
        target = int(annotation['malignant'].item()) # 1 if malignant, 0 if benign
        skin_tone = annotation['skin_tone'].item() # Fitzpatrick grouping - 12, 34, or 56
        return path, img, target, skin_tone

    def subset(self, skin_tone=None, diagnosis=None):
        """Return a subset of the DDI dataset based on skin tone and malignancy of lesion.

        Args:
            skin_tone (list of int): Which skin tones to include in the subset.
                Options are {12, 34, 56}.
            diagnosis (list of str): Include malignant and/or benign images.
                Options are {"benign", "malignant"}.
        """
        skin_tone = [12, 34, 56] if skin_tone is None else skin_tone
        diagnosis = ["benign", "malignant"] if diagnosis is None else diagnosis
        for si in skin_tone:
            assert si in [12, 34, 56], f"{si} is not a valid skin tone"
        for di in diagnosis:
            assert di in ["benign", "malignant"], f"{di} is not a valid diagnosis"
        indices = np.where(self.annotations['skin_tone'].isin(skin_tone) & \
                           self.annotations['malignant'].isin([di=="malignant" for di in diagnosis]))[0]
        return Subset(self, indices)
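

if __name__ == "__main__":
    # Minimal usage sketch (assumes a local ./DDI folder with the layout
    # described in the class docstring; the counts depend on your copy).
    dataset = DDI_Dataset("DDI", transform=test_transform)
    fst56_malignant = dataset.subset(skin_tone=[56], diagnosis=["malignant"])
    print(f"{len(fst56_malignant)} malignant FST V-VI images out of {len(dataset)} total")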

--------------------------------------------------------------------------------
/ddi_model.py:
--------------------------------------------------------------------------------
"""Code for defining and loading the models we trained."""
import os
import torch
import torchvision


# Google Drive paths to our models
MODEL_WEB_PATHS = {
    # base form of models trained on skin data
    'HAM10000':'https://drive.google.com/uc?id=1ToT8ifJ5lcWh8Ix19ifWlMcMz9UZXcmo',
    'DeepDerm':'https://drive.google.com/uc?id=1OLt11htu9bMPgsE33vZuDiU5Xe4UqKVJ',

    # robust training algorithms
    'GroupDRO':'https://drive.google.com/uc?id=193ippDUYpMaOaEyLjd1DNsOiW0aRXL75',
    'CORAL':   'https://drive.google.com/uc?id=18rMU0nRd4LiHN9WkXoDROJ2o2sG1_GD8',
    'CDANN':   'https://drive.google.com/uc?id=1PvvgQVqcrth840bFZ3ddLdVSL7NkxiRK',
}

# thresholds determined by maximizing F1-score on the test split of the train
# dataset for the given algorithm
MODEL_THRESHOLDS = {
    'HAM10000':0.733,
    'DeepDerm':0.687,
    # robust training algorithms
    'GroupDRO':0.980,
    'CORAL':0.990,
    'CDANN':0.980,
}

def load_model(model_name, save_dir="DDI-models", download=True):
    """Load the model and download if necessary. Saves model to provided save
    directory."""
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, f"{model_name.lower()}.pth")
    if not os.path.exists(model_path):
        if not download:
            raise Exception("Model not downloaded and download option not"\
                            " enabled.")
        else:
            # Requires installation of gdown (pip install gdown)
            import gdown
            gdown.download(MODEL_WEB_PATHS[model_name], model_path)
    model = torchvision.models.inception_v3(init_weights=False, pretrained=False, transform_input=True)
    model.fc = torch.nn.Linear(2048, 2)
    model.AuxLogits.fc = torch.nn.Linear(768, 2)
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model._ddi_name = model_name
    model._ddi_threshold = MODEL_THRESHOLDS[model_name]
    model._ddi_web_path = MODEL_WEB_PATHS[model_name]
    return model
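

if __name__ == "__main__":
    # Minimal usage sketch: load DeepDerm (downloads the weights on first use,
    # which requires `pip install gdown`) and score one random 299x299 input,
    # mirroring how eval_ddi.py/eval_data.py apply the tuned threshold.
    model = load_model("DeepDerm")
    model.eval()
    with torch.no_grad():
        score = model(torch.randn(1, 3, 299, 299))[:, 1]
    print("malignant" if score.item() > model._ddi_threshold else "benign")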
Saves model to provided save 76 | directory.""" 77 | 78 | os.makedirs(save_dir, exist_ok=True) 79 | model_path = os.path.join(save_dir, f"{model_name.lower()}.pth") 80 | if not os.path.exists(model_path): 81 | if not download: 82 | raise Exception("Model not downloaded and download option not"\ 83 | " enabled.") 84 | else: 85 | # Requires installation of gdown (pip install gdown) 86 | import gdown 87 | gdown.download(MODEL_WEB_PATHS[model_name], model_path) 88 | model = torchvision.models.inception_v3(init_weights=False, pretrained=False, transform_input=True) 89 | model.fc = torch.nn.Linear(2048, 2) 90 | model.AuxLogits.fc = torch.nn.Linear(768, 2) 91 | state_dict = torch.load(model_path) 92 | model.load_state_dict(state_dict) 93 | model._ddi_name = model_name 94 | model._ddi_threshold = MODEL_THRESHOLDS[model_name] 95 | model._ddi_web_path = MODEL_WEB_PATHS[model_name] 96 | return model 97 | 98 | class ImageFolderWithPaths(datasets.ImageFolder): 99 | """Custom dataset that includes image file paths. Extends 100 | torchvision.datasets.ImageFolder 101 | """ 102 | 103 | # override the __getitem__ method. this is the method that dataloader calls 104 | def __getitem__(self, index): 105 | # this is what ImageFolder normally returns 106 | original_tuple = super(ImageFolderWithPaths, self).__getitem__(index) 107 | # the image file path 108 | path = os.path.abspath(self.imgs[index][0]) 109 | # make a new tuple that includes original and the path 110 | tuple_with_path = (original_tuple + (path,)) 111 | return tuple_with_path 112 | 113 | def eval_model(model, image_dir, use_gpu=False, show_plot=False): 114 | """Evaluate loaded model on provided image dataset. Assumes supplied image 115 | directory corresponds to `root` input for torchvision.datasets.ImageFolder 116 | class. 

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default="DDI-models",
        help="File path for where to save models.")
    parser.add_argument('--model', type=str, default="DeepDerm",
        help="Name of the model to load (HAM10000, DeepDerm, GroupDRO, CORAL,"\
             " or CDANN).")
    parser.add_argument('--no_download', action='store_true', default=False,
        help="Set to disable downloading models.")
    parser.add_argument('--data_dir', type=str, default="DDI",
        help="Folder containing dataset to load. Structure should match the"\
             " root directory in torchvision.datasets.ImageFolder with 2"\
             " classes: benign (class 0) and malignant (class 1).")
    parser.add_argument('--eval_dir', type=str, default="DDI-results",
        help="Folder to store evaluation results.")
    parser.add_argument('--use_gpu', action='store_true', default=False,
        help="Set to use GPU for evaluation.")
    parser.add_argument('--plot', action='store_true', default=False,
        help="Set to show ROC plot.")
    args = parser.parse_args()
    return args


def load_model(model_name, save_dir="DDI-models", download=True):
    """Load the model and download if necessary. Saves model to provided save
    directory."""
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, f"{model_name.lower()}.pth")
    if not os.path.exists(model_path):
        if not download:
            raise Exception("Model not downloaded and download option not"\
                            " enabled.")
        else:
            # Requires installation of gdown (pip install gdown)
            import gdown
            gdown.download(MODEL_WEB_PATHS[model_name], model_path)
    model = torchvision.models.inception_v3(init_weights=False, pretrained=False, transform_input=True)
    model.fc = torch.nn.Linear(2048, 2)
    model.AuxLogits.fc = torch.nn.Linear(768, 2)
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model._ddi_name = model_name
    model._ddi_threshold = MODEL_THRESHOLDS[model_name]
    model._ddi_web_path = MODEL_WEB_PATHS[model_name]
    return model

class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder.
    """

    # override the __getitem__ method; this is the method the dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = os.path.abspath(self.imgs[index][0])
        # make a new tuple that includes the original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

def eval_model(model, image_dir, use_gpu=False, show_plot=False):
    """Evaluate loaded model on provided image dataset. Assumes supplied image
    directory corresponds to `root` input for torchvision.datasets.ImageFolder
    class. Assumes the data is split into benign/malignant labels, as this is
    what our models are trained+evaluated on."""

    use_gpu = (use_gpu and torch.cuda.is_available())
    device = torch.device("cuda") if use_gpu else torch.device("cpu")

    # load dataset
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = ImageFolderWithPaths(
        image_dir,
        transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            normalize]))
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32, shuffle=False,
        num_workers=0, pin_memory=use_gpu)

    # prepare model for evaluation
    model.to(device).eval()

    # log output for all images in dataset
    hat, star, all_paths = [], [], []
    for images, target, paths in tqdm.tqdm(dataloader):
        images = images.to(device)
        target = target.to(device)

        with torch.no_grad():
            output = model(images)

        hat.append(output[:,1].detach().cpu().numpy())
        star.append(target.cpu().numpy())
        all_paths.append(paths)

    hat = np.concatenate(hat)
    star = np.concatenate(star)
    all_paths = np.concatenate(all_paths)
    threshold = model._ddi_threshold
    m_name = model._ddi_name
    m_web_path = model._ddi_web_path

    report = classification_report(star, (hat>threshold).astype(int),
        target_names=["benign","malignant"])
    fpr, tpr, _ = roc_curve(star, hat, pos_label=1,
                            sample_weight=None,
                            drop_intermediate=True)
    auc_est = auc(fpr, tpr)

    if show_plot:
        _=plt.plot(fpr, tpr,
            color="blue", linestyle="-", linewidth=2,
            marker="o", markersize=2,
            label=f"AUC={auc_est:.3f}")[0]
        plt.legend()
        plt.show()
        plt.close()

    eval_results = {'predicted_labels':hat, # raw model scores for the malignant class
                    'true_labels':star,     # true labels
                    'images':all_paths,     # image paths
                    'report':report,        # sklearn classification report
                    'ROC_AUC':auc_est,      # ROC-AUC
                    'threshold':threshold,  # > threshold ==> malignant
                    'model':m_name,         # model name
                    'web_path':m_web_path,  # web link to download model
                    }

    return eval_results




if __name__ == '__main__':
    # get arguments from command line
    args = get_args()
    # load model and download if necessary
    model = load_model(args.model,
        save_dir=args.model_dir, download=not args.no_download)
    # evaluate results on data
    eval_results = eval_model(model, args.data_dir, use_gpu=args.use_gpu,
        show_plot=args.plot)

    # save evaluation results in a .pkl file
    if args.eval_dir:
        os.makedirs(args.eval_dir, exist_ok=True)
        eval_save_path = os.path.join(args.eval_dir,
            f"{args.model}-evaluation.pkl")
        with open(eval_save_path, 'wb') as f:
            pickle.dump(eval_results, f)

    # load results with:
    #with open(eval_save_path, 'rb') as f:
    #    results = pickle.load(f)
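
# Expected `data_dir` layout (torchvision ImageFolder convention; the file
# names below are just placeholders):
#   MyData/
#       benign/
#           lesion_001.jpg
#           ...
#       malignant/
#           lesion_002.jpg
#           ...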

--------------------------------------------------------------------------------
/eval_ddi.py:
--------------------------------------------------------------------------------
"""
Code to evaluate the models trained for our paper,
"Disparities in Dermatology AI Performance on a Diverse,
Curated Clinical Image Set",
on the DDI dataset.

Note: assumes DDI data is organized as
    ./DDI
        /images
            /000001.png
            /000002.png
            ...
        /ddi_metadata.csv

(After downloading from the Stanford AIMI repository, this requires moving all
.png files into a new subdirectory titled "images".)

------------------------------------------------------
Examples:

(1) w/command line interface
# evaluate DeepDerm on DDI and store results in `DDI-results`
>>>python3 eval_ddi.py --model=DeepDerm --data_dir=DDI --eval_dir=DDI-results

(2) w/python functions
>>>import eval_ddi
>>>import ddi_model
>>>from ddi_dataset import DDI_Dataset, test_transform
>>>model = ddi_model.load_model("DeepDerm") # load DeepDerm model
>>>dataset = DDI_Dataset("DDI", transform=test_transform) # load DDI dataset
>>>eval_results = eval_ddi.eval_model(model, dataset) # evaluate model on DDI
"""
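
# One way to produce the layout above after downloading from the Stanford AIMI
# repository (an illustrative sketch; adjust the path to your download):
#   import glob, os, shutil
#   os.makedirs("DDI/images", exist_ok=True)
#   for png in glob.glob("DDI/*.png"):
#       shutil.move(png, "DDI/images/")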

import argparse
from ddi_dataset import DDI_Dataset, test_transform
from ddi_model import load_model
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
from sklearn.metrics import (f1_score, balanced_accuracy_score,
    classification_report, confusion_matrix, roc_curve, auc)
import torch
import tqdm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default="DDI-models",
        help="File path for where to save models.")
    parser.add_argument('--model', type=str, default="DeepDerm",
        help="Name of the model to load (HAM10000, DeepDerm, GroupDRO, CORAL,"\
             " or CDANN).")
    parser.add_argument('--no_download', action='store_true', default=False,
        help="Set to disable downloading models.")
    parser.add_argument('--data_dir', type=str, default="DDI",
        help="Folder containing dataset to load. Structure should be: (1)"\
             " `[data_dir]/images` contains all images; (2)"\
             " `[data_dir]/ddi_metadata.csv` contains the CSV metadata for"\
             " the DDI dataset.")
    parser.add_argument('--eval_dir', type=str, default="DDI-results",
        help="Folder to store evaluation results.")
    parser.add_argument('--use_gpu', action='store_true', default=False,
        help="Set to use GPU for evaluation.")
    parser.add_argument('--plot', action='store_true', default=False,
        help="Set to show ROC plot.")
    args = parser.parse_args()
    return args

def eval_model(model, dataset, use_gpu=False, show_plot=False):
    """Evaluate loaded model on the provided DDI dataset (or a Subset of it,
    e.g. from `DDI_Dataset.subset`). Assumes the data has benign/malignant
    labels, as this is what our models are trained+evaluated on."""

    use_gpu = (use_gpu and torch.cuda.is_available())
    device = torch.device("cuda") if use_gpu else torch.device("cpu")

    # load dataset
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32, shuffle=False,
        num_workers=0, pin_memory=use_gpu)

    # prepare model for evaluation
    model.to(device).eval()

    # log output for all images in dataset
    hat, star, all_paths = [], [], []
    for paths, images, target, skin_tone in tqdm.tqdm(dataloader):
        images = images.to(device)
        target = target.to(device)

        with torch.no_grad():
            output = model(images)

        hat.append(output[:,1].detach().cpu().numpy())
        star.append(target.cpu().numpy())
        all_paths.append(paths)

    hat = np.concatenate(hat)
    star = np.concatenate(star)
    all_paths = np.concatenate(all_paths)
    threshold = model._ddi_threshold
    m_name = model._ddi_name
    m_web_path = model._ddi_web_path

    report = classification_report(star, (hat>threshold).astype(int),
        target_names=["benign","malignant"])
    fpr, tpr, _ = roc_curve(star, hat, pos_label=1,
                            sample_weight=None,
                            drop_intermediate=True)
    auc_est = auc(fpr, tpr)

    if show_plot:
        _=plt.plot(fpr, tpr,
            color="blue", linestyle="-", linewidth=2,
            marker="o", markersize=2,
            label=f"AUC={auc_est:.3f}")[0]
        plt.legend()
        plt.show()
        plt.close()

    eval_results = {'predicted_labels':hat, # raw model scores for the malignant class
                    'true_labels':star,     # true labels
                    'images':all_paths,     # image paths
                    'report':report,        # sklearn classification report
                    'ROC_AUC':auc_est,      # ROC-AUC
                    'threshold':threshold,  # > threshold ==> malignant
                    'model':m_name,         # model name
                    'web_path':m_web_path,  # web link to download model
                    }

    return eval_results




if __name__ == '__main__':
    # get arguments from command line
    args = get_args()
    # load model and download if necessary
    model = load_model(args.model,
        save_dir=args.model_dir, download=not args.no_download)
    # load DDI dataset
    dataset = DDI_Dataset(args.data_dir, transform=test_transform)
    # evaluate results on data
    eval_results = eval_model(model, dataset,
        use_gpu=args.use_gpu, show_plot=args.plot)

    # save evaluation results in a .pkl file
    if args.eval_dir:
        os.makedirs(args.eval_dir, exist_ok=True)
        eval_save_path = os.path.join(args.eval_dir,
            f"{args.model}-evaluation.pkl")
        with open(eval_save_path, 'wb') as f:
            pickle.dump(eval_results, f)

    # load results with:
    #with open(eval_save_path, 'rb') as f:
    #    results = pickle.load(f)
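
    # Per-skin-tone evaluation (a sketch using `DDI_Dataset.subset`; uncomment
    # to compute ROC-AUC separately for FST I-II, III-IV, and V-VI):
    #for tone in (12, 34, 56):
    #    tone_eval = eval_model(model, dataset.subset(skin_tone=[tone]),
    #                           use_gpu=args.use_gpu)
    #    print(f"FST {tone}: ROC-AUC = {tone_eval['ROC_AUC']:.3f}")
--------------------------------------------------------------------------------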