├── .gitignore ├── LICENSE ├── README.md ├── dataset_preprocessing ├── README.md ├── amazon_yelp │ ├── create_unlabeled_amazon.py │ ├── generate_splits_amazon.py │ ├── generate_splits_yelp.py │ ├── process_amazon.py │ ├── process_yelp.py │ ├── subsample_amazon.py │ └── utils.py ├── camelyon17 │ ├── README.md │ ├── extract_final_patches_to_disk.py │ ├── generate_all_patch_coords.py │ ├── generate_final_metadata.py │ └── unlabeled │ │ ├── README.md │ │ ├── extract_final_patches_to_disk.py │ │ ├── generate_all_patch_coords.py │ │ ├── generate_final_metadata.py │ │ └── validate.py ├── civilcomments │ ├── README.md │ ├── attr_definitions.py │ ├── process_labeled.py │ └── process_unlabeled.py ├── domainnet │ ├── generate_metadata.py │ └── generate_sentry_metadata.py ├── encode │ ├── README.md │ ├── prep_accessibility.py │ ├── prep_metadata_labels.py │ └── prep_sequence.py ├── fmow │ ├── convert_npy_to_jpg.py │ └── process_metadata_fmow.py ├── iwildcam │ └── create_split.py ├── molpcba_unlabeled │ └── process.py └── poverty │ ├── batcher.py │ ├── convert_poverty_to_npy.py │ ├── dataset_constants.py │ ├── process_metadata_poverty.py │ ├── split_npys.py │ └── split_npys_unlabeled.py ├── examples ├── __init__.py ├── algorithms │ ├── AFN.py │ ├── DANN.py │ ├── ERM.py │ ├── IRM.py │ ├── algorithm.py │ ├── deepCORAL.py │ ├── fixmatch.py │ ├── groupDRO.py │ ├── group_algorithm.py │ ├── initializer.py │ ├── noisy_student.py │ ├── pseudolabel.py │ └── single_model_algorithm.py ├── configs │ ├── algorithm.py │ ├── data_loader.py │ ├── datasets.py │ ├── model.py │ ├── scheduler.py │ ├── supported.py │ └── utils.py ├── data_augmentation │ ├── __init__.py │ └── randaugment.py ├── evaluate.py ├── losses.py ├── models │ ├── CNN_genome.py │ ├── __init__.py │ ├── bert │ │ ├── __init__.py │ │ ├── bert.py │ │ └── distilbert.py │ ├── code_gpt.py │ ├── detection │ │ └── fasterrcnn.py │ ├── domain_adversarial_network.py │ ├── gnn.py │ ├── initializer.py │ ├── layers.py │ └── resnet_multispectral.py ├── noisy_student_wrapper.py ├── optimizer.py ├── pretraining │ ├── mlm │ │ ├── README.md │ │ ├── get_data.py │ │ ├── run_mlm.py │ │ └── run_pretrain.sh │ └── swav │ │ ├── LICENSE │ │ ├── README.md │ │ ├── main_swav.py │ │ └── src │ │ ├── config.py │ │ ├── logger.py │ │ ├── model.py │ │ ├── multicropdataset.py │ │ └── utils.py ├── run_expt.py ├── scheduler.py ├── train.py ├── transforms.py └── utils.py ├── setup.py └── wilds ├── __init__.py ├── common ├── __init__.py ├── data_loaders.py ├── grouper.py ├── metrics │ ├── __init__.py │ ├── all_metrics.py │ ├── loss.py │ └── metric.py └── utils.py ├── datasets ├── __init__.py ├── amazon_dataset.py ├── archive │ ├── __init__.py │ ├── fmow_v1_0_dataset.py │ ├── iwildcam_v1_0_dataset.py │ └── poverty_v1_0_dataset.py ├── bdd100k_dataset.py ├── camelyon17_dataset.py ├── celebA_dataset.py ├── civilcomments_dataset.py ├── domainnet_dataset.py ├── download_utils.py ├── encode_dataset.py ├── fmow_dataset.py ├── globalwheat_dataset.py ├── iwildcam_dataset.py ├── ogbmolpcba_dataset.py ├── poverty_dataset.py ├── py150_dataset.py ├── rxrx1_dataset.py ├── sqf_dataset.py ├── unlabeled │ ├── __init__.py │ ├── amazon_unlabeled_dataset.py │ ├── camelyon17_unlabeled_dataset.py │ ├── civilcomments_unlabeled_dataset.py │ ├── domainnet_unlabeled_dataset.py │ ├── fmow_unlabeled_dataset.py │ ├── globalwheat_unlabeled_dataset.py │ ├── iwildcam_unlabeled_dataset.py │ ├── ogbmolpcba_unlabeled_dataset.py │ ├── poverty_unlabeled_dataset.py │ └── wilds_unlabeled_dataset.py ├── waterbirds_dataset.py ├── 
wilds_dataset.py └── yelp_dataset.py ├── download_datasets.py ├── get_dataset.py └── version.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | build 4 | data 5 | logs 6 | dist 7 | venv 8 | wilds.egg-info 9 | .DS_Store 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 WILDS team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataset_preprocessing/README.md: -------------------------------------------------------------------------------- 1 | ## WILDS dataset preprocessing scripts 2 | 3 | **These files are not directly used by the WILDS package, and users do not need to look at them to use the package.** 4 | 5 | This directory contains scripts that were used to preprocess the WILDS datasets from their original forms into the `*-wilds` forms that we use in our benchmark. 6 | The WILDS package automatically downloads the already-processed forms; 7 | We archive these scripts here just for reproducibility purposes and for users who are interested in the precise details of the dataset preprocessing. 8 | 9 | Some of these scripts have specific requirements beyond what is required for the WILDS package, e.g., specialized software for handling pathology slides. 
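As a concrete illustration, end users never call these scripts directly; they load the already-processed data through the package API. A minimal sketch, assuming the standard `wilds.get_dataset` interface and the `camelyon17` dataset as an example:

```python
from wilds import get_dataset

# Downloads the already-processed *-wilds bundle on first use.
dataset = get_dataset(dataset="camelyon17", root_dir="data", download=True)
train_data = dataset.get_subset("train")
```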
10 | -------------------------------------------------------------------------------- /dataset_preprocessing/amazon_yelp/generate_splits_amazon.py: -------------------------------------------------------------------------------- 1 | import os, json, gzip, argparse, time, csv 2 | import numpy as np 3 | import pandas as pd 4 | from utils import * 5 | 6 | CATEGORIES = ["AMAZON_FASHION", "All_Beauty","Appliances", "Arts_Crafts_and_Sewing", "Automotive", "Books", "CDs_and_Vinyl", "Cell_Phones_and_Accessories", "Clothing_Shoes_and_Jewelry", "Digital_Music", "Electronics", "Gift_Cards", "Grocery_and_Gourmet_Food", "Home_and_Kitchen", "Industrial_and_Scientific", "Kindle_Store", "Luxury_Beauty", "Magazine_Subscriptions", "Movies_and_TV", "Musical_Instruments", "Office_Products", "Patio_Lawn_and_Garden", "Pet_Supplies", "Prime_Pantry", "Software", "Sports_and_Outdoors", "Tools_and_Home_Improvement", "Toys_and_Games", "Video_Games"] 7 | 8 | ############# 9 | ### PATHS ### 10 | ############# 11 | 12 | def data_dir(root_dir): 13 | return os.path.join(root_dir, 'amazon', 'data') 14 | 15 | def generate_user_splits(data_dir, reviews_df, min_size_per_user, 16 | train_size, eval_size, seed): 17 | # mark duplicates 18 | duplicated_within_user = reviews_df[['reviewerID','reviewText']].duplicated() 19 | df_deduplicated_within_user = reviews_df[~duplicated_within_user] 20 | duplicated_text = df_deduplicated_within_user[df_deduplicated_within_user['reviewText'].apply(lambda x: x.lower()).duplicated(keep=False)]['reviewText'] 21 | duplicated_text = set(duplicated_text.values) 22 | reviews_df['duplicate'] = ((reviews_df['reviewText'].isin(duplicated_text)) | duplicated_within_user) 23 | # mark html candidates 24 | reviews_df['contains_html'] = reviews_df['reviewText'].apply(lambda x: '<' in x and '>' in x) 25 | # mark clean ones 26 | reviews_df['clean'] = (~reviews_df['duplicate'] & ~reviews_df['contains_html']) 27 | 28 | # generate splits 29 | generate_group_splits( 30 | data_dir=data_dir, 31 | reviews_df=reviews_df, 32 | min_size_per_group=min_size_per_user, 33 | group_field='reviewerID', 34 | split_name='user', 35 | train_size=train_size, 36 | eval_size=eval_size, 37 | seed=seed, 38 | select_column='clean') 39 | 40 | def generate_users_baseline_splits(data_dir, reviews_df, reviewer_id, seed, user_split_name='user'): 41 | # seed 42 | np.random.seed(seed) 43 | # sizes 44 | n, _ = reviews_df.shape 45 | splits = np.ones(n)*-1 46 | # load user split 47 | orig_splits_df = pd.read_csv(splits_path(data_dir, user_split_name)) 48 | splits[((orig_splits_df['split']==OOD_TEST) & (reviews_df['reviewerID']==reviewer_id)).values] = TEST 49 | # train 50 | train_indices, = np.where(np.logical_and.reduce((reviews_df['reviewerID']==reviewer_id, 51 | splits==-1, 52 | orig_splits_df['clean']))) 53 | np.random.shuffle(train_indices) 54 | eval_size = np.sum(splits==TEST) 55 | splits[train_indices[:eval_size]] = VAL 56 | splits[train_indices[eval_size:]] = TRAIN 57 | split_df = pd.DataFrame({'split': splits}) 58 | split_df.to_csv(splits_path(data_dir, f'{reviewer_id}_baseline'), index=False) 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--root_dir', required=True) 63 | args = parser.parse_args() 64 | 65 | df = pd.read_csv(reviews_path(data_dir(args.root_dir)), 66 | dtype={'reviewerID':str, 'asin':str, 'reviewTime':str,'unixReviewTime':int, 67 | 'reviewText':str,'summary':str,'verified':bool,'category':str, 'reviewYear':int}, 68 | keep_default_na=False, na_values=[]) 69 | 70 | # category 
subpopulation 71 | generate_fixed_group_splits( 72 | data_dir=data_dir(args.root_dir), 73 | reviews_df=df, 74 | group_field='category', 75 | train_groups=None, 76 | split_name='category_subpopulation', 77 | train_size=int(1e6), 78 | eval_size_per_group=1000, 79 | seed=0) 80 | 81 | # category generalization and baselines 82 | train_categories_list = [[category,] for category in CATEGORIES] + \ 83 | [['Books','Movies_and_TV','Home_and_Kitchen','Electronics'], 84 | ['Movies_and_TV','Books'], 85 | ['Movies_and_TV','Books','Home_and_Kitchen']] 86 | for train_categories in train_categories_list: 87 | split_name = ','.join([category.lower() for category in train_categories])+'_generalization' 88 | generate_fixed_group_splits( 89 | data_dir=data_dir(args.root_dir), 90 | reviews_df=df, 91 | group_field='category', 92 | train_groups=train_categories, 93 | split_name=split_name, 94 | train_size=int(1e6), 95 | eval_size_per_group=1000, 96 | seed=0) 97 | 98 | # time shift 99 | generate_time_splits( 100 | data_dir=data_dir(args.root_dir), 101 | reviews_df=df, 102 | year_field='reviewYear', 103 | year_threshold=2013, 104 | train_size=int(1e6), 105 | eval_size_per_year=4000, 106 | seed=0) 107 | 108 | # user splits 109 | generate_user_splits( 110 | data_dir=data_dir(args.root_dir), 111 | reviews_df=df, 112 | min_size_per_user=150, 113 | train_size=int(1e6), 114 | eval_size=1e5, 115 | seed=0) 116 | 117 | baseline_reviewers = ['AV6QDP8Q0ONK4', 'A37BRR2L8PX3R2', 'A1UH21GLZTYYR5', 'ASVY5XSYJ1XOE', 'A1NE43T0OM6NNX', 118 | 'A9Q28YTLYREO7', 'A1CNQTCRQ35IMM', 'A20EEWWSFMZ1PN', 'A3JVZY05VLMYEM', 'A219Y76LD1VP4N'] 119 | for reviewer_id in baseline_reviewers: 120 | generate_users_baseline_splits( 121 | data_dir=data_dir(args.root_dir), 122 | reviews_df=df, 123 | reviewer_id=reviewer_id, 124 | seed=0) 125 | 126 | if __name__=='__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /dataset_preprocessing/amazon_yelp/generate_splits_yelp.py: -------------------------------------------------------------------------------- 1 | import os, json, gzip, argparse, time, csv 2 | import numpy as np 3 | import pandas as pd 4 | from utils import * 5 | 6 | def data_dir(root_dir): 7 | return os.path.join(root_dir, 'yelp', 'data') 8 | 9 | def load_reviews(data_dir): 10 | reviews_df = pd.read_csv(reviews_path(data_dir), 11 | dtype={'review_id': str, 'user_id': str, 'business_id':str, 'stars': int, 12 | 'useful': int, 'funny': int, 'cool':int, 'text': str, 'date':str}, 13 | keep_default_na=False, na_values=[]) 14 | return reviews_df 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--root_dir', required=True) 19 | args = parser.parse_args() 20 | 21 | reviews_df = load_reviews(data_dir(args.root_dir)) 22 | # time 23 | generate_time_splits( 24 | data_dir=data_dir(args.root_dir), 25 | reviews_df=reviews_df, 26 | year_field='year', 27 | year_threshold=2013, 28 | train_size=int(1e6), 29 | eval_size_per_year=1000, 30 | seed=0) 31 | 32 | # user shifts 33 | generate_group_splits( 34 | data_dir=data_dir(args.root_dir), 35 | reviews_df=reviews_df, 36 | min_size_per_group=50, 37 | group_field='user_id', 38 | split_name='user', 39 | train_size=int(1e6), 40 | eval_size=int(4e4), 41 | seed=0) 42 | 43 | if __name__=='__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /dataset_preprocessing/amazon_yelp/process_yelp.py: -------------------------------------------------------------------------------- 1 | 
import os, sys, torch, json, csv, argparse 2 | import numpy as np 3 | import pandas as pd 4 | from transformers import BertTokenizerFast 5 | from utils import * 6 | 7 | ############# 8 | ### PATHS ### 9 | ############# 10 | 11 | def data_dir(root_dir): 12 | return os.path.join(root_dir, 'yelp', 'data') 13 | 14 | def token_length_path(data_dir): 15 | return os.path.join(preprocessing_dir(data_dir), f'token_counts.csv') 16 | 17 | ############ 18 | ### LOAD ### 19 | ############ 20 | 21 | def parse(path): 22 | with open(path, 'r') as f: 23 | for l in f: 24 | yield json.loads(l) 25 | 26 | def load_business_data(data_dir): 27 | keys = ['business_id', 'city', 'state', 'categories'] 28 | df = {} 29 | for k in keys: 30 | df[k] = [] 31 | with open(os.path.join(raw_data_dir(data_dir), 'yelp_academic_dataset_business.json'), 'r') as f: 32 | for i, line in enumerate(f): 33 | data = json.loads(line) 34 | for k in keys: 35 | df[k].append(data[k]) 36 | business_df = pd.DataFrame(df) 37 | return business_df 38 | 39 | ##################### 40 | ### PREPROCESSING ### 41 | ##################### 42 | 43 | def compute_token_length(data_dir): 44 | tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') 45 | token_counts = [] 46 | with open(os.path.join(raw_data_dir(data_dir), 'yelp_academic_dataset_review.json'), 'r') as f: 47 | text_list = [] 48 | for i, line in enumerate(f): 49 | if i % 100000==0: 50 | print(f'Processed {i} reviews') 51 | data = json.loads(line) 52 | text = data['text'] 53 | text_list.append(text) 54 | if len(text_list)==1024: 55 | tokens = tokenizer(text_list, 56 | padding='do_not_pad', 57 | truncation='do_not_truncate', 58 | return_token_type_ids=False, 59 | return_attention_mask=False, 60 | return_overflowing_tokens=False, 61 | return_special_tokens_mask=False, 62 | return_offsets_mapping=False, 63 | return_length=True) 64 | token_counts += tokens['length'] 65 | text_list = [] 66 | if len(text_list)>0: 67 | tokens = tokenizer(text_list, 68 | padding='do_not_pad', 69 | truncation='do_not_truncate', 70 | return_token_type_ids=False, 71 | return_attention_mask=False, 72 | return_overflowing_tokens=False, 73 | return_special_tokens_mask=False, 74 | return_offsets_mapping=False, 75 | return_length=True) 76 | token_counts += tokens['length'] 77 | 78 | csv_path = token_length_path(data_dir) 79 | df = pd.DataFrame({'token_counts': token_counts}) 80 | df.to_csv(csv_path, index=False, quoting=csv.QUOTE_NONNUMERIC) 81 | 82 | def process_reviews(data_dir): 83 | # load pre-computed token length 84 | assert os.path.exists(token_length_path(data_dir)), 'pre-compute token length first' 85 | token_length = pd.read_csv(token_length_path(data_dir))['token_counts'].values 86 | 87 | # filter and export 88 | with open(reviews_path(data_dir), 'w') as f: 89 | fields = ['review_id', 'user_id', 'business_id', 'stars', 'useful', 'funny', 'cool', 'text', 'date'] 90 | writer = csv.DictWriter(f, fields, quoting=csv.QUOTE_NONNUMERIC) 91 | 92 | for i, review in enumerate(parse(os.path.join(raw_data_dir(data_dir), 'yelp_academic_dataset_review.json'))): 93 | if 'text' not in review: 94 | continue 95 | if len(review['text'].strip())==0: 96 | continue 97 | if token_length[i] > 512: 98 | continue 99 | row = {} 100 | for field in fields: 101 | row[field] = review[field] 102 | writer.writerow(row) 103 | # compute year 104 | df = pd.read_csv(reviews_path(data_dir), names=fields, 105 | dtype={'review_id': str, 'user_id': str, 'business_id':str, 'stars': int, 106 | 'useful': int, 'funny': int, 'cool':int, 'text': str, 
'date':str}, 107 | keep_default_na=False, na_values=[]) 108 | print(f'Before deduplication: {df.shape}') 109 | df['year'] = df['date'].apply(lambda x: int(x.split('-')[0])) 110 | # remove duplicates 111 | duplicated_within_user = df[['user_id','text']].duplicated() 112 | df_deduplicated_within_user = df[~duplicated_within_user] 113 | duplicated_text = df_deduplicated_within_user[df_deduplicated_within_user['text'].apply(lambda x: x.lower()).duplicated(keep=False)]['text'] 114 | duplicated_text = set(duplicated_text.values) 115 | if len(duplicated_text)>0: 116 | print('Eliminating reviews with the following duplicate texts:') 117 | print('\n'.join(list(duplicated_text))) 118 | print('') 119 | df['duplicate'] = ((df['text'].isin(duplicated_text)) | duplicated_within_user) 120 | df = df[~df['duplicate']] 121 | print(f'After deduplication: {df[~df["duplicate"]].shape}') 122 | business_df = load_business_data(data_dir) 123 | df = pd.merge(df, business_df, on='business_id', how='left') 124 | df = df.drop(columns=['duplicate']) 125 | df.to_csv(reviews_path(data_dir), index=False, quoting=csv.QUOTE_NONNUMERIC) 126 | 127 | def main(): 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--root_dir', required=True) 130 | args = parser.parse_args() 131 | 132 | for dirpath in [splits_dir(data_dir(args.root_dir)), preprocessing_dir(data_dir(args.root_dir))]: 133 | if not os.path.exists(dirpath): 134 | os.mkdir(dirpath) 135 | 136 | compute_token_length(data_dir(args.root_dir)) 137 | process_reviews(data_dir(args.root_dir)) 138 | 139 | if __name__=='__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /dataset_preprocessing/amazon_yelp/subsample_amazon.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | 5 | import pandas as pd 6 | import numpy as np 7 | 8 | # Fix the seed for reproducibility 9 | np.random.seed(0) 10 | 11 | """ 12 | Subsample the Amazon dataset. 
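Keeps a random fraction `frac` of the training reviewers, marks the reviews of all other training reviewers as not-in-dataset, and regenerates the ID val / ID test splits in splits/user.csv accordingly.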
13 | 14 | Usage: 15 | python dataset_preprocessing/amazon_yelp/subsample_amazon.py 16 | """ 17 | 18 | NOT_IN_DATASET = -1 19 | # Split: {'train': 0, 'val': 1, 'id_val': 2, 'test': 3, 'id_test': 4} 20 | TRAIN, OOD_VAL, ID_VAL, OOD_TEST, ID_TEST = range(5) 21 | 22 | 23 | def main(dataset_path, frac=0.25): 24 | def output_dataset_sizes(split_df): 25 | print("-" * 50) 26 | print(f'Train size: {len(split_df[split_df["split"] == TRAIN])}') 27 | print(f'Val size: {len(split_df[split_df["split"] == OOD_VAL])}') 28 | print(f'ID Val size: {len(split_df[split_df["split"] == ID_VAL])}') 29 | print(f'Test size: {len(split_df[split_df["split"] == OOD_TEST])}') 30 | print(f'ID Test size: {len(split_df[split_df["split"] == ID_TEST])}') 31 | print( 32 | f'Number of examples not included: {len(split_df[split_df["split"] == NOT_IN_DATASET])}' 33 | ) 34 | print("-" * 50) 35 | print("\n") 36 | 37 | data_df = pd.read_csv( 38 | os.path.join(dataset_path, "reviews.csv"), 39 | dtype={ 40 | "reviewerID": str, 41 | "asin": str, 42 | "reviewTime": str, 43 | "unixReviewTime": int, 44 | "reviewText": str, 45 | "summary": str, 46 | "verified": bool, 47 | "category": str, 48 | "reviewYear": int, 49 | }, 50 | keep_default_na=False, 51 | na_values=[], 52 | quoting=csv.QUOTE_NONNUMERIC, 53 | ) 54 | 55 | user_csv_path = os.path.join(dataset_path, "splits", "user.csv") 56 | split_df = pd.read_csv(user_csv_path) 57 | output_dataset_sizes(split_df) 58 | 59 | train_data_df = data_df[split_df["split"] == 0] 60 | train_reviewer_ids = train_data_df.reviewerID.unique() 61 | print(f"Number of unique reviewers in train set: {len(train_reviewer_ids)}") 62 | 63 | # Randomly sample (1 - frac) x number of reviewers 64 | # Blackout all the reviews belonging to the randomly sampled reviewers 65 | subsampled_reviewers_count = int((1 - frac) * len(train_reviewer_ids)) 66 | subsampled_reviewers = np.random.choice( 67 | train_reviewer_ids, subsampled_reviewers_count, replace=False 68 | ) 69 | print(subsampled_reviewers) 70 | 71 | blackout_indices = train_data_df[ 72 | train_data_df["reviewerID"].isin(subsampled_reviewers) 73 | ].index 74 | 75 | # Mark all the corresponding reviews of blackout_indices as -1 76 | split_df.loc[blackout_indices, "split"] = NOT_IN_DATASET 77 | output_dataset_sizes(split_df) 78 | 79 | # Mark duplicates 80 | duplicated_within_user = data_df[["reviewerID", "reviewText"]].duplicated() 81 | df_deduplicated_within_user = data_df[~duplicated_within_user] 82 | duplicated_text = df_deduplicated_within_user[ 83 | df_deduplicated_within_user["reviewText"] 84 | .apply(lambda x: x.lower()) 85 | .duplicated(keep=False) 86 | ]["reviewText"] 87 | duplicated_text = set(duplicated_text.values) 88 | data_df["duplicate"] = ( 89 | data_df["reviewText"].isin(duplicated_text) 90 | ) | duplicated_within_user 91 | 92 | # Mark html candidates 93 | data_df["contains_html"] = data_df["reviewText"].apply( 94 | lambda x: "<" in x and ">" in x 95 | ) 96 | 97 | # Mark clean ones 98 | data_df["clean"] = ~data_df["duplicate"] & ~data_df["contains_html"] 99 | 100 | # Clear ID val and ID test since we're regenerating 101 | split_df.loc[split_df["split"] == ID_VAL, "split"] = NOT_IN_DATASET 102 | split_df.loc[split_df["split"] == ID_TEST, "split"] = NOT_IN_DATASET 103 | 104 | # Regenerate ID val and ID test 105 | train_reviewer_ids = data_df[split_df["split"] == TRAIN]["reviewerID"].unique() 106 | np.random.shuffle(train_reviewer_ids) 107 | cutoff = int(len(train_reviewer_ids) / 2) 108 | id_val_reviewer_ids = train_reviewer_ids[:cutoff] 109 | 
id_test_reviewer_ids = train_reviewer_ids[cutoff:] 110 | split_df.loc[ 111 | (split_df["split"] == NOT_IN_DATASET) 112 | & data_df["clean"] 113 | & data_df["reviewerID"].isin(id_val_reviewer_ids), 114 | "split", 115 | ] = ID_VAL 116 | split_df.loc[ 117 | (split_df["split"] == NOT_IN_DATASET) 118 | & data_df["clean"] 119 | & data_df["reviewerID"].isin(id_test_reviewer_ids), 120 | "split", 121 | ] = ID_TEST 122 | 123 | # Sanity check 124 | assert ( 125 | data_df[(split_df["split"] == ID_VAL)]["reviewerID"].value_counts().min() == 75 126 | ) 127 | assert ( 128 | data_df[(split_df["split"] == ID_VAL)]["reviewerID"].value_counts().max() == 75 129 | ) 130 | assert ( 131 | data_df[(split_df["split"] == ID_TEST)]["reviewerID"].value_counts().min() == 75 132 | ) 133 | assert ( 134 | data_df[(split_df["split"] == ID_TEST)]["reviewerID"].value_counts().max() == 75 135 | ) 136 | 137 | # Write out the new splits to user.csv 138 | output_dataset_sizes(split_df) 139 | split_df.to_csv(user_csv_path, index=False) 140 | print("Done.") 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser(description="Subsample the Amazon dataset.") 145 | parser.add_argument( 146 | "path", 147 | type=str, 148 | help="Path to the Amazon dataset", 149 | ) 150 | parser.add_argument( 151 | "frac", 152 | type=float, 153 | help="Subsample fraction", 154 | ) 155 | 156 | args = parser.parse_args() 157 | main(args.path, args.frac) 158 | -------------------------------------------------------------------------------- /dataset_preprocessing/camelyon17/README.md: -------------------------------------------------------------------------------- 1 | ## Camelyon17-wilds patch processing 2 | 3 | #### Requirements 4 | - openslide-python>=1.1.2 5 | - opencv-python>=4.4.0 6 | 7 | openslide-python relies on first installing OpenSlide; see [installation instructions](https://github.com/openslide/openslide-python). 8 | 9 | #### Instructions 10 | 11 | 1. Download the CAMELYON17 data from https://camelyon17.grand-challenge.org/Data/ into `SLIDE_ROOT`. The dataset is huge, so you might want to only download the 100 WSIs with lesion annotations, which by themselves are already 600G. You can find out which WSIs have annotations by looking at the `lesion_annotations` folder. The patch extraction code expects `SLIDE_ROOT` to contain the `lesion_annotations` and `tif` folders. 12 | 13 | 2. Run `python generate_all_patch_coords.py --slide_root SLIDE_ROOT --output_root OUTPUT_ROOT` to generate a .csv of all potential patches as well as the tissue/tumor/normal masks for each WSI. `OUTPUT_ROOT` is wherever you would like the patches to eventually be written. 14 | 15 | 3. Then run `python generate_final_metadata.py --output_root OUTPUT_ROOT` to select a class-balanced set of patches and assign splits. 16 | 17 | 4. Finally, run `python extract_final_patches_to_disk.py --slide_root SLIDE_ROOT --output_root OUTPUT_ROOT` to extract the chosen patches from the WSIs and write them to disk. 
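After these steps, `OUTPUT_ROOT` contains `metadata.csv` plus one `patches/patient_<p>_node_<n>/` folder per slide. A quick way to sanity-check the result (a minimal sketch; the `center` and `tumor` columns come from `generate_final_metadata.py`):

```python
import os
import pandas as pd

output_root = "OUTPUT_ROOT"  # illustrative placeholder
df = pd.read_csv(os.path.join(output_root, "metadata.csv"),
                 index_col=0, dtype={"patient": "str"})
# Patches should be roughly class-balanced within each hospital.
print(df.groupby(["center", "tumor"]).size())
```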
18 | -------------------------------------------------------------------------------- /dataset_preprocessing/camelyon17/extract_final_patches_to_disk.py: -------------------------------------------------------------------------------- 1 | import openslide 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | import random 7 | from tqdm import tqdm 8 | from generate_all_patch_coords import PATCH_LEVEL, MASK_LEVEL, CENTER_SIZE 9 | 10 | def write_patch_images_from_df(slide_root, output_root): 11 | read_df = pd.read_csv( 12 | os.path.join(output_root, 'metadata.csv'), 13 | index_col=0, 14 | dtype={'patient': 'str'}) 15 | 16 | patch_level = PATCH_LEVEL 17 | center_size = CENTER_SIZE 18 | patch_size = center_size * 3 19 | 20 | for idx in tqdm(read_df.index): 21 | orig_x = read_df.loc[idx, 'x_coord'] 22 | orig_y = read_df.loc[idx, 'y_coord'] 23 | patient = read_df.loc[idx, 'patient'] 24 | node = read_df.loc[idx, 'node'] 25 | 26 | patch_folder = os.path.join( 27 | output_root, 28 | 'patches', 29 | f'patient_{patient}_node_{node}') 30 | patch_path = os.path.join( 31 | patch_folder, 32 | f'patch_patient_{patient}_node_{node}_x_{orig_x}_y_{orig_y}.png') 33 | 34 | os.makedirs(patch_folder, exist_ok=True) 35 | if os.path.isfile(patch_path): 36 | continue 37 | 38 | slide_path = os.path.join( 39 | slide_root, 40 | 'tif', 41 | f'patient_{patient}_node_{node}.tif') 42 | 43 | slide = openslide.OpenSlide(slide_path) 44 | 45 | # Coords are at patch_level 46 | # First shift coords to top left corner of the entire patch 47 | x = orig_x - center_size 48 | y = orig_y - center_size 49 | # Then match to level 0 coords so we can use read_region 50 | x = int(round(x * slide.level_dimensions[0][0] / slide.level_dimensions[patch_level][0])) 51 | y = int(round(y * slide.level_dimensions[0][1] / slide.level_dimensions[patch_level][1])) 52 | 53 | patch = slide.read_region( 54 | (x, y), 55 | 2, 56 | (patch_size, patch_size)) 57 | patch.save(patch_path) 58 | 59 | 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--slide_root', required=True) 63 | parser.add_argument('--output_root', required=True) 64 | args = parser.parse_args() 65 | write_patch_images_from_df( 66 | slide_root=args.slide_root, 67 | output_root=args.output_root) 68 | -------------------------------------------------------------------------------- /dataset_preprocessing/camelyon17/generate_final_metadata.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from matplotlib import pyplot as plt 3 | import argparse 4 | import os,sys 5 | import numpy as np 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | def generate_final_metadata(output_root): 10 | df = pd.read_csv(os.path.join(output_root, 'all_patch_coords.csv'), 11 | index_col=0, 12 | dtype={ 13 | 'patient': 'str', 14 | 'tumor': 'int' 15 | }) 16 | 17 | # Assign slide numbers to patients + nodes 18 | patient_node_list = list(set(df[['patient', 'node']].itertuples(index=False, name=None))) 19 | patient_node_list.sort() 20 | patient_node_to_slide_map = {} 21 | for idx, (patient, node) in enumerate(patient_node_list): 22 | patient_node_to_slide_map[(patient, node)] = idx 23 | 24 | for (patient, node), slide_idx in patient_node_to_slide_map.items(): 25 | mask = (df['patient'] == patient) & (df['node'] == node) 26 | df.loc[mask, 'slide'] = slide_idx 27 | df['slide'] = df['slide'].astype('int') 28 | 29 | # The raw data has the following assignments: 30 | 
# Center 0: patients 0 to 19 31 | # Center 1: patients 20 to 39 32 | # Center 2: patients 40 to 59 33 | # Center 3: patients 60 to 79 34 | # Center 4: patients 80 to 99 35 | num_centers = 5 36 | patients_per_center = 20 37 | df['center'] = df['patient'].astype('int') // patients_per_center 38 | 39 | for k in range(num_centers): 40 | print(f"center {k}: " 41 | f"{np.sum((df['center'] == k) & (df['tumor'] == 0)):6d} non-tumor, " 42 | f"{np.sum((df['center'] == k) & (df['tumor'] == 1)):6d} tumor") 43 | 44 | for center, slide in set(df[['center', 'slide']].itertuples(index=False, name=None)): 45 | assert center == slide // 10 46 | 47 | # Keep all tumor patches, except if the slide has fewer normal than tumor patches 48 | # (slide 096 in center 4) 49 | # in which case we discard the excess tumor patches 50 | indices_to_keep = [] 51 | np.random.seed(0) 52 | tumor_mask = df['tumor'] == 1 53 | for slide in set(df['slide']): 54 | slide_mask = (df['slide'] == slide) 55 | num_tumor = np.sum(slide_mask & tumor_mask) 56 | num_non_tumor = np.sum(slide_mask & ~tumor_mask) 57 | slide_indices_with_tumor = list(df.index[slide_mask & tumor_mask]) 58 | indices_to_keep += list(np.random.choice( 59 | slide_indices_with_tumor, 60 | size=min(num_tumor, num_non_tumor), 61 | replace=False)) 62 | 63 | tumor_keep_mask = np.zeros(len(df)) 64 | tumor_keep_mask[df.index[indices_to_keep]] = 1 65 | 66 | # Within each center and split, keep same number of normal patches as tumor patches 67 | for center in range(num_centers): 68 | print(f'Center {center}:') 69 | center_mask = df['center'] == center 70 | num_tumor = np.sum(center_mask & tumor_keep_mask) 71 | print(f' Num tumor: {num_tumor}') 72 | 73 | num_non_tumor = np.sum(center_mask & ~tumor_mask) 74 | center_indices_without_tumor = list(df.index[center_mask & ~tumor_mask]) 75 | indices_to_keep += list(np.random.choice( 76 | center_indices_without_tumor, 77 | size=min(num_tumor, num_non_tumor), 78 | replace=False)) 79 | 80 | print(f' Num non-tumor: {min(num_tumor, num_non_tumor)} out of {num_non_tumor} ({min(num_tumor, num_non_tumor) / num_non_tumor * 100:.1f}%)') 81 | 82 | df_to_keep = df.loc[indices_to_keep, :].copy().reset_index(drop=True) 83 | 84 | val_frac = 0.1 85 | 86 | split_dict = { 87 | 'train': 0, 88 | 'val': 1, 89 | 'test': 2 90 | } 91 | 92 | df_to_keep['split'] = split_dict['train'] 93 | 94 | all_indices = list(df_to_keep.index) 95 | val_indices = list(np.random.choice( 96 | all_indices, 97 | size=int(val_frac * len(all_indices)), 98 | replace=False)) 99 | df_to_keep.loc[val_indices, 'split'] = split_dict['val'] 100 | 101 | print('Statistics by center:') 102 | for center in range(num_centers): 103 | tumor_mask = df_to_keep['tumor'] == 1 104 | center_mask = df_to_keep['center'] == center 105 | num_tumor = np.sum(center_mask & tumor_mask) 106 | num_non_tumor = np.sum(center_mask & ~tumor_mask) 107 | 108 | print(f'Center {center}') 109 | print(f' {num_tumor} / {num_tumor + num_non_tumor} ({num_tumor / (num_tumor + num_non_tumor) * 100:.1f}%) tumor') 110 | 111 | df_to_keep.to_csv(os.path.join(output_root, 'metadata.csv')) 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--output_root', required=True) 116 | args = parser.parse_args() 117 | generate_final_metadata(args.output_root) 118 | -------------------------------------------------------------------------------- /dataset_preprocessing/camelyon17/unlabeled/README.md: -------------------------------------------------------------------------------- 1 | ## 
Unlabeled Camelyon17-WILDS patch processing 2 | 3 | #### Requirements 4 | 5 | - openslide-python>=1.1.2 6 | - opencv-python>=4.4.0 7 | 8 | openslide-python relies on first installing OpenSlide; 9 | see [installation instructions](https://github.com/openslide/openslide-python). 10 | 11 | #### Instructions 12 | 13 | 1. Download the [CAMELYON17 training data](https://drive.google.com/drive/folders/0BzsdkU4jWx9BSEI2X1VOLUpYZ3c?resourcekey=0-41XIPJNyEAo598wHxVAP9w) 14 | into `SLIDE_ROOT`. 15 | 16 | 2. Run `python generate_all_patch_coords.py --slide_root SLIDE_ROOT --output_root OUTPUT_ROOT` to generate a .csv of all 17 | potential patches as well as the tissue masks for each WSI. `OUTPUT_ROOT` is wherever you would like the 18 | patches to eventually be written. 19 | 20 | 3. Then run `python generate_final_metadata.py --slide_root SLIDE_ROOT --output_root OUTPUT_ROOT` 21 | to generate the metadata.csv file for unlabeled Camelyon. 22 | 23 | 4. Finally, run `python extract_final_patches_to_disk.py --slide_root SLIDE_ROOT --output_root OUTPUT_ROOT` to 24 | extract the chosen patches from the WSIs and write them to disk. -------------------------------------------------------------------------------- /dataset_preprocessing/camelyon17/unlabeled/extract_final_patches_to_disk.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pdb 4 | from tqdm import tqdm 5 | 6 | import openslide 7 | import pandas as pd 8 | 9 | from generate_all_patch_coords import PATCH_LEVEL, CENTER_SIZE 10 | 11 | 12 | def write_patch_images_from_df(slide_root, output_root): 13 | print("Writing patch images to disk...") 14 | read_df = pd.read_csv( 15 | os.path.join(output_root, "metadata.csv"), index_col=0, dtype={"patient": "str"} 16 | ) 17 | 18 | patch_level = PATCH_LEVEL 19 | center_size = CENTER_SIZE 20 | patch_size = center_size * 3 21 | 22 | for idx in tqdm(read_df.index): 23 | orig_x = read_df.loc[idx, "x_coord"] 24 | orig_y = read_df.loc[idx, "y_coord"] 25 | center = read_df.loc[idx, "center"] 26 | patient = read_df.loc[idx, "patient"] 27 | node = read_df.loc[idx, "node"] 28 | 29 | patch_folder = os.path.join( 30 | output_root, "patches", f"patient_{patient}_node_{node}" 31 | ) 32 | patch_path = os.path.join( 33 | patch_folder, 34 | f"patch_patient_{patient}_node_{node}_x_{orig_x}_y_{orig_y}.png", 35 | ) 36 | 37 | os.makedirs(patch_folder, exist_ok=True) 38 | if os.path.isfile(patch_path): 39 | continue 40 | 41 | slide_path = os.path.join( 42 | slide_root, 43 | f"center_{center}", 44 | f"patient_{patient}", 45 | f"patient_{patient}_node_{node}.tif", 46 | ) 47 | slide = openslide.OpenSlide(slide_path) 48 | 49 | # Coords are at patch_level 50 | # First shift coords to top left corner of the entire patch 51 | x = orig_x - center_size 52 | y = orig_y - center_size 53 | # Then match to level 0 coords so we can use read_region 54 | x = int( 55 | round( 56 | x 57 | * slide.level_dimensions[0][0] 58 | / slide.level_dimensions[patch_level][0] 59 | ) 60 | ) 61 | y = int( 62 | round( 63 | y 64 | * slide.level_dimensions[0][1] 65 | / slide.level_dimensions[patch_level][1] 66 | ) 67 | ) 68 | 69 | patch = slide.read_region((x, y), 2, (patch_size, patch_size)) 70 | patch.save(patch_path) 71 | print("Done.") 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("--slide_root", required=True) 77 | parser.add_argument("--output_root", required=True) 78 | args = parser.parse_args() 79 | 
write_patch_images_from_df(slide_root=args.slide_root, output_root=args.output_root) 80 | -------------------------------------------------------------------------------- /dataset_preprocessing/camelyon17/unlabeled/generate_final_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pdb 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from matplotlib import pyplot as plt 8 | 9 | # Fix seed for reproducibility 10 | np.random.seed(0) 11 | 12 | _NUM_CENTERS = 5 13 | _NUM_PATCHES_TO_SUBSAMPLE = 6667 14 | _NUM_PATIENTS_PER_HOSPITAL = 20 15 | 16 | 17 | def generate_final_metadata(slide_root, output_root): 18 | def print_stats(patches_df): 19 | print(f"\nStatistics:\nTotal # of patches: {patches_df.shape[0]}") 20 | for center in range(_NUM_CENTERS): 21 | print( 22 | f"Center {center}: {np.sum(patches_df['center'] == center):6d} patches" 23 | ) 24 | print() 25 | 26 | patches_path = os.path.join(output_root, "all_unlabeled_patch_coords.csv") 27 | print(f"Importing patches from {patches_path}...") 28 | df = pd.read_csv( 29 | patches_path, 30 | index_col=0, 31 | dtype={"patient": "str", "tumor": "int"}, 32 | ) 33 | 34 | # Assign slide numbers to patients + nodes 35 | patient_node_list = list( 36 | set(df[["patient", "node"]].itertuples(index=False, name=None)) 37 | ) 38 | patient_node_list.sort() 39 | patient_node_to_slide_map = {} 40 | for idx, (patient, node) in enumerate(patient_node_list): 41 | patient_node_to_slide_map[(patient, node)] = idx 42 | 43 | for (patient, node), slide_idx in patient_node_to_slide_map.items(): 44 | mask = (df["patient"] == patient) & (df["node"] == node) 45 | df.loc[mask, "slide"] = slide_idx 46 | df["slide"] = df["slide"].astype("int") 47 | 48 | # The raw data has the following assignments: 49 | # Center 0: patients 0 to 19 50 | # Center 1: patients 20 to 39 51 | # Center 2: patients 40 to 59 52 | # Center 3: patients 60 to 79 53 | # Center 4: patients 80 to 99 54 | df["center"] = df["patient"].astype("int") // _NUM_PATIENTS_PER_HOSPITAL 55 | print_stats(df) 56 | for center, slide in set( 57 | df[["center", "slide"]].itertuples(index=False, name=None) 58 | ): 59 | assert center == slide // 100, "Expected 100 slides per center." 60 | 61 | # Remove patches from the original metadata.csv before subsampling. 62 | # There are 50 XML files in the lesion_annotation folder, so 50 patient-node pairs were 63 | # already used in the original WILDS Camelyon dataset. 64 | print( 65 | "Removing patches from slides that were used in the original Camelyon-WILDS dataset..." 66 | ) 67 | for file in os.listdir(os.path.join(slide_root, "lesion_annotations")): 68 | if file.endswith(".xml") and not file.startswith("._"): 69 | prefix = file.split(".xml")[0] 70 | patient = prefix.split("_")[1] 71 | node = prefix.split("_")[3] 72 | 73 | patient_mask = df["patient"] == patient 74 | node_mask = df["node"] == int(node) 75 | df = df[~(patient_mask & node_mask)] 76 | print_stats(df) 77 | 78 | # The labeled Camelyon-WILDS dataset has approximately 300,000 patches. We want about 10x unlabeled data, 79 | # which corresponds to ~3 million patches. Since each hospital of the original Camelyon17 training set 80 | # has a 100 slides, we subsample 6,667 patches from each slide, resulting in 600,030 patches total from each 81 | # hospital except for Center 0. Slide 38 of Center 0 only has 5,824 patches, so we instead subsample a total of 82 | # 599,187 patches for Center 0. 
Therefore, there is a total of 2,999,307 unlabeled patches across the hospitals. 83 | print(f"Subsampling {_NUM_PATCHES_TO_SUBSAMPLE} patches from each slide...") 84 | indices_to_keep = [] 85 | for slide in set(df["slide"]): 86 | slide_mask = df["slide"] == slide 87 | slide_indices = list(df.index[slide_mask]) 88 | print( 89 | f"slide={slide}, choosing {_NUM_PATCHES_TO_SUBSAMPLE} patches from {len(slide_indices)} patches" 90 | ) 91 | if _NUM_PATCHES_TO_SUBSAMPLE < len(slide_indices): 92 | indices_to_keep += list( 93 | np.random.choice( 94 | slide_indices, size=_NUM_PATCHES_TO_SUBSAMPLE, replace=False 95 | ) 96 | ) 97 | else: 98 | print("Adding all slides...") 99 | indices_to_keep += slide_indices 100 | df_to_keep = df.loc[indices_to_keep, :].copy().reset_index(drop=True) 101 | 102 | print_stats(df_to_keep) 103 | df_to_keep.to_csv(os.path.join(output_root, "metadata.csv")) 104 | print("Done.") 105 | 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument("--slide_root", required=True) 110 | parser.add_argument("--output_root", required=True) 111 | args = parser.parse_args() 112 | 113 | generate_final_metadata(slide_root=args.slide_root, output_root=args.output_root) 114 | -------------------------------------------------------------------------------- /dataset_preprocessing/camelyon17/unlabeled/validate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pdb 4 | 5 | """ 6 | Validate the content of the unlabeled Camelyon17 dataset after 7 | preprocessing and uploading to CodaLab. 8 | 9 | Statistics: 10 | Total # of patches: 2,999,307 11 | Center 0: 599,187 patches 12 | Center 1: 600,030 patches 13 | Center 2: 600,030 patches 14 | Center 3: 600,030 patches 15 | Center 4: 600,030 patches 16 | 17 | Usage: 18 | 19 | python dataset_preprocessing/camelyon17/unlabeled/validate.py 20 | """ 21 | 22 | _EXPECTED_SLIDES_COUNT = 450 23 | 24 | 25 | def validate_unlabeled_dataset(root_dir: str): 26 | def get_patients_center(patient_id: str): 27 | patient_no = int(patient_id) 28 | if 0 <= patient_no < 20: 29 | return 0 30 | elif 20 <= patient_no < 40: 31 | return 1 32 | elif 40 <= patient_no < 60: 33 | return 2 34 | elif 60 <= patient_no < 80: 35 | return 3 36 | elif 80 <= patient_no < 100: 37 | return 4 38 | else: 39 | raise ValueError(f"Can't get center for patient {patient_id}.") 40 | 41 | dataset_dir = os.path.join(root_dir, "camelyon17_unlabeled_v1.0") 42 | content = os.listdir(dataset_dir) 43 | assert "patches" in content 44 | assert "RELEASE_v1.0.txt" in content 45 | assert "metadata.csv" in content 46 | 47 | slides_dir = os.path.join(dataset_dir, "patches") 48 | slides = os.listdir(slides_dir) 49 | 50 | slide_count = 0 51 | patch_counts = [0 for _ in range(5)] 52 | for slide in slides: 53 | patches_dir = os.path.join(slides_dir, slide) 54 | if not os.path.isdir(patches_dir): 55 | continue 56 | slide_count += 1 57 | 58 | slide_split = slide.split("_") 59 | assert len(slide_split) == 4 60 | patient_id = slide_split[1] 61 | center = get_patients_center(patient_id) 62 | for patch in os.listdir(patches_dir): 63 | if patch.endswith(".png"): 64 | patch_counts[center] += 1 65 | 66 | assert ( 67 | slide_count == _EXPECTED_SLIDES_COUNT 68 | ), f"Got incorrect number of slides. 
Expected: {_EXPECTED_SLIDES_COUNT}, Actual: {len(slides)}" 69 | print(f"Patch counts: {patch_counts}") 70 | assert patch_counts == [599187, 600030, 600030, 600030, 600030] 71 | assert sum(patch_counts) == 2999307 72 | print("\nVerified.") 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument("root_dir", help="Path to the datasets directory.") 78 | args = parser.parse_args() 79 | validate_unlabeled_dataset(args.root_dir) 80 | -------------------------------------------------------------------------------- /dataset_preprocessing/civilcomments/README.md: -------------------------------------------------------------------------------- 1 | ## CivilComments-wilds processing 2 | 3 | #### Instructions 4 | 5 | 1. Download `all_data.csv` from https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data. 6 | 7 | 2. Run `python process_labeled.py --root ROOT`, where `ROOT` is where you downloaded `ROOT`. This will create `all_data_with_identities.csv` in the same folder, which is the labeled data that we use in WILDS. 8 | 9 | 3. After the above step, run `python process_unlabeled.py --root ROOT`, where `ROOT` is where you downloaded `ROOT`. This will create `unlabeled_data_with_identities.csv` in the same folder, which is the unlabeled data that we optionally use in WILDS. 10 | -------------------------------------------------------------------------------- /dataset_preprocessing/civilcomments/attr_definitions.py: -------------------------------------------------------------------------------- 1 | ORIG_ATTRS = [ 2 | 'male', 3 | 'female', 4 | 'transgender', 5 | 'other_gender', 6 | 'heterosexual', 7 | 'homosexual_gay_or_lesbian', 8 | 'bisexual', 9 | 'other_sexual_orientation', 10 | 'christian', 11 | 'jewish', 12 | 'muslim', 13 | 'hindu', 14 | 'buddhist', 15 | 'atheist', 16 | 'other_religion', 17 | 'black', 18 | 'white', 19 | 'asian', 20 | 'latino', 21 | 'other_race_or_ethnicity', 22 | 'physical_disability', 23 | 'intellectual_or_learning_disability', 24 | 'psychiatric_or_mental_illness', 25 | 'other_disability', 26 | ] 27 | 28 | AGGREGATE_ATTRS = { 29 | 'LGBTQ': [ 30 | 'homosexual_gay_or_lesbian', 31 | 'bisexual', 32 | 'other_sexual_orientation', 33 | 'transgender', 34 | 'other_gender'], 35 | 'other_religions': [ 36 | 'jewish', 37 | 'hindu', 38 | 'buddhist', 39 | 'atheist', 40 | 'other_religion' 41 | ], 42 | 'asian_latino_etc': [ 43 | 'asian', 44 | 'latino', 45 | 'other_race_or_ethnicity' 46 | ], 47 | 'disability_any': [ 48 | 'physical_disability', 49 | 'intellectual_or_learning_disability', 50 | 'psychiatric_or_mental_illness', 51 | 'other_disability', 52 | ], 53 | 'identity_any': ORIG_ATTRS, 54 | } 55 | 56 | GROUP_ATTRS = { 57 | 'gender': [ 58 | 'male', 59 | 'female', 60 | 'transgender', 61 | 'other_gender', 62 | ], 63 | 'orientation': [ 64 | 'heterosexual', 65 | 'homosexual_gay_or_lesbian', 66 | 'bisexual', 67 | 'other_sexual_orientation', 68 | ], 69 | 'religion': [ 70 | 'christian', 71 | 'jewish', 72 | 'muslim', 73 | 'hindu', 74 | 'buddhist', 75 | 'atheist', 76 | 'other_religion' 77 | ], 78 | 'race': [ 79 | 'black', 80 | 'white', 81 | 'asian', 82 | 'latino', 83 | 'other_race_or_ethnicity' 84 | ], 85 | 'disability': [ 86 | 'physical_disability', 87 | 'intellectual_or_learning_disability', 88 | 'psychiatric_or_mental_illness', 89 | 'other_disability', 90 | ] 91 | } 92 | -------------------------------------------------------------------------------- /dataset_preprocessing/civilcomments/process_labeled.py: 
-------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from matplotlib import pyplot as plt 3 | import os,sys 4 | import numpy as np 5 | from tqdm import tqdm 6 | import argparse 7 | 8 | from attr_definitions import GROUP_ATTRS, AGGREGATE_ATTRS, ORIG_ATTRS 9 | 10 | def load_df(root): 11 | """ 12 | Loads the data and removes all examples where we don't have identity annotations. 13 | """ 14 | df = pd.read_csv(os.path.join(root, 'all_data.csv')) 15 | df = df.loc[(df['identity_annotator_count'] > 0), :] 16 | df = df.reset_index(drop=True) 17 | return df 18 | 19 | def augment_df(df): 20 | """ 21 | Augment the dataframe with auxiliary attributes. 22 | First, we create aggregate attributes, like `LGBTQ` or `other_religions`. 23 | These are aggregated because there would otherwise not be enough examples to accurately 24 | estimate their accuracy. 25 | 26 | Next, for each category of demographics (e.g., race, gender), we construct an auxiliary 27 | attribute (e.g., `na_race`, `na_gender`) that is 1 if the comment has no identities related to 28 | that demographic, and is 0 otherwise. 29 | Note that we can't just create a single multi-valued attribute like `gender` because there's 30 | substantial overlap: for example, 4.6% of comments mention both male and female identities. 31 | """ 32 | df = df.copy() 33 | for aggregate_attr in AGGREGATE_ATTRS: 34 | aggregate_mask = pd.Series([False] * len(df)) 35 | for attr in AGGREGATE_ATTRS[aggregate_attr]: 36 | attr_mask = (df[attr] >= 0.5) 37 | aggregate_mask = aggregate_mask | attr_mask 38 | df[aggregate_attr] = 0 39 | df.loc[aggregate_mask, aggregate_attr] = 1 40 | 41 | attr_count = np.zeros(len(df)) 42 | for attr in ORIG_ATTRS: 43 | attr_mask = (df[attr] >= 0.5) 44 | attr_count += attr_mask 45 | df['num_identities'] = attr_count 46 | df['more_than_one_identity'] = (attr_count > 1) 47 | 48 | for group in GROUP_ATTRS: 49 | print(f'## {group}') 50 | counts = {} 51 | na_mask = np.ones(len(df)) 52 | for attr in GROUP_ATTRS[group]: 53 | attr_mask = (df[attr] >= 0.5) 54 | na_mask = na_mask & ~attr_mask 55 | counts[attr] = np.mean(attr_mask) 56 | counts['n/a'] = np.mean(na_mask) 57 | 58 | col_name = f'na_{group}' 59 | df[col_name] = 0 60 | df.loc[na_mask, col_name] = 1 61 | 62 | for k, v in counts.items(): 63 | print(f'{k:40s}: {v:.4f}') 64 | print() 65 | return df 66 | 67 | def construct_splits(df): 68 | """ 69 | Construct splits. 70 | The original data already has a train vs. test split. 71 | We triple the size of the test set so that we can better estimate accuracy on the small groups, 72 | and construct a validation set by randomly sampling articles. 
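Splits are assigned at the article level (rather than per comment) so that comments from the same article never end up in different splits.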
73 | """ 74 | 75 | df = df.copy() 76 | train_df = df.loc[df['split'] == 'train'] 77 | test_df = df.loc[df['split'] == 'test'] 78 | train_articles = set(train_df['article_id'].values) 79 | test_articles = set(test_df['article_id'].values) 80 | # Assert no overlap between train and test articles 81 | assert len(train_articles.intersection(test_articles)) == 0 82 | 83 | n_train = len(train_df) 84 | n_test = len(test_df) 85 | n_train_articles = len(train_articles) 86 | n_test_articles = len(test_articles) 87 | 88 | ## Set params 89 | n_val_articles = n_test_articles 90 | n_new_test_articles = 2 * n_test_articles 91 | 92 | np.random.seed(0) 93 | 94 | # Sample val articles 95 | val_articles = np.random.choice( 96 | list(train_articles), 97 | size=n_val_articles, 98 | replace=False) 99 | df.loc[df['article_id'].isin(val_articles), 'split'] = 'val' 100 | 101 | # Sample new test articles 102 | train_articles = train_articles - set(val_articles) 103 | new_test_articles = np.random.choice( 104 | list(train_articles), 105 | size=n_new_test_articles, 106 | replace=False) 107 | df.loc[df['article_id'].isin(new_test_articles), 'split'] = 'test' 108 | 109 | train_df = df.loc[df['split'] == 'train'] 110 | val_df = df.loc[df['split'] == 'val'] 111 | test_df = df.loc[df['split'] == 'test'] 112 | 113 | train_articles = set(train_df['article_id'].values) 114 | val_articles = set(val_df['article_id'].values) 115 | test_articles = set(test_df['article_id'].values) 116 | 117 | # Sanity checks 118 | assert len(df) == len(train_df) + len(val_df) + len(test_df) 119 | assert n_train == len(train_df) + len(val_df) + np.sum(df['article_id'].isin(new_test_articles)) 120 | assert n_test == len(test_df) - np.sum(df['article_id'].isin(new_test_articles)) 121 | assert n_train_articles == len(train_articles) + len(val_articles) + len(new_test_articles) 122 | assert n_val_articles == len(val_articles) 123 | assert n_test_articles == len(test_articles) - n_new_test_articles 124 | assert len(train_articles.intersection(val_articles)) == 0 125 | assert len(train_articles.intersection(test_articles)) == 0 126 | assert len(val_articles.intersection(test_articles)) == 0 127 | 128 | print('% of examples') 129 | for split in ['train', 'val', 'test']: 130 | print(split, np.mean(df['split'] == split), np.sum(df['split'] == split)) 131 | print('') 132 | 133 | print('class balance') 134 | for split in ['train', 'val', 'test']: 135 | split_df = df.loc[df['split'] == split] 136 | print('pos', np.mean(split_df['toxicity'] > 0.5)) 137 | return df 138 | 139 | if __name__ == '__main__': 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument('--root', required=True) 142 | args = parser.parse_args() 143 | 144 | df = load_df(args.root) 145 | df = augment_df(df) 146 | df = construct_splits(df) 147 | df.to_csv(os.path.join(args.root, f'all_data_with_identities.csv')) 148 | -------------------------------------------------------------------------------- /dataset_preprocessing/civilcomments/process_unlabeled.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import pdb 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | # Fix the seed for reproducibility 10 | np.random.seed(0) 11 | 12 | """ 13 | Process unlabeled data in CivilComments. 
14 | Script is intended to be run after process_labeled.py 15 | 16 | Note that there is substantial overlap between the articles that unlabeled 17 | comments are from and the articles that the labeled comments are from. 18 | Specifically, 92% (1427849 out of 1551516) unlabeled comments are from 19 | articles that also have comments in the labeled set. 20 | """ 21 | 22 | TRAIN, VAL, TEST, UNLABELED = ('train', 'val', 'test', 'extra_unlabeled') 23 | 24 | def load_unlabeled_df(root): 25 | """ 26 | Loads the raw data where we don't have identity annotations. 27 | """ 28 | df = pd.read_csv(os.path.join(root, 'all_data.csv')) 29 | df = df.loc[(df['identity_annotator_count'] == 0), :] 30 | df = df.dropna(axis=0, how='any', subset=['id', 'comment_text', 'article_id']) # make sure data is clean 31 | df = df.reset_index(drop=True) 32 | return df 33 | 34 | def load_labeled_df(root): 35 | """ 36 | Loads the processed data for which we do have identity annotations. 37 | """ 38 | df = pd.read_csv(os.path.join(root, 'all_data_with_identities.csv'), index_col=0) 39 | return df 40 | 41 | def merge_dfs(unlabeled, labeled): 42 | """ 43 | Drops columns that are in unlabeled but not labeled 44 | Adds columns that are in labeled but not unlabeled and sets values to NaN 45 | """ 46 | common_cols = unlabeled.columns & labeled.columns 47 | unlabeled = unlabeled[common_cols] 48 | joint = labeled.append(unlabeled, ignore_index = True) 49 | return joint 50 | 51 | def main(args): 52 | unlabeled = load_unlabeled_df(args.root) 53 | labeled = load_labeled_df(args.root) 54 | 55 | # set all unlabeled examples to the same split 56 | unlabeled['split'] = UNLABELED 57 | 58 | # merge unlabeled, labeled dfs 59 | joint = merge_dfs(unlabeled, labeled) 60 | assert (joint.columns == labeled.columns).all() 61 | 62 | def output_split_sizes(df): 63 | print("-" * 50) 64 | print(f'Train size: {len(df[df["split"] == TRAIN])}') 65 | print(f'Val size: {len(df[df["split"] == VAL])}') 66 | print(f'Test size: {len(df[df["split"] == TEST])}') 67 | print( 68 | f'Unlabeled size: {len(df[df["split"] == UNLABELED])}' 69 | ) 70 | print("-" * 50) 71 | print("\n") 72 | 73 | output_split_sizes(joint) 74 | 75 | # Write out the new unlabeled split to user.csv 76 | joint.to_csv(f'{args.root}/all_data_with_identities_and_unlabeled.csv', index=True) 77 | joint[joint['split'] == UNLABELED].to_csv(f'{args.root}/unlabeled_data_with_identities.csv', index=True) 78 | print("Done.") 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser(description="Create unlabeled splits for CivilComments.") 83 | parser.add_argument( 84 | "--root", 85 | type=str, 86 | help="Path to the dir containing the CivilComments processed labeled csv and full csv.", 87 | ) 88 | args = parser.parse_args() 89 | main(args) 90 | -------------------------------------------------------------------------------- /dataset_preprocessing/domainnet/generate_sentry_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pdb 4 | 5 | import pandas as pd 6 | import numpy as np 7 | 8 | # Fix the seed for reproducibility 9 | np.random.seed(0) 10 | 11 | """ 12 | Generate a CSV with the metadata for DomainNet (SENTRY version): 13 | 14 | @inproceedings{peng2019moment, 15 | title={Moment matching for multi-source domain adaptation}, 16 | author={Peng, Xingchao and Bai, Qinxun and Xia, Xide and Huang, Zijun and Saenko, Kate and Wang, Bo}, 17 | booktitle={Proceedings of the IEEE International Conference 
on Computer Vision}, 18 | pages={1406--1415}, 19 | year={2019} 20 | } 21 | 22 | @article{prabhu2020sentry 23 | author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy}, 24 | title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation}, 25 | year = {2020}, 26 | journal = {arXiv preprint: 2012.11460}, 27 | } 28 | 29 | The dataset can be downloaded from http://ai.bu.edu/M3SDA. 30 | The SENTRY splits can be found https://github.com/virajprabhu/SENTRY/tree/main/data/DomainNet/txt. 31 | 32 | There are 586,576 images in 345 categories (airplane, ball, cup, etc.) across 6 domains (clipart, infograph, 33 | painting, quickdraw, real and sketch) in the original DomainNet dataset. Images are either PNG or JPG files. 34 | 35 | The SENTRY version of the dataset has 40 categories across 4 domains: 36 | "Due to labeling noise prevalent in the full version of DomainNet, we instead use the subset proposed in 37 | Tan et al. [42], which uses 40-commonly seen classes from four domains: Real (R), Clipart (C), Painting (P), 38 | and Sketch (S)." 39 | 40 | The metadata CSV file has the following fields: 41 | 42 | 1. image_path: Path to the image file. The path has the following format: //. 43 | 2. domain: One of the 4 possible domains. 44 | 3. split: One of "train" or "test". 45 | 4. category: One of the 40 possible categories. 46 | 5. y: Given to us by the SENTRY split 47 | 48 | Example usage: 49 | 50 | python dataset_preprocessing/domainnet/generate_sentry_metadata.py . 51 | 52 | """ 53 | 54 | DOMAINS = ["clipart", "painting", "real", "sketch"] 55 | METADATA_COLUMNS = ["image_path", "domain", "split", "category", "y"] 56 | NUM_OF_CATEGORIES = 40 57 | TEST_SPLIT = "test" 58 | TRAIN_SPLIT = "train" 59 | 60 | 61 | def main(sentry_splits_path): 62 | def process_split(split, split_path): 63 | count = 0 64 | categories = set() 65 | with open(split_path) as f: 66 | for line in f.readlines(): 67 | image_path, label = line.strip().split(" ") 68 | metadata_values = image_path.split(os.path.sep) 69 | metadata_dict["image_path"].append(image_path) 70 | metadata_dict["domain"].append(metadata_values[0]) 71 | metadata_dict["split"].append(split) 72 | metadata_dict["category"].append(metadata_values[1]) 73 | categories.add(metadata_values[1]) 74 | metadata_dict["y"].append(int(label)) 75 | count += 1 76 | assert len(categories) == NUM_OF_CATEGORIES 77 | return count 78 | 79 | print("Generating sentry_metadata.csv for DomainNet (SENTRY version)...") 80 | 81 | metadata_dict = {column: [] for column in METADATA_COLUMNS} 82 | for domain in DOMAINS: 83 | train_count = process_split( 84 | TRAIN_SPLIT, 85 | os.path.join(sentry_splits_path, f"{domain}_{TRAIN_SPLIT}_mini.txt"), 86 | ) 87 | test_count = process_split( 88 | TEST_SPLIT, 89 | os.path.join(sentry_splits_path, f"{domain}_{TEST_SPLIT}_mini.txt"), 90 | ) 91 | total_count = train_count + test_count 92 | train_percentage = np.round(float(train_count) / total_count * 100.0, 2) 93 | test_percentage = np.round(float(test_count) / total_count * 100.0, 2) 94 | print( 95 | f"Domain {domain} had {train_count} ({train_percentage}%) training examples " 96 | f"and {test_count} ({test_percentage}%) test examples with a total of {total_count} examples." 
97 |         )
98 |
99 |     # Write metadata out as a CSV file
100 |     metadata_df = pd.DataFrame(metadata_dict)
101 |     metadata_path = os.path.join(sentry_splits_path, "sentry_metadata.csv")
102 |     print(f"Writing metadata out to {metadata_path}...")
103 |     metadata_df.to_csv(metadata_path, index=False)
104 |     print("Done.")
105 |
106 |
107 | if __name__ == "__main__":
108 |     parser = argparse.ArgumentParser(
109 |         description="Generate a CSV with the metadata for DomainNet (SENTRY version)."
110 |     )
111 |     parser.add_argument(
112 |         "path",
113 |         type=str,
114 |         help="Path to the DomainNet dataset downloaded from http://ai.bu.edu/M3SDA",
115 |     )
116 |
117 |     args = parser.parse_args()
118 |     main(args.path)
119 |
--------------------------------------------------------------------------------
/dataset_preprocessing/encode/README.md:
--------------------------------------------------------------------------------
1 | ## ENCODE feature generation and preprocessing
2 |
3 | #### Requirements
4 | - pyBigWig
5 |
6 | #### Instructions to create Codalab bundle
7 |
8 | Here are instructions to reproduce the Codalab bundle, in a directory path `BUNDLE_ROOT_DIRECTORY`.
9 |
10 | 1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz and extract it into `SEQUENCE_PATH`.
11 |
12 | 2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the FASTA file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. (The dataset loader expects this archive to be named `sequence.npz`.)
13 |
14 | 3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. Save these as files named `DNASE.CELL_TYPE.fc.signal.bigwig` (e.g., `DNASE.liver.fc.signal.bigwig`) in the data directory referenced in `prep_accessibility.py`.
15 |
16 | 4. Run `python prep_accessibility.py`. This writes a sorted subsample of each bigwig file to `qn.CELL_TYPE.npy` in the same directory. These are used at runtime when the dataset loader is initialized, to perform quantile normalization on the DNase accessibility signals (see the sketch after these instructions).
17 |
18 | 5. Download the labels from the challenge into a label directory `labels/` created for this purpose:
19 |    - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, downloaded as MAX.train.labels.tsv.gz ).
20 |    - The training chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8077511 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8077648 for the TF MAX, downloaded as MAX.train_wc.labels.tsv.gz ).
21 |    - The validation chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX, downloaded as MAX.val.labels.tsv.gz ).
22 |    - The validation chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX, downloaded as MAX.test.labels.tsv.gz ).
23 |
24 | 6. Run `python prep_metadata_labels.py`.
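
For step 4, the sketch below illustrates one way a sorted reference sample like `qn.CELL_TYPE.npy` can be used to quantile-normalize a raw signal by rank. This is only a minimal illustration, not the code run by the WILDS loader; the `qn.liver.npy` filename and the synthetic arrays are stand-ins.

```python
import numpy as np

def quantile_normalize(signal, reference):
    """Map each value of `signal` onto the sorted `reference` distribution by rank."""
    ranks = np.argsort(np.argsort(signal))       # rank of each position: 0 .. len(signal) - 1
    idx = ranks * len(reference) // len(signal)  # rescale ranks into indices of the reference sample
    return reference[idx]

# Stand-ins: in practice, `reference` would be np.load("qn.liver.npy") and `signal`
# a window of raw DNase fold-change values read with pyBigWig.
reference = np.sort(np.random.gamma(1.0, 2.0, size=100_000))
signal = np.random.rand(12_800)
normalized = quantile_normalize(signal, reference)
```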
25 | 26 | -------------------------------------------------------------------------------- /dataset_preprocessing/encode/prep_accessibility.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/GuanLab/Leopard/blob/master/data/quantile_normalize_bigwig.py 2 | 3 | import argparse, time 4 | import numpy as np 5 | import pyBigWig 6 | 7 | # Human chromosomes in hg19, and their sizes in bp 8 | chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} 9 | 10 | 11 | def qn_sample_to_array( 12 | input_celltypes, 13 | input_chroms=None, 14 | subsampling_ratio=1000, 15 | data_pfx = '/users/abalsubr/wilds/examples/data/encode_v1.0/' 16 | ): 17 | """ 18 | Compute and write distribution of DNase bigwigs corresponding to input celltypes. 19 | """ 20 | if input_chroms is None: 21 | input_chroms = chrom_sizes.keys() 22 | qn_chrom_sizes = { k: chrom_sizes[k] for k in input_chroms } 23 | # Initialize chromosome-specific seeds for subsampling 24 | chr_to_seed = {} 25 | i = 0 26 | for the_chr in qn_chrom_sizes: 27 | chr_to_seed[the_chr] = i 28 | i += 1 29 | 30 | # subsampling 31 | sample_len = np.ceil(np.array(list(qn_chrom_sizes.values()))/subsampling_ratio).astype(int) 32 | sample = np.zeros(sum(sample_len)) 33 | start = 0 34 | j = 0 35 | for the_chr in qn_chrom_sizes: 36 | np.random.seed(chr_to_seed[the_chr]) 37 | for ct in input_celltypes: 38 | path = data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(ct) 39 | bw = pyBigWig.open(path) 40 | signal = np.nan_to_num(np.array(bw.values(the_chr, 0, qn_chrom_sizes[the_chr]))) 41 | index = np.random.randint(0, len(signal), sample_len[j]) 42 | sample[start:(start+sample_len[j])] += (1.0/len(input_celltypes))*signal[index] 43 | start += sample_len[j] 44 | j += 1 45 | print(the_chr, ct) 46 | sample.sort() 47 | np.save(data_pfx + "qn.{}.npy".format('.'.join(input_celltypes)), sample) 48 | 49 | 50 | if __name__ == '__main__': 51 | train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] 52 | all_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878', 'MCF-7', 'HepG2', 'liver'] 53 | for ct in all_celltypes: 54 | qn_sample_to_array([ct], input_chroms=train_chroms) 55 | -------------------------------------------------------------------------------- /dataset_preprocessing/fmow/convert_npy_to_jpg.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import numpy as np 4 | from PIL import Image 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | def main(): 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--root_dir', required=True, 12 | help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') 13 | config = parser.parse_args() 14 | data_dir = Path(config.root_dir) / 'fmow_v1.0' 15 | image_dir = Path(config.root_dir) / 'fmow_v1.0_images_jpg' 16 | os.makedirs(image_dir, exist_ok=True) 17 | 18 | img_counter = 0 19 | for chunk in tqdm(range(101)): 
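        # Each rgb_all_imgs_{chunk}.npy file holds a block of uint8 RGB image arrays;
        # memory-map it and write every image out as a JPEG with a globally increasing index.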
20 | npy_chunk = np.load(data_dir / f'rgb_all_imgs_{chunk}.npy', mmap_mode='r') 21 | for i in range(len(npy_chunk)): 22 | npy_image = npy_chunk[i] 23 | img = Image.fromarray(npy_image, mode='RGB') 24 | img.save(image_dir / f'rgb_img_{img_counter}.jpg') 25 | img_counter += 1 26 | 27 | if __name__=='__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /dataset_preprocessing/fmow/process_metadata_fmow.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | from torchvision import transforms 7 | from wilds.datasets.fmow_dataset import categories 8 | from PIL import Image 9 | import shutil 10 | import time 11 | 12 | root = Path('/u/scr/nlp/dro/fMoW/') 13 | dstroot = Path('/u/scr/nlp/dro/fMoW/data') 14 | 15 | # build test and seq mapping 16 | 17 | with open(root / 'test_gt_mapping.json', 'r') as f: 18 | test_mapping = json.load(f) 19 | with open(root / 'seq_gt_mapping.json', 'r') as f: 20 | seq_mapping = json.load(f) 21 | 22 | def process_mapping(mapping): 23 | new_mapping = {} 24 | for pair in tqdm(mapping): 25 | new_mapping[pair['input']] = pair['output'] 26 | return new_mapping 27 | 28 | test_mapping = process_mapping(test_mapping) 29 | seq_mapping = process_mapping(seq_mapping) 30 | 31 | 32 | rgb_metadata = [] 33 | msrgb_metadata = [] 34 | 35 | for split in ['train', 'val', 'test', 'seq']: 36 | split_dir = root / (split + '_gt') 37 | 38 | len_split_dir = len(list(split_dir.iterdir())) 39 | for class_dir in tqdm(split_dir.iterdir(), total=len_split_dir): 40 | classname = class_dir.stem 41 | len_class_dir = len(list(class_dir.iterdir())) 42 | for class_subdir in tqdm(class_dir.iterdir(), total=len_class_dir): 43 | for metadata_file in class_subdir.iterdir(): 44 | if metadata_file.suffix == '.json': 45 | with open(metadata_file, 'r') as f: 46 | metadata_json = json.load(f) 47 | 48 | locs = metadata_json['raw_location'].split('((')[1].split('))')[0].split(',') 49 | locs = [loc.strip().split(' ') for loc in locs] 50 | locs = [[float(loc[0]), float(loc[1])] for loc in locs] 51 | # lat long are reversed in locs 52 | lats = [loc[1] for loc in locs] 53 | lons = [loc[0] for loc in locs] 54 | 55 | if split in {'train', 'val'}: 56 | img_path = f"{split}/{metadata_file.parent.parent.stem}/{metadata_file.parent.stem}/{metadata_file.stem}.jpg" 57 | else: 58 | test_mapping_key = f"{split_dir.stem}/{metadata_file.parent.parent.stem}/{metadata_file.parent.stem}" 59 | if split == 'test': 60 | img_path_dir = Path(test_mapping[test_mapping_key]) 61 | else: 62 | img_path_dir = Path(seq_mapping[test_mapping_key]) 63 | 64 | new_img_filename = metadata_file.stem.replace(str(metadata_file.parent.stem), img_path_dir.stem) + ".jpg" 65 | img_path = img_path_dir / new_img_filename 66 | 67 | curr_metadata = { 68 | 'split': split, 69 | 'img_filename': metadata_json['img_filename'], 70 | 'img_path': str(img_path), 71 | 'spatial_reference': metadata_json['spatial_reference'], 72 | 'epsg': metadata_json['epsg'], 73 | 'category': metadata_json['bounding_boxes'][1]['category'], 74 | 'visible': metadata_json['bounding_boxes'][1]['visible'], 75 | 'img_width': metadata_json['img_width'], 76 | 'img_height': metadata_json['img_height'], 77 | 'country_code': metadata_json['country_code'], 78 | 'cloud_cover': metadata_json['cloud_cover'], 79 | 'timestamp': metadata_json['timestamp'], 80 | 'lat': np.mean(lats), 81 | 'lon': 
np.mean(lons)} 82 | 83 | if str(metadata_file).endswith('msrgb.json'): 84 | msrgb_metadata.append(curr_metadata) 85 | elif str(metadata_file).endswith('rgb.json'): 86 | rgb_metadata.append(curr_metadata) 87 | 88 | 89 | rgb_df = pd.DataFrame(rgb_metadata) 90 | msrgb_df = pd.DataFrame(msrgb_metadata) 91 | 92 | # add region 93 | def add_region(df): 94 | country_codes_df = pd.read_csv(dstroot / 'country_code_mapping.csv') 95 | countrycode_to_region = {k: v for k, v in zip(country_codes_df['alpha-3'], country_codes_df['region'])} 96 | country_codes = df['country_code'].to_list() 97 | regions = [countrycode_to_region.get(code, 'Other') for code in country_codes] 98 | df['region'] = regions 99 | 100 | add_region(rgb_df) 101 | add_region(msrgb_df) 102 | 103 | rgb_df.to_csv(dstroot / 'rgb_metadata.csv', index=False) 104 | msrgb_df.to_csv(dstroot / 'msrgb_metadata.csv', index=False) 105 | 106 | ################ save rgb imgs to npy 107 | 108 | category_to_idx = {cat: i for i, cat in enumerate(categories)} 109 | default_transform = transforms.Compose([ 110 | transforms.Resize(224), 111 | transforms.CenterCrop(224)]) 112 | metadata = pd.read_csv(dstroot / 'rgb_metadata.csv') 113 | 114 | num_batches = 100 115 | batch_size = len(metadata) // num_batches 116 | if len(metadata) % num_batches != 0: 117 | num_batches += 1 118 | 119 | print("Saving into chunks...") 120 | for j in tqdm(range(num_batches)): 121 | batch_metadata = metadata.iloc[j*batch_size : (j+1)*batch_size] 122 | imgs = [] 123 | 124 | for i in tqdm(range(len(batch_metadata))): 125 | curr_metadata = batch_metadata.iloc[i].to_dict() 126 | 127 | img_path = root / curr_metadata['img_path'] 128 | img = Image.open(img_path) 129 | img = img.convert('RGB') 130 | 131 | img = np.asarray(default_transform(img), dtype=np.uint8) 132 | 133 | imgs.append(img) 134 | imgs = np.asarray(imgs, dtype=np.uint8) 135 | np.save(dstroot / f'rgb_all_imgs_{j}.npy', imgs) 136 | -------------------------------------------------------------------------------- /dataset_preprocessing/molpcba_unlabeled/process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from wilds import get_dataset 3 | from rdkit.Chem import AllChem 4 | from rdkit import Chem 5 | from tqdm import tqdm 6 | import pandas as pd 7 | import os 8 | import torch 9 | 10 | def compute_pcba_fingerprint(): 11 | ''' 12 | Compute the fingerprint features for molpcba molecules. 13 | ''' 14 | os.makedirs('processed_fp', exist_ok = True) 15 | 16 | pcba_dataset = get_dataset(dataset = 'ogb-molpcba') 17 | smiles_list = pd.read_csv('data/ogbg_molpcba/mapping/mol.csv.gz')['smiles'].tolist() 18 | x_list = [] 19 | for smiles in tqdm(smiles_list): 20 | mol = Chem.MolFromSmiles(smiles) 21 | x = np.array(list(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)), dtype=np.int8) 22 | x_list.append(x) 23 | 24 | x = np.stack(x_list) 25 | 26 | np.save('processed_fp/molpcba.npy', x) 27 | 28 | 29 | def jaccard_similarity(vec, mat): 30 | AND = vec * mat 31 | OR = (vec + mat) > 0 32 | denom = np.sum(OR, axis = 1) 33 | nom = np.sum(AND, axis = 1) 34 | 35 | denom[denom==0] = 1 36 | return nom / denom 37 | 38 | 39 | def assign_to_group(): 40 | ''' 41 | Assign unlabeled pubchem molecules to scaffold groups of molpcba. 
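    Each unlabeled molecule is encoded as a 1024-bit Morgan fingerprint (radius 2) and matched to
    its most similar molpcba molecule under Jaccard similarity; the scaffold group of that nearest
    neighbor is then checked against the precomputed assignment in group_assignment.npy.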
42 | ''' 43 | smiles_list = pd.read_csv('molpcba_unlabeled/mapping/unlabeled_smiles.csv', header = None)[0].tolist() 44 | 45 | x_pcba = np.load('processed_fp/molpcba.npy') 46 | print(x_pcba.shape) 47 | print((x_pcba > 1).sum()) 48 | scaffold_group = np.load('data/ogbg_molpcba/raw/scaffold_group.npy') 49 | 50 | # ground-truth assignment 51 | group_assignment = np.load('molpcba_unlabeled/processed/group_assignment.npy') 52 | 53 | for i, smiles in tqdm(enumerate(smiles_list), total = len(smiles_list)): 54 | mol = Chem.MolFromSmiles(smiles) 55 | x = np.array(list(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)), dtype=np.int8) 56 | sim = jaccard_similarity(x, x_pcba) 57 | 58 | max_idx = np.argmax(sim) 59 | a = scaffold_group[max_idx] 60 | b = group_assignment[i] 61 | 62 | print(a, b) 63 | assert a == b # make sure they coincide each other 64 | 65 | 66 | def test_jaccard(): 67 | vec = np.random.randn(1024) > 0 68 | mat = np.random.randn(1000, 1024) 69 | mat[0] = vec 70 | 71 | sim = jaccard_similarity(vec, mat) 72 | print(sim) 73 | 74 | 75 | if __name__ == '__main__': 76 | compute_pcba_fingerprint() 77 | assign_to_group() 78 | 79 | -------------------------------------------------------------------------------- /dataset_preprocessing/poverty/convert_poverty_to_npy.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from github.com/sustainlab-group/africa_poverty/data_analysis/dhs.ipynb 3 | ''' 4 | import tensorflow as tf 5 | import numpy as np 6 | import batcher 7 | import dataset_constants 8 | from tqdm import tqdm 9 | 10 | FOLDS = ['A', 'B', 'C', 'D', 'E'] 11 | SPLITS = ['train', 'val', 'test'] 12 | BAND_ORDER = ['BLUE', 'GREEN', 'RED', 'SWIR1', 'SWIR2', 'TEMP1', 'NIR', 'NIGHTLIGHTS'] 13 | DATASET = '2009-17' 14 | 15 | COUNTRIES = np.asarray(dataset_constants.DHS_COUNTRIES) 16 | 17 | 18 | def get_images(tfrecord_paths, label_name='wealthpooled', return_meta=False): 19 | ''' 20 | Args 21 | - tfrecord_paths: list of str, length N <= 32, paths of TFRecord files 22 | 23 | Returns: np.array, shape [N, 224, 224, 8], type float32 24 | ''' 25 | init_iter, batch_op = batcher.Batcher( 26 | tfrecord_files=tfrecord_paths, 27 | dataset=DATASET, 28 | batch_size=32, 29 | ls_bands='ms', 30 | nl_band='merge', 31 | label_name=label_name, 32 | shuffle=False, 33 | augment=False, 34 | negatives='zero', 35 | normalize=True).get_batch() 36 | with tf.Session() as sess: 37 | sess.run(init_iter) 38 | if return_meta: 39 | ret = sess.run(batch_op) 40 | else: 41 | ret = sess.run(batch_op['images']) 42 | return ret 43 | 44 | 45 | if __name__ == '__main__': 46 | tfrecord_paths = np.asarray(batcher.get_tfrecord_paths(dataset=DATASET, split='all')) 47 | 48 | num_batches = len(tfrecord_paths) // 32 49 | if len(tfrecord_paths) % 32 != 0: 50 | num_batches += 1 51 | 52 | imgs = [] 53 | 54 | for i in tqdm(range(num_batches)): 55 | imgs.append(get_images(tfrecord_paths[i*32: (i+1)*32])) 56 | 57 | imgs = np.concatenate(imgs, axis=0) 58 | np.save('/scr/landsat_poverty_imgs.npy', imgs) 59 | 60 | 61 | ######### process unlabeled data 62 | 63 | tfrecord_paths = [] 64 | root = Path('/atlas/u/chrisyeh/poverty_data/lxv3_transfer') 65 | for country_year in root.iterdir(): 66 | if not country_year.is_dir(): 67 | continue 68 | for tfrecord_file in country_year.iterdir(): 69 | tfrecord_paths.append(str(tfrecord_file)) 70 | 71 | batch_size = 32 72 | num_batches = len(tfrecord_paths) // batch_size 73 | if len(tfrecord_paths) % batch_size != 0: 74 | num_batches += 1 75 | 76 | metadata 
= [] 77 | imgs = [] 78 | 79 | counter = 0 80 | for i in tqdm(range(num_batches)): 81 | batch_paths = tfrecord_paths[i*batch_size: (i+1)*batch_size] 82 | img_batch = get_images(batch_paths, label_name=None, return_meta=True) 83 | nl_means = img_batch['images'][:, :, :, -1].mean((1,2)) 84 | nl_centers = img_batch['images'][:, 112, 112, -1] 85 | 86 | for path, loc, year, nl_mean, nl_center in zip(batch_paths, img_batch['locs'], img_batch['years'], nl_means, nl_centers): 87 | country = "_".join(str(Path(path).parent.stem).split('_')[:-1]) 88 | 89 | metadata.append({'country': country, 'lat': loc[0], 'lon': loc[1], 'year': year, 'nl_mean': float(nl_mean), 'nl_center': float(nl_center)}) 90 | 91 | imgs.append(img_batch['images']) 92 | 93 | if len(imgs) > (10000 // 32): 94 | imgs = np.concatenate(imgs, axis=0) 95 | np.save(f'/u/scr/nlp/dro/poverty/unlabeled_landsat_poverty_imgs_{counter}.npy', imgs) 96 | counter += 1 97 | imgs = [] 98 | if len(imgs) > 0: 99 | imgs = np.concatenate(imgs, axis=0) 100 | np.save(f'/u/scr/nlp/dro/poverty/unlabeled_landsat_poverty_imgs_{counter}.npy', imgs) 101 | 102 | df = pd.DataFrame(metadata) 103 | df.to_csv('/u/scr/nlp/dro/poverty/unlabeled_metadata.csv', index=False) 104 | -------------------------------------------------------------------------------- /dataset_preprocessing/poverty/process_metadata_poverty.py: -------------------------------------------------------------------------------- 1 | ######## 2 | # ADAPTED from github.com/sustainlab-group/africa_poverty 3 | ######## 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import batcher 8 | import dataset_constants 9 | from tqdm import tqdm 10 | from utils.general import load_npz 11 | import pickle 12 | import pandas as pd 13 | from pathlib import Path 14 | 15 | 16 | FOLDS = ['A', 'B', 'C', 'D', 'E'] 17 | SPLITS = ['train', 'val', 'test'] 18 | BAND_ORDER = ['BLUE', 'GREEN', 'RED', 'SWIR1', 'SWIR2', 'TEMP1', 'NIR', 'NIGHTLIGHTS'] 19 | DATASET = '2009-17' 20 | ROOT = Path('../data') # Path to files from sustainlab-group/africa_poverty 21 | DSTROOT = Path('/u/scr/nlp/dro/poverty/data') 22 | 23 | COUNTRIES = np.asarray(dataset_constants.DHS_COUNTRIES) 24 | 25 | file_path = ROOT / 'dhs_image_hists.npz' 26 | npz = load_npz(file_path) 27 | 28 | labels = npz['labels'] 29 | locs = npz['locs'] 30 | years = npz['years'] 31 | nls_center = npz['nls_center'] 32 | nls_mean = npz['nls_mean'] 33 | 34 | num_examples = len(labels) 35 | assert np.all(np.asarray([len(labels), len(locs), len(years)]) == num_examples) 36 | 37 | dmsp_mask = years < 2012 38 | viirs_mask = ~dmsp_mask 39 | 40 | with open(ROOT / 'dhs_loc_dict.pkl', 'rb') as f: 41 | loc_dict = pickle.load(f) 42 | 43 | df_data = [] 44 | for label, loc, nl_mean, nl_center in zip(labels, locs, nls_mean, nls_center): 45 | lat, lon = loc 46 | loc_info = loc_dict[(lat, lon)] 47 | country = loc_info['country'] 48 | year = int(loc_info['country_year'][-4:]) # use the year matching the surveyID 49 | urban = loc_info['urban'] 50 | household = loc_info['households'] 51 | row = [lat, lon, label, country, year, urban, nl_mean, nl_center, household] 52 | df_data.append(row) 53 | df = pd.DataFrame.from_records( 54 | df_data, 55 | columns=['lat', 'lon', 'wealthpooled', 'country', 'year', 'urban', 'nl_mean', 'nl_center', 'households']) 56 | 57 | df.to_csv(DSTROOT / 'dhs_metadata.csv', index=False) 58 | -------------------------------------------------------------------------------- /dataset_preprocessing/poverty/split_npys.py: 
-------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import numpy as np 4 | from PIL import Image 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | def main(): 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--root_dir', required=True, 12 | help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') 13 | config = parser.parse_args() 14 | data_dir = Path(config.root_dir) / 'poverty_v1.0' 15 | indiv_dir = Path(config.root_dir) / 'poverty_v1.0_indiv_npz' 16 | os.makedirs(indiv_dir, exist_ok=True) 17 | 18 | f = np.load(data_dir / 'landsat_poverty_imgs.npy', mmap_mode='r') 19 | f = f.transpose((0, 3, 1, 2)) 20 | for i in tqdm(range(len(f))): 21 | x = f[i] 22 | np.savez_compressed(indiv_dir / f'landsat_poverty_img_{i}.npz', x=x) 23 | 24 | if __name__=='__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /dataset_preprocessing/poverty/split_npys_unlabeled.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import numpy as np 4 | from PIL import Image 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--root_dir', required=True, 11 | help='The poverty data directory.') 12 | parser.add_argument('--out_dir_root', required=True, 13 | help='The directory where output dir should be made.') 14 | args = parser.parse_args() 15 | 16 | data_dir = Path(args.root_dir) 17 | indiv_dir = Path(args.out_dir_root) / 'poverty_unlabeled_v1.0_indiv_npz' / 'images' 18 | indiv_dir.mkdir(exist_ok=True, parents=True) 19 | 20 | counter = 0 21 | for i in range(27): 22 | path = data_dir / f'unlabeled_landsat_poverty_imgs_{i}.npy' 23 | arr = np.load(path, mmap_mode='r') 24 | arr = arr.transpose((0, 3, 1, 2)) 25 | for j in tqdm(range(len(arr))): 26 | x = arr[j] 27 | np.savez_compressed(indiv_dir / f'landsat_poverty_img_{counter}.npz', x=x) 28 | counter += 1 29 | 30 | 31 | if __name__=='__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/examples/__init__.py -------------------------------------------------------------------------------- /examples/algorithms/AFN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from algorithms.single_model_algorithm import SingleModelAlgorithm 4 | from models.initializer import initialize_model 5 | 6 | class AFN(SingleModelAlgorithm): 7 | """ 8 | Adaptive Feature Norm (AFN) 9 | 10 | Original paper: 11 | @InProceedings{Xu_2019_ICCV, 12 | author = {Xu, Ruijia and Li, Guanbin and Yang, Jihan and Lin, Liang}, 13 | title = {Larger Norm More Transferable: An Adaptive Feature Norm Approach for 14 | Unsupervised Domain Adaptation}, 15 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 16 | month = {October}, 17 | year = {2019} 18 | } 19 | """ 20 | 21 | def __init__( 22 | self, 23 | config, 24 | d_out, 25 | grouper, 26 | loss, 27 | metric, 28 | n_train_steps, 29 | ): 30 | # Initialize model 31 | featurizer, classifier = initialize_model(config, d_out=d_out, is_featurizer=True) 32 | model = 
torch.nn.Sequential(featurizer, classifier) 33 | 34 | # Initialize module 35 | super().__init__( 36 | config=config, 37 | model=model, 38 | grouper=grouper, 39 | loss=loss, 40 | metric=metric, 41 | n_train_steps=n_train_steps, 42 | ) 43 | 44 | # Model components 45 | self.featurizer = featurizer 46 | self.classifier = classifier 47 | 48 | # Algorithm hyperparameters 49 | self.penalty_weight = config.afn_penalty_weight 50 | self.delta_r = config.safn_delta_r 51 | self.r = config.hafn_r 52 | self.afn_loss = self.hafn_loss if config.use_hafn else self.safn_loss 53 | 54 | # Additional logging 55 | self.logged_fields.append("classification_loss") 56 | self.logged_fields.append("feature_norm_penalty") 57 | 58 | def safn_loss(self, features): 59 | """ 60 | Adapted from https://github.com/jihanyang/AFN 61 | """ 62 | radius = features.norm(p=2, dim=1).detach() 63 | assert not radius.requires_grad 64 | radius = radius + self.delta_r 65 | loss = ((features.norm(p=2, dim=1) - radius) ** 2).mean() 66 | return loss 67 | 68 | def hafn_loss(self, features): 69 | """ 70 | Adapted from https://github.com/jihanyang/AFN 71 | """ 72 | loss = (features.norm(p=2, dim=1).mean() - self.r) ** 2 73 | return loss 74 | 75 | def process_batch(self, batch, unlabeled_batch=None): 76 | """ 77 | Overrides single_model_algorithm.process_batch(). 78 | Args: 79 | - batch (tuple of Tensors): a batch of data yielded by data loaders 80 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 81 | Output: 82 | - results (dictionary): information about the batch 83 | - y_true (Tensor): ground truth labels for batch 84 | - g (Tensor): groups for batch 85 | - metadata (Tensor): metadata for batch 86 | - features (Tensor): featurizer output for batch 87 | - y_pred (Tensor): full model output for batch 88 | - unlabeled_features (Tensor): featurizer outputs for unlabeled_batch 89 | """ 90 | # Forward pass 91 | x, y_true, metadata = batch 92 | x = x.to(self.device) 93 | y_true = y_true.to(self.device) 94 | g = self.grouper.metadata_to_group(metadata).to(self.device) 95 | features = self.featurizer(x) 96 | y_pred = self.classifier(features) 97 | 98 | results = { 99 | "g": g, 100 | "metadata": metadata, 101 | "y_true": y_true, 102 | "y_pred": y_pred, 103 | "features": features, 104 | } 105 | 106 | if unlabeled_batch is not None: 107 | unlabeled_x, _ = unlabeled_batch 108 | unlabeled_x = unlabeled_x.to(self.device) 109 | results['unlabeled_features'] = self.featurizer(unlabeled_x) 110 | return results 111 | 112 | def objective(self, results): 113 | classification_loss = self.loss.compute( 114 | results["y_pred"], results["y_true"], return_dict=False 115 | ) 116 | 117 | if self.is_training: 118 | f_source = results.pop("features") 119 | f_target = results.pop("unlabeled_features") 120 | feature_norm_penalty = self.afn_loss(f_source) + self.afn_loss(f_target) 121 | else: 122 | feature_norm_penalty = 0.0 123 | 124 | # Add to results for additional logging 125 | self.save_metric_for_logging( 126 | results, "classification_loss", classification_loss 127 | ) 128 | self.save_metric_for_logging( 129 | results, "feature_norm_penalty", feature_norm_penalty 130 | ) 131 | return classification_loss + self.penalty_weight * feature_norm_penalty -------------------------------------------------------------------------------- /examples/algorithms/DANN.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | 5 | from 
algorithms.single_model_algorithm import SingleModelAlgorithm 6 | from models.domain_adversarial_network import DomainAdversarialNetwork 7 | from models.initializer import initialize_model 8 | from optimizer import initialize_optimizer_with_model_params 9 | from losses import initialize_loss 10 | from utils import concat_input 11 | 12 | class DANN(SingleModelAlgorithm): 13 | """ 14 | Domain-adversarial training of neural networks. 15 | 16 | Original paper: 17 | @inproceedings{dann, 18 | title={Domain-Adversarial Training of Neural Networks}, 19 | author={Ganin, Ustinova, Ajakan, Germain, Larochelle, Laviolette, Marchand and Lempitsky}, 20 | booktitle={Journal of Machine Learning Research 17}, 21 | year={2016} 22 | } 23 | """ 24 | 25 | def __init__( 26 | self, 27 | config, 28 | d_out, 29 | grouper, 30 | loss, 31 | metric, 32 | n_train_steps, 33 | n_domains, 34 | group_ids_to_domains, 35 | ): 36 | # Initialize model 37 | featurizer, classifier = initialize_model( 38 | config, d_out=d_out, is_featurizer=True 39 | ) 40 | model = DomainAdversarialNetwork(featurizer, classifier, n_domains) 41 | parameters_to_optimize: List[Dict] = model.get_parameters_with_lr( 42 | featurizer_lr=config.dann_featurizer_lr, 43 | classifier_lr=config.dann_classifier_lr, 44 | discriminator_lr=config.dann_discriminator_lr, 45 | ) 46 | self.optimizer = initialize_optimizer_with_model_params(config, parameters_to_optimize) 47 | self.domain_loss = initialize_loss('cross_entropy', config) 48 | 49 | # Initialize module 50 | super().__init__( 51 | config=config, 52 | model=model, 53 | grouper=grouper, 54 | loss=loss, 55 | metric=metric, 56 | n_train_steps=n_train_steps, 57 | ) 58 | self.group_ids_to_domains = group_ids_to_domains 59 | 60 | # Algorithm hyperparameters 61 | self.penalty_weight = config.dann_penalty_weight 62 | 63 | # Additional logging 64 | self.logged_fields.append("classification_loss") 65 | self.logged_fields.append("domain_classification_loss") 66 | 67 | def process_batch(self, batch, unlabeled_batch=None): 68 | """ 69 | Overrides single_model_algorithm.process_batch(). 
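        When an unlabeled batch is provided, labeled and unlabeled inputs are concatenated so that
        the domain discriminator is trained on examples from all domains, while label predictions
        are kept only for the labeled examples.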
70 | Args: 71 | - batch (tuple of Tensors): a batch of data yielded by data loaders 72 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 73 | Output: 74 | - results (dictionary): information about the batch 75 | - y_true (Tensor): ground truth labels for batch 76 | - g (Tensor): groups for batch 77 | - metadata (Tensor): metadata for batch 78 | - y_pred (Tensor): model output for batch 79 | - domains_true (Tensor): true domains for batch and unlabeled batch 80 | - domains_pred (Tensor): predicted domains for batch and unlabeled batch 81 | - unlabeled_features (Tensor): featurizer outputs for unlabeled_batch 82 | """ 83 | # Forward pass 84 | x, y_true, metadata = batch 85 | g = self.grouper.metadata_to_group(metadata).to(self.device) 86 | domains_true = self.group_ids_to_domains[g] 87 | 88 | if unlabeled_batch is not None: 89 | unlabeled_x, unlabeled_metadata = unlabeled_batch 90 | unlabeled_domains_true = self.group_ids_to_domains[ 91 | self.grouper.metadata_to_group(unlabeled_metadata) 92 | ] 93 | 94 | # Concatenate examples and true domains 95 | x_cat = concat_input(x, unlabeled_x) 96 | domains_true = torch.cat([domains_true, unlabeled_domains_true]) 97 | else: 98 | x_cat = x 99 | 100 | x_cat = x_cat.to(self.device) 101 | y_true = y_true.to(self.device) 102 | domains_true = domains_true.to(self.device) 103 | y_pred, domains_pred = self.model(x_cat) 104 | 105 | # Ignore the predicted labels for the unlabeled data 106 | y_pred = y_pred[: len(y_true)] 107 | 108 | return { 109 | "g": g, 110 | "metadata": metadata, 111 | "y_true": y_true, 112 | "y_pred": y_pred, 113 | "domains_true": domains_true, 114 | "domains_pred": domains_pred, 115 | } 116 | 117 | def objective(self, results): 118 | classification_loss = self.loss.compute( 119 | results["y_pred"], results["y_true"], return_dict=False 120 | ) 121 | 122 | if self.is_training: 123 | domain_classification_loss = self.domain_loss.compute( 124 | results.pop("domains_pred"), 125 | results.pop("domains_true"), 126 | return_dict=False, 127 | ) 128 | else: 129 | domain_classification_loss = 0.0 130 | 131 | # Add to results for additional logging 132 | self.save_metric_for_logging( 133 | results, "classification_loss", classification_loss 134 | ) 135 | self.save_metric_for_logging( 136 | results, "domain_classification_loss", domain_classification_loss 137 | ) 138 | return classification_loss + domain_classification_loss * self.penalty_weight 139 | -------------------------------------------------------------------------------- /examples/algorithms/ERM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from algorithms.single_model_algorithm import SingleModelAlgorithm 3 | from models.initializer import initialize_model 4 | from utils import move_to 5 | 6 | class ERM(SingleModelAlgorithm): 7 | def __init__(self, config, d_out, grouper, loss, 8 | metric, n_train_steps): 9 | model = initialize_model(config, d_out) 10 | # initialize module 11 | super().__init__( 12 | config=config, 13 | model=model, 14 | grouper=grouper, 15 | loss=loss, 16 | metric=metric, 17 | n_train_steps=n_train_steps, 18 | ) 19 | self.use_unlabeled_y = config.use_unlabeled_y # Expect x,y,m from unlabeled loaders and train on the unlabeled y 20 | 21 | def process_batch(self, batch, unlabeled_batch=None): 22 | """ 23 | Overrides single_model_algorithm.process_batch(). 24 | ERM defines its own process_batch to handle if self.use_unlabeled_y is true. 
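        When self.use_unlabeled_y is True, the unlabeled loader is expected to yield (x, y, metadata)
        triples, and objective() combines the labeled and unlabeled losses weighted by their batch sizes.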
25 | Args: 26 | - batch (tuple of Tensors): a batch of data yielded by data loaders 27 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 28 | Output: 29 | - results (dictionary): information about the batch 30 | - y_true (Tensor): ground truth labels for batch 31 | - g (Tensor): groups for batch 32 | - metadata (Tensor): metadata for batch 33 | - y_pred (Tensor): model output for batch 34 | - unlabeled_g (Tensor): groups for unlabeled batch 35 | - unlabeled_metadata (Tensor): metadata for unlabeled batch 36 | - unlabeled_y_pred (Tensor): predictions for unlabeled batch for fully-supervised ERM experiments 37 | - unlabeled_y_true (Tensor): true labels for unlabeled batch for fully-supervised ERM experiments 38 | """ 39 | x, y_true, metadata = batch 40 | x = move_to(x, self.device) 41 | y_true = move_to(y_true, self.device) 42 | g = move_to(self.grouper.metadata_to_group(metadata), self.device) 43 | 44 | outputs = self.get_model_output(x, y_true) 45 | 46 | results = { 47 | 'g': g, 48 | 'y_true': y_true, 49 | 'y_pred': outputs, 50 | 'metadata': metadata, 51 | } 52 | if unlabeled_batch is not None: 53 | if self.use_unlabeled_y: # expect loaders to return x,y,m 54 | x, y, metadata = unlabeled_batch 55 | y = move_to(y, self.device) 56 | else: 57 | x, metadata = unlabeled_batch 58 | x = move_to(x, self.device) 59 | results['unlabeled_metadata'] = metadata 60 | if self.use_unlabeled_y: 61 | results['unlabeled_y_pred'] = self.get_model_output(x, y) 62 | results['unlabeled_y_true'] = y 63 | results['unlabeled_g'] = self.grouper.metadata_to_group(metadata).to(self.device) 64 | return results 65 | 66 | def objective(self, results): 67 | labeled_loss = self.loss.compute(results['y_pred'], results['y_true'], return_dict=False) 68 | if self.use_unlabeled_y and 'unlabeled_y_true' in results: 69 | unlabeled_loss = self.loss.compute( 70 | results['unlabeled_y_pred'], 71 | results['unlabeled_y_true'], 72 | return_dict=False 73 | ) 74 | lab_size = len(results['y_pred']) 75 | unl_size = len(results['unlabeled_y_pred']) 76 | return (lab_size * labeled_loss + unl_size * unlabeled_loss) / (lab_size + unl_size) 77 | else: 78 | return labeled_loss -------------------------------------------------------------------------------- /examples/algorithms/IRM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.initializer import initialize_model 3 | from algorithms.single_model_algorithm import SingleModelAlgorithm 4 | from wilds.common.utils import split_into_groups 5 | import torch.autograd as autograd 6 | from wilds.common.metrics.metric import ElementwiseMetric, MultiTaskMetric 7 | from optimizer import initialize_optimizer 8 | 9 | class IRM(SingleModelAlgorithm): 10 | """ 11 | Invariant risk minimization. 12 | 13 | Original paper: 14 | @article{arjovsky2019invariant, 15 | title={Invariant risk minimization}, 16 | author={Arjovsky, Martin and Bottou, L{\'e}on and Gulrajani, Ishaan and Lopez-Paz, David}, 17 | journal={arXiv preprint arXiv:1907.02893}, 18 | year={2019} 19 | } 20 | 21 | The IRM penalty function below is adapted from the code snippet 22 | provided in the above paper. 
23 | """ 24 | def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): 25 | """ 26 | Algorithm-specific arguments (in config): 27 | - irm_lambda 28 | - irm_penalty_anneal_iters 29 | """ 30 | # check config 31 | assert config.train_loader == 'group' 32 | assert config.uniform_over_groups 33 | assert config.distinct_groups 34 | # initialize model 35 | model = initialize_model(config, d_out).to(config.device) 36 | # initialize the module 37 | super().__init__( 38 | config=config, 39 | model=model, 40 | grouper=grouper, 41 | loss=loss, 42 | metric=metric, 43 | n_train_steps=n_train_steps, 44 | ) 45 | 46 | # additional logging 47 | self.logged_fields.append('penalty') 48 | # set IRM-specific variables 49 | self.irm_lambda = config.irm_lambda 50 | self.irm_penalty_anneal_iters = config.irm_penalty_anneal_iters 51 | self.scale = torch.tensor(1.).to(self.device).requires_grad_() 52 | self.update_count = 0 53 | self.config = config # Need to store config for IRM because we need to re-init optimizer 54 | 55 | assert isinstance(self.loss, ElementwiseMetric) or isinstance(self.loss, MultiTaskMetric) 56 | 57 | def irm_penalty(self, losses): 58 | grad_1 = autograd.grad(losses[0::2].mean(), [self.scale], create_graph=True)[0] 59 | grad_2 = autograd.grad(losses[1::2].mean(), [self.scale], create_graph=True)[0] 60 | result = torch.sum(grad_1 * grad_2) 61 | return result 62 | 63 | def objective(self, results): 64 | # Compute penalty on each group 65 | # To be consistent with the DomainBed implementation, 66 | # this returns the average loss and penalty across groups, regardless of group size 67 | # But the GroupLoader ensures that each group is of the same size in each minibatch 68 | unique_groups, group_indices, _ = split_into_groups(results['g']) 69 | n_groups_per_batch = unique_groups.numel() 70 | avg_loss = 0. 71 | penalty = 0. 
72 | 73 | for i_group in group_indices: # Each element of group_indices is a list of indices 74 | group_losses, _ = self.loss.compute_flattened( 75 | self.scale * results['y_pred'][i_group], 76 | results['y_true'][i_group], 77 | return_dict=False) 78 | if group_losses.numel()>0: 79 | avg_loss += group_losses.mean() 80 | if self.is_training: # Penalties only make sense when training 81 | penalty += self.irm_penalty(group_losses) 82 | avg_loss /= n_groups_per_batch 83 | penalty /= n_groups_per_batch 84 | 85 | if self.update_count >= self.irm_penalty_anneal_iters: 86 | penalty_weight = self.irm_lambda 87 | else: 88 | penalty_weight = 1.0 89 | 90 | self.save_metric_for_logging(results, 'penalty', penalty) 91 | return avg_loss + penalty * penalty_weight 92 | 93 | def _update(self, results, should_step=True): 94 | if self.update_count == self.irm_penalty_anneal_iters: 95 | print('Hit IRM penalty anneal iters') 96 | # Reset optimizer to deal with the changing penalty weight 97 | self.optimizer = initialize_optimizer(self.config, self.model) 98 | super()._update(results, should_step=should_step) 99 | self.update_count += 1 100 | -------------------------------------------------------------------------------- /examples/algorithms/algorithm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils import move_to, detach_and_clone 3 | 4 | 5 | class Algorithm(nn.Module): 6 | def __init__(self, device): 7 | super().__init__() 8 | self.device = device 9 | self.out_device = 'cpu' 10 | self._has_log = False 11 | self.reset_log() 12 | 13 | def update(self, batch): 14 | """ 15 | Process the batch, update the log, and update the model 16 | Args: 17 | - batch (tuple of Tensors): a batch of data yielded by data loaders 18 | Output: 19 | - results (dictionary): information about the batch, such as: 20 | - g (Tensor) 21 | - y_true (Tensor) 22 | - metadata (Tensor) 23 | - loss (Tensor) 24 | - metrics (Tensor) 25 | """ 26 | raise NotImplementedError 27 | 28 | def evaluate(self, batch): 29 | """ 30 | Process the batch and update the log, without updating the model 31 | Args: 32 | - batch (tuple of Tensors): a batch of data yielded by data loaders 33 | Output: 34 | - results (dictionary): information about the batch, such as: 35 | - g (Tensor) 36 | - y_true (Tensor) 37 | - metadata (Tensor) 38 | - loss (Tensor) 39 | - metrics (Tensor) 40 | """ 41 | raise NotImplementedError 42 | 43 | def train(self, mode=True): 44 | """ 45 | Switch to train mode 46 | """ 47 | self.is_training = mode 48 | super().train(mode) 49 | self.reset_log() 50 | 51 | @property 52 | def has_log(self): 53 | return self._has_log 54 | 55 | def reset_log(self): 56 | """ 57 | Resets log by clearing out the internal log, Algorithm.log_dict 58 | """ 59 | self._has_log = False 60 | self.log_dict = {} 61 | 62 | def update_log(self, results): 63 | """ 64 | Updates the internal log, Algorithm.log_dict 65 | Args: 66 | - results (dictionary) 67 | """ 68 | raise NotImplementedError 69 | 70 | def get_log(self): 71 | """ 72 | Sanitizes the internal log (Algorithm.log_dict) and outputs it. 
73 | 74 | """ 75 | raise NotImplementedError 76 | 77 | def get_pretty_log_str(self): 78 | raise NotImplementedError 79 | 80 | def step_schedulers(self, is_epoch, metrics={}, log_access=False): 81 | """ 82 | Update all relevant schedulers 83 | Args: 84 | - is_epoch (bool): epoch-wise update if set to True, batch-wise update otherwise 85 | - metrics (dict): a dictionary of metrics that can be used for scheduler updates 86 | - log_access (bool): whether metrics from self.get_log() can be used to update schedulers 87 | """ 88 | raise NotImplementedError 89 | 90 | def sanitize_dict(self, in_dict, to_out_device=True): 91 | """ 92 | Helper function that sanitizes dictionaries by: 93 | - moving to the specified output device 94 | - removing any gradient information 95 | - detaching and cloning the tensors 96 | Args: 97 | - in_dict (dictionary) 98 | Output: 99 | - out_dict (dictionary): sanitized version of in_dict 100 | """ 101 | out_dict = detach_and_clone(in_dict) 102 | if to_out_device: 103 | out_dict = move_to(out_dict, self.out_device) 104 | return out_dict -------------------------------------------------------------------------------- /examples/algorithms/deepCORAL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.initializer import initialize_model 3 | from algorithms.single_model_algorithm import SingleModelAlgorithm 4 | from wilds.common.utils import split_into_groups 5 | from utils import concat_input 6 | 7 | class DeepCORAL(SingleModelAlgorithm): 8 | """ 9 | Deep CORAL. 10 | This algorithm was originally proposed as an unsupervised domain adaptation algorithm. 11 | 12 | Original paper: 13 | @inproceedings{sun2016deep, 14 | title={Deep CORAL: Correlation alignment for deep domain adaptation}, 15 | author={Sun, Baochen and Saenko, Kate}, 16 | booktitle={European Conference on Computer Vision}, 17 | pages={443--450}, 18 | year={2016}, 19 | organization={Springer} 20 | } 21 | 22 | The original CORAL loss is the distance between second-order statistics (covariances) 23 | of the source and target features. 24 | 25 | The CORAL penalty function below is adapted from DomainBed's implementation: 26 | https://github.com/facebookresearch/DomainBed/blob/1a61f7ff44b02776619803a1dd12f952528ca531/domainbed/algorithms.py#L539 27 | """ 28 | def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): 29 | # check config 30 | assert config.train_loader == 'group' 31 | assert config.uniform_over_groups 32 | assert config.distinct_groups 33 | # initialize models 34 | featurizer, classifier = initialize_model(config, d_out=d_out, is_featurizer=True) 35 | featurizer = featurizer.to(config.device) 36 | classifier = classifier.to(config.device) 37 | model = torch.nn.Sequential(featurizer, classifier) 38 | # initialize module 39 | super().__init__( 40 | config=config, 41 | model=model, 42 | grouper=grouper, 43 | loss=loss, 44 | metric=metric, 45 | n_train_steps=n_train_steps, 46 | ) 47 | # algorithm hyperparameters 48 | self.penalty_weight = config.coral_penalty_weight 49 | # additional logging 50 | self.logged_fields.append('penalty') 51 | # set model components 52 | self.featurizer = featurizer 53 | self.classifier = classifier 54 | 55 | def coral_penalty(self, x, y): 56 | if x.dim() > 2: 57 | # featurizers output Tensors of size (batch_size, ..., feature dimensionality). 
58 | # we flatten to Tensors of size (*, feature dimensionality) 59 | x = x.view(-1, x.size(-1)) 60 | y = y.view(-1, y.size(-1)) 61 | 62 | mean_x = x.mean(0, keepdim=True) 63 | mean_y = y.mean(0, keepdim=True) 64 | cent_x = x - mean_x 65 | cent_y = y - mean_y 66 | cova_x = (cent_x.t() @ cent_x) / (len(x) - 1) 67 | cova_y = (cent_y.t() @ cent_y) / (len(y) - 1) 68 | 69 | mean_diff = (mean_x - mean_y).pow(2).mean() 70 | cova_diff = (cova_x - cova_y).pow(2).mean() 71 | 72 | return mean_diff + cova_diff 73 | 74 | def process_batch(self, batch, unlabeled_batch=None): 75 | """ 76 | Overrides single_model_algorithm.process_batch(). 77 | Args: 78 | - batch (tuple of Tensors): a batch of data yielded by data loaders 79 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 80 | Output: 81 | - results (dictionary): information about the batch 82 | - y_true (Tensor): ground truth labels for batch 83 | - g (Tensor): groups for batch 84 | - metadata (Tensor): metadata for batch 85 | - unlabeled_g (Tensor): groups for unlabeled batch 86 | - features (Tensor): featurizer output for batch and unlabeled batch 87 | - y_pred (Tensor): full model output for batch and unlabeled batch 88 | """ 89 | # forward pass 90 | x, y_true, metadata = batch 91 | y_true = y_true.to(self.device) 92 | g = self.grouper.metadata_to_group(metadata).to(self.device) 93 | 94 | results = { 95 | 'g': g, 96 | 'y_true': y_true, 97 | 'metadata': metadata, 98 | } 99 | 100 | if unlabeled_batch is not None: 101 | unlabeled_x, unlabeled_metadata = unlabeled_batch 102 | x = concat_input(x, unlabeled_x) 103 | unlabeled_g = self.grouper.metadata_to_group(unlabeled_metadata).to(self.device) 104 | results['unlabeled_g'] = unlabeled_g 105 | 106 | x = x.to(self.device) 107 | features = self.featurizer(x) 108 | outputs = self.classifier(features) 109 | y_pred = outputs[: len(y_true)] 110 | 111 | results['features'] = features 112 | results['y_pred'] = y_pred 113 | return results 114 | 115 | def objective(self, results): 116 | if self.is_training: 117 | features = results.pop('features') 118 | 119 | # Split into groups 120 | groups = concat_input(results['g'], results['unlabeled_g']) if 'unlabeled_g' in results else results['g'] 121 | unique_groups, group_indices, _ = split_into_groups(groups) 122 | n_groups_per_batch = unique_groups.numel() 123 | 124 | # Compute penalty - perform pairwise comparisons between features of all the groups 125 | penalty = torch.zeros(1, device=self.device) 126 | for i_group in range(n_groups_per_batch): 127 | for j_group in range(i_group+1, n_groups_per_batch): 128 | penalty += self.coral_penalty(features[group_indices[i_group]], features[group_indices[j_group]]) 129 | if n_groups_per_batch > 1: 130 | penalty /= (n_groups_per_batch * (n_groups_per_batch-1) / 2) # get the mean penalty 131 | else: 132 | penalty = 0. 
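        # Final objective: average classification loss on the labeled examples, plus the pairwise
        # CORAL penalty (mean and covariance alignment across groups) scaled by coral_penalty_weight.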
133 | 134 | self.save_metric_for_logging(results, 'penalty', penalty) 135 | avg_loss = self.loss.compute(results['y_pred'], results['y_true'], return_dict=False) 136 | return avg_loss + penalty * self.penalty_weight 137 | -------------------------------------------------------------------------------- /examples/algorithms/fixmatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from models.initializer import initialize_model 5 | from algorithms.single_model_algorithm import SingleModelAlgorithm 6 | from configs.supported import process_pseudolabels_functions 7 | from utils import detach_and_clone 8 | 9 | 10 | class FixMatch(SingleModelAlgorithm): 11 | """ 12 | FixMatch. 13 | This algorithm was originally proposed as a semi-supervised learning algorithm. 14 | 15 | Loss is of the form 16 | \ell_s + \lambda * \ell_u 17 | where 18 | \ell_s = cross-entropy with true labels using weakly augmented labeled examples 19 | \ell_u = cross-entropy with pseudolabel generated using weak augmentation and prediction 20 | using strong augmentation 21 | 22 | Original paper: 23 | @article{sohn2020fixmatch, 24 | title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence}, 25 | author={Sohn, Kihyuk and Berthelot, David and Li, Chun-Liang and Zhang, Zizhao and Carlini, Nicholas and Cubuk, Ekin D and Kurakin, Alex and Zhang, Han and Raffel, Colin}, 26 | journal={arXiv preprint arXiv:2001.07685}, 27 | year={2020} 28 | } 29 | """ 30 | def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): 31 | featurizer, classifier = initialize_model( 32 | config, d_out=d_out, is_featurizer=True 33 | ) 34 | model = torch.nn.Sequential(featurizer, classifier) 35 | 36 | # initialize module 37 | super().__init__( 38 | config=config, 39 | model=model, 40 | grouper=grouper, 41 | loss=loss, 42 | metric=metric, 43 | n_train_steps=n_train_steps, 44 | ) 45 | # algorithm hyperparameters 46 | self.fixmatch_lambda = config.self_training_lambda 47 | self.confidence_threshold = config.self_training_threshold 48 | self.process_pseudolabels_function = process_pseudolabels_functions[config.process_pseudolabels_function] 49 | 50 | # Additional logging 51 | self.logged_fields.append("pseudolabels_kept_frac") 52 | self.logged_fields.append("classification_loss") 53 | self.logged_fields.append("consistency_loss") 54 | 55 | def process_batch(self, batch, unlabeled_batch=None): 56 | """ 57 | Overrides single_model_algorithm.process_batch(). 
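        Pseudolabels are computed on the weakly augmented view with gradients disabled and thresholded
        at self.confidence_threshold; the strongly augmented view is then passed through the model
        together with the labeled examples in a single forward pass.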
58 | Args: 59 | - batch (x, y, m): a batch of data yielded by data loaders 60 | - unlabeled_batch: examples ((x_weak, x_strong), m) where x_weak is weakly augmented but x_strong is strongly augmented 61 | Output: 62 | - results (dictionary): information about the batch 63 | - y_true (Tensor): ground truth labels for batch 64 | - g (Tensor): groups for batch 65 | - metadata (Tensor): metadata for batch 66 | - y_pred (Tensor): model output for batch 67 | - unlabeled_g (Tensor): groups for unlabeled batch 68 | - unlabeled_metadata (Tensor): metadata for unlabeled batch 69 | - unlabeled_weak_y_pseudo (Tensor): pseudolabels on x_weak of the unlabeled batch, already thresholded 70 | - unlabeled_strong_y_pred (Tensor): model output on x_strong of the unlabeled batch, already thresholded 71 | """ 72 | # Labeled examples 73 | x, y_true, metadata = batch 74 | x = x.to(self.device) 75 | y_true = y_true.to(self.device) 76 | g = self.grouper.metadata_to_group(metadata).to(self.device) 77 | # package the results 78 | results = { 79 | 'g': g, 80 | 'y_true': y_true, 81 | 'metadata': metadata 82 | } 83 | pseudolabels_kept_frac = 0 84 | 85 | # Unlabeled examples 86 | if unlabeled_batch is not None: 87 | (x_weak, x_strong), metadata = unlabeled_batch 88 | x_weak = x_weak.to(self.device) 89 | x_strong = x_strong.to(self.device) 90 | 91 | g = self.grouper.metadata_to_group(metadata).to(self.device) 92 | results['unlabeled_metadata'] = metadata 93 | results['unlabeled_g'] = g 94 | 95 | with torch.no_grad(): 96 | outputs = self.model(x_weak) 97 | _, pseudolabels, pseudolabels_kept_frac, mask = self.process_pseudolabels_function( 98 | outputs, 99 | self.confidence_threshold, 100 | ) 101 | results['unlabeled_weak_y_pseudo'] = detach_and_clone(pseudolabels) 102 | 103 | self.save_metric_for_logging( 104 | results, "pseudolabels_kept_frac", pseudolabels_kept_frac 105 | ) 106 | 107 | # Concat and call forward 108 | n_lab = x.shape[0] 109 | if unlabeled_batch is not None: 110 | x_concat = torch.cat((x, x_strong), dim=0) 111 | else: 112 | x_concat = x 113 | 114 | outputs = self.model(x_concat) 115 | results['y_pred'] = outputs[:n_lab] 116 | if unlabeled_batch is not None: 117 | results['unlabeled_strong_y_pred'] = outputs[n_lab:] if mask is None else outputs[n_lab:][mask] 118 | return results 119 | 120 | def objective(self, results): 121 | # Labeled loss 122 | classification_loss = self.loss.compute(results['y_pred'], results['y_true'], return_dict=False) 123 | 124 | # Pseudolabeled loss 125 | if 'unlabeled_weak_y_pseudo' in results: 126 | loss_output = self.loss.compute( 127 | results['unlabeled_strong_y_pred'], 128 | results['unlabeled_weak_y_pseudo'], 129 | return_dict=False, 130 | ) 131 | consistency_loss = loss_output * results['pseudolabels_kept_frac'] 132 | else: 133 | consistency_loss = 0 134 | 135 | # Add to results for additional logging 136 | self.save_metric_for_logging( 137 | results, "classification_loss", classification_loss 138 | ) 139 | self.save_metric_for_logging( 140 | results, "consistency_loss", consistency_loss 141 | ) 142 | 143 | return classification_loss + self.fixmatch_lambda * consistency_loss 144 | -------------------------------------------------------------------------------- /examples/algorithms/groupDRO.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from algorithms.single_model_algorithm import SingleModelAlgorithm 3 | from models.initializer import initialize_model 4 | 5 | class GroupDRO(SingleModelAlgorithm): 6 | """ 7 | Group 
distributionally robust optimization. 8 | 9 | Original paper: 10 | @inproceedings{sagawa2019distributionally, 11 | title={Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization}, 12 | author={Sagawa, Shiori and Koh, Pang Wei and Hashimoto, Tatsunori B and Liang, Percy}, 13 | booktitle={International Conference on Learning Representations}, 14 | year={2019} 15 | } 16 | """ 17 | def __init__(self, config, d_out, grouper, loss, metric, n_train_steps, is_group_in_train): 18 | # check config 19 | assert config.uniform_over_groups 20 | # initialize model 21 | model = initialize_model(config, d_out) 22 | # initialize module 23 | super().__init__( 24 | config=config, 25 | model=model, 26 | grouper=grouper, 27 | loss=loss, 28 | metric=metric, 29 | n_train_steps=n_train_steps, 30 | ) 31 | # additional logging 32 | self.logged_fields.append('group_weight') 33 | # step size 34 | self.group_weights_step_size = config.group_dro_step_size 35 | # initialize adversarial weights 36 | self.group_weights = torch.zeros(grouper.n_groups) 37 | self.group_weights[is_group_in_train] = 1 38 | self.group_weights = self.group_weights/self.group_weights.sum() 39 | self.group_weights = self.group_weights.to(self.device) 40 | 41 | def process_batch(self, batch, unlabeled_batch=None): 42 | results = super().process_batch(batch) 43 | results['group_weight'] = self.group_weights 44 | return results 45 | 46 | def objective(self, results): 47 | """ 48 | Takes an output of SingleModelAlgorithm.process_batch() and computes the 49 | optimized objective. For group DRO, the objective is the weighted average 50 | of losses, where groups have weights groupDRO.group_weights. 51 | Args: 52 | - results (dictionary): output of SingleModelAlgorithm.process_batch() 53 | Output: 54 | - objective (Tensor): optimized objective; size (1,). 55 | """ 56 | group_losses, _, _ = self.loss.compute_group_wise( 57 | results['y_pred'], 58 | results['y_true'], 59 | results['g'], 60 | self.grouper.n_groups, 61 | return_dict=False) 62 | return group_losses @ self.group_weights 63 | 64 | def _update(self, results, should_step=True): 65 | """ 66 | Process the batch, update the log, and update the model, group weights, and scheduler. 
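        Group weights are updated with an exponentiated-gradient step: each weight is multiplied by
        exp(group_weights_step_size * group_loss) and the weights are renormalized to sum to one, so
        higher-loss groups receive more weight.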
67 | Args: 68 | - batch (tuple of Tensors): a batch of data yielded by data loaders 69 | Output: 70 | - results (dictionary): information about the batch, such as: 71 | - g (Tensor) 72 | - y_true (Tensor) 73 | - metadata (Tensor) 74 | - loss (Tensor) 75 | - metrics (Tensor) 76 | - objective (float) 77 | """ 78 | # compute group losses 79 | group_losses, _, _ = self.loss.compute_group_wise( 80 | results['y_pred'], 81 | results['y_true'], 82 | results['g'], 83 | self.grouper.n_groups, 84 | return_dict=False) 85 | # update group weights 86 | self.group_weights = self.group_weights * torch.exp(self.group_weights_step_size*group_losses.data) 87 | self.group_weights = (self.group_weights/(self.group_weights.sum())) 88 | # save updated group weights 89 | results['group_weight'] = self.group_weights 90 | # update model 91 | super()._update(results, should_step=should_step) 92 | -------------------------------------------------------------------------------- /examples/configs/algorithm.py: -------------------------------------------------------------------------------- 1 | algorithm_defaults = { 2 | 'ERM': { 3 | 'train_loader': 'standard', 4 | 'uniform_over_groups': False, 5 | 'eval_loader': 'standard', 6 | 'randaugment_n': 2, # When running ERM + data augmentation 7 | }, 8 | 'groupDRO': { 9 | 'train_loader': 'standard', 10 | 'uniform_over_groups': True, 11 | 'distinct_groups': True, 12 | 'eval_loader': 'standard', 13 | 'group_dro_step_size': 0.01, 14 | }, 15 | 'deepCORAL': { 16 | 'train_loader': 'group', 17 | 'uniform_over_groups': True, 18 | 'distinct_groups': True, 19 | 'eval_loader': 'standard', 20 | 'coral_penalty_weight': 1., 21 | 'randaugment_n': 2, 22 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 23 | }, 24 | 'IRM': { 25 | 'train_loader': 'group', 26 | 'uniform_over_groups': True, 27 | 'distinct_groups': True, 28 | 'eval_loader': 'standard', 29 | 'irm_lambda': 100., 30 | 'irm_penalty_anneal_iters': 500, 31 | }, 32 | 'DANN': { 33 | 'train_loader': 'group', 34 | 'uniform_over_groups': True, 35 | 'distinct_groups': True, 36 | 'eval_loader': 'standard', 37 | 'randaugment_n': 2, 38 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 39 | }, 40 | 'AFN': { 41 | 'train_loader': 'standard', 42 | 'uniform_over_groups': False, 43 | 'eval_loader': 'standard', 44 | 'use_hafn': False, 45 | 'afn_penalty_weight': 0.01, 46 | 'safn_delta_r': 1.0, 47 | 'hafn_r': 1.0, 48 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 49 | 'randaugment_n': 2, 50 | }, 51 | 'FixMatch': { 52 | 'train_loader': 'standard', 53 | 'uniform_over_groups': False, 54 | 'eval_loader': 'standard', 55 | 'self_training_lambda': 1, 56 | 'self_training_threshold': 0.7, 57 | 'scheduler': 'FixMatchLR', 58 | 'randaugment_n': 2, 59 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled examples 60 | }, 61 | 'PseudoLabel': { 62 | 'train_loader': 'standard', 63 | 'uniform_over_groups': False, 64 | 'eval_loader': 'standard', 65 | 'self_training_lambda': 1, 66 | 'self_training_threshold': 0.7, 67 | 'pseudolabel_T2': 0.4, 68 | 'scheduler': 'FixMatchLR', 69 | 'randaugment_n': 2, 70 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 71 | }, 72 | 'NoisyStudent': { 73 | 'train_loader': 'standard', 74 | 'uniform_over_groups': False, 75 | 'eval_loader': 'standard', 76 | 'noisystudent_add_dropout': 
True, 77 | 'noisystudent_dropout_rate': 0.5, 78 | 'scheduler': 'FixMatchLR', 79 | 'randaugment_n': 2, 80 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /examples/configs/data_loader.py: -------------------------------------------------------------------------------- 1 | loader_defaults = { 2 | 'loader_kwargs': { 3 | 'num_workers': 4, 4 | 'pin_memory': True, 5 | }, 6 | 'unlabeled_loader_kwargs': { 7 | 'num_workers': 8, 8 | 'pin_memory': True, 9 | }, 10 | 'n_groups_per_batch': 4, 11 | } 12 | -------------------------------------------------------------------------------- /examples/configs/model.py: -------------------------------------------------------------------------------- 1 | model_defaults = { 2 | 'bert-base-uncased': { 3 | 'optimizer': 'AdamW', 4 | 'max_grad_norm': 1.0, 5 | 'scheduler': 'linear_schedule_with_warmup', 6 | }, 7 | 'distilbert-base-uncased': { 8 | 'optimizer': 'AdamW', 9 | 'max_grad_norm': 1.0, 10 | 'scheduler': 'linear_schedule_with_warmup', 11 | }, 12 | 'code-gpt-py': { 13 | 'optimizer': 'AdamW', 14 | 'max_grad_norm': 1.0, 15 | 'scheduler': 'linear_schedule_with_warmup', 16 | }, 17 | 'densenet121': { 18 | 'model_kwargs': { 19 | 'pretrained':True, 20 | }, 21 | 'target_resolution': (224, 224), 22 | }, 23 | 'wideresnet50': { 24 | 'model_kwargs': { 25 | 'pretrained':True, 26 | }, 27 | 'target_resolution': (224, 224), 28 | }, 29 | 'resnet18': { 30 | 'model_kwargs':{ 31 | 'pretrained':True, 32 | }, 33 | 'target_resolution': (224, 224), 34 | }, 35 | 'resnet34': { 36 | 'model_kwargs':{ 37 | 'pretrained':True, 38 | }, 39 | 'target_resolution': (224, 224), 40 | }, 41 | 'resnet50': { 42 | 'model_kwargs': { 43 | 'pretrained': True, 44 | }, 45 | 'target_resolution': (224, 224), 46 | }, 47 | 'resnet101': { 48 | 'model_kwargs': { 49 | 'pretrained': True, 50 | }, 51 | 'target_resolution': (224, 224), 52 | }, 53 | 'gin-virtual': {}, 54 | 'resnet18_ms': { 55 | 'target_resolution': (224, 224), 56 | }, 57 | 'logistic_regression': {}, 58 | 'unet-seq': { 59 | 'optimizer': 'Adam' 60 | }, 61 | 'fasterrcnn': { 62 | 'model_kwargs': { 63 | 'pretrained_model': True, 64 | 'pretrained_backbone': True, 65 | 'min_size' :1024, 66 | 'max_size' :1024 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /examples/configs/scheduler.py: -------------------------------------------------------------------------------- 1 | scheduler_defaults = { 2 | 'linear_schedule_with_warmup': { 3 | 'scheduler_kwargs':{ 4 | 'num_warmup_steps': 0, 5 | }, 6 | }, 7 | 'cosine_schedule_with_warmup': { 8 | 'scheduler_kwargs':{ 9 | 'num_warmup_steps': 0, 10 | }, 11 | }, 12 | 'ReduceLROnPlateau': { 13 | 'scheduler_kwargs':{}, 14 | }, 15 | 'StepLR': { 16 | 'scheduler_kwargs':{ 17 | 'step_size': 1, 18 | } 19 | }, 20 | 'FixMatchLR': { 21 | 'scheduler_kwargs': {}, 22 | }, 23 | 'MultiStepLR': { 24 | 'scheduler_kwargs':{ 25 | 'gamma': 0.1, 26 | } 27 | }, 28 | } 29 | -------------------------------------------------------------------------------- /examples/configs/supported.py: -------------------------------------------------------------------------------- 1 | from wilds.common.metrics.all_metrics import ( 2 | Accuracy, 3 | MultiTaskAccuracy, 4 | MSE, 5 | multiclass_logits_to_pred, 6 | binary_logits_to_pred, 7 | pseudolabel_binary_logits, 8 | pseudolabel_multiclass_logits, 9 | pseudolabel_identity, 10 | pseudolabel_detection, 
11 | pseudolabel_detection_discard_empty, 12 | MultiTaskAveragePrecision 13 | ) 14 | 15 | algo_log_metrics = { 16 | 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), 17 | 'mse': MSE(), 18 | 'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred), 19 | 'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred), 20 | 'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None), 21 | None: None, 22 | } 23 | 24 | process_outputs_functions = { 25 | 'binary_logits_to_pred': binary_logits_to_pred, 26 | 'multiclass_logits_to_pred': multiclass_logits_to_pred, 27 | None: None, 28 | } 29 | 30 | process_pseudolabels_functions = { 31 | 'pseudolabel_binary_logits': pseudolabel_binary_logits, 32 | 'pseudolabel_multiclass_logits': pseudolabel_multiclass_logits, 33 | 'pseudolabel_identity': pseudolabel_identity, 34 | 'pseudolabel_detection': pseudolabel_detection, 35 | 'pseudolabel_detection_discard_empty': pseudolabel_detection_discard_empty, 36 | } 37 | 38 | # see initialize_*() functions for correspondence= 39 | # See algorithms/initializer.py 40 | algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM', 'DANN', 'AFN', 'FixMatch', 'PseudoLabel', 'NoisyStudent'] 41 | 42 | # See transforms.py 43 | transforms = ['bert', 'image_base', 'image_resize', 'image_resize_and_center_crop', 'poverty', 'rxrx1'] 44 | additional_transforms = ['randaugment', 'weak'] 45 | 46 | # See models/initializer.py 47 | models = ['resnet18_ms', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'wideresnet50', 48 | 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', 49 | 'gin-virtual', 'logistic_regression', 'code-gpt-py', 50 | 'fasterrcnn', 'unet-seq'] 51 | 52 | # See optimizer.py 53 | optimizers = ['SGD', 'Adam', 'AdamW'] 54 | 55 | # See scheduler.py 56 | schedulers = ['linear_schedule_with_warmup', 'cosine_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR', 'FixMatchLR', 'MultiStepLR'] 57 | 58 | # See losses.py 59 | losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'fasterrcnn_criterion', 'cross_entropy_logits'] 60 | -------------------------------------------------------------------------------- /examples/data_augmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/examples/data_augmentation/__init__.py -------------------------------------------------------------------------------- /examples/data_augmentation/randaugment.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/YBZh/Bridging_UDA_SSL 2 | 3 | import torch 4 | from PIL import Image, ImageOps, ImageEnhance, ImageDraw 5 | 6 | 7 | def AutoContrast(img, _): 8 | return ImageOps.autocontrast(img) 9 | 10 | 11 | def Brightness(img, v): 12 | assert v >= 0.0 13 | return ImageEnhance.Brightness(img).enhance(v) 14 | 15 | 16 | def Color(img, v): 17 | assert v >= 0.0 18 | return ImageEnhance.Color(img).enhance(v) 19 | 20 | 21 | def Contrast(img, v): 22 | assert v >= 0.0 23 | return ImageEnhance.Contrast(img).enhance(v) 24 | 25 | 26 | def Equalize(img, _): 27 | return ImageOps.equalize(img) 28 | 29 | 30 | def Invert(img, _): 31 | return ImageOps.invert(img) 32 | 33 | 34 | def Identity(img, v): 35 | return img 36 | 37 | 38 | def Posterize(img, v): # [4, 8] 39 | v = int(v) 40 | v = max(1, v) 41 | return ImageOps.posterize(img, v) 42 | 43 | 44 | def Rotate(img, v): # [-30, 30] 45 | 
return img.rotate(v) 46 | 47 | 48 | def Sharpness(img, v): # [0.1,1.9] 49 | assert v >= 0.0 50 | return ImageEnhance.Sharpness(img).enhance(v) 51 | 52 | 53 | def ShearX(img, v): # [-0.3, 0.3] 54 | return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0)) 55 | 56 | 57 | def ShearY(img, v): # [-0.3, 0.3] 58 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0)) 59 | 60 | 61 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 62 | v = v * img.size[0] 63 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) 64 | 65 | 66 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 67 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) 68 | 69 | 70 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 71 | v = v * img.size[1] 72 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) 73 | 74 | 75 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 76 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) 77 | 78 | 79 | def Solarize(img, v): # [0, 256] 80 | assert 0 <= v <= 256 81 | return ImageOps.solarize(img, v) 82 | 83 | 84 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] => change to [0, 0.5] 85 | assert 0.0 <= v <= 0.5 86 | 87 | v = v * img.size[0] 88 | return CutoutAbs(img, v) 89 | 90 | 91 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 92 | if v < 0: 93 | return img 94 | w, h = img.size 95 | x_center = _sample_uniform(0, w) 96 | y_center = _sample_uniform(0, h) 97 | 98 | x0 = int(max(0, x_center - v / 2.0)) 99 | y0 = int(max(0, y_center - v / 2.0)) 100 | x1 = min(w, x0 + v) 101 | y1 = min(h, y0 + v) 102 | 103 | xy = (x0, y0, x1, y1) 104 | color = (125, 123, 114) 105 | img = img.copy() 106 | ImageDraw.Draw(img).rectangle(xy, color) 107 | return img 108 | 109 | 110 | FIX_MATCH_AUGMENTATION_POOL = [ 111 | (AutoContrast, 0, 1), 112 | (Brightness, 0.05, 0.95), 113 | (Color, 0.05, 0.95), 114 | (Contrast, 0.05, 0.95), 115 | (Equalize, 0, 1), 116 | (Identity, 0, 1), 117 | (Posterize, 4, 8), 118 | (Rotate, -30, 30), 119 | (Sharpness, 0.05, 0.95), 120 | (ShearX, -0.3, 0.3), 121 | (ShearY, -0.3, 0.3), 122 | (Solarize, 0, 256), 123 | (TranslateX, -0.3, 0.3), 124 | (TranslateY, -0.3, 0.3), 125 | ] 126 | 127 | 128 | def _sample_uniform(a, b): 129 | return torch.empty(1).uniform_(a, b).item() 130 | 131 | 132 | class RandAugment: 133 | def __init__(self, n, augmentation_pool): 134 | assert n >= 1, "RandAugment N has to be a value greater than or equal to 1." 
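# A minimal usage sketch (the value of n and the image variable are illustrative, not taken from this file):
#     transform = RandAugment(n=2, augmentation_pool=FIX_MATCH_AUGMENTATION_POOL)
#     strongly_augmented = transform(pil_image)
# Each call samples n ops from the pool at random strengths, then applies Cutout (see __call__ below).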
135 | self.n = n 136 | self.augmentation_pool = augmentation_pool 137 | 138 | def __call__(self, img): 139 | ops = [ 140 | self.augmentation_pool[torch.randint(len(self.augmentation_pool), (1,))] 141 | for _ in range(self.n) 142 | ] 143 | for op, min_val, max_val in ops: 144 | val = min_val + float(max_val - min_val) * _sample_uniform(0, 1) 145 | img = op(img, val) 146 | cutout_val = _sample_uniform(0, 1) * 0.5 147 | img = Cutout(img, cutout_val) 148 | return img 149 | -------------------------------------------------------------------------------- /examples/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss 3 | from wilds.common.metrics.all_metrics import MSE 4 | from utils import cross_entropy_with_logits_loss 5 | 6 | def initialize_loss(loss, config): 7 | if loss == 'cross_entropy': 8 | return ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none', ignore_index=-100)) 9 | 10 | elif loss == 'lm_cross_entropy': 11 | return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none', ignore_index=-100)) 12 | 13 | elif loss == 'mse': 14 | return MSE(name='loss') 15 | 16 | elif loss == 'multitask_bce': 17 | return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')) 18 | 19 | elif loss == 'fasterrcnn_criterion': 20 | from models.detection.fasterrcnn import FasterRCNNLoss 21 | return ElementwiseLoss(loss_fn=FasterRCNNLoss(config.device)) 22 | 23 | elif loss == 'cross_entropy_logits': 24 | return ElementwiseLoss(loss_fn=cross_entropy_with_logits_loss) 25 | 26 | else: 27 | raise ValueError(f'loss {loss} not recognized') 28 | -------------------------------------------------------------------------------- /examples/models/CNN_genome.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def single_conv(in_channels, out_channels, kernel_size=7): 9 | padding_size = int((kernel_size-1)/2) 10 | return nn.Sequential( 11 | nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding_size), 12 | nn.BatchNorm1d(out_channels), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | def double_conv(in_channels, out_channels, kernel_size=7): 17 | padding_size = int((kernel_size-1)/2) 18 | return nn.Sequential( 19 | nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding_size), 20 | nn.BatchNorm1d(out_channels), 21 | nn.ReLU(inplace=True), 22 | nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding_size), 23 | nn.BatchNorm1d(out_channels), 24 | nn.ReLU(inplace=True) 25 | ) 26 | 27 | 28 | class UNet(nn.Module): 29 | def __init__(self, num_tasks=16, n_channels_in=5): 30 | super().__init__() 31 | 32 | self.dconv_down1 = double_conv(n_channels_in, 15) 33 | self.dconv_down2 = double_conv(15, 22) 34 | self.dconv_down3 = double_conv(22, 33) 35 | self.dconv_down4 = double_conv(33, 49) 36 | self.dconv_down5 = double_conv(49, 73) 37 | self.dconv_down6 = double_conv(73, 109) 38 | 39 | self.maxpool = nn.MaxPool1d(2) 40 | # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 41 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 42 | # self.conv_middle = single_conv(109, 109) 43 | self.upsamp_6 = nn.ConvTranspose1d(109, 109, 2, stride=2) 44 | 45 | self.dconv_up5 = double_conv(73 + 109, 73) 46 | self.upsamp_5 = nn.ConvTranspose1d(73, 
73, 2, stride=2) 47 | self.dconv_up4 = double_conv(49 + 73, 49) 48 | self.upsamp_4 = nn.ConvTranspose1d(49, 49, 2, stride=2) 49 | self.dconv_up3 = double_conv(33 + 49, 33) 50 | self.upsamp_3 = nn.ConvTranspose1d(33, 33, 2, stride=2) 51 | self.dconv_up2 = double_conv(22 + 33, 22) 52 | self.upsamp_2 = nn.ConvTranspose1d(22, 22, 2, stride=2) 53 | self.dconv_up1 = double_conv(15 + 22, 15) 54 | self.upsamp_1 = nn.ConvTranspose1d(15, 15, 2, stride=2) 55 | 56 | self.conv_last = nn.Conv1d(15, 1, 200, stride=50, padding=0) 57 | self.d_out = num_tasks if num_tasks is not None else 253 58 | 59 | self.fc_last = nn.Linear(253, 128) 60 | 61 | 62 | def forward(self, x): 63 | # input_size = 12800 64 | # input_channels = 5 65 | x = x.float() 66 | conv1 = self.dconv_down1(x) # Output size: (input_size) x 15 67 | x = self.maxpool(conv1) # (input_size / 2) x 15 68 | 69 | conv2 = self.dconv_down2(x) # (input_size / 2) x 22 70 | x = self.maxpool(conv2) # (input_size / 4) x 22 71 | 72 | conv3 = self.dconv_down3(x) # (input_size / 4) x 33 73 | x = self.maxpool(conv3) # (input_size / 8) x 33 74 | 75 | conv4 = self.dconv_down4(x) # (input_size / 8) x 49 76 | x = self.maxpool(conv4) # (input_size / 16) x 49 77 | 78 | conv5 = self.dconv_down5(x) # (input_size / 16) x 73 79 | x = self.maxpool(conv5) # (input_size / 32) x 73 80 | 81 | conv6 = self.dconv_down6(x) # (input_size / 32) x 109 82 | # conv6 = self.conv_middle(conv6) # Optional: convolution here. 83 | 84 | # Encoder finished. 85 | 86 | x = self.upsamp_6(conv6) # (input_size / 16) x 109 87 | x = torch.cat([x, conv5], dim=1) # (input_size / 16) x (109 + 73) 88 | 89 | x = self.dconv_up5(x) # (input_size / 16) x 73 90 | x = self.upsamp_5(x) # (input_size / 8) x 73 91 | x = torch.cat([x, conv4], dim=1) # (input_size / 8) x (73 + 49) 92 | 93 | x = self.dconv_up4(x) # (input_size / 8) x 49 94 | x = self.upsamp_4(x) # (input_size / 4) x 49 95 | x = torch.cat([x, conv3], dim=1) # (input_size / 4) x (49 + 33) 96 | 97 | x = self.dconv_up3(x) # (input_size / 4) x 33 98 | x = self.upsamp_3(x) # (input_size / 2) x 33 99 | x = torch.cat([x, conv2], dim=1) # (input_size / 2) x (33 + 22) 100 | 101 | x = self.dconv_up2(x) # (input_size / 2) x 22 102 | x = self.upsamp_2(x) # (input_size) x 22 103 | x = torch.cat([x, conv1], dim=1) # (input_size) x (22 + 15) 104 | 105 | x = self.dconv_up1(x) # (input_size) x 15 106 | 107 | x = self.conv_last(x) # (input_size/50 - 3) x 1 108 | x = torch.squeeze(x) 109 | 110 | # Default input_size == 12800: x has size N x 1 x 253 at this point. 
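# Where the 253 above comes from: conv_last uses kernel_size=200, stride=50, padding=0,
# so an input of length 12800 yields floor((12800 - 200) / 50) + 1 = 253 output positions;
# torch.squeeze then drops the singleton channel dimension, leaving shape N x 253.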
111 | if self.d_out == 253: 112 | out = x 113 | else: 114 | out = self.fc_last(x) 115 | # out = x[:, 64:192] # middle 128 values 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /examples/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/examples/models/__init__.py -------------------------------------------------------------------------------- /examples/models/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/examples/models/bert/__init__.py -------------------------------------------------------------------------------- /examples/models/bert/bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertForSequenceClassification, BertModel 2 | import torch 3 | 4 | class BertClassifier(BertForSequenceClassification): 5 | def __init__(self, config): 6 | super().__init__(config) 7 | self.d_out = config.num_labels 8 | 9 | def __call__(self, x): 10 | input_ids = x[:, :, 0] 11 | attention_mask = x[:, :, 1] 12 | token_type_ids = x[:, :, 2] 13 | outputs = super().__call__( 14 | input_ids=input_ids, 15 | attention_mask=attention_mask, 16 | token_type_ids=token_type_ids 17 | )[0] 18 | return outputs 19 | 20 | class BertFeaturizer(BertModel): 21 | def __init__(self, config): 22 | super().__init__(config) 23 | self.d_out = config.hidden_size 24 | 25 | def __call__(self, x): 26 | input_ids = x[:, :, 0] 27 | attention_mask = x[:, :, 1] 28 | token_type_ids = x[:, :, 2] 29 | outputs = super().__call__( 30 | input_ids=input_ids, 31 | attention_mask=attention_mask, 32 | token_type_ids=token_type_ids 33 | )[1] # get pooled output 34 | return outputs 35 | -------------------------------------------------------------------------------- /examples/models/bert/distilbert.py: -------------------------------------------------------------------------------- 1 | from transformers import DistilBertForSequenceClassification, DistilBertModel 2 | 3 | class DistilBertClassifier(DistilBertForSequenceClassification): 4 | def __init__(self, config): 5 | super().__init__(config) 6 | 7 | def __call__(self, x): 8 | input_ids = x[:, :, 0] 9 | attention_mask = x[:, :, 1] 10 | outputs = super().__call__( 11 | input_ids=input_ids, 12 | attention_mask=attention_mask, 13 | )[0] 14 | return outputs 15 | 16 | 17 | class DistilBertFeaturizer(DistilBertModel): 18 | def __init__(self, config): 19 | super().__init__(config) 20 | self.d_out = config.hidden_size 21 | 22 | def __call__(self, x): 23 | input_ids = x[:, :, 0] 24 | attention_mask = x[:, :, 1] 25 | hidden_state = super().__call__( 26 | input_ids=input_ids, 27 | attention_mask=attention_mask, 28 | )[0] 29 | pooled_output = hidden_state[:, 0] 30 | return pooled_output 31 | -------------------------------------------------------------------------------- /examples/models/code_gpt.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2LMHeadModel, GPT2Model 2 | import torch 3 | 4 | class GPT2LMHeadLogit(GPT2LMHeadModel): 5 | def __init__(self, config): 6 | super().__init__(config) 7 | self.d_out = config.vocab_size 8 | 9 | def __call__(self, x): 10 | outputs = super().__call__(x) 11 | logits = outputs[0] #[batch_size, seqlen, 
vocab_size] 12 | return logits 13 | 14 | 15 | class GPT2Featurizer(GPT2Model): 16 | def __init__(self, config): 17 | super().__init__(config) 18 | self.d_out = config.n_embd 19 | 20 | def __call__(self, x): 21 | outputs = super().__call__(x) 22 | hidden_states = outputs[0] #[batch_size, seqlen, n_embd] 23 | return hidden_states 24 | 25 | 26 | class GPT2FeaturizerLMHeadLogit(GPT2LMHeadModel): 27 | def __init__(self, config): 28 | super().__init__(config) 29 | self.d_out = config.vocab_size 30 | self.transformer = GPT2Featurizer(config) 31 | 32 | def __call__(self, x): 33 | hidden_states = self.transformer(x) #[batch_size, seqlen, n_embd] 34 | logits = self.lm_head(hidden_states) #[batch_size, seqlen, vocab_size] 35 | return logits 36 | -------------------------------------------------------------------------------- /examples/models/domain_adversarial_network.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | 7 | 8 | class DomainDiscriminator(nn.Sequential): 9 | """ 10 | Adapted from https://github.com/thuml/Transfer-Learning-Library 11 | 12 | Domain discriminator model from 13 | `"Domain-Adversarial Training of Neural Networks" `_ 14 | In the original paper and implementation, we distinguish whether the input features come 15 | from the source domain or the target domain. 16 | 17 | We extended this to work with multiple domains, which is controlled by the n_domains 18 | argument. 19 | 20 | Args: 21 | in_feature (int): dimension of the input feature 22 | n_domains (int): number of domains to discriminate 23 | hidden_size (int): dimension of the hidden features 24 | batch_norm (bool): whether use :class:`~torch.nn.BatchNorm1d`. 25 | Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True. 
26 | Shape: 27 | - Inputs: (minibatch, `in_feature`) 28 | - Outputs: :math:`(minibatch, n_domains)` 29 | """ 30 | 31 | def __init__( 32 | self, in_feature: int, n_domains, hidden_size: int = 1024, batch_norm=True 33 | ): 34 | if batch_norm: 35 | super(DomainDiscriminator, self).__init__( 36 | nn.Linear(in_feature, hidden_size), 37 | nn.BatchNorm1d(hidden_size), 38 | nn.ReLU(), 39 | nn.Linear(hidden_size, hidden_size), 40 | nn.BatchNorm1d(hidden_size), 41 | nn.ReLU(), 42 | nn.Linear(hidden_size, n_domains), 43 | ) 44 | else: 45 | super(DomainDiscriminator, self).__init__( 46 | nn.Linear(in_feature, hidden_size), 47 | nn.ReLU(inplace=True), 48 | nn.Dropout(0.5), 49 | nn.Linear(hidden_size, hidden_size), 50 | nn.ReLU(inplace=True), 51 | nn.Dropout(0.5), 52 | nn.Linear(hidden_size, n_domains), 53 | ) 54 | 55 | def get_parameters_with_lr(self, lr) -> List[Dict]: 56 | return [{"params": self.parameters(), "lr": lr}] 57 | 58 | class GradientReverseFunction(Function): 59 | """ 60 | Credit: https://github.com/thuml/Transfer-Learning-Library 61 | """ 62 | @staticmethod 63 | def forward( 64 | ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.0 65 | ) -> torch.Tensor: 66 | ctx.coeff = coeff 67 | output = input * 1.0 68 | return output 69 | 70 | @staticmethod 71 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]: 72 | return grad_output.neg() * ctx.coeff, None 73 | 74 | 75 | class GradientReverseLayer(nn.Module): 76 | """ 77 | Credit: https://github.com/thuml/Transfer-Learning-Library 78 | """ 79 | def __init__(self): 80 | super(GradientReverseLayer, self).__init__() 81 | 82 | def forward(self, *input): 83 | return GradientReverseFunction.apply(*input) 84 | 85 | 86 | class DomainAdversarialNetwork(nn.Module): 87 | def __init__(self, featurizer, classifier, n_domains): 88 | super().__init__() 89 | self.featurizer = featurizer 90 | self.classifier = classifier 91 | self.domain_classifier = DomainDiscriminator(featurizer.d_out, n_domains) 92 | self.gradient_reverse_layer = GradientReverseLayer() 93 | 94 | def forward(self, input): 95 | features = self.featurizer(input) 96 | y_pred = self.classifier(features) 97 | features = self.gradient_reverse_layer(features) 98 | domains_pred = self.domain_classifier(features) 99 | return y_pred, domains_pred 100 | 101 | def get_parameters_with_lr(self, featurizer_lr, classifier_lr, discriminator_lr) -> List[Dict]: 102 | """ 103 | Adapted from https://github.com/thuml/Transfer-Learning-Library 104 | 105 | A parameter list which decides optimization hyper-parameters, 106 | such as the relative learning rate of each layer 107 | """ 108 | # In TLL's implementation, the learning rate of this classifier is set 10 times to that of the 109 | # feature extractor for better accuracy by default. For our implementation, we allow the learning 110 | # rates to be passed in separately for featurizer and classifier. 
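# For example (illustrative values, not defaults taken from this file), the returned list can be
# passed directly to a torch optimizer, where each group's 'lr' overrides the optimizer-wide default:
#     param_groups = dann_model.get_parameters_with_lr(1e-4, 1e-3, 1e-3)
#     optimizer = torch.optim.SGD(param_groups, lr=1e-4, momentum=0.9)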
111 | params = [ 112 | {"params": self.featurizer.parameters(), "lr": featurizer_lr}, 113 | {"params": self.classifier.parameters(), "lr": classifier_lr}, 114 | ] 115 | return params + self.domain_classifier.get_parameters_with_lr(discriminator_lr) 116 | -------------------------------------------------------------------------------- /examples/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Identity(nn.Module): 6 | """An identity layer""" 7 | def __init__(self, d): 8 | super().__init__() 9 | self.in_features = d 10 | self.out_features = d 11 | 12 | def forward(self, x): 13 | return x 14 | -------------------------------------------------------------------------------- /examples/noisy_student_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper code to run multiple iterations of Noisy Student, using the same hyperparameters between iterations. The initial teacher's weights must be provided by the command line. 3 | 4 | Normally, to run 2 warm-started iterations with some initial teacher weights, one would run a sequence of commands: 5 | python examples/run_expt.py --root_dir $HOME --log_dir ./student1 --dataset DATASET --algorithm NoisyStudent --unlabeled_split test_unlabeled --teacher_model_path teacher_weights.pth --pretrained_model_path teacher_weights.pth 6 | python examples/run_expt.py --root_dir $HOME --log_dir ./student2 --dataset DATASET --algorithm NoisyStudent --unlabeled_split test_unlabeled --teacher_model_path ./student1/model.pth --pretrained_model_path ./student1/model.pth 7 | 8 | With this script, to run 2 warm-started iterations with some initial teacher weights: 9 | python examples/noisy_student_wrapper.py 2 teacher_weights.pth --root_dir $HOME --log_dir . --dataset DATASET --unlabeled_split test_unlabeled 10 | 11 | i.e. usage: 12 | python examples/noisy_student_wrapper.py [NUM_ITERS] [INITIAL_TEACHER_WEIGHTS] [REST OF RUN_EXPT COMMAND STRING] 13 | 14 | Notes: 15 | - Students are all warm-started with the current teacher's weights. 16 | - This command will use the FIRST occurrence of --log_dir (instead of the last). 
17 | """ 18 | import argparse 19 | import os 20 | import pathlib 21 | import pdb 22 | import subprocess 23 | 24 | SUCCESS_RETURN_CODE = 0 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("num_iters", type=int) 28 | parser.add_argument("initial_teacher_path", type=str) # required 29 | parser.add_argument("cmd", nargs=argparse.REMAINDER) 30 | args = parser.parse_args() 31 | 32 | assert args.initial_teacher_path.endswith(".pth") 33 | assert os.path.exists( 34 | args.initial_teacher_path 35 | ), f"Model weights did not exist at {args.initial_teacher_path}" 36 | prefix = pathlib.Path(__file__).parent.resolve() 37 | 38 | 39 | def remove_arg(args, arg_to_remove): 40 | idx = args.cmd.index(f"--{arg_to_remove}") 41 | value = args.cmd[idx + 1] 42 | args.cmd = args.cmd[:idx] + args.cmd[idx + 2 :] 43 | return value 44 | 45 | 46 | # Parse out a few args that we need 47 | try: 48 | idx = args.cmd.index("--log_dir") 49 | log_dir = args.cmd[idx + 1] 50 | args.cmd = ( 51 | args.cmd[:idx] + args.cmd[idx + 2 :] 52 | ) # will need to modify this between iters, so remove from args.cmd 53 | except: 54 | log_dir = "./logs" # default in run_expt.py 55 | 56 | idx = args.cmd.index("--dataset") 57 | dataset = args.cmd[idx + 1] 58 | 59 | try: 60 | idx = args.cmd.index("--seed") 61 | seed = args.cmd[idx + 1] 62 | except: 63 | seed = 0 # default in run_expt.py 64 | 65 | try: 66 | idx = args.cmd.index("--dataset_kwargs") 67 | fold = args.cmd[idx + 1] 68 | assert fold.startswith("fold=") 69 | fold = fold.replace("fold=", "") 70 | except: 71 | fold = "A" 72 | 73 | # Train the teacher model without unlabeled data and default values for gradient_accumulation_steps and n_epochs 74 | unlabeled_split = remove_arg(args, "unlabeled_split") 75 | gradient_accumulation_steps = remove_arg(args, "gradient_accumulation_steps") 76 | n_epochs = remove_arg(args, "n_epochs") 77 | 78 | # Run student iterations 79 | for i in range(1, args.num_iters + 1): 80 | if i == 1: 81 | teacher_weights = args.initial_teacher_path 82 | else: 83 | if dataset == "poverty": 84 | teacher_weights = ( 85 | f"{log_dir}/student{i - 1}/{dataset}_fold:{fold}_epoch:best_model.pth" 86 | ) 87 | else: 88 | teacher_weights = ( 89 | f"{log_dir}/student{i-1}/{dataset}_seed:{seed}_epoch:best_model.pth" 90 | ) 91 | cmd = ( 92 | f"python {prefix}/run_expt.py --algorithm NoisyStudent {' '.join(args.cmd)}" 93 | + f" --unlabeled_split {unlabeled_split} --gradient_accumulation_steps {gradient_accumulation_steps}" 94 | + f" --n_epochs {n_epochs} --log_dir {log_dir}/student{i}" 95 | + f" --teacher_model_path {teacher_weights}" 96 | + f" --pretrained_model_path {teacher_weights}" # warm starting 97 | ) 98 | print(f">>> Running {cmd}") 99 | return_code = subprocess.Popen(cmd, shell=True).wait() 100 | if return_code != SUCCESS_RETURN_CODE: 101 | raise RuntimeError( 102 | f"FAILED: Iteration {i} failed with return code: {return_code}" 103 | ) 104 | 105 | print(">>> Done!") 106 | -------------------------------------------------------------------------------- /examples/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD, Adam 2 | from transformers import AdamW 3 | 4 | def initialize_optimizer(config, model): 5 | # initialize optimizers 6 | if config.optimizer=='SGD': 7 | params = filter(lambda p: p.requires_grad, model.parameters()) 8 | optimizer = SGD( 9 | params, 10 | lr=config.lr, 11 | weight_decay=config.weight_decay, 12 | **config.optimizer_kwargs) 13 | elif config.optimizer=='AdamW': 
14 | if 'bert' in config.model or 'gpt' in config.model: 15 | no_decay = ['bias', 'LayerNorm.weight'] 16 | else: 17 | no_decay = [] 18 | 19 | params = [ 20 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': config.weight_decay}, 21 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 22 | ] 23 | optimizer = AdamW( 24 | params, 25 | lr=config.lr, 26 | **config.optimizer_kwargs) 27 | elif config.optimizer == 'Adam': 28 | params = filter(lambda p: p.requires_grad, model.parameters()) 29 | optimizer = Adam( 30 | params, 31 | lr=config.lr, 32 | weight_decay=config.weight_decay, 33 | **config.optimizer_kwargs) 34 | else: 35 | raise ValueError(f'Optimizer {config.optimizer} not recognized.') 36 | 37 | return optimizer 38 | 39 | def initialize_optimizer_with_model_params(config, params): 40 | if config.optimizer=='SGD': 41 | optimizer = SGD( 42 | params, 43 | lr=config.lr, 44 | weight_decay=config.weight_decay, 45 | **config.optimizer_kwargs 46 | ) 47 | elif config.optimizer=='AdamW': 48 | optimizer = AdamW( 49 | params, 50 | lr=config.lr, 51 | weight_decay=config.weight_decay, 52 | **config.optimizer_kwargs 53 | ) 54 | elif config.optimizer == 'Adam': 55 | optimizer = Adam( 56 | params, 57 | lr=config.lr, 58 | weight_decay=config.weight_decay, 59 | **config.optimizer_kwargs 60 | ) 61 | else: 62 | raise ValueError(f'Optimizer {config.optimizer} not supported.') 63 | 64 | return optimizer 65 | -------------------------------------------------------------------------------- /examples/pretraining/mlm/README.md: -------------------------------------------------------------------------------- 1 | # Masked LM Pre-training 2 | 3 | ## Dependencies 4 | - datasets==1.11.0 5 | - transformers==4.9.1 6 | 7 | ## Usage 8 | 1. Format the unlabeled text data in the hugging-face format 9 | ``` 10 | python3 examples/pretraining/mlm/get_data.py 11 | ``` 12 | 13 | 2. Run the commands in `examples/pretraining/mlm/run_pretrain.sh` to start masked LM pre-training 14 | 15 | 3. 
Use the pre-trained model in WILDS fine-tuning, e.g., 16 | ``` 17 | python3 examples/run_expt.py --dataset civilcomments --algorithm ERM --root_dir data \ 18 | --model distilbert-base-uncased \ 19 | --pretrained_model_path examples/pretraining/mlm/data/_run__distilbert-base-uncased__civilcomments__b32a256_lr1e-4/checkpoint-1500/pytorch_model.bin 20 | ``` 21 | -------------------------------------------------------------------------------- /examples/pretraining/mlm/get_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | from collections import defaultdict 7 | import csv 8 | 9 | os.system('mkdir -p examples/pretraining/mlm/data') 10 | 11 | 12 | ######################## CivilComments ######################## 13 | CCU_metadata_df = pd.read_csv('data/civilcomments_unlabeled_v1.0/unlabeled_data_with_identities.csv', index_col=0) 14 | CCU_text_array = list(CCU_metadata_df['comment_text']) #1_551_515 15 | 16 | with open('examples/pretraining/mlm/data/civilcomments_train.json', 'w') as outf: 17 | for text in tqdm(CCU_text_array): 18 | print (json.dumps({'text': text}), file=outf) 19 | 20 | 21 | CC_metadata_df = pd.read_csv('data/civilcomments_v1.0/all_data_with_identities.csv', index_col=0) 22 | CC_text_array_val = list(CC_metadata_df[CC_metadata_df['split'] == 'val']['comment_text']) #45_180 23 | 24 | with open('examples/pretraining/mlm/data/civilcomments_val.json', 'w') as outf: 25 | for text in tqdm(CC_text_array_val): 26 | print (json.dumps({'text': text}), file=outf) 27 | 28 | 29 | 30 | ######################## Amazon ######################## 31 | amazon_data_df: pd.DataFrame = pd.read_csv( 32 | 'data/amazon_v2.1/reviews.csv', 33 | dtype={ 34 | "reviewerID": str, 35 | "asin": str, 36 | "reviewTime": str, 37 | "unixReviewTime": int, 38 | "reviewText": str, 39 | "summary": str, 40 | "verified": bool, 41 | "category": str, 42 | "reviewYear": int, 43 | }, 44 | keep_default_na=False, 45 | na_values=[], 46 | quoting=csv.QUOTE_NONNUMERIC, 47 | ) #10_116_947 48 | 49 | amazon_split_df: pd.DataFrame = pd.read_csv('data/amazon_v2.1/splits/user.csv') #10_116_947 50 | is_in_dataset: bool = (amazon_split_df["split"] != -1) 51 | 52 | amazon_split_df = amazon_split_df[is_in_dataset] #4_002_170 53 | amazon_data_df = amazon_data_df[is_in_dataset] #4_002_170 54 | 55 | # "val_unlabeled": 11, "test_unlabeled": 12, "extra_unlabeled": 13, "val": 1 56 | _text_array_11 = list(amazon_data_df[amazon_split_df['split']==11]['reviewText']) #266_066 57 | _text_array_12 = list(amazon_data_df[amazon_split_df['split']==12]['reviewText']) #268_761 58 | _text_array_13 = list(amazon_data_df[amazon_split_df['split']==13]['reviewText']) #2_927_841 59 | _text_array_val = list(amazon_data_df[amazon_split_df['split']==1]['reviewText']) #100_050 60 | 61 | with open('examples/pretraining/mlm/data/amazon_train_11.json', 'w') as outf: 62 | for text in tqdm(_text_array_11): 63 | print (json.dumps({'text': text}), file=outf) 64 | 65 | with open('examples/pretraining/mlm/data/amazon_train_12.json', 'w') as outf: 66 | for text in tqdm(_text_array_12): 67 | print (json.dumps({'text': text}), file=outf) 68 | 69 | with open('examples/pretraining/mlm/data/amazon_train_13.json', 'w') as outf: 70 | for text in tqdm(_text_array_13): 71 | print (json.dumps({'text': text}), file=outf) 72 | 73 | with open('examples/pretraining/mlm/data/amazon_train_11_12_13.json', 'w') as outf: 74 | for text in 
tqdm(_text_array_11 + _text_array_12 + _text_array_13): 75 | print (json.dumps({'text': text}), file=outf) 76 | 77 | with open('examples/pretraining/mlm/data/amazon_val.json', 'w') as outf: 78 | for text in tqdm(_text_array_val): 79 | print (json.dumps({'text': text}), file=outf) 80 | -------------------------------------------------------------------------------- /examples/pretraining/mlm/run_pretrain.sh: -------------------------------------------------------------------------------- 1 | ######################## CivilComments ######################## 2 | dt=`date '+%Y%m%d_%H%M%S'` 3 | data_dir="mlm_pretrain/data" 4 | TRAIN_FILE="${data_dir}/civilcomments_train.json" 5 | VAL_FILE="${data_dir}/civilcomments_val.json" 6 | model="distilbert-base-uncased" 7 | outdir="${data_dir}/_run__${model}__civilcomments__b32a256_lr1e-4__${dt}" 8 | mkdir -p $outdir 9 | 10 | CUDA_VISIBLE_DEVICES=1 python3.7 -u mlm_pretrain/src/run_mlm.py \ 11 | --model_name_or_path $model \ 12 | --train_file $TRAIN_FILE --validation_file $VAL_FILE \ 13 | --do_train --do_eval --output_dir $outdir --overwrite_output_dir \ 14 | --line_by_line --max_seq_length 300 --fp16 --preprocessing_num_workers 10 --learning_rate 1e-4 \ 15 | --max_steps 1000 --logging_first_step --logging_steps 10 --save_steps 100 \ 16 | --evaluation_strategy steps --eval_steps 100 \ 17 | --per_device_train_batch_size 32 --per_device_eval_batch_size 64 --gradient_accumulation_steps 256 \ 18 | |& tee $outdir/log.txt 19 | 20 | 21 | 22 | ######################## Amazon ######################## 23 | dt=`date '+%Y%m%d_%H%M%S'` 24 | data_dir="mlm_pretrain/data" 25 | TRAIN_FILE="${data_dir}/amazon_train_12.json" 26 | VAL_FILE="${data_dir}/amazon_val.json" 27 | model="distilbert-base-uncased" 28 | outdir="${data_dir}/_run__${model}__amazon_12__b16a512_lr1e-4__${dt}" 29 | mkdir -p $outdir 30 | 31 | CUDA_VISIBLE_DEVICES=9 python3.7 -u mlm_pretrain/src/run_mlm.py \ 32 | --model_name_or_path $model \ 33 | --train_file $TRAIN_FILE --validation_file $VAL_FILE \ 34 | --do_train --do_eval --output_dir $outdir --overwrite_output_dir \ 35 | --line_by_line --max_seq_length 512 --fp16 --preprocessing_num_workers 10 --learning_rate 1e-4 \ 36 | --max_steps 1000 --logging_first_step --logging_steps 10 --save_steps 100 \ 37 | --evaluation_strategy steps --eval_steps 100 \ 38 | --per_device_train_batch_size 16 --per_device_eval_batch_size 32 --gradient_accumulation_steps 512 \ 39 | |& tee $outdir/log.txt 40 | -------------------------------------------------------------------------------- /examples/pretraining/swav/README.md: -------------------------------------------------------------------------------- 1 | # SwAV pre-training 2 | 3 | This folder is contains a lightly modified version of the SwAV code from https://github.com/facebookresearch/swav, licensed under CC BY-NC 4.0. 4 | 5 | If you use this algorithm, please cite the original source: 6 | ``` 7 | @article{caron2020unsupervised, 8 | title={Unsupervised Learning of Visual Features by Contrasting Cluster Assignments}, 9 | author={Caron, Mathilde and Misra, Ishan and Mairal, Julien and Goyal, Priya and Bojanowski, Piotr and Joulin, Armand}, 10 | booktitle={Proceedings of Advances in Neural Information Processing Systems (NeurIPS)}, 11 | year={2020} 12 | } 13 | ``` 14 | 15 | SwAV requires installation of the NVIDIA Apex library for mixed-precision training. 
Each Apex installation has a specific CUDA extension--more information can be found in the "requirements" section of the original SwAV repository's README: ([link](https://github.com/facebookresearch/swav)). 16 | 17 | ## Changes 18 | We made the following changes to the SwAV repository to interface with the WILDS code. 19 | 20 | ### `multicropdataset.py` 21 | - Added a new dataset class, CustomSplitMultiCropDataset, to accommodate WILDS data loaders, allowing SwAV to train on multiple datasets at once. 22 | ### Model building code 23 | - Pulled the changes from standard ResNets to SwAV-compatible ResNets into a new file (`model.py`), allowing us to incorporate WILDS-Unlabeled architectures, including ResNets and DenseNets. 24 | ### `main_swav.py` 25 | - Edited data loading and model building code to be compatible with the two changes noted above. 26 | 27 | ## Pre-training on WILDS 28 | 29 | To run SwAV pre-training on the WILDS datasets with the default hyperparameters used in the [paper](https://arxiv.org/abs/2112.05090), 30 | simply run: 31 | 32 | ```bash 33 | python -m torch.distributed.launch --nproc_per_node=<NUM_GPUS> main_swav.py --dataset <DATASET> --root_dir <ROOT_DIR> 34 | ``` 35 | 36 | We support SwAV pre-training on the following datasets: 37 | 38 | - `camelyon17` 39 | - `iwildcam` 40 | - `fmow` 41 | - `poverty` 42 | - `domainnet` -------------------------------------------------------------------------------- /examples/pretraining/swav/src/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import logging 10 | import time 11 | from datetime import timedelta 12 | import pandas as pd 13 | 14 | 15 | class LogFormatter: 16 | def __init__(self): 17 | self.start_time = time.time() 18 | 19 | def format(self, record): 20 | elapsed_seconds = round(record.created - self.start_time) 21 | 22 | prefix = "%s - %s - %s" % ( 23 | record.levelname, 24 | time.strftime("%x %X"), 25 | timedelta(seconds=elapsed_seconds), 26 | ) 27 | message = record.getMessage() 28 | message = message.replace("\n", "\n" + " " * (len(prefix) + 3)) 29 | return "%s - %s" % (prefix, message) if message else "" 30 | 31 | 32 | def create_logger(filepath, rank): 33 | """ 34 | Create a logger. 35 | Use a different log file for each process.
36 | """ 37 | # create log formatter 38 | log_formatter = LogFormatter() 39 | 40 | # create file handler and set level to debug 41 | if filepath is not None: 42 | if rank > 0: 43 | filepath = "%s-%i" % (filepath, rank) 44 | file_handler = logging.FileHandler(filepath, "a") 45 | file_handler.setLevel(logging.DEBUG) 46 | file_handler.setFormatter(log_formatter) 47 | 48 | # create console handler and set level to info 49 | console_handler = logging.StreamHandler() 50 | console_handler.setLevel(logging.INFO) 51 | console_handler.setFormatter(log_formatter) 52 | 53 | # create logger and set level to debug 54 | logger = logging.getLogger() 55 | logger.handlers = [] 56 | logger.setLevel(logging.DEBUG) 57 | logger.propagate = False 58 | if filepath is not None: 59 | logger.addHandler(file_handler) 60 | logger.addHandler(console_handler) 61 | 62 | # reset logger elapsed time 63 | def reset_time(): 64 | log_formatter.start_time = time.time() 65 | 66 | logger.reset_time = reset_time 67 | 68 | return logger 69 | 70 | 71 | class PD_Stats(object): 72 | """ 73 | Log stuff with pandas library 74 | """ 75 | 76 | def __init__(self, path, columns): 77 | self.path = path 78 | 79 | # reload path stats 80 | if os.path.isfile(self.path): 81 | self.stats = pd.read_pickle(self.path) 82 | 83 | # check that columns are the same 84 | assert list(self.stats.columns) == list(columns) 85 | 86 | else: 87 | self.stats = pd.DataFrame(columns=columns) 88 | 89 | def update(self, row, save=True): 90 | self.stats.loc[len(self.stats.index)] = row 91 | 92 | # save the statistics 93 | if save: 94 | self.stats.to_pickle(self.path) 95 | -------------------------------------------------------------------------------- /examples/pretraining/swav/src/model.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file defines the SwAVModel class, a wrapper around WILDS-Unlabeled architectures 3 | # that implements the changes necessary to make the networks compatible with SwAV 4 | # training (e.g. prototypes, projection head, etc.). Currently, the supported architectures 5 | # are ResNets and DenseNets. 
6 | # 7 | 8 | import os 9 | import sys 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torchvision.models as models 15 | 16 | sys.path.insert(1, os.path.join(sys.path[0], '../../..')) 17 | import examples.models.resnet_multispectral as resnet_ms 18 | 19 | class SwAVModel(nn.Module): 20 | def __init__( 21 | self, 22 | base_model, 23 | normalize=False, 24 | output_dim=0, 25 | hidden_mlp=0, 26 | nmb_prototypes=0, 27 | ): 28 | super(SwAVModel, self).__init__() 29 | 30 | self.base_model = base_model # base CNN architecture 31 | self.l2norm = normalize # whether to normalize output features 32 | 33 | # projection head 34 | last_dim = base_model.d_out # output dimensionality of final featurizer layer 35 | if output_dim == 0: 36 | self.projection_head = None 37 | elif hidden_mlp == 0: 38 | self.projection_head = nn.Linear(last_dim, output_dim) 39 | else: 40 | self.projection_head = nn.Sequential( 41 | nn.Linear(last_dim, hidden_mlp), 42 | nn.BatchNorm1d(hidden_mlp), 43 | nn.ReLU(inplace=True), 44 | nn.Linear(hidden_mlp, output_dim), 45 | ) 46 | 47 | # prototype layer 48 | self.prototypes = None 49 | if isinstance(nmb_prototypes, list): 50 | self.prototypes = MultiPrototypes(output_dim, nmb_prototypes) 51 | elif nmb_prototypes > 0: 52 | self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False) 53 | 54 | def forward_head(self, x): 55 | if self.projection_head is not None: 56 | x = self.projection_head(x) 57 | 58 | if self.l2norm: 59 | x = F.normalize(x, dim=1, p=2) 60 | 61 | if self.prototypes is not None: 62 | return x, self.prototypes(x) 63 | return x 64 | 65 | def forward(self, inputs): 66 | if not isinstance(inputs, list): 67 | inputs = [inputs] 68 | idx_crops = torch.cumsum(torch.unique_consecutive( 69 | torch.tensor([inp.shape[-1] for inp in inputs]), 70 | return_counts=True, 71 | )[1], 0) 72 | start_idx = 0 73 | for end_idx in idx_crops: 74 | _out = self.base_model( 75 | torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True)) 76 | if start_idx == 0: 77 | output = _out 78 | else: 79 | output = torch.cat((output, _out)) 80 | start_idx = end_idx 81 | return self.forward_head(output) 82 | 83 | class MultiPrototypes(nn.Module): 84 | def __init__(self, output_dim, nmb_prototypes): 85 | super(MultiPrototypes, self).__init__() 86 | self.nmb_heads = len(nmb_prototypes) 87 | for i, k in enumerate(nmb_prototypes): 88 | self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False)) 89 | 90 | def forward(self, x): 91 | out = [] 92 | for i in range(self.nmb_heads): 93 | out.append(getattr(self, "prototypes" + str(i))(x)) 94 | return out 95 | -------------------------------------------------------------------------------- /examples/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau, StepLR, CosineAnnealingLR, MultiStepLR 2 | 3 | def initialize_scheduler(config, optimizer, n_train_steps): 4 | # construct schedulers 5 | if config.scheduler is None: 6 | return None 7 | elif config.scheduler == 'linear_schedule_with_warmup': 8 | from transformers import get_linear_schedule_with_warmup 9 | scheduler = get_linear_schedule_with_warmup( 10 | optimizer, 11 | num_training_steps=n_train_steps, 12 | **config.scheduler_kwargs) 13 | step_every_batch = True 14 | use_metric = False 15 | elif config.scheduler == 'cosine_schedule_with_warmup': 16 | from transformers import get_cosine_schedule_with_warmup 17 | scheduler = 
get_cosine_schedule_with_warmup( 18 | optimizer, 19 | num_training_steps=n_train_steps, 20 | **config.scheduler_kwargs) 21 | step_every_batch = True 22 | use_metric = False 23 | elif config.scheduler=='ReduceLROnPlateau': 24 | assert config.scheduler_metric_name, f'scheduler metric must be specified for {config.scheduler}' 25 | scheduler = ReduceLROnPlateau( 26 | optimizer, 27 | **config.scheduler_kwargs) 28 | step_every_batch = False 29 | use_metric = True 30 | elif config.scheduler == 'StepLR': 31 | scheduler = StepLR(optimizer, **config.scheduler_kwargs) 32 | step_every_batch = False 33 | use_metric = False 34 | elif config.scheduler == 'FixMatchLR': 35 | scheduler = LambdaLR( 36 | optimizer, 37 | lambda x: (1.0 + 10 * float(x) / n_train_steps) ** -0.75 38 | ) 39 | step_every_batch = True 40 | use_metric = False 41 | elif config.scheduler == 'MultiStepLR': 42 | scheduler = MultiStepLR(optimizer, **config.scheduler_kwargs) 43 | step_every_batch = False 44 | use_metric = False 45 | else: 46 | raise ValueError(f'Scheduler: {config.scheduler} not supported.') 47 | 48 | # add a step_every_batch field 49 | scheduler.step_every_batch = step_every_batch 50 | scheduler.use_metric = use_metric 51 | return scheduler 52 | 53 | def step_scheduler(scheduler, metric=None): 54 | if isinstance(scheduler, ReduceLROnPlateau): 55 | assert metric is not None 56 | scheduler.step(metric) 57 | else: 58 | scheduler.step() 59 | 60 | class LinearScheduleWithWarmupAndThreshold(): 61 | """ 62 | Linear scheduler with warmup and threshold for non-LR parameters. 63 | The parameter is held at 0 until some T1, linearly increased until T2, and then held 64 | at some max value after T2. 65 | Designed to be called by step_scheduler() above and used within Algorithm class. 66 | Args: 67 | - last_warmup_step: aka T1. for steps [0, T1) keep param = 0 68 | - threshold_step: aka T2.
step over period [T1, T2) to reach param = max value 69 | - max value: end value of the param 70 | """ 71 | def __init__(self, max_value, last_warmup_step=0, threshold_step=1, step_every_batch=False): 72 | self.max_value = max_value 73 | self.T1 = last_warmup_step 74 | self.T2 = threshold_step 75 | assert (0 <= self.T1) and (self.T1 < self.T2) 76 | 77 | # internal tracker of which step we're on 78 | self.current_step = 0 79 | self.value = 0 80 | 81 | # required fields called in Algorithm when stepping schedulers 82 | self.step_every_batch = step_every_batch 83 | self.use_metric = False 84 | 85 | def step(self): 86 | """This function is first called AFTER step 0, so increment first to set value for next step""" 87 | self.current_step += 1 88 | if self.current_step < self.T1: 89 | self.value = 0 90 | elif self.current_step < self.T2: 91 | self.value = (self.current_step - self.T1) / (self.T2 - self.T1) * self.max_value 92 | else: 93 | self.value = self.max_value 94 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import os 3 | import sys 4 | 5 | here = os.path.abspath(os.path.dirname(__file__)) 6 | sys.path.insert(0, os.path.join(here, 'wilds')) 7 | from version import __version__ 8 | 9 | print(f'Version {__version__}') 10 | 11 | with open("README.md", "r", encoding="utf-8") as fh: 12 | long_description = fh.read() 13 | 14 | setuptools.setup( 15 | name="wilds", 16 | version=__version__, 17 | author="WILDS team", 18 | author_email="wilds@cs.stanford.edu", 19 | url="https://wilds.stanford.edu", 20 | description="WILDS distribution shift benchmark", 21 | long_description=long_description, 22 | long_description_content_type="text/markdown", 23 | install_requires = [ 24 | 'numpy>=1.19.1', 25 | 'ogb>=1.2.6', 26 | 'outdated>=0.2.0', 27 | 'pandas>=1.1.0', 28 | 'pillow>=7.2.0', 29 | 'ogb>=1.2.6', 30 | 'pytz>=2020.4', 31 | 'torch>=1.7.0', 32 | 'torchvision>=0.8.2', 33 | 'tqdm>=4.53.0', 34 | 'scikit-learn>=0.20.0', 35 | 'scipy>=1.5.4' 36 | ], 37 | license='MIT', 38 | packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert', 'examples.data_augmentation']), 39 | classifiers=[ 40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 41 | 'Intended Audience :: Science/Research', 42 | "Programming Language :: Python :: 3", 43 | "License :: OSI Approved :: MIT License", 44 | ], 45 | python_requires='>=3.6', 46 | ) 47 | -------------------------------------------------------------------------------- /wilds/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .get_dataset import get_dataset 3 | 4 | benchmark_datasets = [ 5 | 'amazon', 6 | 'camelyon17', 7 | 'civilcomments', 8 | 'iwildcam', 9 | 'ogb-molpcba', 10 | 'poverty', 11 | 'fmow', 12 | 'py150', 13 | 'rxrx1', 14 | 'globalwheat', 15 | ] 16 | 17 | additional_datasets = [ 18 | 'celebA', 19 | 'domainnet', 20 | 'waterbirds', 21 | 'yelp', 22 | 'bdd100k', 23 | 'sqf', 24 | 'encode' 25 | ] 26 | 27 | supported_datasets = benchmark_datasets + additional_datasets 28 | 29 | unlabeled_datasets = [ 30 | 'amazon', 31 | 'camelyon17', 32 | 'domainnet', 33 | 'civilcomments', 34 | 'iwildcam', 35 | 'ogb-molpcba', 36 | 'poverty', 37 | 'fmow', 38 | 'globalwheat', 39 | ] 40 | 41 | unlabeled_splits = [ 42 | 'train_unlabeled', 43 | 'val_unlabeled', 44 | 'test_unlabeled', 45 
| 'extra_unlabeled' 46 | ] -------------------------------------------------------------------------------- /wilds/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/wilds/common/__init__.py -------------------------------------------------------------------------------- /wilds/common/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/wilds/common/metrics/__init__.py -------------------------------------------------------------------------------- /wilds/common/metrics/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wilds.common.utils import avg_over_groups, maximum 3 | from wilds.common.metrics.metric import ElementwiseMetric, Metric, MultiTaskMetric 4 | 5 | class Loss(Metric): 6 | def __init__(self, loss_fn, name=None): 7 | self.loss_fn = loss_fn 8 | if name is None: 9 | name = 'loss' 10 | super().__init__(name=name) 11 | 12 | def _compute(self, y_pred, y_true): 13 | """ 14 | Helper for computing element-wise metric, implemented for each metric 15 | Args: 16 | - y_pred (Tensor): Predicted targets or model output 17 | - y_true (Tensor): True targets 18 | Output: 19 | - element_wise_metrics (Tensor): tensor of size (batch_size, ) 20 | """ 21 | return self.loss_fn(y_pred, y_true) 22 | 23 | def worst(self, metrics): 24 | """ 25 | Given a list/numpy array/Tensor of metrics, computes the worst-case metric 26 | Args: 27 | - metrics (Tensor, numpy array, or list): Metrics 28 | Output: 29 | - worst_metric (float): Worst-case metric 30 | """ 31 | return maximum(metrics) 32 | 33 | class ElementwiseLoss(ElementwiseMetric): 34 | def __init__(self, loss_fn, name=None): 35 | self.loss_fn = loss_fn 36 | if name is None: 37 | name = 'loss' 38 | super().__init__(name=name) 39 | 40 | def _compute_element_wise(self, y_pred, y_true): 41 | """ 42 | Helper for computing element-wise metric, implemented for each metric 43 | Args: 44 | - y_pred (Tensor): Predicted targets or model output 45 | - y_true (Tensor): True targets 46 | Output: 47 | - element_wise_metrics (Tensor): tensor of size (batch_size, ) 48 | """ 49 | return self.loss_fn(y_pred, y_true) 50 | 51 | def worst(self, metrics): 52 | """ 53 | Given a list/numpy array/Tensor of metrics, computes the worst-case metric 54 | Args: 55 | - metrics (Tensor, numpy array, or list): Metrics 56 | Output: 57 | - worst_metric (float): Worst-case metric 58 | """ 59 | return maximum(metrics) 60 | 61 | class MultiTaskLoss(MultiTaskMetric): 62 | def __init__(self, loss_fn, name=None): 63 | self.loss_fn = loss_fn # should be elementwise 64 | if name is None: 65 | name = 'loss' 66 | super().__init__(name=name) 67 | 68 | def _compute_flattened(self, flattened_y_pred, flattened_y_true): 69 | if isinstance(self.loss_fn, torch.nn.BCEWithLogitsLoss): 70 | flattened_y_pred = flattened_y_pred.float() 71 | flattened_y_true = flattened_y_true.float() 72 | elif isinstance(self.loss_fn, torch.nn.CrossEntropyLoss): 73 | flattened_y_true = flattened_y_true.long() 74 | flattened_loss = self.loss_fn(flattened_y_pred, flattened_y_true) 75 | return flattened_loss 76 | 77 | def worst(self, metrics): 78 | """ 79 | Given a list/numpy array/Tensor of metrics, computes the worst-case metric 80 | Args: 81 | - metrics (Tensor, numpy array, or list): 
Metrics 82 | Output: 83 | - worst_metric (float): Worst-case metric 84 | """ 85 | return maximum(metrics) 86 | -------------------------------------------------------------------------------- /wilds/common/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Subset 4 | from pandas.api.types import CategoricalDtype 5 | 6 | def minimum(numbers, empty_val=0.): 7 | if isinstance(numbers, torch.Tensor): 8 | if numbers.numel()==0: 9 | return torch.tensor(empty_val, device=numbers.device) 10 | else: 11 | return numbers[~torch.isnan(numbers)].min() 12 | elif isinstance(numbers, np.ndarray): 13 | if numbers.size==0: 14 | return np.array(empty_val) 15 | else: 16 | return np.nanmin(numbers) 17 | else: 18 | if len(numbers)==0: 19 | return empty_val 20 | else: 21 | return min(numbers) 22 | 23 | def maximum(numbers, empty_val=0.): 24 | if isinstance(numbers, torch.Tensor): 25 | if numbers.numel()==0: 26 | return torch.tensor(empty_val, device=numbers.device) 27 | else: 28 | return numbers[~torch.isnan(numbers)].max() 29 | elif isinstance(numbers, np.ndarray): 30 | if numbers.size==0: 31 | return np.array(empty_val) 32 | else: 33 | return np.nanmax(numbers) 34 | else: 35 | if len(numbers)==0: 36 | return empty_val 37 | else: 38 | return max(numbers) 39 | 40 | def split_into_groups(g): 41 | """ 42 | Args: 43 | - g (Tensor): Vector of groups 44 | Returns: 45 | - groups (Tensor): Unique groups present in g 46 | - group_indices (list): List of Tensors, where the i-th tensor is the indices of the 47 | elements of g that equal groups[i]. 48 | Has the same length as len(groups). 49 | - unique_counts (Tensor): Counts of each element in groups. 50 | Has the same length as len(groups). 51 | """ 52 | unique_groups, unique_counts = torch.unique(g, sorted=False, return_counts=True) 53 | group_indices = [] 54 | for group in unique_groups: 55 | group_indices.append( 56 | torch.nonzero(g == group, as_tuple=True)[0]) 57 | return unique_groups, group_indices, unique_counts 58 | 59 | def get_counts(g, n_groups): 60 | """ 61 | This differs from split_into_groups in how it handles missing groups. 62 | get_counts always returns a count Tensor of length n_groups, 63 | whereas split_into_groups returns a unique_counts Tensor 64 | whose length is the number of unique groups present in g. 65 | Args: 66 | - g (Tensor): Vector of groups 67 | Returns: 68 | - counts (Tensor): A list of length n_groups, denoting the count of each group. 69 | """ 70 | unique_groups, unique_counts = torch.unique(g, sorted=False, return_counts=True) 71 | counts = torch.zeros(n_groups, device=g.device) 72 | counts[unique_groups] = unique_counts.float() 73 | return counts 74 | 75 | def avg_over_groups(v, g, n_groups): 76 | """ 77 | Args: 78 | v (Tensor): Vector containing the quantity to average over. 79 | g (Tensor): Vector of the same length as v, containing group information. 
80 | Returns: 81 | group_avgs (Tensor): Vector of length num_groups 82 | group_counts (Tensor) 83 | """ 84 | import torch_scatter 85 | assert v.device==g.device 86 | assert v.numel()==g.numel() 87 | group_count = get_counts(g, n_groups) 88 | group_avgs = torch_scatter.scatter(src=v, index=g, dim_size=n_groups, reduce='mean') 89 | return group_avgs, group_count 90 | 91 | def map_to_id_array(df, ordered_map={}): 92 | maps = {} 93 | array = np.zeros(df.shape) 94 | for i, c in enumerate(df.columns): 95 | if c in ordered_map: 96 | category_type = CategoricalDtype(categories=ordered_map[c], ordered=True) 97 | else: 98 | category_type = 'category' 99 | series = df[c].astype(category_type) 100 | maps[c] = series.cat.categories.values 101 | array[:,i] = series.cat.codes.values 102 | return maps, array 103 | 104 | def subsample_idxs(idxs, num=5000, take_rest=False, seed=None): 105 | seed = (seed + 541433) if seed is not None else None 106 | rng = np.random.default_rng(seed) 107 | 108 | idxs = idxs.copy() 109 | rng.shuffle(idxs) 110 | if take_rest: 111 | idxs = idxs[num:] 112 | else: 113 | idxs = idxs[:num] 114 | return idxs 115 | 116 | def shuffle_arr(arr, seed=None): 117 | seed = (seed + 548207) if seed is not None else None 118 | rng = np.random.default_rng(seed) 119 | 120 | arr = arr.copy() 121 | rng.shuffle(arr) 122 | return arr 123 | 124 | def threshold_at_recall(y_pred, y_true, global_recall=60): 125 | """ Calculate the model threshold to use to achieve a desired global_recall level. Assumes that 126 | y_true is a vector of the true binary labels.""" 127 | return np.percentile(y_pred[y_true == 1], 100-global_recall) 128 | 129 | def numel(obj): 130 | if torch.is_tensor(obj): 131 | return obj.numel() 132 | elif isinstance(obj, list): 133 | return len(obj) 134 | else: 135 | raise TypeError("Invalid type for numel") 136 | -------------------------------------------------------------------------------- /wilds/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/wilds/datasets/__init__.py -------------------------------------------------------------------------------- /wilds/datasets/archive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/wilds/datasets/archive/__init__.py -------------------------------------------------------------------------------- /wilds/datasets/celebA_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | from PIL import Image 5 | import numpy as np 6 | from wilds.datasets.wilds_dataset import WILDSDataset 7 | from wilds.common.grouper import CombinatorialGrouper 8 | from wilds.common.metrics.all_metrics import Accuracy 9 | 10 | class CelebADataset(WILDSDataset): 11 | """ 12 | A variant of the CelebA dataset. 13 | This dataset is not part of the official WILDS benchmark. 14 | We provide it for convenience and to facilitate comparisons to previous work. 15 | 16 | Supported `split_scheme`: 17 | 'official' 18 | 19 | Input (x): 20 | Images of celebrity faces that have already been cropped and centered. 21 | 22 | Label (y): 23 | y is binary. It is 1 if the celebrity in the image has blond hair, and is 0 otherwise. 
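    Example (a minimal loading sketch, assuming the package's top-level
    `wilds.get_dataset` helper and the standard `(x, y, metadata)` indexing
    convention; illustrative only):
        >>> from wilds import get_dataset
        >>> dataset = get_dataset(dataset='celebA', root_dir='data', download=True)
        >>> x, y, metadata = dataset[0]   # x is a PIL image; y is 1 for blond hair, 0 otherwise
        >>> train_data = dataset.get_subset('train')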
24 | 25 | Metadata: 26 | Each image is annotated with whether the celebrity has been labeled 'Male' or 'Female'. 27 | 28 | Website: 29 | http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html 30 | 31 | Original publication: 32 | @inproceedings{liu2015faceattributes, 33 | title = {Deep Learning Face Attributes in the Wild}, 34 | author = {Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou}, 35 | booktitle = {Proceedings of International Conference on Computer Vision (ICCV)}, 36 | month = {December}, 37 | year = {2015} 38 | } 39 | 40 | This variant of the dataset is identical to the setup in: 41 | @inproceedings{sagawa2019distributionally, 42 | title = {Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization}, 43 | author = {Sagawa, Shiori and Koh, Pang Wei and Hashimoto, Tatsunori B and Liang, Percy}, 44 | booktitle = {International Conference on Learning Representations}, 45 | year = {2019} 46 | } 47 | 48 | License: 49 | This version of the dataset was originally downloaded from Kaggle 50 | https://www.kaggle.com/jessicali9530/celeba-dataset 51 | 52 | It is available for non-commercial research purposes only. 53 | """ 54 | _dataset_name = 'celebA' 55 | _versions_dict = { 56 | '1.0': { 57 | 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xfe55077f5cd541f985ebf9ec50473293/contents/blob/', 58 | 'compressed_size': 1_308_557_312}} 59 | 60 | def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): 61 | self._version = version 62 | self._data_dir = self.initialize_data_dir(root_dir, download) 63 | target_name = 'Blond_Hair' 64 | confounder_names = ['Male'] 65 | 66 | # Read in attributes 67 | attrs_df = pd.read_csv( 68 | os.path.join(self.data_dir, 'list_attr_celeba.csv')) 69 | 70 | # Split out filenames and attribute names 71 | # Note: idx and filenames are off by one. 
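# (Illustrative note: in the standard CelebA release, image files are numbered
# from '000001.jpg', so DataFrame row i corresponds to file i+1 on disk;
# storing the 'image_id' strings below avoids any manual index arithmetic.)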
72 | self._input_array = attrs_df['image_id'].values 73 | self._original_resolution = (178, 218) 74 | attrs_df = attrs_df.drop(labels='image_id', axis='columns') 75 | attr_names = attrs_df.columns.copy() 76 | def attr_idx(attr_name): 77 | return attr_names.get_loc(attr_name) 78 | 79 | # Then cast attributes to numpy array and set them to 0 and 1 80 | # (originally, they're -1 and 1) 81 | attrs_df = attrs_df.values 82 | attrs_df[attrs_df == -1] = 0 83 | 84 | # Get the y values 85 | target_idx = attr_idx(target_name) 86 | self._y_array = torch.LongTensor(attrs_df[:, target_idx]) 87 | self._y_size = 1 88 | self._n_classes = 2 89 | 90 | # Get metadata 91 | confounder_idx = [attr_idx(a) for a in confounder_names] 92 | confounders = attrs_df[:, confounder_idx] 93 | 94 | self._metadata_array = torch.cat( 95 | (torch.LongTensor(confounders), self._y_array.reshape((-1, 1))), 96 | dim=1) 97 | confounder_names = [s.lower() for s in confounder_names] 98 | self._metadata_fields = confounder_names + ['y'] 99 | self._metadata_map = { 100 | 'y': ['not blond', ' blond'] # Padding for str formatting 101 | } 102 | 103 | self._eval_grouper = CombinatorialGrouper( 104 | dataset=self, 105 | groupby_fields=(confounder_names + ['y'])) 106 | 107 | # Extract splits 108 | self._split_scheme = split_scheme 109 | if self._split_scheme != 'official': 110 | raise ValueError(f'Split scheme {self._split_scheme} not recognized') 111 | split_df = pd.read_csv( 112 | os.path.join(self.data_dir, 'list_eval_partition.csv')) 113 | self._split_array = split_df['partition'].values 114 | 115 | super().__init__(root_dir, download, split_scheme) 116 | 117 | def get_input(self, idx): 118 | # Note: idx and filenames are off by one. 119 | img_filename = os.path.join( 120 | self.data_dir, 121 | 'img_align_celeba', 122 | self._input_array[idx]) 123 | x = Image.open(img_filename).convert('RGB') 124 | return x 125 | 126 | def eval(self, y_pred, y_true, metadata, prediction_fn=None): 127 | """ 128 | Computes all evaluation metrics. 129 | Args: 130 | - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). 131 | But they can also be other model outputs such that prediction_fn(y_pred) 132 | are predicted labels. 133 | - y_true (LongTensor): Ground-truth labels 134 | - metadata (Tensor): Metadata 135 | - prediction_fn (function): A function that turns y_pred into predicted labels 136 | Output: 137 | - results (dictionary): Dictionary of evaluation metrics 138 | - results_str (str): String summarizing the evaluation metrics 139 | """ 140 | metric = Accuracy(prediction_fn=prediction_fn) 141 | return self.standard_group_eval( 142 | metric, 143 | self._eval_grouper, 144 | y_pred, y_true, metadata) 145 | -------------------------------------------------------------------------------- /wilds/datasets/ogbmolpcba_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from wilds.datasets.wilds_dataset import WILDSDataset 5 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 6 | from ogb.utils.url import download_url 7 | import torch_geometric 8 | if torch_geometric.__version__ >= '2.0.0': 9 | from torch_geometric.loader.dataloader import Collater as PyGCollater 10 | else: 11 | from torch_geometric.data.dataloader import Collater as PyGCollater 12 | 13 | class OGBPCBADataset(WILDSDataset): 14 | """ 15 | The OGB-molpcba dataset. 
16 | This dataset is directly adopted from Open Graph Benchmark, and originally curated by MoleculeNet. 17 | 18 | Supported `split_scheme`: 19 | - 'official' or 'scaffold', which are equivalent 20 | 21 | Input (x): 22 | Molecular graphs represented as Pytorch Geometric data objects 23 | 24 | Label (y): 25 | y represents 128-class binary labels. 26 | 27 | Metadata: 28 | - scaffold 29 | Each molecule is annotated with the scaffold ID that the molecule is assigned to. 30 | 31 | Website: 32 | https://ogb.stanford.edu/docs/graphprop/#ogbg-mol 33 | 34 | Original publication: 35 | @article{hu2020ogb, 36 | title={Open Graph Benchmark: Datasets for Machine Learning on Graphs}, 37 | author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren}, B. {Liu}, M. {Catasta}, J. {Leskovec}}, 38 | journal={arXiv preprint arXiv:2005.00687}, 39 | year={2020} 40 | } 41 | 42 | @article{wu2018moleculenet, 43 | title={MoleculeNet: a benchmark for molecular machine learning}, 44 | author={Z. {Wu}, B. {Ramsundar}, E. V {Feinberg}, J. {Gomes}, C. {Geniesse}, A. S {Pappu}, K. {Leswing}, V. {Pande}}, 45 | journal={Chemical science}, 46 | volume={9}, 47 | number={2}, 48 | pages={513--530}, 49 | year={2018}, 50 | publisher={Royal Society of Chemistry} 51 | } 52 | 53 | License: 54 | This dataset is distributed under the MIT license. 55 | https://github.com/snap-stanford/ogb/blob/master/LICENSE 56 | """ 57 | 58 | _dataset_name = 'ogb-molpcba' 59 | _versions_dict = { 60 | '1.0': { 61 | 'download_url': None, 62 | 'compressed_size': None}} 63 | 64 | def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): 65 | self._version = version 66 | if version is not None: 67 | raise ValueError('Versioning for OGB-MolPCBA is handled through the OGB package. Please set version=none.') 68 | # internally call ogb package 69 | self.ogb_dataset = PygGraphPropPredDataset(name = 'ogbg-molpcba', root = root_dir) 70 | 71 | # set variables 72 | self._data_dir = self.ogb_dataset.root 73 | if split_scheme=='official': 74 | split_scheme = 'scaffold' 75 | self._split_scheme = split_scheme 76 | self._y_type = 'float' # although the task is binary classification, the prediction target contains nan value, thus we need float 77 | self._y_size = self.ogb_dataset.num_tasks 78 | self._n_classes = self.ogb_dataset.__num_classes__ 79 | 80 | self._split_array = torch.zeros(len(self.ogb_dataset)).long() 81 | split_idx = self.ogb_dataset.get_idx_split() 82 | self._split_array[split_idx['train']] = 0 83 | self._split_array[split_idx['valid']] = 1 84 | self._split_array[split_idx['test']] = 2 85 | 86 | self._y_array = self.ogb_dataset.data.y 87 | 88 | self._metadata_fields = ['scaffold'] 89 | 90 | metadata_file_path = os.path.join(self.ogb_dataset.root, 'raw', 'scaffold_group.npy') 91 | if not os.path.exists(metadata_file_path): 92 | download_url('https://snap.stanford.edu/ogb/data/misc/ogbg_molpcba/scaffold_group.npy', os.path.join(self.ogb_dataset.root, 'raw')) 93 | self._metadata_array = torch.from_numpy(np.load(metadata_file_path)).reshape(-1,1).long() 94 | 95 | if torch_geometric.__version__ >= '1.7.0': 96 | self._collate = PyGCollater(follow_batch=[], exclude_keys=[]) 97 | else: 98 | self._collate = PyGCollater(follow_batch=[]) 99 | 100 | self._metric = Evaluator('ogbg-molpcba') 101 | 102 | super().__init__(root_dir, download, split_scheme) 103 | 104 | def get_input(self, idx): 105 | return self.ogb_dataset[int(idx)] 106 | 107 | def eval(self, y_pred, y_true, metadata, prediction_fn=None): 108 | """ 109 | Computes all 
evaluation metrics. 110 | Args: 111 | - y_pred (FloatTensor): Binary logits from a model 112 | - y_true (LongTensor): Ground-truth labels 113 | - metadata (Tensor): Metadata 114 | - prediction_fn (function): A function that turns y_pred into predicted labels. 115 | Only None is supported because OGB Evaluators accept binary logits 116 | Output: 117 | - results (dictionary): Dictionary of evaluation metrics 118 | - results_str (str): String summarizing the evaluation metrics 119 | """ 120 | assert prediction_fn is None, "OGBPCBADataset.eval() does not support prediction_fn. Only binary logits accepted" 121 | input_dict = {"y_true": y_true, "y_pred": y_pred} 122 | results = self._metric.eval(input_dict) 123 | 124 | return results, f"Average precision: {results['ap']:.3f}\n" 125 | -------------------------------------------------------------------------------- /wilds/datasets/unlabeled/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p-lambda/wilds/472677590de351857197a9bf24958838c39c272b/wilds/datasets/unlabeled/__init__.py -------------------------------------------------------------------------------- /wilds/datasets/unlabeled/camelyon17_unlabeled_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from PIL import Image 7 | 8 | from wilds.datasets.camelyon17_dataset import TEST_CENTER, VAL_CENTER 9 | from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset 10 | from wilds.common.grouper import CombinatorialGrouper 11 | 12 | 13 | class Camelyon17UnlabeledDataset(WILDSUnlabeledDataset): 14 | """ 15 | Unlabeled Camelyon17-WILDS dataset. 16 | This dataset contains patches from all of the slides in the original CAMELYON17 training data, 17 | except for the slides that were labeled with lesion annotations and therefore used in the 18 | labeled Camelyon17Dataset. 19 | 20 | Supported `split_scheme`: 21 | 'official' 22 | 23 | Input (x): 24 | 96x96 image patches extracted from histopathology slides. 25 | 26 | Metadata: 27 | Each patch is annotated with the ID of the hospital it came from (integer from 0 to 4) 28 | and the slide it came from (integer from 0 to 49). 29 | 30 | Website: 31 | https://camelyon17.grand-challenge.org/ 32 | 33 | Original publication: 34 | @article{bandi2018detection, 35 | title={From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge}, 36 | author={Bandi, Peter and Geessink, Oscar and Manson, Quirine and Van Dijk, Marcory and Balkenhol, Maschenka and Hermsen, Meyke and Bejnordi, Babak Ehteshami and Lee, Byungjae and Paeng, Kyunghyun and Zhong, Aoxiao and others}, 37 | journal={IEEE transactions on medical imaging}, 38 | volume={38}, 39 | number={2}, 40 | pages={550--560}, 41 | year={2018}, 42 | publisher={IEEE} 43 | } 44 | 45 | License: 46 | This dataset is in the public domain and is distributed under CC0. 
47 | https://creativecommons.org/publicdomain/zero/1.0/ 48 | """ 49 | 50 | _dataset_name = "camelyon17_unlabeled" 51 | _versions_dict = { 52 | "1.0": { 53 | "download_url": "https://worksheets.codalab.org/rest/bundles/0xa78be8a88a00487a92006936514967d2/contents/blob/", 54 | "compressed_size": 69_442_379_933, 55 | } 56 | } 57 | 58 | def __init__( 59 | self, version=None, root_dir="data", download=False, split_scheme="official" 60 | ): 61 | self._version = version 62 | self._data_dir = self.initialize_data_dir(root_dir, download) 63 | self._original_resolution = (96, 96) 64 | 65 | # Read in metadata 66 | self._metadata_df = pd.read_csv( 67 | os.path.join(self._data_dir, "metadata.csv"), 68 | index_col=0, 69 | dtype={"patient": "str"}, 70 | ) 71 | 72 | # Get filenames 73 | self._input_array = [ 74 | f"patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png" 75 | for patient, node, x, y in self._metadata_df.loc[ 76 | :, ["patient", "node", "x_coord", "y_coord"] 77 | ].itertuples(index=False, name=None) 78 | ] 79 | 80 | self._split_scheme = split_scheme 81 | if self._split_scheme == "official": 82 | self._split_dict = { 83 | "train_unlabeled": 10, 84 | "val_unlabeled": 11, 85 | "test_unlabeled": 12, 86 | } 87 | self._split_names = { 88 | "train_unlabeled": "Unlabeled Train", 89 | "val_unlabeled": "Unlabeled Validation", 90 | "test_unlabeled": "Unlabeled Test", 91 | } 92 | else: 93 | raise ValueError(f"Split scheme {self._split_scheme} not recognized") 94 | 95 | # Extract splits 96 | centers = self._metadata_df["center"].values.astype("long") 97 | num_centers = int(np.max(centers)) + 1 98 | self._metadata_df["split"] = self.split_dict["train_unlabeled"] 99 | val_center_mask = self._metadata_df["center"] == VAL_CENTER 100 | test_center_mask = self._metadata_df["center"] == TEST_CENTER 101 | self._metadata_df.loc[val_center_mask, "split"] = self.split_dict[ 102 | "val_unlabeled" 103 | ] 104 | self._metadata_df.loc[test_center_mask, "split"] = self.split_dict[ 105 | "test_unlabeled" 106 | ] 107 | # Centers 1 and 2 have 600,030 unlabeled examples each. 108 | # The rest of the unlabeled data is used for the train_unlabeled split (1,799,247 total). 109 | assert self._metadata_df.loc[val_center_mask].shape[0] == 600_030 110 | assert self._metadata_df.loc[test_center_mask].shape[0] == 600_030 111 | train_center_mask = ~self._metadata_df["center"].isin([VAL_CENTER, TEST_CENTER]) 112 | assert self._metadata_df.loc[train_center_mask].shape[0] == 1_799_247 113 | 114 | self._split_array = self._metadata_df["split"].values 115 | 116 | self._y_array = 100 * torch.LongTensor(self._metadata_df["tumor"].values) # in metadata.csv, these are all -1 117 | self._metadata_array = torch.stack( 118 | ( 119 | torch.LongTensor(centers), 120 | torch.LongTensor(self._metadata_df["slide"].values), 121 | self._y_array, 122 | ), 123 | dim=1, 124 | ) 125 | self._metadata_fields = ["hospital", "slide", "y"] 126 | 127 | self._eval_grouper = CombinatorialGrouper( 128 | dataset=self, groupby_fields=["slide"] 129 | ) 130 | 131 | super().__init__(root_dir, download, split_scheme) 132 | 133 | def get_input(self, idx): 134 | """ 135 | Returns x for a given idx. 
136 | """ 137 | img_filename = os.path.join(self.data_dir, self._input_array[idx]) 138 | x = Image.open(img_filename).convert("RGB") 139 | return x 140 | -------------------------------------------------------------------------------- /wilds/datasets/unlabeled/civilcomments_unlabeled_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from typing import Any, Dict, List, Optional, Tuple, Union 4 | 5 | import torch 6 | import pandas as pd 7 | import numpy as np 8 | 9 | from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset 10 | from wilds.common.utils import map_to_id_array 11 | 12 | 13 | class CivilCommentsUnlabeledDataset(WILDSUnlabeledDataset): 14 | """ 15 | Unlabeled CivilComments-WILDS toxicity classification dataset. 16 | This is a modified version of the original CivilComments dataset. 17 | 18 | Supported `split_scheme`: 19 | 'official' 20 | 21 | Input (x): 22 | A comment on an online article, comprising one or more sentences of text. 23 | 24 | Website: 25 | https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification 26 | 27 | Original publication: 28 | @inproceedings{borkan2019nuanced, 29 | title={Nuanced metrics for measuring unintended bias with real data for text classification}, 30 | author={Borkan, Daniel and Dixon, Lucas and Sorensen, Jeffrey and Thain, Nithum and Vasserman, Lucy}, 31 | booktitle={Companion Proceedings of The 2019 World Wide Web Conference}, 32 | pages={491--500}, 33 | year={2019} 34 | } 35 | 36 | License: 37 | This dataset is in the public domain and is distributed under CC0. 38 | https://creativecommons.org/publicdomain/zero/1.0/ 39 | """ 40 | 41 | _NOT_IN_DATASET: int = -1 42 | 43 | _dataset_name: str = "civilcomments_unlabeled" 44 | _versions_dict: Dict[str, Dict[str, Union[str, int]]] = { 45 | "1.0": { 46 | 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x1c471f23448e4518b000fe47aa7724e0/contents/blob/', 47 | 'compressed_size': 254_142_009 48 | }, 49 | } 50 | 51 | def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): 52 | self._version = version 53 | self._data_dir = self.initialize_data_dir(root_dir, download) 54 | 55 | # Read in metadata 56 | self._metadata_df = pd.read_csv( 57 | os.path.join(self._data_dir, 'unlabeled_data_with_identities.csv'), 58 | index_col=0) 59 | 60 | # Extract text 61 | self._text_array = list(self._metadata_df['comment_text']) 62 | 63 | # Extract splits 64 | self._split_scheme = split_scheme 65 | if self._split_scheme != 'official': 66 | raise ValueError(f'Split scheme {self._split_scheme} not recognized') 67 | 68 | # metadata_df contains split names in strings, so convert them to ints 69 | self._split_dict = { "extra_unlabeled": 13 } 70 | self._split_names = { "extra_unlabeled": "Unlabeled Extra" } 71 | self._metadata_df['split'] = self.split_dict["extra_unlabeled"] 72 | self._split_array = self._metadata_df['split'].values 73 | 74 | # Metadata (Not Available) 75 | # We want grouper to assign all values to their own group, so fill 76 | # all metadata fields with '2'. The normal dataset has binary metadata, 77 | # so this will not overlap. 
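# (A sketch of why the constant 2 works, assuming the usual CombinatorialGrouper
# API; the names here are hypothetical and the lines are not executed:
#     grouper = CombinatorialGrouper(dataset=full_dataset, groupby_fields=['black', 'y'])
#     groups = grouper.metadata_to_group(metadata)
# Labeled CivilComments rows only take metadata values 0 or 1, so every
# unlabeled row, whose fields are all 2, lands in a group of its own that no
# labeled example can share.)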
78 | self._identity_vars = [ 79 | 'male', 80 | 'female', 81 | 'LGBTQ', 82 | 'christian', 83 | 'muslim', 84 | 'other_religions', 85 | 'black', 86 | 'white' 87 | ] 88 | self._auxiliary_vars = [ 89 | 'identity_any', 90 | 'severe_toxicity', 91 | 'obscene', 92 | 'threat', 93 | 'insult', 94 | 'identity_attack', 95 | 'sexual_explicit' 96 | ] 97 | 98 | self._y_array = torch.LongTensor(self._metadata_df['toxicity'].values >= 0.5) 99 | self._metadata_array = torch.cat( 100 | ( 101 | torch.ones( 102 | len(self._metadata_df), 103 | len(self._identity_vars) + len(self._auxiliary_vars) 104 | ) * 2, 105 | self._y_array.unsqueeze(dim=-1) 106 | ), 107 | axis=1 108 | ) 109 | self._metadata_fields = self._identity_vars + self._auxiliary_vars + ['y'] 110 | 111 | super().__init__(root_dir, download, split_scheme) 112 | 113 | def get_input(self, idx): 114 | return self._text_array[idx] 115 | 116 | -------------------------------------------------------------------------------- /wilds/datasets/unlabeled/iwildcam_unlabeled_dataset.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pathlib import Path 3 | import os 4 | 5 | from PIL import Image 6 | import pandas as pd 7 | import numpy as np 8 | import torch 9 | import json 10 | 11 | from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset 12 | from wilds.common.grouper import CombinatorialGrouper 13 | from wilds.common.metrics.all_metrics import Accuracy, Recall, F1 14 | 15 | 16 | class IWildCamUnlabeledDataset(WILDSUnlabeledDataset): 17 | """ 18 | The unlabeled iWildCam2020-WILDS dataset. 19 | This is a modified version of the original iWildCam2020 competition dataset. 20 | Input (x): 21 | RGB images from camera traps 22 | Metadata: 23 | Each image is annotated with the ID of the location (camera trap) it came from. 
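    Example (a minimal loading sketch, assuming the standard
    `wilds.get_dataset(..., unlabeled=True)` entry point and that unlabeled
    examples are indexed as `(x, metadata)` pairs; illustrative only):
        >>> from wilds import get_dataset
        >>> dataset = get_dataset(dataset='iwildcam', unlabeled=True, root_dir='data', download=True)
        >>> unlabeled_data = dataset.get_subset('extra_unlabeled')
        >>> x, metadata = unlabeled_data[0]   # no label is returned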
24 | Website: 25 | http://lila.science/datasets/wcscameratraps 26 | https://library.wcs.org/ScienceData/Camera-Trap-Data-Summary.aspx 27 | Original publication: 28 | @misc{wcsdataset, 29 | title = {Wildlife Conservation Society Camera Traps Dataset}, 30 | howpublished = {\\url{http://lila.science/datasets/wcscameratraps}}, 31 | } 32 | License: 33 | This dataset is distributed under Community Data License Agreement – Permissive – Version 1.0 34 | https://cdla.io/permissive-1-0/ 35 | """ 36 | 37 | _dataset_name = "iwildcam_unlabeled" 38 | _versions_dict = { 39 | "1.0": { 40 | "download_url": "https://worksheets.codalab.org/rest/bundles/0xff56ea50fbf64aabbc4d09b2e8d50e18/contents/blob/", 41 | "compressed_size": 41_016_937_676, 42 | } 43 | } 44 | 45 | def __init__( 46 | self, version=None, root_dir="data", download=False, split_scheme="official" 47 | ): 48 | 49 | self._version = version 50 | self._split_scheme = split_scheme 51 | if self._split_scheme != "official": 52 | raise ValueError(f"Split scheme {self._split_scheme} not recognized") 53 | 54 | # path 55 | self._data_dir = Path(self.initialize_data_dir(root_dir, download)) 56 | 57 | # Load splits 58 | df = pd.read_csv(self._data_dir / "metadata.csv") 59 | 60 | # Splits 61 | self._split_dict = {"extra_unlabeled": 0} 62 | self._split_names = {"extra_unlabeled": "Extra Unlabeled"} 63 | df["split_id"] = 0 64 | self._split_array = df["split_id"].values 65 | 66 | # Filenames 67 | df["filename"] = df["uid"].apply(lambda x: x + ".jpg") 68 | self._input_array = df["filename"].values 69 | 70 | # Location/group info 71 | n_groups = df["location_remapped"].nunique() 72 | self._n_groups = n_groups 73 | 74 | def get_date(x): 75 | if isinstance(x, str): 76 | return datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f") 77 | else: 78 | return -1 79 | 80 | ## Extract datetime subcomponents and include in metadata 81 | df["datetime_obj"] = df["datetime"].apply(get_date) 82 | df["year"] = df["datetime_obj"].apply( 83 | lambda x: int(x.year) if isinstance(x, datetime) else -1 84 | ) 85 | df["month"] = df["datetime_obj"].apply( 86 | lambda x: int(x.month) if isinstance(x, datetime) else -1 87 | ) 88 | df["day"] = df["datetime_obj"].apply( 89 | lambda x: int(x.day) if isinstance(x, datetime) else -1 90 | ) 91 | df["hour"] = df["datetime_obj"].apply( 92 | lambda x: int(x.hour) if isinstance(x, datetime) else -1 93 | ) 94 | df["minute"] = df["datetime_obj"].apply( 95 | lambda x: int(x.minute) if isinstance(x, datetime) else -1 96 | ) 97 | df["second"] = df["datetime_obj"].apply( 98 | lambda x: int(x.second) if isinstance(x, datetime) else -1 99 | ) 100 | 101 | df["y"] = df["y"].apply( # filter out "bad" labels (-1 means the category was not in iwildcam_v2.0; 99999 means the category was unknown). map all to -100. 
102 | lambda x: x if ((x != -1) and (x != 99999)) else -100 103 | ) 104 | self._y_array = torch.LongTensor(df['y'].values) 105 | 106 | self._metadata_array = torch.tensor( 107 | np.stack( 108 | [ 109 | df["location_remapped"].values, 110 | df["sequence_remapped"].values, 111 | df["year"].values, 112 | df["month"].values, 113 | df["day"].values, 114 | df["hour"].values, 115 | df["minute"].values, 116 | df["second"].values, 117 | df["y"], 118 | ], 119 | axis=1, 120 | ) 121 | ) 122 | self._metadata_fields = [ 123 | "location", 124 | "sequence", 125 | "year", 126 | "month", 127 | "day", 128 | "hour", 129 | "minute", 130 | "second", 131 | "y", 132 | ] 133 | 134 | # eval grouper 135 | self._eval_grouper = CombinatorialGrouper( 136 | dataset=self, groupby_fields=(["location"]) 137 | ) 138 | 139 | super().__init__(root_dir, download, split_scheme) 140 | 141 | def get_input(self, idx): 142 | """ 143 | Args: 144 | - idx (int): Index of a data point 145 | Output: 146 | - x (Tensor): Input features of the idx-th data point 147 | """ 148 | 149 | # All images are in the train folder 150 | img_path = self.data_dir / "images" / self._input_array[idx] 151 | img = Image.open(img_path) 152 | 153 | return img 154 | -------------------------------------------------------------------------------- /wilds/datasets/unlabeled/ogbmolpcba_unlabeled_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset 6 | 7 | from ogb.graphproppred import PygGraphPropPredDataset 8 | from ogb.utils.url import download_url 9 | import torch_geometric 10 | if torch_geometric.__version__ >= '2.0.0': 11 | from torch_geometric.loader.dataloader import Collater as PyGCollater 12 | else: 13 | from torch_geometric.data.dataloader import Collater as PyGCollater 14 | 15 | class OGBPCBAUnlabeledDataset(WILDSUnlabeledDataset): 16 | """ 17 | Unlabeled dataset for OGB-molpcba. There are 5 million unlabeled molecules randomly sampled from the entire PubChem database. 18 | 19 | Input (x): 20 | Molecular graphs represented as Pytorch Geometric data objects 21 | 22 | Metadata: 23 | - scaffold 24 | Each molecule is annotated with the scaffold ID that the molecule is assigned to. 25 | 26 | Website: 27 | https://ogb.stanford.edu/docs/graphprop/#ogbg-mol 28 | 29 | Original publication: 30 | @article{hu2020ogb, 31 | title={Open Graph Benchmark: Datasets for Machine Learning on Graphs}, 32 | author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren}, B. {Liu}, M. {Catasta}, J. {Leskovec}}, 33 | journal={arXiv preprint arXiv:2005.00687}, 34 | year={2020} 35 | } 36 | 37 | License: 38 | This dataset is distributed under the MIT license. 39 | https://github.com/snap-stanford/ogb/blob/master/LICENSE 40 | """ 41 | 42 | _dataset_name = 'ogb-molpcba_unlabeled' 43 | _versions_dict = { 44 | '1.0': { 45 | 'download_url': None, 46 | 'compressed_size': None}} 47 | 48 | def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): 49 | self._version = version 50 | if version is not None: 51 | raise ValueError('Versioning for Unlabeled MolPCBA is handled through the OGB package. 
Please set version=none.') 52 | # internally call ogb package 53 | 54 | ### Setting up meta-information for the dataset 55 | meta_dict = {} 56 | meta_dict['dir_path'] = os.path.join(root_dir, 'molpcba_unlabeled') 57 | meta_dict['url'] = 'http://snap.stanford.edu/ogb/data/wilds/molpcba_unlabeled.zip' 58 | meta_dict['num tasks'] = 0 59 | meta_dict['eval metric'] = None 60 | meta_dict['download_name'] = 'molpcba_unlabeled' 61 | meta_dict['version'] = 1 62 | meta_dict['add_inverse_edge'] = 'False' 63 | meta_dict['data type'] = 'mol' 64 | meta_dict['has_node_attr'] = 'True' 65 | meta_dict['has_edge_attr'] = 'True' 66 | meta_dict['task type'] = 'classification' 67 | meta_dict['num classes'] = -1 68 | meta_dict['split'] = 'scaffold' 69 | meta_dict['additional node files'] = 'None' 70 | meta_dict['additional edge files'] = 'None' 71 | meta_dict['binary'] = 'True' 72 | 73 | self.ogb_dataset = PygGraphPropPredDataset(name = 'molpcba_unlabeled', root = root_dir, meta_dict = meta_dict) 74 | self.ogb_dataset.data.y = None 75 | 76 | # set variables 77 | self._data_dir = self.ogb_dataset.root 78 | if split_scheme=='official': 79 | split_scheme = 'scaffold' 80 | self._split_scheme = split_scheme 81 | 82 | self._split_array = torch.zeros(len(self.ogb_dataset)).long() 83 | split_idx = self.ogb_dataset.get_idx_split() 84 | self._split_array[split_idx['train']] = 10 85 | self._split_array[split_idx['valid']] = 11 86 | self._split_array[split_idx['test']] = 12 87 | 88 | self._metadata_fields = ['scaffold'] 89 | 90 | metadata_file_path = os.path.join(self.ogb_dataset.root, 'processed', 'group_assignment.npy') 91 | self._metadata_array = torch.from_numpy(np.load(metadata_file_path)).reshape(-1,1).long() 92 | 93 | if torch_geometric.__version__ >= '1.7.0': 94 | self._collate = PyGCollater(follow_batch=[], exclude_keys=[]) 95 | else: 96 | self._collate = PyGCollater(follow_batch=[]) 97 | 98 | super().__init__(root_dir, download, split_scheme) 99 | 100 | def get_input(self, idx): 101 | return self.ogb_dataset[int(idx)] 102 | -------------------------------------------------------------------------------- /wilds/download_datasets.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import wilds 4 | 5 | def main(): 6 | """ 7 | Downloads the latest versions of all specified datasets, 8 | if they do not already exist. 9 | """ 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--root_dir', required=True, 12 | help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') 13 | parser.add_argument('--datasets', nargs='*', default=None, 14 | help=f'Specify a space-separated list of dataset names to download. If left unspecified, the script will download all of the official benchmark datasets. 
Available choices are {wilds.supported_datasets}.') 15 | parser.add_argument('--unlabeled', default=False, action='store_true', 16 | help=f'If this flag is set, the unlabeled versions of the datasets will be downloaded instead of the labeled ones.') 17 | config = parser.parse_args() 18 | 19 | if config.datasets is None: 20 | config.datasets = wilds.benchmark_datasets 21 | 22 | for dataset in config.datasets: 23 | if dataset not in wilds.supported_datasets: 24 | raise ValueError(f'{dataset} not recognized; must be one of {wilds.supported_datasets}.') 25 | 26 | print(f'Downloading the following datasets: {config.datasets}') 27 | for dataset in config.datasets: 28 | print(f'=== {dataset} ===') 29 | wilds.get_dataset( 30 | dataset=dataset, 31 | root_dir=config.root_dir, 32 | unlabeled=config.unlabeled, 33 | download=True) 34 | 35 | 36 | if __name__=='__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /wilds/version.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/snap-stanford/ogb/blob/master/ogb/version.py 2 | 3 | import os 4 | import logging 5 | from threading import Thread 6 | 7 | __version__ = '2.0.0' 8 | 9 | try: 10 | os.environ['OUTDATED_IGNORE'] = '1' 11 | from outdated import check_outdated # noqa 12 | except ImportError: 13 | check_outdated = None 14 | 15 | def check(): 16 | try: 17 | is_outdated, latest = check_outdated('wilds', __version__) 18 | if is_outdated: 19 | logging.warning( 20 | f'The WILDS package is out of date. Your version is ' 21 | f'{__version__}, while the latest version is {latest}.') 22 | except Exception: 23 | pass 24 | 25 | if check_outdated is not None: 26 | thread = Thread(target=check) 27 | thread.start() 28 | --------------------------------------------------------------------------------
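The loss and grouping utilities above (`ElementwiseLoss`, `avg_over_groups`, `maximum`, `CombinatorialGrouper`) compose into a per-group evaluation loop. The sketch below shows one way to wire them together for a worst-group validation loss on Camelyon17; it assumes the standard WILDS loaders, a generic torchvision model, and the optional `torch_scatter` dependency, and is illustrative rather than a transcription of the package's `examples/` training code.

# Illustrative sketch only: per-group and worst-group validation loss.
# Assumes torchvision and torch_scatter are installed alongside wilds.
import torch
import torchvision.transforms as transforms
from torchvision.models import densenet121

from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.loss import ElementwiseLoss
from wilds.common.utils import avg_over_groups

dataset = get_dataset(dataset='camelyon17', root_dir='data', download=True)
grouper = CombinatorialGrouper(dataset, groupby_fields=['hospital'])
val_data = dataset.get_subset('val', transform=transforms.ToTensor())
loader = get_eval_loader('standard', val_data, batch_size=32)

model = densenet121(num_classes=2).eval()   # stand-in model, untrained
# reduction='none' keeps one loss value per example, which is what
# ElementwiseLoss._compute_element_wise passes through unchanged.
loss_metric = ElementwiseLoss(loss_fn=torch.nn.CrossEntropyLoss(reduction='none'))

with torch.no_grad():
    x, y, metadata = next(iter(loader))
    g = grouper.metadata_to_group(metadata)
    per_example_loss = loss_metric._compute_element_wise(model(x), y)
    group_avgs, group_counts = avg_over_groups(per_example_loss, g, grouper.n_groups)
    # worst() is the maximum over groups, matching the definition in loss.py;
    # groups that received no examples are excluded before taking the maximum.
    worst_loss = loss_metric.worst(group_avgs[group_counts > 0])
    print(f'worst-group loss: {worst_loss.item():.3f}')

Datasets can also be fetched ahead of time with the bundled downloader, e.g. `python wilds/download_datasets.py --root_dir data --datasets camelyon17`, instead of passing `download=True`.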