diff --git a/README.md b/README.md index 1cc09658..ca56d6a0 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pip install wilds If you have already installed it, please check that you have the latest version: ```bash python -c "import wilds; print(wilds.__version__)" -# This should print "1.0.0". If it doesn't, update by running: +# This should print "1.1.0". If it doesn't, update by running: pip install -U wilds ``` @@ -42,15 +42,15 @@ pip install -e . ### Requirements - numpy>=1.19.1 +- ogb>=1.2.6 +- outdated>=0.2.0 - pandas>=1.1.0 - pillow>=7.2.0 -- torch>=1.7.0 -- tqdm>=4.53.0 - pytz>=2020.4 -- outdated>=0.2.0 -- ogb>=1.2.3 +- torch>=1.7.0 - torch-scatter>=2.0.5 - torch-geometric>=1.6.1 +- tqdm>=4.53.0 Running `pip install wilds` or `pip install -e .` will automatically check for and install all of these requirements except for the `torch-scatter` and `torch-geometric` packages, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries). @@ -70,39 +70,69 @@ To run these scripts, you will need to install these additional dependencies: All baseline experiments in the paper were run on Python 3.8.5 and CUDA 10.1. -## Usage -### Default models -In the `examples/` folder, we provide a set of scripts that we used to train models on the WILDS package. These scripts are configured with the default models and hyperparameters that we used for all of the baselines described in our paper. All baseline results in the paper can be easily replicated with commands like: + +## Using the example scripts + +In the `examples/` folder, we provide a set of scripts that can be used to download WILDS datasets and train models on them. +These scripts are configured with the default models and hyperparameters that we used for all of the baselines described in our paper. All baseline results in the paper can be easily replicated with commands like: ```bash -cd examples -python run_expt.py --dataset iwildcam --algorithm ERM --root_dir data -python run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data +python examples/run_expt.py --dataset iwildcam --algorithm ERM --root_dir data +python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data ``` The scripts are set up to facilitate general-purpose algorithm development: new algorithms can be added to `examples/algorithms` and then run on all of the WILDS datasets using the default models. The first time you run these scripts, you might need to download the datasets. You can do so with the `--download` argument, for example: ``` -python run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data --download +python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data --download ``` +Alternatively, you can use the standalone `wilds/download_datasets.py` script to download the datasets, for example: + +```bash +python wilds/download_datasets.py --root_dir data +``` + +This will download all datasets to the specified `data` folder. You can also use the `--datasets` argument to download particular datasets. + +These are the sizes of each of our datasets, as well as their approximate time taken to train and evaluate the default model for a single ERM run using a NVIDIA V100 GPU. + +| Dataset command | Modality | Download size (GB) | Size on disk (GB) | Train+eval time (Hours) | +|-----------------|----------|--------------------|-------------------|-------------------------| +| iwildcam | Image | 11 | 25 | 7 | +| camelyon17 | Image | 10 | 15 | 2 | +| ogb-molpcba | Graph | 0.04 | 2 | 15 | +| civilcomments | Text | 0.1 | 0.3 | 4.5 | +| fmow | Image | 50 | 55 | 6 | +| poverty | Image | 12 | 14 | 5 | +| amazon | Text | 6.6 | 7 | 5 | +| py150 | Text | 0.1 | 0.8 | 9.5 | + +While the `camelyon17` dataset is small and fast to train on, we advise against using it as the only dataset to prototype methods on, as the test performance of models trained on this dataset tend to exhibit a large degree of variability over random seeds. + +The image datasets (`iwildcam`, `camelyon17`, `fmow`, and `poverty`) tend to have high disk I/O usage. If training time is much slower for you than the approximate times listed above, consider checking if I/O is a bottleneck (e.g., by moving to a local disk if you are using a network drive, or by increasing the number of data loader workers). To speed up training, you could also disable evaluation at each epoch or for all splits by toggling `--evaluate_all_splits` and related arguments. + +We have an [executable version](https://wilds.stanford.edu/codalab) of our paper on CodaLab that contains the exact commands, code, and data used for the experiments reported in our paper. Trained model weights for all datasets can also be found there. + + +## Using the WILDS package ### Data loading The WILDS package provides a simple, standardized interface for all datasets in the benchmark. This short Python snippet covers all of the steps of getting started with a WILDS dataset, including dataset download and initialization, accessing various splits, and preparing a user-customizable data loader. ```py ->>> from wilds.datasets.iwildcam_dataset import IWildCamDataset +>>> from wilds import get_dataset >>> from wilds.common.data_loaders import get_train_loader >>> import torchvision.transforms as transforms # Load the full dataset, and download it if necessary ->>> dataset = IWildCamDataset(download=True) +>>> dataset = get_dataset(dataset='iwildcam', download=True) # Get the training set >>> train_data = dataset.get_subset('train', -... transform=transforms.Compose([transforms.Resize((224,224)), +... transform=transforms.Compose([transforms.Resize((448,448)), ... transforms.ToTensor()])) # Prepare the standard data loader @@ -171,11 +201,12 @@ Invoking the `eval` method of each dataset yields all metrics reported in the pa >>> dataset.eval(all_y_pred, all_y_true, all_metadata) {'recall_macro_all': 0.66, ...} ``` +Most `eval` methods take in predicted labels for `all_y_pred` by default, but the default inputs vary across datasets and are documented in the `eval` docstrings of the corresponding dataset class. ## Citing WILDS If you use WILDS datasets in your work, please cite [our paper](https://arxiv.org/abs/2012.07421) ([Bibtex](https://wilds.stanford.edu/assets/files/bibtex.md)): -- **WILDS: A Benchmark of in-the-Wild Distribution Shifts** (2020). Pang Wei Koh*, Shiori Sagawa*, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Sara Beery, Jure Leskovec, Anshul Kundaje, Emma Pierson, Sergey Levine, Chelsea Finn, and Percy Liang. +- **WILDS: A Benchmark of in-the-Wild Distribution Shifts** (2021). Pang Wei Koh*, Shiori Sagawa*, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, Tony Lee, Etienne David, Ian Stavness, Wei Guo, Berton A. Earnshaw, Imran S. Haque, Sara Beery, Jure Leskovec, Anshul Kundaje, Emma Pierson, Sergey Levine, Chelsea Finn, and Percy Liang. Please also cite the original papers that introduce the datasets, as listed on the [datasets page](https://wilds.stanford.edu/datasets/). diff --git a/dataset_preprocessing/amazon_yelp/subsample_amazon.py b/dataset_preprocessing/amazon_yelp/subsample_amazon.py new file mode 100644 index 00000000..7b4971bd --- /dev/null +++ b/dataset_preprocessing/amazon_yelp/subsample_amazon.py @@ -0,0 +1,157 @@ +import argparse +import csv +import os + +import pandas as pd +import numpy as np + +# Fix the seed for reproducibility +np.random.seed(0) + +""" +Subsample the Amazon dataset. + +Usage: + python dataset_preprocessing/amazon_yelp/subsample_amazon.py +""" + +NOT_IN_DATASET = -1 +# Split: {'train': 0, 'val': 1, 'id_val': 2, 'test': 3, 'id_test': 4} +TRAIN, OOD_VAL, ID_VAL, OOD_TEST, ID_TEST = range(5) + + +def main(dataset_path, frac=0.25): + def output_dataset_sizes(split_df): + print("-" * 50) + print(f'Train size: {len(split_df[split_df["split"] == TRAIN])}') + print(f'Val size: {len(split_df[split_df["split"] == OOD_VAL])}') + print(f'ID Val size: {len(split_df[split_df["split"] == ID_VAL])}') + print(f'Test size: {len(split_df[split_df["split"] == OOD_TEST])}') + print(f'ID Test size: {len(split_df[split_df["split"] == ID_TEST])}') + print( + f'Number of examples not included: {len(split_df[split_df["split"] == NOT_IN_DATASET])}' + ) + print("-" * 50) + print("\n") + + data_df = pd.read_csv( + os.path.join(dataset_path, "reviews.csv"), + dtype={ + "reviewerID": str, + "asin": str, + "reviewTime": str, + "unixReviewTime": int, + "reviewText": str, + "summary": str, + "verified": bool, + "category": str, + "reviewYear": int, + }, + keep_default_na=False, + na_values=[], + quoting=csv.QUOTE_NONNUMERIC, + ) + + user_csv_path = os.path.join(dataset_path, "splits", "user.csv") + split_df = pd.read_csv(user_csv_path) + output_dataset_sizes(split_df) + + train_data_df = data_df[split_df["split"] == 0] + train_reviewer_ids = train_data_df.reviewerID.unique() + print(f"Number of unique reviewers in train set: {len(train_reviewer_ids)}") + + # Randomly sample (1 - frac) x number of reviewers + # Blackout all the reviews belonging to the randomly sampled reviewers + subsampled_reviewers_count = int((1 - frac) * len(train_reviewer_ids)) + subsampled_reviewers = np.random.choice( + train_reviewer_ids, subsampled_reviewers_count, replace=False + ) + print(subsampled_reviewers) + + blackout_indices = train_data_df[ + train_data_df["reviewerID"].isin(subsampled_reviewers) + ].index + + # Mark all the corresponding reviews of blackout_indices as -1 + split_df.loc[blackout_indices, "split"] = NOT_IN_DATASET + output_dataset_sizes(split_df) + + # Mark duplicates + duplicated_within_user = data_df[["reviewerID", "reviewText"]].duplicated() + df_deduplicated_within_user = data_df[~duplicated_within_user] + duplicated_text = df_deduplicated_within_user[ + df_deduplicated_within_user["reviewText"] + .apply(lambda x: x.lower()) + .duplicated(keep=False) + ]["reviewText"] + duplicated_text = set(duplicated_text.values) + data_df["duplicate"] = ( + data_df["reviewText"].isin(duplicated_text) + ) | duplicated_within_user + + # Mark html candidates + data_df["contains_html"] = data_df["reviewText"].apply( + lambda x: "<" in x and ">" in x + ) + + # Mark clean ones + data_df["clean"] = ~data_df["duplicate"] & ~data_df["contains_html"] + + # Clear ID val and ID test since we're regenerating + split_df.loc[split_df["split"] == ID_VAL, "split"] = NOT_IN_DATASET + split_df.loc[split_df["split"] == ID_TEST, "split"] = NOT_IN_DATASET + + # Regenerate ID val and ID test + train_reviewer_ids = data_df[split_df["split"] == TRAIN]["reviewerID"].unique() + np.random.shuffle(train_reviewer_ids) + cutoff = int(len(train_reviewer_ids) / 2) + id_val_reviewer_ids = train_reviewer_ids[:cutoff] + id_test_reviewer_ids = train_reviewer_ids[cutoff:] + split_df.loc[ + (split_df["split"] == NOT_IN_DATASET) + & data_df["clean"] + & data_df["reviewerID"].isin(id_val_reviewer_ids), + "split", + ] = ID_VAL + split_df.loc[ + (split_df["split"] == NOT_IN_DATASET) + & data_df["clean"] + & data_df["reviewerID"].isin(id_test_reviewer_ids), + "split", + ] = ID_TEST + + # Sanity check + assert ( + data_df[(split_df["split"] == ID_VAL)]["reviewerID"].value_counts().min() == 75 + ) + assert ( + data_df[(split_df["split"] == ID_VAL)]["reviewerID"].value_counts().max() == 75 + ) + assert ( + data_df[(split_df["split"] == ID_TEST)]["reviewerID"].value_counts().min() == 75 + ) + assert ( + data_df[(split_df["split"] == ID_TEST)]["reviewerID"].value_counts().max() == 75 + ) + + # Write out the new splits to user.csv + output_dataset_sizes(split_df) + split_df.to_csv(user_csv_path, index=False) + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Subsample the Amazon dataset.") + parser.add_argument( + "path", + type=str, + help="Path to the Amazon dataset", + ) + parser.add_argument( + "frac", + type=float, + help="Subsample fraction", + ) + + args = parser.parse_args() + main(args.path, args.frac) diff --git a/dataset_preprocessing/fmow/convert_npy_to_jpg.py b/dataset_preprocessing/fmow/convert_npy_to_jpg.py new file mode 100644 index 00000000..c883198b --- /dev/null +++ b/dataset_preprocessing/fmow/convert_npy_to_jpg.py @@ -0,0 +1,28 @@ +import os, sys +import argparse +import numpy as np +from PIL import Image +from pathlib import Path +from tqdm import tqdm + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument('--root_dir', required=True, + help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') + config = parser.parse_args() + data_dir = Path(config.root_dir) / 'fmow_v1.0' + image_dir = Path(config.root_dir) / 'fmow_v1.0_images_jpg' + os.makedirs(image_dir, exist_ok=True) + + img_counter = 0 + for chunk in tqdm(range(101)): + npy_chunk = np.load(data_dir / f'rgb_all_imgs_{chunk}.npy', mmap_mode='r') + for i in range(len(npy_chunk)): + npy_image = npy_chunk[i] + img = Image.fromarray(npy_image, mode='RGB') + img.save(image_dir / f'rgb_img_{img_counter}.jpg') + img_counter += 1 + +if __name__=='__main__': + main() diff --git a/dataset_preprocessing/iwildcam/create_split.py b/dataset_preprocessing/iwildcam/create_split.py index d181a7fd..249c4dcd 100644 --- a/dataset_preprocessing/iwildcam/create_split.py +++ b/dataset_preprocessing/iwildcam/create_split.py @@ -1,69 +1,52 @@ +from datetime import datetime +from pathlib import Path +import argparse +import json +from PIL import Image import pandas as pd import numpy as np -# Examples to skip due to e.g them missing, loading issues -LOCATIONS_TO_SKIP = [537] - -CANNOT_OPEN = ['99136aa6-21bc-11ea-a13a-137349068a90.jpg', - '87022118-21bc-11ea-a13a-137349068a90.jpg', - '8f17b296-21bc-11ea-a13a-137349068a90.jpg', - '883572ba-21bc-11ea-a13a-137349068a90.jpg', - '896c1198-21bc-11ea-a13a-137349068a90.jpg', - '8792549a-21bc-11ea-a13a-137349068a90.jpg', - '94529be0-21bc-11ea-a13a-137349068a90.jpg'] - -CANNOT_LOAD = ['929da9de-21bc-11ea-a13a-137349068a90.jpg', - '9631e6a0-21bc-11ea-a13a-137349068a90.jpg', - '8c3a31fc-21bc-11ea-a13a-137349068a90.jpg', - '88313344-21bc-11ea-a13a-137349068a90.jpg', - '8c53e822-21bc-11ea-a13a-137349068a90.jpg', - '911848a8-21bc-11ea-a13a-137349068a90.jpg', - '98bd006c-21bc-11ea-a13a-137349068a90.jpg', - '91ba7b50-21bc-11ea-a13a-137349068a90.jpg', - '9799f64a-21bc-11ea-a13a-137349068a90.jpg', - '88007592-21bc-11ea-a13a-137349068a90.jpg', - '94860606-21bc-11ea-a13a-137349068a90.jpg', - '9166fbd8-21bc-11ea-a13a-137349068a90.jpg'] - -OTHER = ['8e0c091a-21bc-11ea-a13a-137349068a90.jpg'] # This one got slightly different error - - -IDS_TO_SKIP = CANNOT_OPEN + CANNOT_LOAD + OTHER +def create_split(data_dir, seed): + np_rng = np.random.default_rng(seed) + # Loading json was adapted from + # https://www.kaggle.com/ateplyuk/iwildcam2020-pytorch-start + filename = f'iwildcam2021_train_annotations_final.json' + with open(data_dir / filename ) as json_file: + data = json.load(json_file) -def create_split(data_dir): - train_df, val_cis_df, val_trans_df, test_cis_df, test_trans_df = _create_split(data_dir, seed=0) + df_annotations = pd.DataFrame({ + 'category_id': [item['category_id'] for item in data['annotations']], + 'image_id': [item['image_id'] for item in data['annotations']] + }) - train_df.to_csv(data_dir / 'train.csv') - val_cis_df.to_csv(data_dir / 'val_cis.csv') - val_trans_df.to_csv(data_dir / 'val_trans.csv') - test_cis_df.to_csv(data_dir / 'test_cis.csv') - test_trans_df.to_csv(data_dir / 'test_trans.csv') + df_metadata = pd.DataFrame({ + 'image_id': [item['id'] for item in data['images']], + 'location': [item['location'] for item in data['images']], + 'filename': [item['file_name'] for item in data['images']], + 'datetime': [item['datetime'] for item in data['images']], + 'frame_num': [item['frame_num'] for item in data['images']], # this attribute is not used + 'seq_id': [item['seq_id'] for item in data['images']] # this attribute is not used + }) -def _create_split(data_dir, seed, skip=True): - data_dir = Path(data_dir) - np_rng = np.random.default_rng(seed) + df = df_metadata.merge(df_annotations, on='image_id', how='inner') - # Load Kaggle train data - with open(data_dir / r'iwildcam2020_train_annotations.json' ) as json_file: - data = json.load(json_file) + # Create category_id to name dictionary + cat_id_to_name_map = {} + for item in data['categories']: + cat_id_to_name_map[item['id']] = item['name'] + df['category_name'] = df['category_id'].apply(lambda x: cat_id_to_name_map[x]) - # This line was adapted from - # https://www.kaggle.com/ateplyuk/iwildcam2020-pytorch-start - df = pd.DataFrame( - { - 'id': [item['id'] for item in data['annotations']], - 'category_id': [item['category_id'] for item in data['annotations']], - 'image_id': [item['image_id'] for item in data['annotations']], - 'location': [item['location'] for item in data['images']], - 'filename': [item['file_name'] for item in data['images']], - 'datetime': [item['datetime'] for item in data['images']], - 'frame_num': [item['frame_num'] for item in data['images']], # this attribute is not used - 'seq_id': [item['seq_id'] for item in data['images']] # this attribute is not used - }) + # Extract the date from the datetime. + df['datetime_obj'] = df['datetime'].apply(lambda x: datetime.strptime(x, '%Y-%m-%d %H:%M:%S.%f')) + df['date'] = df['datetime_obj'].apply(lambda x: x.date()) + # Retrieve the sequences that span 2 days + grouped_by = df.groupby('seq_id') + nunique_dates = grouped_by['date'].nunique() + seq_ids_that_span_across_days = nunique_dates[nunique_dates.values > 1].reset_index()['seq_id'].values # Split by location to get the cis & trans validation set locations = np.unique(df['location']) @@ -78,13 +61,14 @@ def _create_split(data_dir, seed, skip=True): train_locations, val_trans_locations = locations[:n_train_locations], locations[n_train_locations:(n_train_locations+n_val_locations)] test_trans_locations = locations[(n_train_locations+n_val_locations):] + remaining_df, val_trans_df = df[df['location'].isin(train_locations)], df[df['location'].isin(val_trans_locations)] test_trans_df = df[df['location'].isin(test_trans_locations)] # Split remaining samples by dates to get the cis validation and test set - frac_validation = 0.05 - frac_test = 0.05 - unique_dates = np.unique(remaining_df['datetime']) + frac_validation = 0.07 + frac_test = 0.09 + unique_dates = np.unique(remaining_df['date']) n_dates = len(unique_dates) n_val_dates = int(n_dates * frac_validation) n_test_dates = int(n_dates * frac_test) @@ -94,9 +78,9 @@ def _create_split(data_dir, seed, skip=True): train_dates, val_cis_dates = unique_dates[:n_train_dates], unique_dates[n_train_dates:(n_train_dates+n_val_dates)] test_cis_dates = unique_dates[(n_train_dates+n_val_dates):] - val_cis_df = remaining_df[remaining_df['datetime'].isin(val_cis_dates)] - test_cis_df = remaining_df[remaining_df['datetime'].isin(test_cis_dates)] - train_df = remaining_df[remaining_df['datetime'].isin(train_dates)] + val_cis_df = remaining_df[remaining_df['date'].isin(val_cis_dates)] + test_cis_df = remaining_df[remaining_df['date'].isin(test_cis_dates)] + train_df = remaining_df[remaining_df['date'].isin(train_dates)] # Locations in val_cis and test_cis but not in train are all moved to train set # since we want all locations in tcis splits to be in the train set. @@ -120,35 +104,89 @@ def _create_split(data_dir, seed, skip=True): test_cis_df = test_cis_df[test_cis_df['category_id'].isin(train_classes)] test_trans_df = test_trans_df[test_trans_df['category_id'].isin(train_classes)] - # Remove examples that are corrupted in some way - if skip: - train_df, val_cis_df, val_trans_df, test_cis_df, test_trans_df = remove([train_df, val_cis_df, - val_trans_df, test_cis_df, - test_trans_df]) - + # Assert that all sequences that spanned across multiple days ended up in the same split + for seq_id in seq_ids_that_span_across_days: + n_splits = 0 + for split_df in [train_df, val_cis_df, test_cis_df]: + if seq_id in split_df['seq_id'].values: + n_splits += 1 + assert n_splits == 1, "Each sequence should only be in one split. Please move manually" # Reset index - train_df.reset_index(inplace=True), val_cis_df.reset_index(inplace=True), val_trans_df.reset_index(inplace=True) - test_cis_df.reset_index(inplace=True), test_trans_df.reset_index(inplace=True) + train_df.reset_index(inplace=True, drop=True), val_cis_df.reset_index(inplace=True, drop=True), val_trans_df.reset_index(inplace=True, drop=True) + test_cis_df.reset_index(inplace=True, drop=True), test_trans_df.reset_index(inplace=True, drop=True) + + print("n train: ", len(train_df)) + print("n val trans: ", len(val_trans_df)) + print("n test trans: ", len(test_trans_df)) + print("n val cis: ", len(val_cis_df)) + print("n test cis: ", len(test_cis_df)) + + # Merge into one df + train_df['split'] = 'train' + val_trans_df['split'] = 'val' + test_trans_df['split'] = 'test' + val_cis_df['split'] = 'id_val' + test_cis_df['split'] = 'id_test' + df = pd.concat([train_df, val_trans_df, test_trans_df, test_cis_df, val_cis_df]) + df = df.reset_index(drop=True) + + # Create y labels by remapping the category ids to be contiguous + unique_categories = np.unique(df['category_id']) + n_classes = len(unique_categories) + category_to_label = dict([(i, j) for i, j in zip(unique_categories, range(n_classes))]) + df['y'] = df['category_id'].apply(lambda x: category_to_label[x]).values + print("N classes: ", n_classes) + + # Create y to category name map and save + categories_df = pd.DataFrame({ + 'category_id': [item['id'] for item in data['categories']], + 'name': [item['name'] for item in data['categories']] + }) + + categories_df['y'] = categories_df['category_id'].apply(lambda x: category_to_label[x] if x in category_to_label else 99999) + categories_df = categories_df.sort_values('y').reset_index(drop=True) + categories_df = categories_df[['y','category_id','name']] + + # Create remapped location id such that they are contigious + location_ids = df['location'] + locations = np.unique(location_ids) + n_groups = len(locations) + location_to_group_id = {locations[i]: i for i in range(n_groups)} + df['location_remapped' ] = df['location'].apply(lambda x: location_to_group_id[x]) + + # Create remapped sequence id such that they are contigious + sequence_ids = df['seq_id'] + sequences = np.unique(sequence_ids) + n_sequences = len(sequences) + sequence_to_normalized_id = {sequences[i]: i for i in range(n_sequences)} + df['sequence_remapped' ] = df['seq_id'].apply(lambda x: sequence_to_normalized_id[x]) + # Make sure there's no overlap - for df in [val_cis_df, val_trans_df, test_cis_df, test_trans_df]: - assert not check_overlap(train_df, df) - - return train_df, val_cis_df, val_trans_df, test_cis_df, test_trans_df - -def remove(dfs): - new_dfs = [] - for df in dfs: - df = df[~df['location'].isin(LOCATIONS_TO_SKIP)] - df = df[~df['filename'].isin(IDS_TO_SKIP)] - new_dfs.append(df) - return new_dfs - -def check_overlap(df1, df2): - files1 = set(df1['filename']) - files2 = set(df2['filename']) + for split_df in [val_cis_df, val_trans_df, test_cis_df, test_trans_df]: + assert not check_overlap(train_df, split_df) + + # Save + df = df.sort_values(['split','location_remapped', 'sequence_remapped','datetime']).reset_index(drop=True) + cols = ['split', 'location_remapped', 'location', 'sequence_remapped', 'seq_id', 'y', 'category_id', 'datetime', 'filename', 'image_id'] + df[cols].to_csv(data_dir / 'metadata.csv') + categories_df.to_csv(data_dir / 'categories.csv', index=False) + + +def check_overlap(df1, df2, column='filename'): + files1 = set(df1[column]) + files2 = set(df2[column]) intersection = files1.intersection(files2) n_intersection = len(intersection) return False if n_intersection == 0 else True + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str) + args = parser.parse_args() + + create_split(Path(args.data_dir), seed=0) diff --git a/dataset_preprocessing/poverty/split_npys.py b/dataset_preprocessing/poverty/split_npys.py new file mode 100644 index 00000000..4bf9f023 --- /dev/null +++ b/dataset_preprocessing/poverty/split_npys.py @@ -0,0 +1,25 @@ +import os, sys +import argparse +import numpy as np +from PIL import Image +from pathlib import Path +from tqdm import tqdm + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument('--root_dir', required=True, + help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') + config = parser.parse_args() + data_dir = Path(config.root_dir) / 'poverty_v1.0' + indiv_dir = Path(config.root_dir) / 'poverty_v1.0_indiv_npz' + os.makedirs(indiv_dir, exist_ok=True) + + f = np.load(data_dir / 'landsat_poverty_imgs.npy', mmap_mode='r') + f = f.transpose((0, 3, 1, 2)) + for i in tqdm(range(len(f))): + x = f[i] + np.savez_compressed(indiv_dir / f'landsat_poverty_img_{i}.npz', x=x) + +if __name__=='__main__': + main() diff --git a/examples/algorithms/deepCORAL.py b/examples/algorithms/deepCORAL.py index 7069d127..e82981d4 100644 --- a/examples/algorithms/deepCORAL.py +++ b/examples/algorithms/deepCORAL.py @@ -27,8 +27,9 @@ def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): assert config.uniform_over_groups assert config.distinct_groups # initialize models - featurizer = initialize_model(config, d_out=None).to(config.device) - classifier = torch.nn.Linear(featurizer.d_out, d_out).to(config.device) + featurizer, classifier = initialize_model(config, d_out=d_out, is_featurizer=True) + featurizer = featurizer.to(config.device) + classifier = classifier.to(config.device) model = torch.nn.Sequential(featurizer, classifier).to(config.device) # initialize module super().__init__( @@ -48,6 +49,12 @@ def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): self.classifier = classifier def coral_penalty(self, x, y): + if x.dim() > 2: + # featurizers output Tensors of size (batch_size, ..., feature dimensionality). + # we flatten to Tensors of size (*, feature dimensionality) + x = x.view(-1, x.size(-1)) + y = y.view(-1, y.size(-1)) + mean_x = x.mean(0, keepdim=True) mean_y = y.mean(0, keepdim=True) cent_x = x - mean_x diff --git a/examples/algorithms/initializer.py b/examples/algorithms/initializer.py index 9c6a5444..00748cfc 100644 --- a/examples/algorithms/initializer.py +++ b/examples/algorithms/initializer.py @@ -14,6 +14,8 @@ def initialize_algorithm(config, datasets, train_grouper): if (train_dataset.is_classification) and (train_dataset.y_size == 1): # For single-task classification, we have one output per class d_out = train_dataset.n_classes + elif (train_dataset.is_classification) and (train_dataset.y_size is None): + d_out = train_dataset.n_classes elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): # For multi-task binary classification (each output is the logit for each binary class) d_out = train_dataset.y_size diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 1d15c7af..cd2d1d6f 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -1,19 +1,24 @@ dataset_defaults = { 'amazon': { 'split_scheme': 'official', - 'model': 'bert-base-uncased', + 'model': 'distilbert-base-uncased', 'train_transform': 'bert', 'eval_transform': 'bert', 'max_token_length': 512, 'loss_function': 'cross_entropy', 'algo_log_metric': 'accuracy', 'batch_size': 8, - 'lr': 2e-6, + 'lr': 1e-5, 'weight_decay': 0.01, 'n_epochs': 3, 'n_groups_per_batch': 2, 'irm_lambda': 1.0, - 'coral_penalty_weight': 10.0, + 'coral_penalty_weight': 1.0, + 'loader_kwargs': { + 'num_workers': 1, + 'pin_memory': True, + }, + 'process_outputs_function': 'multiclass_logits_to_pred', }, 'bdd100k': { 'split_scheme': 'official', @@ -28,9 +33,10 @@ 'lr': 0.001, 'weight_decay': 0.0001, 'n_epochs': 10, - 'algo_log_metric': 'multitask_accuracy', + 'algo_log_metric': 'multitask_binary_accuracy', 'train_transform': 'image_base', 'eval_transform': 'image_base', + 'process_outputs_function': 'binary_logits_to_pred', }, 'camelyon17': { 'split_scheme': 'official', @@ -38,6 +44,7 @@ 'model_kwargs': {'pretrained': False}, 'train_transform': 'image_base', 'eval_transform': 'image_base', + 'target_resolution': (96, 96), 'loss_function': 'cross_entropy', 'groupby_fields': ['hospital'], 'val_metric': 'acc_avg', @@ -53,6 +60,7 @@ 'irm_lambda': 1.0, 'coral_penalty_weight': 0.1, 'algo_log_metric': 'accuracy', + 'process_outputs_function': 'multiclass_logits_to_pred', }, 'celebA': { 'split_scheme': 'official', @@ -72,10 +80,11 @@ 'weight_decay': 0.0, 'n_epochs': 200, 'algo_log_metric': 'accuracy', + 'process_outputs_function': 'multiclass_logits_to_pred', }, 'civilcomments': { 'split_scheme': 'official', - 'model': 'bert-base-uncased', + 'model': 'distilbert-base-uncased', 'train_transform': 'bert', 'eval_transform': 'bert', 'loss_function': 'cross_entropy', @@ -88,12 +97,19 @@ 'n_epochs': 5, 'algo_log_metric': 'accuracy', 'max_token_length': 300, + 'irm_lambda': 1.0, + 'coral_penalty_weight': 10.0, + 'loader_kwargs': { + 'num_workers': 1, + 'pin_memory': True, + }, + 'process_outputs_function': 'multiclass_logits_to_pred', }, 'fmow': { 'split_scheme': 'official', 'dataset_kwargs': { 'oracle_training_set': False, - 'seed':111, + 'seed': 111, 'use_ood_val': True }, 'model': 'densenet121', @@ -102,7 +118,7 @@ 'eval_transform': 'image_base', 'loss_function': 'cross_entropy', 'groupby_fields': ['year',], - 'val_metric': 'acc_avg', + 'val_metric': 'acc_worst_region', 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': 'StepLR', @@ -115,6 +131,7 @@ 'irm_lambda': 1.0, 'coral_penalty_weight': 0.1, 'algo_log_metric': 'accuracy', + 'process_outputs_function': 'multiclass_logits_to_pred', }, 'iwildcam': { 'loss_function': 'cross_entropy', @@ -122,22 +139,23 @@ 'model_kwargs': {'pretrained': True}, 'train_transform': 'image_base', 'eval_transform': 'image_base', - 'target_resolution': (224, 224), + 'target_resolution': (448, 448), 'val_metric_decreasing': False, 'algo_log_metric': 'accuracy', 'model': 'resnet50', - 'lr': 1e-5, + 'lr': 3e-5, 'weight_decay': 0.0, 'batch_size': 16, - 'n_epochs': 18, + 'n_epochs': 12, 'optimizer': 'Adam', 'split_scheme': 'official', 'scheduler': None, 'groupby_fields': ['location',], 'n_groups_per_batch': 2, 'irm_lambda': 1., - 'coral_penalty_weight': 0.1, + 'coral_penalty_weight': 10., 'no_group_logging': True, + 'process_outputs_function': 'multiclass_logits_to_pred' }, 'ogb-molpcba': { 'split_scheme': 'official', @@ -156,6 +174,28 @@ 'irm_lambda': 1., 'coral_penalty_weight': 0.1, 'no_group_logging': True, + 'process_outputs_function': None, + 'algo_log_metric': 'multitask_binary_accuracy', + }, + 'py150': { + 'split_scheme': 'official', + 'model': 'code-gpt-py', + 'loss_function': 'lm_cross_entropy', + 'val_metric': 'acc', + 'val_metric_decreasing': False, + 'optimizer': 'AdamW', + 'optimizer_kwargs': {'eps':1e-8}, + 'lr': 8e-5, + 'weight_decay': 0., + 'n_epochs': 3, + 'batch_size': 6, + 'groupby_fields': ['repo',], + 'n_groups_per_batch': 2, + 'irm_lambda': 1., + 'coral_penalty_weight': 1., + 'no_group_logging': True, + 'algo_log_metric': 'multitask_accuracy', + 'process_outputs_function': 'multiclass_logits_to_pred', }, 'poverty': { 'split_scheme': 'official', @@ -165,17 +205,13 @@ 'oracle_training_set': False, 'use_ood_val': True }, - 'loader_kwargs': { - 'num_workers': 1, - 'pin_memory': False, - }, 'model': 'resnet18_ms', 'model_kwargs': {'num_channels': 8}, 'train_transform': 'poverty_train', 'eval_transform': None, 'loss_function': 'mse', 'groupby_fields': ['country',], - 'val_metric': 'r_all', + 'val_metric': 'r_wg', 'val_metric_decreasing': False, 'algo_log_metric': 'mse', 'optimizer': 'Adam', @@ -187,7 +223,8 @@ 'n_epochs': 200, 'n_groups_per_batch': 8, 'irm_lambda': 1.0, - 'coral_penalty_weight': 10, + 'coral_penalty_weight': 0.1, + 'process_outputs_function': None, }, 'waterbirds': { 'split_scheme': 'official', @@ -208,6 +245,7 @@ 'lr': 1e-5, 'weight_decay': 1.0, 'n_epochs': 300, + 'process_outputs_function': 'multiclass_logits_to_pred', }, 'yelp': { 'split_scheme': 'official', @@ -222,6 +260,27 @@ 'weight_decay': 0.01, 'n_epochs': 3, 'n_groups_per_batch': 2, + 'process_outputs_function': 'multiclass_logits_to_pred', + }, + 'sqf': { + 'split_scheme': 'all_race', + 'model': 'logistic_regression', + 'train_transform': None, + 'eval_transform': None, + 'model_kwargs': {'in_features': 104}, + 'loss_function': 'cross_entropy', + 'groupby_fields': ['y'], + 'val_metric': 'precision_at_global_recall_all', + 'val_metric_decreasing': False, + 'algo_log_metric': 'accuracy', + 'optimizer': 'Adam', + 'optimizer_kwargs': {}, + 'scheduler': None, + 'batch_size': 4, + 'lr': 5e-5, + 'weight_decay': 0, + 'n_epochs': 4, + 'process_outputs_function': None, }, } diff --git a/examples/configs/model.py b/examples/configs/model.py index 12a429a7..46714bbe 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -4,6 +4,16 @@ 'max_grad_norm': 1.0, 'scheduler': 'linear_schedule_with_warmup', }, + 'distilbert-base-uncased': { + 'optimizer': 'AdamW', + 'max_grad_norm': 1.0, + 'scheduler': 'linear_schedule_with_warmup', + }, + 'code-gpt-py': { + 'optimizer': 'AdamW', + 'max_grad_norm': 1.0, + 'scheduler': 'linear_schedule_with_warmup', + }, 'densenet121': { 'model_kwargs':{ 'pretrained':True, diff --git a/examples/configs/supported.py b/examples/configs/supported.py index bcbe54a9..8b66b74e 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -1,53 +1,37 @@ import torch.nn as nn import torch import sys, os -# Datasets -from wilds.datasets.amazon_dataset import AmazonDataset -from wilds.datasets.bdd100k_dataset import BDD100KDataset -from wilds.datasets.camelyon17_dataset import Camelyon17Dataset -from wilds.datasets.celebA_dataset import CelebADataset -from wilds.datasets.civilcomments_dataset import CivilCommentsDataset -from wilds.datasets.fmow_dataset import FMoWDataset -from wilds.datasets.iwildcam_dataset import IWildCamDataset -from wilds.datasets.ogbmolpcba_dataset import OGBPCBADataset -from wilds.datasets.poverty_dataset import PovertyMapDataset -from wilds.datasets.waterbirds_dataset import WaterbirdsDataset -from wilds.datasets.yelp_dataset import YelpDataset + # metrics from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss -from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE - -datasets = { - 'amazon': AmazonDataset, - 'camelyon17': Camelyon17Dataset, - 'celebA': CelebADataset, - 'civilcomments': CivilCommentsDataset, - 'iwildcam': IWildCamDataset, - 'waterbirds': WaterbirdsDataset, - 'yelp': YelpDataset, - 'ogb-molpcba': OGBPCBADataset, - 'poverty': PovertyMapDataset, - 'fmow': FMoWDataset, - 'bdd100k': BDD100KDataset, -} +from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred losses = { 'cross_entropy': ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), + 'lm_cross_entropy': MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), 'mse': MSE(name='loss'), 'multitask_bce': MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')), } algo_log_metrics = { - 'accuracy': Accuracy(), + 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), 'mse': MSE(), - 'multitask_accuracy': MultiTaskAccuracy(), + 'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred), + 'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred), + None: None, +} + +process_outputs_functions = { + 'binary_logits_to_pred': binary_logits_to_pred, + 'multiclass_logits_to_pred': multiclass_logits_to_pred, None: None, } # see initialize_*() functions for correspondence transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] -models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'gin-virtual', - 'logistic_regression'] +models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', + 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', + 'gin-virtual', 'logistic_regression', 'code-gpt-py'] algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] optimizers = ['SGD', 'Adam', 'AdamW'] schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] diff --git a/examples/models/bert/__init__.py b/examples/models/bert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/models/bert.py b/examples/models/bert/bert.py similarity index 100% rename from examples/models/bert.py rename to examples/models/bert/bert.py diff --git a/examples/models/bert/distilbert.py b/examples/models/bert/distilbert.py new file mode 100644 index 00000000..c508fea2 --- /dev/null +++ b/examples/models/bert/distilbert.py @@ -0,0 +1,30 @@ +from transformers import DistilBertForSequenceClassification, DistilBertModel + +class DistilBertClassifier(DistilBertForSequenceClassification): + def __init__(self, config): + super().__init__(config) + + def __call__(self, x): + input_ids = x[:, :, 0] + attention_mask = x[:, :, 1] + outputs = super().__call__( + input_ids=input_ids, + attention_mask=attention_mask, + )[0] + return outputs + + +class DistilBertFeaturizer(DistilBertModel): + def __init__(self, config): + super().__init__(config) + self.d_out = config.hidden_size + + def __call__(self, x): + input_ids = x[:, :, 0] + attention_mask = x[:, :, 1] + hidden_state = super().__call__( + input_ids=input_ids, + attention_mask=attention_mask, + )[0] + pooled_output = hidden_state[:, 0] + return pooled_output diff --git a/examples/models/code_gpt.py b/examples/models/code_gpt.py new file mode 100644 index 00000000..a85ef064 --- /dev/null +++ b/examples/models/code_gpt.py @@ -0,0 +1,35 @@ +from transformers import GPT2LMHeadModel, GPT2Model +import torch + +class GPT2LMHeadLogit(GPT2LMHeadModel): + def __init__(self, config): + super().__init__(config) + self.d_out = config.vocab_size + + def __call__(self, x): + outputs = super().__call__(x) + logits = outputs[0] #[batch_size, seqlen, vocab_size] + return logits + + +class GPT2Featurizer(GPT2Model): + def __init__(self, config): + super().__init__(config) + self.d_out = config.n_embd + + def __call__(self, x): + outputs = super().__call__(x) + hidden_states = outputs[0] #[batch_size, seqlen, n_embd] + return hidden_states + + +class GPT2FeaturizerLMHeadLogit(GPT2LMHeadModel): + def __init__(self, config): + super().__init__(config) + self.d_out = config.vocab_size + self.transformer = GPT2Featurizer(config) + + def __call__(self, x): + hidden_states = self.transformer(x) #[batch_size, seqlen, n_embd] + logits = self.lm_head(hidden_states) #[batch_size, seqlen, vocab_size] + return logits diff --git a/examples/models/initializer.py b/examples/models/initializer.py index cea5ebfc..4d414763 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -1,41 +1,108 @@ import torch.nn as nn import torchvision -from models.bert import BertClassifier, BertFeaturizer +from models.bert.bert import BertClassifier, BertFeaturizer +from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer from models.resnet_multispectral import ResNet18 from models.layers import Identity from models.gnn import GINVirtual +from models.code_gpt import GPT2LMHeadLogit, GPT2FeaturizerLMHeadLogit +from transformers import GPT2Tokenizer -def initialize_model(config, d_out): - if config.model == 'resnet18_ms': - # multispectral resnet 18 - model = ResNet18(num_classes=d_out, **config.model_kwargs) - elif config.model in ('resnet50', 'resnet34', 'wideresnet50','densenet121'): - model = initialize_torchvision_model( - name=config.model, - d_out=d_out, - **config.model_kwargs) - elif config.model.startswith('bert'): - if d_out is None: +def initialize_model(config, d_out, is_featurizer=False): + """ + Initializes models according to the config + Args: + - config (dictionary): config dictionary + - d_out (int): the dimensionality of the model output + - is_featurizer (bool): whether to return a model or a (featurizer, classifier) pair that constitutes a model. + Output: + If is_featurizer=True: + - featurizer: a model that outputs feature Tensors of shape (batch_size, ..., feature dimensionality) + - classifier: a model that takes in feature Tensors and outputs predictions. In most cases, this is a linear layer. + + If is_featurizer=False: + - model: a model that is equivalent to nn.Sequential(featurizer, classifier) + """ + if config.model in ('resnet50', 'resnet34', 'wideresnet50', 'densenet121'): + if is_featurizer: + featurizer = initialize_torchvision_model( + name=config.model, + d_out=None, + **config.model_kwargs) + classifier = nn.Linear(featurizer.d_out, d_out) + model = (featurizer, classifier) + else: + model = initialize_torchvision_model( + name=config.model, + d_out=d_out, + **config.model_kwargs) + elif 'bert' in config.model: + if is_featurizer: + featurizer = initialize_bert_based_model(config, d_out, is_featurizer) + classifier = nn.Linear(featurizer.d_out, d_out) + model = (featurizer, classifier) + else: + model = initialize_bert_based_model(config, d_out) + elif config.model == 'resnet18_ms': # multispectral resnet 18 + if is_featurizer: + featurizer = ResNet18(num_classes=None, **config.model_kwargs) + classifier = nn.Linear(featurizer.d_out, d_out) + model = (featurizer, classifier) + else: + model = ResNet18(num_classes=d_out, **config.model_kwargs) + elif config.model == 'gin-virtual': + if is_featurizer: + featurizer = GINVirtual(num_tasks=None, **config.model_kwargs) + classifier = nn.Linear(featurizer.d_out, d_out) + model = (featurizer, classifier) + else: + model = GINVirtual(num_tasks=d_out, **config.model_kwargs) + elif config.model == 'code-gpt-py': + name = 'microsoft/CodeGPT-small-py' + tokenizer = GPT2Tokenizer.from_pretrained(name) + if is_featurizer: + model = GPT2FeaturizerLMHeadLogit.from_pretrained(name) + model.resize_token_embeddings(len(tokenizer)) + featurizer = model.transformer + classifier = model.lm_head + model = (featurizer, classifier) + else: + model = GPT2LMHeadLogit.from_pretrained(name) + model.resize_token_embeddings(len(tokenizer)) + elif config.model == 'logistic_regression': + assert not is_featurizer, "Featurizer not supported for logistic regression" + model = nn.Linear(out_features=d_out, **config.model_kwargs) + else: + raise ValueError(f'Model: {config.model} not recognized.') + return model + +def initialize_bert_based_model(config, d_out, is_featurizer=False): + if config.model == 'bert-base-uncased': + if is_featurizer: model = BertFeaturizer.from_pretrained(config.model, **config.model_kwargs) else: model = BertClassifier.from_pretrained( config.model, num_labels=d_out, **config.model_kwargs) - elif config.model == 'logistic_regression': - model = nn.Linear(out_features=d_out, **config.model_kwargs) - elif config.model == 'gin-virtual': - model = GINVirtual(num_tasks=d_out, **config.model_kwargs) + elif config.model == 'distilbert-base-uncased': + if is_featurizer: + model = DistilBertFeaturizer.from_pretrained(config.model, **config.model_kwargs) + else: + model = DistilBertClassifier.from_pretrained( + config.model, + num_labels=d_out, + **config.model_kwargs) else: - raise ValueError('Model not recognized.') + raise ValueError(f'Model: {config.model} not recognized.') return model def initialize_torchvision_model(name, d_out, **kwargs): # get constructor and last layer names - if name=='wideresnet50': + if name == 'wideresnet50': constructor_name = 'wide_resnet50_2' last_layer_name = 'fc' - elif name=='densenet121': + elif name == 'densenet121': constructor_name = name last_layer_name = 'classifier' elif name in ('resnet50', 'resnet34'): @@ -47,13 +114,12 @@ def initialize_torchvision_model(name, d_out, **kwargs): constructor = getattr(torchvision.models, constructor_name) model = constructor(**kwargs) # adjust the last layer - d = getattr(model, last_layer_name).in_features - if d_out is None: # want to initialize a featurizer model - last_layer = Identity(d) - model.d_out = d + d_features = getattr(model, last_layer_name).in_features + if d_out is None: # want to initialize a featurizer model + last_layer = Identity(d_features) + model.d_out = d_features else: # want to initialize a classifier for a particular num_classes - last_layer = nn.Linear(d, d_out) + last_layer = nn.Linear(d_features, d_out) model.d_out = d_out setattr(model, last_layer_name, last_layer) - # set the feature dimension as an attribute for convenience return model diff --git a/examples/optimizer.py b/examples/optimizer.py index a31777ff..bc390394 100644 --- a/examples/optimizer.py +++ b/examples/optimizer.py @@ -2,8 +2,6 @@ from transformers import AdamW def initialize_optimizer(config, model): - if config.model.startswith('bert'): - assert config.optimizer=='AdamW', 'Only AdamW supported for BERT models' # initialize optimizers if config.optimizer=='SGD': params = filter(lambda p: p.requires_grad, model.parameters()) @@ -13,8 +11,11 @@ def initialize_optimizer(config, model): weight_decay=config.weight_decay, **config.optimizer_kwargs) elif config.optimizer=='AdamW': - assert config.model.startswith('bert'), "Only BERT supported for AdamW" - no_decay = ['bias', 'LayerNorm.weight'] + if 'bert' in config.model or 'gpt' in config.model: + no_decay = ['bias', 'LayerNorm.weight'] + else: + no_decay = [] + params = [ {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': config.weight_decay}, {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} diff --git a/examples/run_expt.py b/examples/run_expt.py index 166df04f..acee29db 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -8,6 +8,7 @@ import sys from collections import defaultdict +import wilds from wilds.common.data_loaders import get_train_loader, get_eval_loader from wilds.common.grouper import CombinatorialGrouper @@ -23,7 +24,7 @@ def main(): parser = argparse.ArgumentParser() # Required arguments - parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True) + parser.add_argument('-d', '--dataset', choices=wilds.supported_datasets, required=True) parser.add_argument('--algorithm', required=True, choices=supported.algorithms) parser.add_argument('--root_dir', required=True, help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') @@ -34,7 +35,8 @@ def main(): parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?', help='If true, tries to downloads the dataset if it does not exist in root_dir.') parser.add_argument('--frac', type=float, default=1.0, - help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.') + help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.') + parser.add_argument('--version', default=None, type=str) # Loaders parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={}) @@ -53,7 +55,7 @@ def main(): # Transforms parser.add_argument('--train_transform', choices=supported.transforms) parser.add_argument('--eval_transform', choices=supported.transforms) - parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.') + parser.add_argument('--target_resolution', nargs='+', type=int, help='The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.') parser.add_argument('--resize_scale', type=float) parser.add_argument('--max_token_length', type=int) @@ -87,10 +89,11 @@ def main(): parser.add_argument('--scheduler_metric_name') # Evaluation + parser.add_argument('--process_outputs_function', choices = supported.process_outputs_functions) parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True) parser.add_argument('--eval_splits', nargs='+', default=[]) parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False) - parser.add_argument('--eval_epoch', default=None, type=int) + parser.add_argument('--eval_epoch', default=None, type=int, help='If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.') # Misc parser.add_argument('--device', type=int, default=0) @@ -133,7 +136,9 @@ def main(): set_seed(config.seed) # Data - full_dataset = supported.datasets[config.dataset]( + full_dataset = wilds.get_dataset( + dataset=config.dataset, + version=config.version, root_dir=config.root_dir, download=config.download, split_scheme=config.split_scheme, @@ -193,7 +198,7 @@ def main(): datasets[split]['split'] = split datasets[split]['name'] = full_dataset.split_names[split] datasets[split]['verbose'] = verbose - # Loggers + # Loggers datasets[split]['eval_logger'] = BatchLogger( os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) @@ -204,7 +209,8 @@ def main(): initialize_wandb(config) # Logging dataset info - if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1: + # Show class breakdown if feasible + if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1 and full_dataset.n_classes <= 10: log_grouper = CombinatorialGrouper( dataset=full_dataset, groupby_fields=['y']) @@ -244,7 +250,6 @@ def main(): epoch_offset=0 best_val_metric=None - train( algorithm=algorithm, datasets=datasets, diff --git a/examples/train.py b/examples/train.py index ba29be90..63deaee1 100644 --- a/examples/train.py +++ b/examples/train.py @@ -3,6 +3,7 @@ import torch from utils import save import torch.autograd.profiler as profiler +from configs.supported import process_outputs_functions def log_results(algorithm, dataset, general_logger, epoch, batch_idx): if algorithm.has_log: @@ -22,9 +23,6 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): algorithm.train() else: algorithm.eval() - # process = psutil.Process(os.getpid()) - - # process = psutil.Process(os.getpid()) # Not preallocating memory is slower # but makes it easier to handle different types of data loaders @@ -49,7 +47,10 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): # The subsequent detach is just for safety # (they should already be detached in batch_results) epoch_y_true.append(batch_results['y_true'].clone().detach()) - epoch_y_pred.append(batch_results['y_pred'].clone().detach()) + y_pred = batch_results['y_pred'].clone().detach() + if config.process_outputs_function is not None: + y_pred = process_outputs_functions[config.process_outputs_function](y_pred) + epoch_y_pred.append(y_pred) epoch_metadata.append(batch_results['metadata'].clone().detach()) if train and (batch_idx+1) % config.log_every==0: @@ -135,7 +136,10 @@ def evaluate(algorithm, datasets, epoch, general_logger, config): for batch in iterator: batch_results = algorithm.evaluate(batch) epoch_y_true.append(batch_results['y_true'].clone().detach()) - epoch_y_pred.append(batch_results['y_pred'].clone().detach()) + y_pred = batch_results['y_pred'].clone().detach() + if config.process_outputs_function is not None: + y_pred = process_outputs_functions[config.process_outputs_function](y_pred) + epoch_y_pred.append(y_pred) epoch_metadata.append(batch_results['metadata'].clone().detach()) results, results_str = dataset['dataset'].eval( diff --git a/examples/transforms.py b/examples/transforms.py index cbacb1f1..bafbd42f 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -1,5 +1,5 @@ import torchvision.transforms as transforms -from transformers import BertTokenizerFast +from transformers import BertTokenizerFast, DistilBertTokenizerFast import torch def initialize_transform(transform_name, config, dataset): @@ -17,9 +17,10 @@ def initialize_transform(transform_name, config, dataset): raise ValueError(f"{transform_name} not recognized") def initialize_bert_transform(config): - assert config.model.startswith('bert') + assert 'bert' in config.model assert config.max_token_length is not None - tokenizer = BertTokenizerFast.from_pretrained(config.model) + + tokenizer = getBertTokenizer(config.model) def transform(text): tokens = tokenizer( text, @@ -27,15 +28,31 @@ def transform(text): truncation=True, max_length=config.max_token_length, return_tensors='pt') - x = torch.stack( - (tokens['input_ids'], - tokens['attention_mask'], - tokens['token_type_ids']), - dim=2) + if config.model == 'bert-base-uncased': + x = torch.stack( + (tokens['input_ids'], + tokens['attention_mask'], + tokens['token_type_ids']), + dim=2) + elif config.model == 'distilbert-base-uncased': + x = torch.stack( + (tokens['input_ids'], + tokens['attention_mask']), + dim=2) x = torch.squeeze(x, dim=0) # First shape dim is always 1 return x return transform +def getBertTokenizer(model): + if model == 'bert-base-uncased': + tokenizer = BertTokenizerFast.from_pretrained(model) + elif model == 'distilbert-base-uncased': + tokenizer = DistilBertTokenizerFast.from_pretrained(model) + else: + raise ValueError(f'Model: {model} not recognized.') + + return tokenizer + def initialize_image_base_transform(config, dataset): transform_steps = [] if dataset.original_resolution is not None and min(dataset.original_resolution)!=max(dataset.original_resolution): diff --git a/examples/utils.py b/examples/utils.py index dcfcba3e..8a12f859 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -30,7 +30,7 @@ class ParseKwargs(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, dict()) for value in values: - key, value_str = value.split('=') + key, value_str = value.split('=') if value_str.replace('-','').isnumeric(): processed_val = int(value_str) elif value_str.replace('-','').replace('.','').isnumeric(): diff --git a/setup.py b/setup.py index ab7c7c98..9cd1f596 100644 --- a/setup.py +++ b/setup.py @@ -26,13 +26,13 @@ 'scikit-learn>=0.20.0', 'pillow>=7.2.0', 'torch>=1.7.0', - 'ogb>=1.2.3', + 'ogb>=1.2.6', 'tqdm>=4.53.0', 'outdated>=0.2.0', 'pytz>=2020.4', ], license='MIT', - packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models']), + packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']), classifiers=[ 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Intended Audience :: Science/Research', diff --git a/wilds/__init__.py b/wilds/__init__.py index 77ac4a0d..77f0ad5a 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -1 +1,23 @@ from .version import __version__ +from .get_dataset import get_dataset + +benchmark_datasets = [ + 'amazon', + 'camelyon17', + 'civilcomments', + 'iwildcam', + 'ogb-molpcba', + 'poverty', + 'fmow', + 'py150', +] + +additional_datasets = [ + 'celebA', + 'waterbirds', + 'yelp', + 'bdd100k', + 'sqf', +] + +supported_datasets = benchmark_datasets + additional_datasets diff --git a/wilds/common/grouper.py b/wilds/common/grouper.py index 2c6f8d82..07dc92a3 100644 --- a/wilds/common/grouper.py +++ b/wilds/common/grouper.py @@ -87,11 +87,13 @@ def __init__(self, dataset, groupby_fields): # Note that this might result in some empty groups. self.groupby_field_indices = [i for (i, field) in enumerate(dataset.metadata_fields) if field in groupby_fields] if len(self.groupby_field_indices) != len(self.groupby_fields): - raise ValueError('at least one group field not found in dataset.metadata_fields') + raise ValueError('At least one group field not found in dataset.metadata_fields') grouped_metadata = dataset.metadata_array[:, self.groupby_field_indices] if not isinstance(grouped_metadata, torch.LongTensor): - warnings.warn(f'CombinatorialGrouper: converting metadata with fields [{", ".join(groupby_fields)}] into long') - grouped_metadata = grouped_metadata.long() + grouped_metadata_long = grouped_metadata.long() + if not torch.all(grouped_metadata == grouped_metadata_long): + warnings.warn(f'CombinatorialGrouper: converting metadata with fields [{", ".join(groupby_fields)}] into long') + grouped_metadata = grouped_metadata_long for idx, field in enumerate(self.groupby_fields): min_value = grouped_metadata[:,idx].min() if min_value < 0: @@ -150,4 +152,3 @@ def group_str(self, group): def group_field_str(self, group): return self.group_str(group).replace('=', ':').replace(',','_').replace(' ','') - diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 3c2af169..0f5d7eb1 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -8,7 +8,7 @@ import sklearn.metrics from scipy.stats import pearsonr -def logits_to_score(logits): +def binary_logits_to_score(logits): assert logits.dim() in (1,2) if logits.dim()==2: #multi-class logits assert logits.size(1)==2, "Only binary classification" @@ -17,22 +17,19 @@ def logits_to_score(logits): score = logits return score -def logits_to_pred(logits): - assert logits.dim() in (1,2) - if logits.dim()==2: #multi-class logits - pred = torch.argmax(logits, 1) - else: - pred = (logits>0).long() - return pred - -def logits_to_binary_pred(logits): - assert logits.dim() in (1,2) - pred = (logits>0).long() - return pred +def multiclass_logits_to_pred(logits): + """ + Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions + by taking an argmax at the last dimension + """ + assert logits.dim() > 1 + return logits.argmax(-1) +def binary_logits_to_pred(logits): + return (logits>0).long() class Accuracy(ElementwiseMetric): - def __init__(self, prediction_fn=logits_to_pred, name=None): + def __init__(self, prediction_fn=None, name=None): self.prediction_fn = prediction_fn if name is None: name = 'acc' @@ -47,7 +44,7 @@ def worst(self, metrics): return minimum(metrics) class MultiTaskAccuracy(MultiTaskMetric): - def __init__(self, prediction_fn=logits_to_binary_pred, name=None): + def __init__(self, prediction_fn=None, name=None): self.prediction_fn = prediction_fn # should work on flattened inputs if name is None: name = 'acc' @@ -62,7 +59,7 @@ def worst(self, metrics): return minimum(metrics) class Recall(Metric): - def __init__(self, prediction_fn=logits_to_pred, name=None, average='binary'): + def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn if name is None: name = f'recall' @@ -81,7 +78,7 @@ def worst(self, metrics): return minimum(metrics) class F1(Metric): - def __init__(self, prediction_fn=logits_to_pred, name=None, average='binary'): + def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn if name is None: name = f'F1' @@ -128,3 +125,20 @@ def __init__(self, name=None): if name is None: name = 'mse' super().__init__(name=name, loss_fn=mse_loss) + +class PrecisionAtRecall(Metric): + """Given a specific model threshold, determine the precision score achieved""" + def __init__(self, threshold, score_fn=None, name=None): + self.score_fn = score_fn + self.threshold = threshold + if name is None: + name = "precision_at_global_recall" + super().__init__(name=name) + + def _compute(self, y_pred, y_true): + score = self.score_fn(y_pred) + predictions = (score > self.threshold) + return torch.tensor(sklearn.metrics.precision_score(y_true, predictions)) + + def worst(self, metrics): + return minimum(metrics) diff --git a/wilds/common/metrics/loss.py b/wilds/common/metrics/loss.py index 40df9b0b..4d2aa1ad 100644 --- a/wilds/common/metrics/loss.py +++ b/wilds/common/metrics/loss.py @@ -29,7 +29,7 @@ def worst(self, metrics): - worst_metric (float): Worst-case metric """ return maximum(metrics) - + class ElementwiseLoss(ElementwiseMetric): def __init__(self, loss_fn, name=None): self.loss_fn = loss_fn @@ -69,6 +69,8 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true): if isinstance(self.loss_fn, torch.nn.BCEWithLogitsLoss): flattened_y_pred = flattened_y_pred.float() flattened_y_true = flattened_y_true.float() + elif isinstance(self.loss_fn, torch.nn.CrossEntropyLoss): + flattened_y_true = flattened_y_true.long() flattened_loss = self.loss_fn(flattened_y_pred, flattened_y_true) return flattened_loss @@ -81,4 +83,3 @@ def worst(self, metrics): - worst_metric (float): Worst-case metric """ return maximum(metrics) - diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 4c3e8440..9c4372b0 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -135,7 +135,7 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): y_true[g == group_idx])) group_metrics = torch.stack(group_metrics) worst_group_metric = self.worst(group_metrics[group_counts>0]) - + return group_metrics, group_counts, worst_group_metric class ElementwiseMetric(Metric): @@ -212,7 +212,7 @@ def compute_element_wise(self, y_pred, y_true, return_dict=True): def compute_flattened(self, y_pred, y_true, return_dict=True): flattened_metrics = self.compute_element_wise(y_pred, y_true, return_dict=False) - index = torch.arange(y_true.numel()) + index = torch.arange(y_true.numel()) if return_dict: return {self.name: flattened_metrics, 'index': index} else: diff --git a/wilds/common/utils.py b/wilds/common/utils.py index 9fd6426f..7854393a 100644 --- a/wilds/common/utils.py +++ b/wilds/common/utils.py @@ -122,7 +122,7 @@ def shuffle_arr(arr, seed=None): rng.shuffle(arr) return arr -def threshold_at_recall(y_pred, y_true, global_recall=0.6): +def threshold_at_recall(y_pred, y_true, global_recall=60): """ Calculate the model threshold to use to achieve a desired global_recall level. Assumes that y_true is a vector of the true binary labels.""" - return np.percentile(y_pred[y_true == 1], global_recall) + return np.percentile(y_pred[y_true == 1], 100-global_recall) diff --git a/wilds/datasets/amazon_dataset.py b/wilds/datasets/amazon_dataset.py index 518beee2..81e633b7 100644 --- a/wilds/datasets/amazon_dataset.py +++ b/wilds/datasets/amazon_dataset.py @@ -50,12 +50,20 @@ class AmazonDataset(WILDSDataset): License: None. However, the original authors request that the data be used for research purposes only. """ - def __init__(self, root_dir='data', download=False, split_scheme='official'): - # set variables - self._dataset_name = 'amazon' - self._version = '1.0' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x60237058e01749cda7b0701c2bd01420/contents/blob/' - self._compressed_size = 4_066_541_568 + _dataset_name = 'amazon' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x60237058e01749cda7b0701c2bd01420/contents/blob/', + 'compressed_size': 4_066_541_568 + }, + '2.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xadbf6198d3a64bdc96fb64d6966b5e79/contents/blob/', + 'compressed_size': 1_987_523_759 + }, + } + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + self._version = version # the official split is the user split if split_scheme=='official': split_scheme = 'user' @@ -85,41 +93,54 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self.initialize_split_dicts() # eval self.initialize_eval_grouper() - self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) def get_input(self, idx): return self._input_array[idx] - def eval(self, y_pred, y_true, metadata): + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = Accuracy(prediction_fn=prediction_fn) if self.split_scheme=='user': # first compute groupwise accuracies g = self._eval_grouper.metadata_to_group(metadata) results = { - **self._metric.compute(y_pred, y_true), - **self._metric.compute_group_wise(y_pred, y_true, g, self._eval_grouper.n_groups) + **metric.compute(y_pred, y_true), + **metric.compute_group_wise(y_pred, y_true, g, self._eval_grouper.n_groups) } accs = [] for group_idx in range(self._eval_grouper.n_groups): group_str = self._eval_grouper.group_field_str(group_idx) - group_metric = results.pop(self._metric.group_metric_field(group_idx)) - group_counts = results.pop(self._metric.group_count_field(group_idx)) - results[f'{self._metric.name}_{group_str}'] = group_metric + group_metric = results.pop(metric.group_metric_field(group_idx)) + group_counts = results.pop(metric.group_count_field(group_idx)) + results[f'{metric.name}_{group_str}'] = group_metric results[f'count_{group_str}'] = group_counts if group_counts>0: accs.append(group_metric) accs = np.array(accs) results['10th_percentile_acc'] = np.percentile(accs, 10) - results[f'{self._metric.worst_group_metric_field}'] = self._metric.worst(accs) + results[f'{metric.worst_group_metric_field}'] = metric.worst(accs) results_str = ( - f"Average {self._metric.name}: {results[self._metric.agg_metric_field]:.3f}\n" - f"10th percentile {self._metric.name}: {results['10th_percentile_acc']:.3f}\n" - f"Worst-group {self._metric.name}: {results[self._metric.worst_group_metric_field]:.3f}\n" + f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n" + f"10th percentile {metric.name}: {results['10th_percentile_acc']:.3f}\n" + f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n" ) return results, results_str else: return self.standard_group_eval( - self._metric, + metric, self._eval_grouper, y_pred, y_true, metadata) diff --git a/wilds/datasets/archive/__init__.py b/wilds/datasets/archive/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/wilds/datasets/archive/fmow_v1_0_dataset.py b/wilds/datasets/archive/fmow_v1_0_dataset.py new file mode 100644 index 00000000..2fef7d51 --- /dev/null +++ b/wilds/datasets/archive/fmow_v1_0_dataset.py @@ -0,0 +1,230 @@ +from pathlib import Path +import shutil +import pandas as pd +import torch +from torch.utils.data import Dataset +import pickle +import numpy as np +import torchvision.transforms.functional as F +from torchvision import transforms +import tarfile +import datetime +import pytz +from PIL import Image +from tqdm import tqdm +from wilds.common.utils import subsample_idxs +from wilds.common.metrics.all_metrics import Accuracy +from wilds.common.grouper import CombinatorialGrouper +from wilds.datasets.wilds_dataset import WILDSDataset + +Image.MAX_IMAGE_PIXELS = 10000000000 + + +categories = ["airport", "airport_hangar", "airport_terminal", "amusement_park", "aquaculture", "archaeological_site", "barn", "border_checkpoint", "burial_site", "car_dealership", "construction_site", "crop_field", "dam", "debris_or_rubble", "educational_institution", "electric_substation", "factory_or_powerplant", "fire_station", "flooded_road", "fountain", "gas_station", "golf_course", "ground_transportation_station", "helipad", "hospital", "impoverished_settlement", "interchange", "lake_or_pond", "lighthouse", "military_facility", "multi-unit_residential", "nuclear_powerplant", "office_building", "oil_or_gas_facility", "park", "parking_lot_or_garage", "place_of_worship", "police_station", "port", "prison", "race_track", "railway_bridge", "recreational_facility", "road_bridge", "runway", "shipyard", "shopping_mall", "single-unit_residential", "smokestack", "solar_farm", "space_facility", "stadium", "storage_tank", "surface_mine", "swimming_pool", "toll_booth", "tower", "tunnel_opening", "waste_disposal", "water_treatment_facility", "wind_farm", "zoo"] + + +class FMoWDataset(WILDSDataset): + """ + The Functional Map of the World land use / building classification dataset. + This is a processed version of the Functional Map of the World dataset originally sourced from https://github.com/fMoW/dataset. + + Support `split_scheme` + 'official': official split, which is equivalent to 'time_after_2016' + `time_after_{YEAR}` for YEAR between 2002--2018 + + Input (x): + 224 x 224 x 3 RGB satellite image. + + Label (y): + y is one of 62 land use / building classes + + Metadata: + each image is annotated with a location coordinate, timestamp, country code. This dataset computes region as a derivative of country code. + + Website: https://github.com/fMoW/dataset + + Original publication: + @inproceedings{fmow2018, + title={Functional Map of the World}, + author={Christie, Gordon and Fendley, Neil and Wilson, James and Mukherjee, Ryan}, + booktitle={CVPR}, + year={2018} + } + + License: + Distributed under the FMoW Challenge Public License. + https://github.com/fMoW/dataset/blob/master/LICENSE + + """ + _dataset_name = 'fmow' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xc59ea8261dfe4d2baa3820866e33d781/contents/blob/', + 'compressed_size': 70_000_000_000} + } + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', oracle_training_set=False, seed=111, use_ood_val=False): + self._version = version + self._data_dir = self.initialize_data_dir(root_dir, download) + + self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4} + self._split_names = {'train': 'Train', 'id_val': 'ID Val', 'id_test': 'ID Test', 'val': 'OOD Val', 'test': 'OOD Test'} + if split_scheme=='official': + split_scheme='time_after_2016' + self._split_scheme = split_scheme + self.oracle_training_set = oracle_training_set + + self.root = Path(self._data_dir) + self.seed = int(seed) + self._original_resolution = (224, 224) + + self.category_to_idx = {cat: i for i, cat in enumerate(categories)} + + self.metadata = pd.read_csv(self.root / 'rgb_metadata.csv') + country_codes_df = pd.read_csv(self.root / 'country_code_mapping.csv') + countrycode_to_region = {k: v for k, v in zip(country_codes_df['alpha-3'], country_codes_df['region'])} + regions = [countrycode_to_region.get(code, 'Other') for code in self.metadata['country_code'].to_list()] + self.metadata['region'] = regions + all_countries = self.metadata['country_code'] + + self.num_chunks = 101 + self.chunk_size = len(self.metadata) // (self.num_chunks - 1) + + if self._split_scheme.startswith('time_after'): + year = int(self._split_scheme.split('_')[2]) + year_dt = datetime.datetime(year, 1, 1, tzinfo=pytz.UTC) + self.test_ood_mask = np.asarray(pd.to_datetime(self.metadata['timestamp']) >= year_dt) + # use 3 years of the training set as validation + year_minus_3_dt = datetime.datetime(year-3, 1, 1, tzinfo=pytz.UTC) + self.val_ood_mask = np.asarray(pd.to_datetime(self.metadata['timestamp']) >= year_minus_3_dt) & ~self.test_ood_mask + self.ood_mask = self.test_ood_mask | self.val_ood_mask + else: + raise ValueError(f"Not supported: self._split_scheme = {self._split_scheme}") + + self._split_array = -1 * np.ones(len(self.metadata)) + for split in self._split_dict.keys(): + idxs = np.arange(len(self.metadata)) + if split == 'test': + test_mask = np.asarray(self.metadata['split'] == 'test') + idxs = idxs[self.test_ood_mask & test_mask] + elif split == 'val': + val_mask = np.asarray(self.metadata['split'] == 'val') + idxs = idxs[self.val_ood_mask & val_mask] + elif split == 'id_test': + test_mask = np.asarray(self.metadata['split'] == 'test') + idxs = idxs[~self.ood_mask & test_mask] + elif split == 'id_val': + val_mask = np.asarray(self.metadata['split'] == 'val') + idxs = idxs[~self.ood_mask & val_mask] + else: + split_mask = np.asarray(self.metadata['split'] == split) + idxs = idxs[~self.ood_mask & split_mask] + + if self.oracle_training_set and split == 'train': + test_mask = np.asarray(self.metadata['split'] == 'test') + unused_ood_idxs = np.arange(len(self.metadata))[self.ood_mask & ~test_mask] + subsample_unused_ood_idxs = subsample_idxs(unused_ood_idxs, num=len(idxs)//2, seed=self.seed+2) + subsample_train_idxs = subsample_idxs(idxs.copy(), num=len(idxs) // 2, seed=self.seed+3) + idxs = np.concatenate([subsample_unused_ood_idxs, subsample_train_idxs]) + self._split_array[idxs] = self._split_dict[split] + + if not use_ood_val: + self._split_dict = {'train': 0, 'val': 1, 'id_test': 2, 'ood_val': 3, 'test': 4} + self._split_names = {'train': 'Train', 'val': 'ID Val', 'id_test': 'ID Test', 'ood_val': 'OOD Val', 'test': 'OOD Test'} + + # filter out sequestered images from full dataset + seq_mask = np.asarray(self.metadata['split'] == 'seq') + # take out the sequestered images + self._split_array = self._split_array[~seq_mask] + self.full_idxs = np.arange(len(self.metadata))[~seq_mask] + + self._y_array = np.asarray([self.category_to_idx[y] for y in list(self.metadata['category'])]) + self.metadata['y'] = self._y_array + self._y_array = torch.from_numpy(self._y_array).long()[~seq_mask] + self._y_size = 1 + self._n_classes = 62 + + # convert region to idxs + all_regions = list(self.metadata['region'].unique()) + region_to_region_idx = {region: i for i, region in enumerate(all_regions)} + self._metadata_map = {'region': all_regions} + region_idxs = [region_to_region_idx[region] for region in self.metadata['region'].tolist()] + self.metadata['region'] = region_idxs + + # make a year column in metadata + year_array = -1 * np.ones(len(self.metadata)) + ts = pd.to_datetime(self.metadata['timestamp']) + for year in range(2002, 2018): + year_mask = np.asarray(ts >= datetime.datetime(year, 1, 1, tzinfo=pytz.UTC)) \ + & np.asarray(ts < datetime.datetime(year+1, 1, 1, tzinfo=pytz.UTC)) + year_array[year_mask] = year - 2002 + self.metadata['year'] = year_array + self._metadata_map['year'] = list(range(2002, 2018)) + + self._metadata_fields = ['region', 'year', 'y'] + self._metadata_array = torch.from_numpy(self.metadata[self._metadata_fields].astype(int).to_numpy()).long()[~seq_mask] + + self._eval_groupers = { + 'year': CombinatorialGrouper(dataset=self, groupby_fields=['year']), + 'region': CombinatorialGrouper(dataset=self, groupby_fields=['region']), + } + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + """ + idx = self.full_idxs[idx] + batch_idx = idx // self.chunk_size + within_batch_idx = idx % self.chunk_size + img_batch = np.load(self.root / f'rgb_all_imgs_{batch_idx}.npy', mmap_mode='r') + img = img_batch[within_batch_idx].copy() + return img + + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = Accuracy(prediction_fn=prediction_fn) + # Overall evaluation + evaluate by year + all_results, all_results_str = self.standard_group_eval( + metric, + self._eval_groupers['year'], + y_pred, y_true, metadata) + # Evaluate by region and ignore the "Other" region + region_grouper = self._eval_groupers['region'] + region_results = metric.compute_group_wise( + y_pred, + y_true, + region_grouper.metadata_to_group(metadata), + region_grouper.n_groups) + all_results[f'{metric.name}_worst_year'] = all_results.pop(metric.worst_group_metric_field) + region_metric_list = [] + for group_idx in range(region_grouper.n_groups): + group_str = region_grouper.group_field_str(group_idx) + group_metric = region_results[metric.group_metric_field(group_idx)] + group_counts = region_results[metric.group_count_field(group_idx)] + all_results[f'{metric.name}_{group_str}'] = group_metric + all_results[f'count_{group_str}'] = group_counts + if region_results[metric.group_count_field(group_idx)] == 0 or "Other" in group_str: + continue + all_results_str += ( + f' {region_grouper.group_str(group_idx)} ' + f"[n = {region_results[metric.group_count_field(group_idx)]:6.0f}]:\t" + f"{metric.name} = {region_results[metric.group_metric_field(group_idx)]:5.3f}\n") + region_metric_list.append(region_results[metric.group_metric_field(group_idx)]) + all_results[f'{metric.name}_worst_region'] = metric.worst(region_metric_list) + all_results_str += f"Worst-group {metric.name}: {all_results[f'{metric.name}_worst_region']:.3f}\n" + + return all_results, all_results_str diff --git a/wilds/datasets/archive/iwildcam_v1_0_dataset.py b/wilds/datasets/archive/iwildcam_v1_0_dataset.py new file mode 100644 index 00000000..49c53d1e --- /dev/null +++ b/wilds/datasets/archive/iwildcam_v1_0_dataset.py @@ -0,0 +1,168 @@ +from datetime import datetime +from pathlib import Path +import os + +from PIL import Image +import pandas as pd +import numpy as np +import torch +import json + +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import Accuracy, Recall, F1 + + +class IWildCamDataset(WILDSDataset): + """ + The iWildCam2020 dataset. + This is a modified version of the original iWildCam2020 competition dataset. + Input (x): + RGB images from camera traps + Label (y): + y is one of 186 classes corresponding to animal species + Metadata: + Each image is annotated with the ID of the location (camera trap) it came from. + Website: + https://www.kaggle.com/c/iwildcam-2020-fgvc7 + Original publication: + @article{beery2020iwildcam, + title={The iWildCam 2020 Competition Dataset}, + author={Beery, Sara and Cole, Elijah and Gjoka, Arvi}, + journal={arXiv preprint arXiv:2004.10340}, + year={2020} + } + License: + This dataset is distributed under Community Data License Agreement – Permissive – Version 1.0 + https://cdla.io/permissive-1-0/ + """ + _dataset_name = 'iwildcam' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x3f1b346ff2d74b5daf1a08685d68c6ec/contents/blob/', + 'compressed_size': 90_094_666_806}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + + self._version = version + self._split_scheme = split_scheme + if self._split_scheme != 'official': + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + + # path + self._data_dir = Path(self.initialize_data_dir(root_dir, download)) + + # Load splits + train_df = pd.read_csv(self._data_dir / 'train.csv') + val_trans_df = pd.read_csv(self._data_dir / 'val_trans.csv') + test_trans_df = pd.read_csv(self._data_dir / 'test_trans.csv') + val_cis_df = pd.read_csv(self._data_dir / 'val_cis.csv') + test_cis_df = pd.read_csv(self._data_dir / 'test_cis.csv') + + # Merge all dfs + train_df['split'] = 'train' + val_trans_df['split'] = 'val' + test_trans_df['split'] = 'test' + val_cis_df['split'] = 'id_val' + test_cis_df['split'] = 'id_test' + df = pd.concat([train_df, val_trans_df, test_trans_df, test_cis_df, val_cis_df]) + + # Splits + data = {} + self._split_dict = {'train': 0, 'val': 1, 'test': 2, 'id_val': 3, 'id_test': 4} + self._split_names = {'train': 'Train', 'val': 'Validation (OOD/Trans)', + 'test': 'Test (OOD/Trans)', 'id_val': 'Validation (ID/Cis)', + 'id_test': 'Test (ID/Cis)'} + + df['split_id'] = df['split'].apply(lambda x: self._split_dict[x]) + self._split_array = df['split_id'].values + + # Filenames + self._input_array = df['filename'].values + + # Labels + unique_categories = np.unique(df['category_id']) + self._n_classes = len(unique_categories) + category_to_label = dict([(i, j) for i, j in zip(unique_categories, range(self._n_classes))]) + label_to_category = dict([(v, k) for k, v in category_to_label.items()]) + self._y_array = torch.tensor(df['category_id'].apply(lambda x: category_to_label[x]).values) + self._y_size = 1 + + # Location/group info + location_ids = df['location'] + locations = np.unique(location_ids) + n_groups = len(locations) + location_to_group_id = {locations[i]: i for i in range(n_groups)} + df['group_id' ] = df['location'].apply(lambda x: location_to_group_id[x]) + + self._n_groups = n_groups + + # Extract datetime subcomponents and include in metadata + df['datetime_obj'] = df['datetime'].apply(lambda x: datetime.strptime(x, '%Y-%m-%d %H:%M:%S.%f')) + df['year'] = df['datetime_obj'].apply(lambda x: int(x.year)) + df['month'] = df['datetime_obj'].apply(lambda x: int(x.month)) + df['day'] = df['datetime_obj'].apply(lambda x: int(x.day)) + df['hour'] = df['datetime_obj'].apply(lambda x: int(x.hour)) + df['minute'] = df['datetime_obj'].apply(lambda x: int(x.minute)) + df['second'] = df['datetime_obj'].apply(lambda x: int(x.second)) + + self._metadata_array = torch.tensor(np.stack([df['group_id'].values, + df['year'].values, df['month'].values, df['day'].values, + df['hour'].values, df['minute'].values, df['second'].values, + self.y_array], axis=1)) + self._metadata_fields = ['location', 'year', 'month', 'day', 'hour', 'minute', 'second', 'y'] + # eval grouper + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=(['location'])) + + super().__init__(root_dir, download, split_scheme) + + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metrics = [ + Accuracy(prediction_fn=prediction_fn), + Recall(prediction_fn=prediction_fn, average='macro'), + F1(prediction_fn=prediction_fn, average='macro'), + ] + + results = {} + + for i in range(len(metrics)): + results.update({ + **metrics[i].compute(y_pred, y_true), + }) + + results_str = ( + f"Average acc: {results[metrics[0].agg_metric_field]:.3f}\n" + f"Recall macro: {results[metrics[1].agg_metric_field]:.3f}\n" + f"F1 macro: {results[metrics[2].agg_metric_field]:.3f}\n" + ) + + return results, results_str + + def get_input(self, idx): + """ + Args: + - idx (int): Index of a data point + Output: + - x (Tensor): Input features of the idx-th data point + """ + + # All images are in the train folder + img_path = self.data_dir / 'train' / self._input_array[idx] + img = Image.open(img_path) + + return img diff --git a/wilds/datasets/archive/poverty_v1_0_dataset.py b/wilds/datasets/archive/poverty_v1_0_dataset.py new file mode 100644 index 00000000..438e7beb --- /dev/null +++ b/wilds/datasets/archive/poverty_v1_0_dataset.py @@ -0,0 +1,280 @@ +from pathlib import Path +import pandas as pd +import torch +from torch.utils.data import Dataset +import pickle +import numpy as np +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.metrics.all_metrics import MSE, PearsonCorrelation +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.utils import subsample_idxs, shuffle_arr + +DATASET = '2009-17' +BAND_ORDER = ['BLUE', 'GREEN', 'RED', 'SWIR1', 'SWIR2', 'TEMP1', 'NIR', 'NIGHTLIGHTS'] + + +DHS_COUNTRIES = [ + 'angola', 'benin', 'burkina_faso', 'cameroon', 'cote_d_ivoire', + 'democratic_republic_of_congo', 'ethiopia', 'ghana', 'guinea', 'kenya', + 'lesotho', 'malawi', 'mali', 'mozambique', 'nigeria', 'rwanda', 'senegal', + 'sierra_leone', 'tanzania', 'togo', 'uganda', 'zambia', 'zimbabwe'] + +_SURVEY_NAMES_2009_17A = { + 'train': ['cameroon', 'democratic_republic_of_congo', 'ghana', 'kenya', + 'lesotho', 'malawi', 'mozambique', 'nigeria', 'senegal', + 'togo', 'uganda', 'zambia', 'zimbabwe'], + 'val': ['benin', 'burkina_faso', 'guinea', 'sierra_leone', 'tanzania'], + 'test': ['angola', 'cote_d_ivoire', 'ethiopia', 'mali', 'rwanda'], +} +_SURVEY_NAMES_2009_17B = { + 'train': ['angola', 'cote_d_ivoire', 'democratic_republic_of_congo', + 'ethiopia', 'kenya', 'lesotho', 'mali', 'mozambique', + 'nigeria', 'rwanda', 'senegal', 'togo', 'uganda', 'zambia'], + 'val': ['cameroon', 'ghana', 'malawi', 'zimbabwe'], + 'test': ['benin', 'burkina_faso', 'guinea', 'sierra_leone', 'tanzania'], +} +_SURVEY_NAMES_2009_17C = { + 'train': ['angola', 'benin', 'burkina_faso', 'cote_d_ivoire', 'ethiopia', + 'guinea', 'kenya', 'lesotho', 'mali', 'rwanda', 'senegal', + 'sierra_leone', 'tanzania', 'zambia'], + 'val': ['democratic_republic_of_congo', 'mozambique', 'nigeria', 'togo', 'uganda'], + 'test': ['cameroon', 'ghana', 'malawi', 'zimbabwe'], +} +_SURVEY_NAMES_2009_17D = { + 'train': ['angola', 'benin', 'burkina_faso', 'cameroon', 'cote_d_ivoire', + 'ethiopia', 'ghana', 'guinea', 'malawi', 'mali', 'rwanda', + 'sierra_leone', 'tanzania', 'zimbabwe'], + 'val': ['kenya', 'lesotho', 'senegal', 'zambia'], + 'test': ['democratic_republic_of_congo', 'mozambique', 'nigeria', 'togo', 'uganda'], +} +_SURVEY_NAMES_2009_17E = { + 'train': ['benin', 'burkina_faso', 'cameroon', 'democratic_republic_of_congo', + 'ghana', 'guinea', 'malawi', 'mozambique', 'nigeria', 'sierra_leone', + 'tanzania', 'togo', 'uganda', 'zimbabwe'], + 'val': ['angola', 'cote_d_ivoire', 'ethiopia', 'mali', 'rwanda'], + 'test': ['kenya', 'lesotho', 'senegal', 'zambia'], +} + +SURVEY_NAMES = { + '2009-17A': _SURVEY_NAMES_2009_17A, + '2009-17B': _SURVEY_NAMES_2009_17B, + '2009-17C': _SURVEY_NAMES_2009_17C, + '2009-17D': _SURVEY_NAMES_2009_17D, + '2009-17E': _SURVEY_NAMES_2009_17E, +} + + +# means and standard deviations calculated over the entire dataset (train + val + test), +# with negative values set to 0, and ignoring any pixel that is 0 across all bands +# all images have already been mean subtracted and normalized (x - mean) / std + +_MEANS_2009_17 = { + 'BLUE': 0.059183, + 'GREEN': 0.088619, + 'RED': 0.104145, + 'SWIR1': 0.246874, + 'SWIR2': 0.168728, + 'TEMP1': 299.078023, + 'NIR': 0.253074, + 'DMSP': 4.005496, + 'VIIRS': 1.096089, + # 'NIGHTLIGHTS': 5.101585, # nightlights overall +} + +_STD_DEVS_2009_17 = { + 'BLUE': 0.022926, + 'GREEN': 0.031880, + 'RED': 0.051458, + 'SWIR1': 0.088857, + 'SWIR2': 0.083240, + 'TEMP1': 4.300303, + 'NIR': 0.058973, + 'DMSP': 23.038301, + 'VIIRS': 4.786354, + # 'NIGHTLIGHTS': 23.342916, # nightlights overall +} + + +def split_by_countries(idxs, ood_countries, metadata): + countries = np.asarray(metadata['country'].iloc[idxs]) + is_ood = np.any([(countries == country) for country in ood_countries], axis=0) + return idxs[~is_ood], idxs[is_ood] + + +class PovertyMapDataset(WILDSDataset): + """ + The PovertyMap poverty measure prediction dataset. + This is a processed version of LandSat 5/7/8 satellite imagery originally from Google Earth Engine under the names `LANDSAT/LC08/C01/T1_SR`,`LANDSAT/LE07/C01/T1_SR`,`LANDSAT/LT05/C01/T1_SR`, + nighttime light imagery from the DMSP and VIIRS satellites (Google Earth Engine names `NOAA/DMSP-OLS/CALIBRATED_LIGHTS_V4` and `NOAA/VIIRS/DNB/MONTHLY_V1/VCMSLCFG`) + and processed DHS survey metadata obtained from https://github.com/sustainlab-group/africa_poverty and originally from `https://dhsprogram.com/data/available-datasets.cfm`. + + Supported `split_scheme`: + 'official' and `countries`, which are equivalent + + Input (x): + 224 x 224 x 8 satellite image, with 7 channels from LandSat and 1 nighttime light channel from DMSP/VIIRS. Already mean/std normalized. + + Output (y): + y is a real-valued asset wealth index. Higher index corresponds to more asset wealth. + + Metadata: + each image is annotated with location coordinates (noised for anonymity), survey year, urban/rural classification, country, nighttime light mean, nighttime light median. + + Website: https://github.com/sustainlab-group/africa_poverty + + Original publication: + @article{yeh2020using, + author = {Yeh, Christopher and Perez, Anthony and Driscoll, Anne and Azzari, George and Tang, Zhongyi and Lobell, David and Ermon, Stefano and Burke, Marshall}, + day = {22}, + doi = {10.1038/s41467-020-16185-w}, + issn = {2041-1723}, + journal = {Nature Communications}, + month = {5}, + number = {1}, + title = {{Using publicly available satellite imagery and deep learning to understand economic well-being in Africa}}, + url = {https://www.nature.com/articles/s41467-020-16185-w}, + volume = {11}, + year = {2020} + } + + License: + LandSat/DMSP/VIIRS data is U.S. Public Domain. + + """ + _dataset_name = 'poverty' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x9a2add5219db4ebc89965d7f42719750/contents/blob/', + 'compressed_size': 18_630_656_000}} + + def __init__(self, version=None, root_dir='data', download=False, + split_scheme='official', + no_nl=False, fold='A', oracle_training_set=False, + use_ood_val=True, + cache_size=100): + self._version = version + self._data_dir = self.initialize_data_dir(root_dir, download) + + self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4} + self._split_names = {'train': 'Train', 'id_val': 'ID Val', 'id_test': 'ID Test', 'val': 'OOD Val', 'test': 'OOD Test'} + + if split_scheme=='official': + split_scheme = 'countries' + self._split_scheme = split_scheme + if self._split_scheme != 'countries': + raise ValueError("Split scheme not recognized") + + self.oracle_training_set = oracle_training_set + + self.no_nl = no_nl + if fold not in {'A', 'B', 'C', 'D', 'E'}: + raise ValueError("Fold must be A, B, C, D, or E") + + self.root = Path(self._data_dir) + self.metadata = pd.read_csv(self.root / 'dhs_metadata.csv') + # country folds, split off OOD + country_folds = SURVEY_NAMES[f'2009-17{fold}'] + + self._split_array = -1 * np.ones(len(self.metadata)) + + incountry_folds_split = np.arange(len(self.metadata)) + # take the test countries to be ood + idxs_id, idxs_ood_test = split_by_countries(incountry_folds_split, country_folds['test'], self.metadata) + # also create a validation OOD set + idxs_id, idxs_ood_val = split_by_countries(idxs_id, country_folds['val'], self.metadata) + for split in ['test', 'val', 'id_test', 'id_val', 'train']: + # keep ood for test, otherwise throw away ood data + if split == 'test': + idxs = idxs_ood_test + elif split == 'val': + idxs = idxs_ood_val + else: + idxs = idxs_id + num_eval = 2000 + # if oracle, do 50-50 split between OOD and ID + if split == 'train' and self.oracle_training_set: + idxs = subsample_idxs(incountry_folds_split, num=len(idxs_id), seed=ord(fold))[num_eval:] + elif split != 'train' and self.oracle_training_set: + eval_idxs = subsample_idxs(incountry_folds_split, num=len(idxs_id), seed=ord(fold))[:num_eval] + elif split == 'train': + idxs = subsample_idxs(idxs, take_rest=True, num=num_eval, seed=ord(fold)) + else: + eval_idxs = subsample_idxs(idxs, take_rest=False, num=num_eval, seed=ord(fold)) + + if split != 'train': + if split == 'id_val': + idxs = eval_idxs[:num_eval//2] + else: + idxs = eval_idxs[num_eval//2:] + self._split_array[idxs] = self._split_dict[split] + + if not use_ood_val: + self._split_dict = {'train': 0, 'val': 1, 'id_test': 2, 'ood_val': 3, 'test': 4} + self._split_names = {'train': 'Train', 'val': 'ID Val', 'id_test': 'ID Test', 'ood_val': 'OOD Val', 'test': 'OOD Test'} + + self.cache_size = cache_size + self.cache_counter = 0 + self.imgs = np.load(self.root / 'landsat_poverty_imgs.npy', mmap_mode='r') + self.imgs = self.imgs.transpose((0, 3, 1, 2)) + + self._y_array = torch.from_numpy(np.asarray(self.metadata['wealthpooled'])[:, np.newaxis]).float() + self._y_size = 1 + + # add country group field + country_to_idx = {country: i for i, country in enumerate(DHS_COUNTRIES)} + self.metadata['country'] = [country_to_idx[country] for country in self.metadata['country'].tolist()] + self._metadata_map = {'country': DHS_COUNTRIES} + self._metadata_array = torch.from_numpy(self.metadata[['urban', 'wealthpooled', 'country']].astype(float).to_numpy()) + # rename wealthpooled to y + self._metadata_fields = ['urban', 'y', 'country'] + + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['urban']) + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + """ + img = self.imgs[idx].copy() + if self.no_nl: + img[-1] = 0 + img = torch.from_numpy(img).float() + # consider refreshing cache if cache_size is limited + if self.cache_size < self.imgs.shape[0]: + self.cache_counter += 1 + if self.cache_counter > self.cache_size: + self.imgs = np.load(self.root / 'landsat_poverty_imgs.npy', mmap_mode='r') + self.imgs = self.imgs.transpose((0, 3, 1, 2)) + self.cache_counter = 0 + + return img + + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model + - y_true (LongTensor): Ground-truth values + - metadata (Tensor): Metadata + - prediction_fn (function): Only None supported + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + assert prediction_fn is None, "PovertyMapDataset.eval() does not support prediction_fn" + + metrics = [MSE(), PearsonCorrelation()] + + all_results = {} + all_results_str = '' + for metric in metrics: + results, results_str = self.standard_group_eval( + metric, + self._eval_grouper, + y_pred, y_true, metadata) + all_results.update(results) + all_results_str += results_str + return all_results, all_results_str diff --git a/wilds/datasets/bdd100k_dataset.py b/wilds/datasets/bdd100k_dataset.py index 0b97df31..29f4f16a 100644 --- a/wilds/datasets/bdd100k_dataset.py +++ b/wilds/datasets/bdd100k_dataset.py @@ -45,18 +45,18 @@ class BDD100KDataset(WILDSDataset): License (original text): Copyright ©2018. The Regents of the University of California (Regents). All Rights Reserved. - Permission to use, copy, modify, and distribute this software and its documentation for educational, research, and - not-for-profit purposes, without fee and without a signed licensing agreement; and permission use, copy, modify and - distribute this software for commercial purposes (such rights not subject to transfer) to BDD member and its affiliates, - is hereby granted, provided that the above copyright notice, this paragraph and the following two paragraphs appear in - all copies, modifications, and distributions. Contact The Office of Technology Licensing, UC Berkeley, 2150 Shattuck + Permission to use, copy, modify, and distribute this software and its documentation for educational, research, and + not-for-profit purposes, without fee and without a signed licensing agreement; and permission use, copy, modify and + distribute this software for commercial purposes (such rights not subject to transfer) to BDD member and its affiliates, + is hereby granted, provided that the above copyright notice, this paragraph and the following two paragraphs appear in + all copies, modifications, and distributions. Contact The Office of Technology Licensing, UC Berkeley, 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial licensing opportunities. - IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, - INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF REGENTS HAS BEEN ADVISED + IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, + INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY - AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED HEREUNDER IS PROVIDED + REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY + AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. """ @@ -65,11 +65,15 @@ class BDD100KDataset(WILDSDataset): TIMEOFDAY_SPLITS = ['daytime', 'night', 'dawn/dusk', 'undefined'] LOCATION_SPLITS = ['New York', 'California'] - def __init__(self, root_dir='data', download=False, split_scheme='official'): - self._dataset_name = 'bdd100k' - self._version = '1.0' + _dataset_name = 'bdd100k' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x0ac62ae89a644676a57fa61d6aa2f87d/contents/blob/', + 'compressed_size': None}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + self._version = version self._original_resolution = (1280, 720) - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x0ac62ae89a644676a57fa61d6aa2f87d/contents/blob/' self._data_dir = self.initialize_data_dir(root_dir, download) self.root = Path(self.data_dir) @@ -103,14 +107,27 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): split_names = (self.TIMEOFDAY_SPLITS if split_to_load == 'timeofday' else self.LOCATION_SPLITS) self._metadata_map = {split_to_load: split_names} - self._metric = MultiTaskAccuracy() def get_input(self, idx): img = Image.open(self.root / 'images' / self._image_array[idx]) return img - def eval(self, y_pred, y_true, metadata): - results = self._metric.compute(y_pred, y_true) - results_str = (f'{self._metric.name}: ' - f'{results[self._metric.agg_metric_field]:.3f}\n') + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = MultiTaskAccuracy(prediction_fn=prediction_fn) + results = metric.compute(y_pred, y_true) + results_str = (f'{metric.name}: ' + f'{results[metric.agg_metric_field]:.3f}\n') return results, results_str diff --git a/wilds/datasets/camelyon17_dataset.py b/wilds/datasets/camelyon17_dataset.py index 0a76f615..2efeaa41 100644 --- a/wilds/datasets/camelyon17_dataset.py +++ b/wilds/datasets/camelyon17_dataset.py @@ -45,11 +45,14 @@ class Camelyon17Dataset(WILDSDataset): https://creativecommons.org/publicdomain/zero/1.0/ """ - def __init__(self, root_dir='data', download=False, split_scheme='official'): - self._dataset_name = 'camelyon17' - self._version = '1.0' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/' - self._compressed_size = 10_658_709_504 + _dataset_name = 'camelyon17' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/', + 'compressed_size': 10_658_709_504}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) self._original_resolution = (96,96) @@ -120,8 +123,6 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): dataset=self, groupby_fields=['slide']) - self._metric = Accuracy() - super().__init__(root_dir, download, split_scheme) def get_input(self, idx): @@ -134,8 +135,22 @@ def get_input(self, idx): x = Image.open(img_filename).convert('RGB') return x - def eval(self, y_pred, y_true, metadata): + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = Accuracy(prediction_fn=prediction_fn) return self.standard_group_eval( - self._metric, + metric, self._eval_grouper, y_pred, y_true, metadata) diff --git a/wilds/datasets/celebA_dataset.py b/wilds/datasets/celebA_dataset.py index 37b9ffd2..06fcde93 100644 --- a/wilds/datasets/celebA_dataset.py +++ b/wilds/datasets/celebA_dataset.py @@ -51,11 +51,14 @@ class CelebADataset(WILDSDataset): It is available for non-commercial research purposes only. """ - - def __init__(self, root_dir='data', download=False, split_scheme='official'): - self._dataset_name = 'celebA' - self._version = '1.0' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0xa174edc9c11041869d11f98d1dc19935/contents/blob/' + _dataset_name = 'celebA' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xfe55077f5cd541f985ebf9ec50473293/contents/blob/', + 'compressed_size': 1_308_557_312}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) target_name = 'Blond_Hair' confounder_names = ['Male'] @@ -100,7 +103,6 @@ def attr_idx(attr_name): self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=(confounder_names + ['y'])) - self._metric = Accuracy() # Extract splits self._split_scheme = split_scheme @@ -121,8 +123,22 @@ def get_input(self, idx): x = Image.open(img_filename).convert('RGB') return x - def eval(self, y_pred, y_true, metadata): + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = Accuracy(prediction_fn=prediction_fn) return self.standard_group_eval( - self._metric, + metric, self._eval_grouper, y_pred, y_true, metadata) diff --git a/wilds/datasets/civilcomments_dataset.py b/wilds/datasets/civilcomments_dataset.py index 78fe9310..c4d6bb8b 100644 --- a/wilds/datasets/civilcomments_dataset.py +++ b/wilds/datasets/civilcomments_dataset.py @@ -55,11 +55,14 @@ class CivilCommentsDataset(WILDSDataset): https://creativecommons.org/publicdomain/zero/1.0/ """ - def __init__(self, root_dir='data', download=False, split_scheme='official'): - self._dataset_name = 'civilcomments' - self._version = '1.0' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8cd3de0634154aeaad2ee6eb96723c6e/contents/blob/' - self._compressed_size = 90_644_480 + _dataset_name = 'civilcomments' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8cd3de0634154aeaad2ee6eb96723c6e/contents/blob/', + 'compressed_size': 90_644_480}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) # Read in metadata @@ -121,18 +124,31 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): dataset=self, groupby_fields=[identity_var, 'y']) for identity_var in self._identity_vars] - self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) def get_input(self, idx): return self._text_array[idx] - def eval(self, y_pred, y_true, metadata): + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = Accuracy(prediction_fn=prediction_fn) results = { - **self._metric.compute(y_pred, y_true), + **metric.compute(y_pred, y_true), } - results_str = f"Average {self._metric.name}: {results[self._metric.agg_metric_field]:.3f}\n" + results_str = f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n" # Each eval_grouper is over label + a single identity # We only want to keep the groups where the identity is positive # The groups are: @@ -145,31 +161,31 @@ def eval(self, y_pred, y_true, metadata): for identity_var, eval_grouper in zip(self._identity_vars, self._eval_groupers): g = eval_grouper.metadata_to_group(metadata) group_results = { - **self._metric.compute_group_wise(y_pred, y_true, g, eval_grouper.n_groups) + **metric.compute_group_wise(y_pred, y_true, g, eval_grouper.n_groups) } results_str += f" {identity_var:20s}" for group_idx in range(eval_grouper.n_groups): group_str = eval_grouper.group_field_str(group_idx) if f'{identity_var}:1' in group_str: - group_metric = group_results[self._metric.group_metric_field(group_idx)] - group_counts = group_results[self._metric.group_count_field(group_idx)] - results[f'{self._metric.name}_{group_str}'] = group_metric + group_metric = group_results[metric.group_metric_field(group_idx)] + group_counts = group_results[metric.group_count_field(group_idx)] + results[f'{metric.name}_{group_str}'] = group_metric results[f'count_{group_str}'] = group_counts if f'y:0' in group_str: label_str = 'non_toxic' else: label_str = 'toxic' results_str += ( - f" {self._metric.name} on {label_str}: {group_metric:.3f}" + f" {metric.name} on {label_str}: {group_metric:.3f}" f" (n = {results[f'count_{group_str}']:6.0f}) " ) if worst_group_metric is None: worst_group_metric = group_metric else: - worst_group_metric = self._metric.worst( + worst_group_metric = metric.worst( [worst_group_metric, group_metric]) results_str += f"\n" - results[f'{self._metric.worst_group_metric_field}'] = worst_group_metric - results_str += f"Worst-group {self._metric.name}: {worst_group_metric:.3f}\n" + results[f'{metric.worst_group_metric_field}'] = worst_group_metric + results_str += f"Worst-group {metric.name}: {worst_group_metric:.3f}\n" return results, results_str diff --git a/wilds/datasets/fmow_dataset.py b/wilds/datasets/fmow_dataset.py index 7c6e1814..4a310b40 100644 --- a/wilds/datasets/fmow_dataset.py +++ b/wilds/datasets/fmow_dataset.py @@ -57,12 +57,14 @@ class FMoWDataset(WILDSDataset): """ _dataset_name = 'fmow' - _download_url = 'https://worksheets.codalab.org/rest/bundles/0xc59ea8261dfe4d2baa3820866e33d781/contents/blob/' - _version = '1.0' + _versions_dict = { + '1.1': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xaec91eb7c9d548ebb15e1b5e60f966ab/contents/blob/', + 'compressed_size': 53_893_324_800} + } - def __init__(self, root_dir='data', download=False, split_scheme='official', - oracle_training_set=False, seed=111, use_ood_val=False): - self._compressed_size = 70_000_000_000 + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', oracle_training_set=False, seed=111, use_ood_val=False): + self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4} @@ -167,48 +169,59 @@ def __init__(self, root_dir='data', download=False, split_scheme='official', 'region': CombinatorialGrouper(dataset=self, groupby_fields=['region']), } - self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) def get_input(self, idx): - """ - Returns x for a given idx. - """ - idx = self.full_idxs[idx] - batch_idx = idx // self.chunk_size - within_batch_idx = idx % self.chunk_size - img_batch = np.load(self.root / f'rgb_all_imgs_{batch_idx}.npy', mmap_mode='r') - return img_batch[within_batch_idx] - - def eval(self, y_pred, y_true, metadata): + """ + Returns x for a given idx. + """ + idx = self.full_idxs[idx] + img = Image.open(self.root / 'images' / f'rgb_img_{idx}.png').convert('RGB') + return img + + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = Accuracy(prediction_fn=prediction_fn) # Overall evaluation + evaluate by year all_results, all_results_str = self.standard_group_eval( - self._metric, + metric, self._eval_groupers['year'], y_pred, y_true, metadata) # Evaluate by region and ignore the "Other" region region_grouper = self._eval_groupers['region'] - region_results = self._metric.compute_group_wise( + region_results = metric.compute_group_wise( y_pred, y_true, region_grouper.metadata_to_group(metadata), region_grouper.n_groups) - all_results[f'{self._metric.name}_worst_year'] = all_results.pop(self._metric.worst_group_metric_field) + all_results[f'{metric.name}_worst_year'] = all_results.pop(metric.worst_group_metric_field) region_metric_list = [] for group_idx in range(region_grouper.n_groups): group_str = region_grouper.group_field_str(group_idx) - group_metric = region_results[self._metric.group_metric_field(group_idx)] - group_counts = region_results[self._metric.group_count_field(group_idx)] - all_results[f'{self._metric.name}_{group_str}'] = group_metric + group_metric = region_results[metric.group_metric_field(group_idx)] + group_counts = region_results[metric.group_count_field(group_idx)] + all_results[f'{metric.name}_{group_str}'] = group_metric all_results[f'count_{group_str}'] = group_counts - if region_results[self._metric.group_count_field(group_idx)] == 0 or "Other" in group_str: + if region_results[metric.group_count_field(group_idx)] == 0 or "Other" in group_str: continue all_results_str += ( f' {region_grouper.group_str(group_idx)} ' - f"[n = {region_results[self._metric.group_count_field(group_idx)]:6.0f}]:\t" - f"{self._metric.name} = {region_results[self._metric.group_metric_field(group_idx)]:5.3f}\n") - region_metric_list.append(region_results[self._metric.group_metric_field(group_idx)]) - all_results[f'{self._metric.name}_worst_region'] = self._metric.worst(region_metric_list) - all_results_str += f"Worst-group {self._metric.name}: {all_results[f'{self._metric.name}_worst_region']:.3f}\n" + f"[n = {region_results[metric.group_count_field(group_idx)]:6.0f}]:\t" + f"{metric.name} = {region_results[metric.group_metric_field(group_idx)]:5.3f}\n") + region_metric_list.append(region_results[metric.group_metric_field(group_idx)]) + all_results[f'{metric.name}_worst_region'] = metric.worst(region_metric_list) + all_results_str += f"Worst-group {metric.name}: {all_results[f'{metric.name}_worst_region']:.3f}\n" return all_results, all_results_str diff --git a/wilds/datasets/iwildcam_dataset.py b/wilds/datasets/iwildcam_dataset.py index 011c1f1c..533f7fbb 100644 --- a/wilds/datasets/iwildcam_dataset.py +++ b/wilds/datasets/iwildcam_dataset.py @@ -1,3 +1,4 @@ +from datetime import datetime from pathlib import Path import os @@ -35,37 +36,26 @@ class IWildCamDataset(WILDSDataset): This dataset is distributed under Community Data License Agreement – Permissive – Version 1.0 https://cdla.io/permissive-1-0/ """ + _dataset_name = 'iwildcam' + _versions_dict = { + '2.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x6313da2b204647e79a14b468131fcd64/contents/blob/', + 'compressed_size': 12_000_000_000}} - def __init__(self, root_dir='data', download=False, split_scheme='official'): + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): - self._dataset_name = 'iwildcam' - self._version = '1.0' + self._version = version self._split_scheme = split_scheme if self._split_scheme != 'official': raise ValueError(f'Split scheme {self._split_scheme} not recognized') # path - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x3f1b346ff2d74b5daf1a08685d68c6ec/contents/blob/' - self._compressed_size = 90_094_666_806 self._data_dir = Path(self.initialize_data_dir(root_dir, download)) # Load splits - train_df = pd.read_csv(self._data_dir / 'train.csv') - val_trans_df = pd.read_csv(self._data_dir / 'val_trans.csv') - test_trans_df = pd.read_csv(self._data_dir / 'test_trans.csv') - val_cis_df = pd.read_csv(self._data_dir / 'val_cis.csv') - test_cis_df = pd.read_csv(self._data_dir / 'test_cis.csv') - - # Merge all dfs - train_df['split'] = 'train' - val_trans_df['split'] = 'val' - test_trans_df['split'] = 'test' - val_cis_df['split'] = 'id_val' - test_cis_df['split'] = 'id_test' - df = pd.concat([train_df, val_trans_df, test_trans_df, test_cis_df, val_cis_df]) + df = pd.read_csv(self._data_dir / 'metadata.csv') # Splits - data = {} self._split_dict = {'train': 0, 'val': 1, 'test': 2, 'id_val': 3, 'id_test': 4} self._split_names = {'train': 'Train', 'val': 'Validation (OOD/Trans)', 'test': 'Test (OOD/Trans)', 'id_val': 'Validation (ID/Cis)', @@ -78,46 +68,75 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._input_array = df['filename'].values # Labels - unique_categories = np.unique(df['category_id']) - self._n_classes = len(unique_categories) - category_to_label = dict([(i, j) for i, j in zip(unique_categories, range(self._n_classes))]) - label_to_category = dict([(v, k) for k, v in category_to_label.items()]) - self._y_array = torch.tensor(df['category_id'].apply(lambda x: category_to_label[x]).values) + self._y_array = torch.tensor(df['y'].values) + self._n_classes = max(df['y']) + 1 self._y_size = 1 + assert len(np.unique(df['y'])) == self._n_classes # Location/group info - location_ids = df['location'] - locations = np.unique(location_ids) - n_groups = len(locations) - location_to_group_id = {locations[i]: i for i in range(n_groups)} - df['group_id' ] = df['location'].apply(lambda x: location_to_group_id[x]) - + n_groups = max(df['location_remapped']) + 1 self._n_groups = n_groups - self._metadata_array = torch.tensor(np.stack([df['group_id'].values, self.y_array], axis=1)) - self._metadata_fields = ['location', 'y'] + assert len(np.unique(df['location_remapped'])) == self._n_groups + + # Sequence info + n_sequences = max(df['sequence_remapped']) + 1 + self._n_sequences = n_sequences + assert len(np.unique(df['sequence_remapped'])) == self._n_sequences + + # Extract datetime subcomponents and include in metadata + df['datetime_obj'] = df['datetime'].apply(lambda x: datetime.strptime(x, '%Y-%m-%d %H:%M:%S.%f')) + df['year'] = df['datetime_obj'].apply(lambda x: int(x.year)) + df['month'] = df['datetime_obj'].apply(lambda x: int(x.month)) + df['day'] = df['datetime_obj'].apply(lambda x: int(x.day)) + df['hour'] = df['datetime_obj'].apply(lambda x: int(x.hour)) + df['minute'] = df['datetime_obj'].apply(lambda x: int(x.minute)) + df['second'] = df['datetime_obj'].apply(lambda x: int(x.second)) + + self._metadata_array = torch.tensor(np.stack([df['location_remapped'].values, + df['sequence_remapped'].values, + df['year'].values, df['month'].values, df['day'].values, + df['hour'].values, df['minute'].values, df['second'].values, + self.y_array], axis=1)) + self._metadata_fields = ['location', 'sequence', 'year', 'month', 'day', 'hour', 'minute', 'second', 'y'] + # eval grouper self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=(['location'])) - self._metrics = [Accuracy(), Recall(average='macro'), Recall(average='weighted'), - F1(average='macro'), F1(average='weighted')] super().__init__(root_dir, download, split_scheme) - def eval(self, y_pred, y_true, metadata): + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metrics = [ + Accuracy(prediction_fn=prediction_fn), + Recall(prediction_fn=prediction_fn, average='macro'), + F1(prediction_fn=prediction_fn, average='macro'), + ] + results = {} - for i in range(len(self._metrics)): + for i in range(len(metrics)): results.update({ - **self._metrics[i].compute(y_pred, y_true), + **metrics[i].compute(y_pred, y_true), }) results_str = ( - f"Average acc: {results[self._metrics[0].agg_metric_field]:.3f}\n" - f"Recall macro: {results[self._metrics[1].agg_metric_field]:.3f}\n" - f"Recall weighted: {results[self._metrics[2].agg_metric_field]:.3f}\n" - f"F1 macro: {results[self._metrics[3].agg_metric_field]:.3f}\n" - f"F1 weighted: {results[self._metrics[4].agg_metric_field]:.3f}\n" + f"Average acc: {results[metrics[0].agg_metric_field]:.3f}\n" + f"Recall macro: {results[metrics[1].agg_metric_field]:.3f}\n" + f"F1 macro: {results[metrics[2].agg_metric_field]:.3f}\n" ) return results, results_str @@ -134,5 +153,4 @@ def get_input(self, idx): img_path = self.data_dir / 'train' / self._input_array[idx] img = Image.open(img_path) - return img diff --git a/wilds/datasets/ogbmolpcba_dataset.py b/wilds/datasets/ogbmolpcba_dataset.py index 38ddd4ab..413fd330 100644 --- a/wilds/datasets/ogbmolpcba_dataset.py +++ b/wilds/datasets/ogbmolpcba_dataset.py @@ -51,12 +51,20 @@ class OGBPCBADataset(WILDSDataset): https://github.com/snap-stanford/ogb/blob/master/LICENSE """ - def __init__(self, root_dir='data', download=False, split_scheme='official'): + _dataset_name = 'ogbg-molpcba' + _versions_dict = { + '1.0': { + 'download_url': None, + 'compressed_size': None}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + self._version = version + if version is not None: + raise ValueError('Versioning for OGB-MolPCBA is handled through the OGB package. Please set version=none.') # internally call ogb package self.ogb_dataset = PygGraphPropPredDataset(name = 'ogbg-molpcba', root = root_dir) # set variables - self._dataset_name = 'ogbg-molpcba' self._data_dir = self.ogb_dataset.root if split_scheme=='official': split_scheme = 'scaffold' @@ -88,7 +96,20 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): def get_input(self, idx): return self.ogb_dataset[int(idx)] - def eval(self, y_pred, y_true, metadata): + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (FloatTensor): Binary logits from a model + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels. + Only None is supported because OGB Evaluators accept binary logits + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + assert prediction_fn is None, "OGBPCBADataset.eval() does not support prediction_fn. Only binary logits accepted" input_dict = {"y_true": y_true, "y_pred": y_pred} results = self._metric.eval(input_dict) diff --git a/wilds/datasets/poverty_dataset.py b/wilds/datasets/poverty_dataset.py index 889881c7..7b062002 100644 --- a/wilds/datasets/poverty_dataset.py +++ b/wilds/datasets/poverty_dataset.py @@ -142,13 +142,17 @@ class PovertyMapDataset(WILDSDataset): """ _dataset_name = 'poverty' - _download_url = 'https://worksheets.codalab.org/rest/bundles/0x9a2add5219db4ebc89965d7f42719750/contents/blob/' - _version = '1.0' - - def __init__(self, root_dir='data', download=False, split_scheme='official', - no_nl=True, fold='A', oracle_training_set=False, use_ood_val=False): - - self._compressed_size = 18_630_656_000 + _versions_dict = { + '1.1': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xfc0aa86ad9af4eb08c42dfc40eacf094/contents/blob/', + 'compressed_size': 13_091_823_616}} + + def __init__(self, version=None, root_dir='data', download=False, + split_scheme='official', + no_nl=False, fold='A', oracle_training_set=False, + use_ood_val=True, + cache_size=100): + self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4} @@ -208,10 +212,6 @@ def __init__(self, root_dir='data', download=False, split_scheme='official', self._split_dict = {'train': 0, 'val': 1, 'id_test': 2, 'ood_val': 3, 'test': 4} self._split_names = {'train': 'Train', 'val': 'ID Val', 'id_test': 'ID Test', 'ood_val': 'OOD Val', 'test': 'OOD Test'} - - self.imgs = np.load(self.root / 'landsat_poverty_imgs.npy', mmap_mode='r') - - self.imgs = self.imgs.transpose((0, 3, 1, 2)) self._y_array = torch.from_numpy(np.asarray(self.metadata['wealthpooled'])[:, np.newaxis]).float() self._y_size = 1 @@ -227,32 +227,38 @@ def __init__(self, root_dir='data', download=False, split_scheme='official', dataset=self, groupby_fields=['urban']) - self._metrics = [MSE(), PearsonCorrelation()] - self.cache_counter = 0 - super().__init__(root_dir, download, split_scheme) def get_input(self, idx): - """ - Returns x for a given idx. - """ - img = self.imgs[idx].copy() - if self.no_nl: - img[-1] = 0 - img = torch.from_numpy(img).float() - - self.cache_counter += 1 - if self.cache_counter > 1000: - self.imgs = np.load(self.root / 'landsat_poverty_imgs.npy', mmap_mode='r') - self.imgs = self.imgs.transpose((0, 3, 1, 2)) - self.cache_counter = 0 - - return img - - def eval(self, y_pred, y_true, metadata): + """ + Returns x for a given idx. + """ + img = np.load(self.root / 'images' / f'landsat_poverty_img_{idx}.npz')['x'] + if self.no_nl: + img[-1] = 0 + img = torch.from_numpy(img).float() + + return img + + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model + - y_true (LongTensor): Ground-truth values + - metadata (Tensor): Metadata + - prediction_fn (function): Only None supported + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + assert prediction_fn is None, "PovertyMapDataset.eval() does not support prediction_fn" + + metrics = [MSE(), PearsonCorrelation()] + all_results = {} all_results_str = '' - for metric in self._metrics: + for metric in metrics: results, results_str = self.standard_group_eval( metric, self._eval_grouper, diff --git a/wilds/datasets/py150_dataset.py b/wilds/datasets/py150_dataset.py new file mode 100644 index 00000000..e821c632 --- /dev/null +++ b/wilds/datasets/py150_dataset.py @@ -0,0 +1,188 @@ +from pathlib import Path +import os + +import pandas as pd +import numpy as np +import torch +import json +import gc +from wilds.common.metrics.all_metrics import Accuracy +from wilds.datasets.wilds_dataset import WILDSDataset +from transformers import GPT2Tokenizer + +class Py150Dataset(WILDSDataset): + """ + The Py150 dataset. + This is a modified version of the original Py150 dataset. + Input (x): + A Python code snippet (a sequence of tokens) + Label (y): + A sequence of next tokens (shifted x) + Metadata: + Each example is annotated with the original GitHub repo id. + This repo id can be matched with the name of the repo in natural language by + matching it with the contents of the metadata/ folder in the downloaded dataset. + Similarly, each example can also associated with the name of the file in natural language. + Website: + https://www.sri.inf.ethz.ch/py150 + https://github.com/microsoft/CodeXGLUE + Original publication: + @article{raychev2016probabilistic, + title={Probabilistic model for code with decision trees}, + author={Raychev, Veselin and Bielik, Pavol and Vechev, Martin}, + journal={ACM SIGPLAN Notices}, + year={2016}, + } + @article{CodeXGLUE, + title={CodeXGLUE: A Benchmark Dataset and Open Challenge for Code Intelligence}, + year={2020}, + } + License: + This dataset is distributed under the MIT license. + """ + + _dataset_name = 'py150' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x442a0661a84649e69c0a946cc5f84237/contents/blob/', + 'compressed_size': 162_811_706}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + + self._version = version + self._split_scheme = split_scheme + if self._split_scheme != 'official': + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + + # path + self._data_dir = Path(self.initialize_data_dir(root_dir, download)) + + # Load data + df = self._load_all_data() + self._TYPE2ID = {'class':0, 'method':1, 'punctuation':2, 'keyword':3, 'builtin':4, 'literal':5, 'other_identifier':6, 'masked':-100} + self._ID2TYPE = {v: k for k, v in self._TYPE2ID.items()} + + # Splits + data = {} + self._split_dict = {'train': 0, 'val': 1, 'test': 2, 'id_val': 3, 'id_test': 4} + self._split_names = {'train': 'Train', 'val': 'Validation (OOD)', + 'test': 'Test (OOD)', 'id_val': 'Validation (ID)', + 'id_test': 'Test (ID)'} + + df['split_id'] = df['split'].apply(lambda x: self._split_dict[x]) + self._split_array = df['split_id'].values + + # Input + self._input_array = torch.tensor(list(df['input'].apply(lambda x: x[:-1]).values)) #[n_samples, seqlen-1] + + # Labels + name = 'microsoft/CodeGPT-small-py' + tokenizer = GPT2Tokenizer.from_pretrained(name) + self._n_classes = len(tokenizer) + self._y_array = torch.tensor(list(df['input'].apply(lambda x: x[1:]).values)) + self._y_size = None + + _repo = torch.tensor(df['repo'].values).reshape(-1,1) #[n_samples, 1] + _tok_type = torch.tensor(list(df['tok_type'].apply(lambda x: x[1:]).values)) #[n_samples, seqlen-1] + length = _tok_type.size(1) + self._metadata_fields = ['repo'] + [f'tok_{i}_type' for i in range(length)] + self._metadata_array = torch.cat([_repo, _tok_type], dim=1) + + self._y_array = self._y_array.float() + self._y_array[(_tok_type==self._TYPE2ID['masked']).bool()] = float('nan') + + super().__init__(root_dir, download, split_scheme) + + def _compute_acc(self, y_pred, y_true, eval_pos): + flattened_y_pred = y_pred[eval_pos] + flattened_y_true = y_true[eval_pos] + assert flattened_y_pred.size()==flattened_y_true.size() and flattened_y_pred.dim()==1 + if len(flattened_y_pred) == 0: + acc = 0 + else: + acc = (flattened_y_pred==flattened_y_true).float().mean().item() + return acc + + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + if prediction_fn is not None: + y_pred = prediction_fn(y_pred) + + #y_pred: [n_samples, seqlen-1] + #y_true: [n_samples, seqlen-1] + tok_type = metadata[:, 1:] #[n_samples, seqlen-1] + results = {} + results_str = "" + + #Acc for class & method combined + eval_pos = (tok_type == self._TYPE2ID['class']) | (tok_type == self._TYPE2ID['method']) + acc = self._compute_acc(y_pred, y_true, eval_pos) + results['acc'] = acc + results['Acc (Class-Method)'] = acc + results_str += f"Acc (Class-Method): {acc:.3f}\n" + + #Overall acc + eval_pos = ~torch.isnan(y_true) + acc = self._compute_acc(y_pred, y_true, eval_pos) + results['Acc (Overall)'] = acc + results_str += f"Acc (Overall): {acc:.3f}\n" + + #Acc for each token type + for TYPE, TYPEID in self._TYPE2ID.items(): + if TYPE == 'masked': + continue + eval_pos = (tok_type == TYPEID) + acc = self._compute_acc(y_pred, y_true, eval_pos) + results[f'Acc ({TYPE})'] = acc + results_str += f"Acc ({TYPE}): {acc:.3f}\n" + + return results, results_str + + def get_input(self, idx): + """ + Args: + - idx (int): Index of a data point + Output: + - x (Tensor): Input features of the idx-th data point + """ + return self._input_array[idx] + + + def _load_all_data(self): + def fname2repo_id(fname, repo_name2id): + return repo_name2id['/'.join(fname.split('/')[:2])] + + def get_split_name(name): + if name.startswith('OOD'): return name.replace('OOD','') + if name.startswith('ID'): return name.replace('ID','id_') + return name + + _df = pd.read_csv(self._data_dir/'metadata/repo_file_names/repo_ids.csv') + repo_name2id = {repo_name: id for id, repo_name in zip(_df.id, _df.repo_name)} + + dfs = [] + pad_token_id = 1 + for type in ['train', 'IDval', 'OODval', 'IDtest', 'OODtest']: + inputs = json.load(open(self._data_dir/f'processed/{type}_input.json')) + fnames = open(self._data_dir/f'metadata/repo_file_names/{type}.txt').readlines() + repo_ids = [fname2repo_id(fname, repo_name2id) for fname in fnames] + splits = [get_split_name(type)] * len(inputs) + tok_types = json.load(open(self._data_dir/f'processed/{type}_input_tok_type.json')) + assert len(repo_ids) == len(inputs) == len(tok_types) + + _df = pd.DataFrame({'input': inputs, 'tok_type': tok_types, 'repo': repo_ids, 'split': splits}) + dfs.append(_df) + + return pd.concat(dfs) diff --git a/wilds/datasets/sqf_dataset.py b/wilds/datasets/sqf_dataset.py new file mode 100644 index 00000000..d7f233c5 --- /dev/null +++ b/wilds/datasets/sqf_dataset.py @@ -0,0 +1,304 @@ +import os +import torch +import pandas as pd +import numpy as np +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.metrics.all_metrics import Accuracy, PrecisionAtRecall, binary_logits_to_score, multiclass_logits_to_pred +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.utils import subsample_idxs, threshold_at_recall +import torch.nn.functional as F + +class SQFDataset(WILDSDataset): + """ + New York City stop-question-and-frisk data. + The dataset covers data from 2009 - 2012, as orginally provided by the New York Police Department (NYPD) and later cleaned by Goel, Rao, and Shroff, 2016. + + Supported `split_scheme`: + 'black', 'all_race', 'bronx', or 'all_borough' + + Input (x): + For the 'black' and 'all_race' split schemes: + 29 pre-stop observable features + + 75 one-hot district indicators = 104 features + + For the 'bronx' and 'all_borough' split schemes: + 29 pre-stop observable features. + As these split schemes study location shifts, we remove the district + indicators here as they prevent generalizing to new locations. + In order to run the example code with these split_schemes, + pass in the command-line parameter `--model_kwargs in_features=29` + to `examples/run_expt.py`. + + Label (y): + Binary. It is 1 if the stop is listed as finding a weapon, and 0 otherwise. + + Metadata: + Each stop is annotated with the borough the stop took place, + the race of the stopped person, and whether the stop took + place in 2009-2010 or in 2011-2012 + + Website: + NYPD - https://www1.nyc.gov/site/nypd/stats/reports-analysis/stopfrisk.page + Cleaned data - https://5harad.com/data/sqf.RData + + Cleaning and analysis citation: + @article{goel_precinct_2016, + title = {Precinct or prejudice? {Understanding} racial disparities in {New} {York} {City}’s stop-and-frisk policy}, + volume = {10}, + issn = {1932-6157}, + shorttitle = {Precinct or prejudice?}, + url = {http://projecteuclid.org/euclid.aoas/1458909920}, + doi = {10.1214/15-AOAS897}, + language = {en}, + number = {1}, + journal = {The Annals of Applied Statistics}, + author = {Goel, Sharad and Rao, Justin M. and Shroff, Ravi}, + month = mar, + year = {2016}, + pages = {365--394}, + } + + License: + The original data frmo the NYPD is in the public domain. + The cleaned data from Goel, Rao, and Shroff is shared with permission. + """ + _dataset_name = 'sqf' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xea27fd7daef642d2aa95b02f1e3ac404/contents/blob/', + 'compressed_size': 36_708_352}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='all_race'): + # set variables + self._version = version + self._split_scheme = split_scheme + self._y_size = 1 + self._n_classes = 2 + # path + self._data_dir = self.initialize_data_dir(root_dir, download) + + # Load data + data_df = pd.read_csv(os.path.join(self.data_dir, 'sqf.csv') , index_col=0) + data_df = data_df[data_df['suspected.crime'] == 'cpw'] + categories = ['black', 'white hispanic', 'black hispanic', 'hispanic', 'white'] + data_df = data_df.loc[data_df['suspect.race'].map(lambda x: x in categories)] + data_df['suspect.race'] = data_df['suspect.race'].map(lambda x: 'Hispanic' if 'hispanic' in x else x.title()) + + # Only track weapons stops + data_df = data_df[data_df['suspected.crime']=='cpw'] + + # Get district features if measuring race, don't if measuring boroughs + self.feats_to_use = self.get_split_features(data_df.columns) + + # Drop rows that don't have all of the predictive features. + # This preserves almost all rows. + data_df = data_df.dropna(subset=self.feats_to_use) + + # Get indices based on new index / after dropping rows with missing data + train_idxs, test_idxs, val_idxs = self.get_split_indices(data_df) + + # Drop rows with unused metadata categories + data_df = data_df.loc[train_idxs + test_idxs + val_idxs] + + # Reindex for simplicity + data_df.index = range(data_df.shape[0]) + train_idxs = range(0, len(train_idxs)) + test_idxs = range(len(train_idxs), len(train_idxs)+ len(test_idxs)) + val_idxs = range(test_idxs[-1], data_df.shape[0]) + + # Normalize continuous features + data_df = self.normalize_data(data_df, train_idxs) + self._input_array = data_df + + # Create split dictionaries + self._split_dict, self._split_names = self.initialize_split_dicts() + + # Get whether a weapon was found for various groups + self._y_array = torch.from_numpy(data_df['found.weapon'].values).long() + + # Metadata will be int dicts + explicit_identity_label_df, self._metadata_map = self.load_metadata(data_df, ['suspect.race', 'borough', 'train.period']) + self._metadata_array = torch.cat( + ( + torch.LongTensor(explicit_identity_label_df.values), + self._y_array.reshape((-1, 1)) + ), + dim=1 + ) + self._metadata_fields = ['suspect race', 'borough', '2010 or earlier?'] + ['y'] + + self._split_array = self.get_split_maps(data_df, train_idxs, test_idxs, val_idxs) + data_df = data_df[self.feats_to_use] + self._input_array = pd.get_dummies( + data_df, + columns=[i for i in self.feats_to_use + if 'suspect.' not in i and 'observation.period' not in i], + drop_first=True) + + # Recover relevant features after taking dummies + new_feats = [] + for i in self.feats_to_use: + for j in self._input_array: + if i in j: + new_feats.append(j) + else: + pass + self._input_array = self._input_array[new_feats] + self._eval_grouper = self.initialize_eval_grouper() + + def load_metadata(self, data_df, identity_vars): + metadata_df = data_df[identity_vars].copy() + metadata_names = ['suspect race', 'borough', '2010 or earlier?'] + metadata_ordered_maps = {} + for col_name, meta_name in zip(metadata_df.columns, metadata_names): + col_order = sorted(set(metadata_df[col_name])) + col_dict = dict(zip(col_order, range(len(col_order)))) + metadata_ordered_maps[col_name] = col_order + metadata_df[meta_name] = metadata_df[col_name].map(col_dict) + return metadata_df[metadata_names], metadata_ordered_maps + + def get_split_indices(self, data_df): + """Finds splits based on the split type """ + test_idxs = data_df[data_df.year > 2010].index.tolist() + train_df = data_df[data_df.year <= 2010] + validation_id_idxs = subsample_idxs( + train_df.index.tolist(), + num=int(train_df.shape[0] * 0.2), + seed=2851, + take_rest=False) + + train_df = train_df[~train_df.index.isin(validation_id_idxs)] + + if 'black' == self._split_scheme: + train_idxs = train_df[train_df['suspect.race'] == 'Black'].index.tolist() + + elif 'all_race' in self._split_scheme: + black_train_size = train_df[train_df['suspect.race'] == 'Black'].shape[0] + train_idxs = subsample_idxs(train_df.index.tolist(), num=black_train_size, take_rest=False, seed=4999) + + elif 'all_borough' == self._split_scheme: + bronx_train_size = train_df[train_df['borough'] == 'Bronx'].shape[0] + train_idxs = subsample_idxs(train_df.index.tolist(), num=bronx_train_size, take_rest=False, seed=8614) + + elif 'bronx' == self._split_scheme: + train_idxs = train_df[train_df['borough'] == 'Bronx'].index.tolist() + + else: + raise ValueError(f'Split scheme {self.split_scheme} not recognized') + + return train_idxs, test_idxs, validation_id_idxs + + def get_split_maps(self, data_df, train_idxs, test_idxs, val_idxs): + """Using the existing split indices, create a map to put entries to training and validation sets. """ + split_array = np.zeros(data_df.shape[0]) + split_array[train_idxs] = 0 + split_array[test_idxs] = 1 + split_array[val_idxs] = 2 + return split_array + + def get_split_features(self, columns): + """Get features that include precinct if we're splitting on race or don't include if we're using borough splits.""" + feats_to_use = [] + if 'bronx' not in self._split_scheme and 'borough' not in self._split_scheme: + feats_to_use.append('precinct') + + feats_to_use += ['suspect.height', 'suspect.weight', 'suspect.age', 'observation.period', + 'inside.outside', 'location.housing', 'radio.run', 'officer.uniform'] + # Primary stop reasoning features + feats_to_use += [i for i in columns if 'stopped.bc' in i] + # Secondary stop reasoning features, if any + feats_to_use += [i for i in columns if 'additional' in i] + + return feats_to_use + + def normalize_data(self, df, train_idxs): + """"Normalizes the data as Goel et al do - continuous features only""" + columns_to_norm = ['suspect.height', 'suspect.weight', 'suspect.age', 'observation.period'] + df_unnormed_train = df.loc[train_idxs].copy() + for feature_name in columns_to_norm: + df[feature_name] = df[feature_name] - np.mean(df_unnormed_train[feature_name]) + df[feature_name] = df[feature_name] / np.std(df_unnormed_train[feature_name]) + return df + + def initialize_split_dicts(self): + """Identify split indices and name splits""" + split_dict = {'train': 0, 'test': 1, 'val':2} + if 'all_borough' == self.split_scheme : + split_names = { + 'train': 'Stops in 2009 & 2010, subsampled to match Bronx train set size', + 'test': 'All stops in 2011 & 2012', + 'val': '20% sample of all stops 2009 & 2010' + } + elif 'bronx' == self.split_scheme: + split_names = { + 'train': 'Bronx stops in 2009 & 2010', + 'test': 'All stops in 2011 & 2012', + 'val': '20% sample of all stops 2009 & 2010' + } + elif 'black' == self.split_scheme: + split_names = { + 'train': '80% Black Stops 2009 and 2010', + 'test': 'All stops in 2011 & 2012', + 'val': '20% sample of all stops 2009 & 2010' + } + elif 'all_race' == self.split_scheme: + split_names = { + 'train': 'Stops in 2009 & 2010, subsampled to match Black people train set size', + 'test': 'All stops in 2011 & 2012', + 'val': '20% sample of all stops 2009 & 2010' + } + else: + raise ValueError(f'Split scheme {self.split_scheme} not recognized') + return split_dict, split_names + + def get_input(self, idx): + return torch.FloatTensor(self._input_array.loc[idx].values) + + def eval(self, y_pred, y_true, metadata, prediction_fn=multiclass_logits_to_pred, score_fn=binary_logits_to_score): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are multi-class logits (FloatTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels and score_fn(y_pred) are confidence scores. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + """Evaluate the precision achieved overall and across groups for a given global recall""" + g = self._eval_grouper.metadata_to_group(metadata) + + y_scores = score_fn(y_pred) + threshold_60 = threshold_at_recall(y_scores, y_true, global_recall=60) + + accuracy_metric = Accuracy(prediction_fn=prediction_fn) + PAR_metric = PrecisionAtRecall(threshold_60, score_fn=score_fn) + + results = accuracy_metric.compute(y_pred, y_true) + results.update(PAR_metric.compute(y_pred, y_true)) + results.update(accuracy_metric.compute_group_wise(y_pred, y_true, g, self._eval_grouper.n_groups)) + results.update(PAR_metric.compute_group_wise(y_pred, y_true, g, self._eval_grouper.n_groups)) + + results_str = ( + f"Average {PAR_metric.name}: {results[PAR_metric.agg_metric_field]:.3f}\n" + f"Average {accuracy_metric.name}: {results[accuracy_metric.agg_metric_field]:.3f}\n" + ) + + return results, results_str + + def initialize_eval_grouper(self): + if 'black' in self.split_scheme or 'race' in self.split_scheme : + eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields = ['suspect race'] + ) + elif 'bronx' in self.split_scheme or 'all_borough' == self.split_scheme: + eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields = ['borough']) + else: + raise ValueError(f'Split scheme {self.split_scheme} not recognized') + return eval_grouper diff --git a/wilds/datasets/waterbirds_dataset.py b/wilds/datasets/waterbirds_dataset.py index d9e69349..9caeb4cb 100644 --- a/wilds/datasets/waterbirds_dataset.py +++ b/wilds/datasets/waterbirds_dataset.py @@ -53,10 +53,14 @@ class WaterbirdsDataset(WILDSDataset): The use of this dataset is restricted to non-commercial research and educational purposes. """ - def __init__(self, root_dir='data', download=False, split_scheme='official'): - self._dataset_name = 'waterbirds' - self._version = '1.0' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x505056d5cdea4e4eaa0e242cbfe2daa4/contents/blob/' + _dataset_name = 'waterbirds' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x505056d5cdea4e4eaa0e242cbfe2daa4/contents/blob/', + 'compressed_size': None}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) if not os.path.exists(self.data_dir): @@ -96,7 +100,6 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=(['background', 'y'])) - self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) @@ -110,8 +113,22 @@ def get_input(self, idx): x = Image.open(img_filename).convert('RGB') return x - def eval(self, y_pred, y_true, metadata): + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = Accuracy(prediction_fn=prediction_fn) return self.standard_group_eval( - self._metric, + metric, self._eval_grouper, y_pred, y_true, metadata) diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index ab149bac..1f8bf21a 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -1,5 +1,6 @@ import os -import shutil +import time + import torch import numpy as np @@ -109,6 +110,24 @@ def check_init(self): if self.y_size == 1: assert 'y' in self.metadata_fields + @property + def latest_version(cls): + def is_later(u, v): + """Returns true if u is a later version than v.""" + u_major, u_minor = tuple(map(int, u.split('.'))) + v_major, v_minor = tuple(map(int, v.split('.'))) + if (u_major > v_major) or ( + (u_major == v_major) and (u_minor > v_minor)): + return True + else: + return False + + latest_version = '0.0' + for key in cls.versions_dict.keys(): + if is_later(key, latest_version): + latest_version = key + return latest_version + @property def dataset_name(self): """ @@ -121,16 +140,25 @@ def version(self): """ A string that identifies the dataset version, e.g., '1.0'. """ - return self._version + if self._version is None: + return self.latest_version + else: + return self._version @property - def download_url(self): + def versions_dict(self): """ - URL for downloading the dataset archive. + A dictionary where each key is a version string (e.g., '1.0') + and each value is a dictionary containing the 'download_url' and + 'compressed_size' keys. + + 'download_url' is the URL for downloading the dataset archive. If None, the dataset cannot be downloaded automatically (e.g., because it first requires accepting a usage agreement). + + 'compressed_size' is the approximate size of the compressed dataset in bytes. """ - return getattr(self, '_download_url', None) + return self._versions_dict @property def data_dir(self): @@ -256,13 +284,6 @@ def original_resolution(self): """ return getattr(self, '_original_resolution', None) - @property - def compressed_size(self): - """ - Size of the compressed bundle - """ - return getattr(self, '_compressed_size', None) - def initialize_data_dir(self, root_dir, download): """ Helper function for downloading/updating the dataset if required. @@ -271,102 +292,78 @@ def initialize_data_dir(self, root_dir, download): Datasets for which we don't control the download, like Yelp, might not handle versions similarly. """ + if self.version not in self.versions_dict: + raise ValueError(f'Version {self.version} not supported. Must be in {self.versions_dict.keys()}.') + + download_url = self.versions_dict[self.version]['download_url'] + compressed_size = self.versions_dict[self.version]['compressed_size'] + os.makedirs(root_dir, exist_ok=True) data_dir = os.path.join(root_dir, f'{self.dataset_name}_v{self.version}') version_file = os.path.join(data_dir, f'RELEASE_v{self.version}.txt') current_major_version, current_minor_version = tuple(map(int, self.version.split('.'))) + # Check if we specified the latest version. Otherwise, print a warning. + latest_major_version, latest_minor_version = tuple(map(int, self.latest_version.split('.'))) + if latest_major_version > current_major_version: + print( + f'*****************************\n' + f'{self.dataset_name} has been updated to version {self.latest_version}.\n' + f'You are currently using version {self.version}.\n' + f'We highly recommend updating the dataset by not specifying the older version in the command-line argument or dataset constructor.\n' + f'See https://wilds.stanford.edu/changelog for changes.\n' + f'*****************************\n') + elif latest_minor_version > current_minor_version: + print( + f'*****************************\n' + f'{self.dataset_name} has been updated to version {self.latest_version}.\n' + f'You are currently using version {self.version}.\n' + f'Please consider updating the dataset.\n' + f'See https://wilds.stanford.edu/changelog for changes.\n' + f'*****************************\n') + # If the data_dir exists and contains the right RELEASE file, # we assume the dataset is correctly set up if os.path.exists(data_dir) and os.path.exists(version_file): return data_dir - # If the data_dir exists and is not empty, and the download_url is set, + # If the data_dir exists and does not contain the right RELEASE file, but it is not empty and the download_url is not set, # we assume the dataset is correctly set up if ((os.path.exists(data_dir)) and (len(os.listdir(data_dir)) > 0) and - (self.download_url is None)): + (download_url is None)): return data_dir - # Otherwise, check if there's an older version of the dataset around - old_major_version, old_minor_version = -1, -1 - old_folders = [ - f for f in os.listdir(root_dir) if ( - os.path.isdir(os.path.join(root_dir, f)) and - f.startswith(self.dataset_name))] - for old_folder in old_folders: - prefix = f'{self.dataset_name}_v' - try: - version = old_folder.split(prefix)[1] - if os.path.exists( - os.path.join(root_dir, old_folder, f'RELEASE_v{version}.txt')): - major_version, minor_version = tuple(map(int, version.split('.'))) - if ((old_major_version < major_version) or - ((old_major_version == major_version) and - (old_minor_version < minor_version))): - old_major_version, old_minor_version = major_version, minor_version - latest_existing_data_dir = os.path.join(root_dir, old_folder) - except: - continue - - do_download = False - - # No existing dataset - if (old_major_version == -1): - if download == False: - if self.download_url is None: - raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. {self.dataset_name} cannot be automatically downloaded. Please download it manually.') - else: - raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. Initialize the dataset with download=True to download the dataset. If you are using the example script, run with --download. This might take some time for large datasets.') + # Otherwise, we assume the dataset needs to be downloaded. + # If download == False, then return an error. + if download == False: + if download_url is None: + raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. {self.dataset_name} cannot be automatically downloaded. Please download it manually.') else: - do_download = True - - # Older major version: - # Prompt for update, ignore `download` flag - elif (old_major_version < current_major_version): - print( - '***********\n' - f'{self.dataset_name} has been updated to a new major version.\n' - f'We recommend updating the dataset.\n') - confirm = input(f'Will you update the dataset now? This might take some time for large datasets. (y/n)\n').lower() - if confirm == 'y': - do_download = True - - # Same major version, older minor version: - # Notify user but do not prompt unless `download` is set - elif ((old_major_version == current_major_version) and - (old_minor_version < current_minor_version)): - print( - '***********\n' - f'{self.dataset_name} has been updated to a new minor version.\n') - if download == False: - print( - 'Initialize the dataset with download=True to download the dataset. If you are using the example script, run with --download. This might take some time for large datasets.\n' - '***********\n') - else: - do_download = True - - # Download if necessary - if do_download == False: - data_dir = latest_existing_data_dir - else: - if self.download_url is None: - raise ValueError(f'Sorry, {self.dataset_name} cannot be automatically downloaded. Please download it manually.') - - from wilds.datasets.download_utils import download_and_extract_archive - print(f'Downloading dataset to {data_dir}...') - print(f'You can also download the dataset manually at https://wilds.stanford.edu/downloads.') - try: - download_and_extract_archive( - url=self.download_url, - download_root=data_dir, - filename='archive.tar.gz', - remove_finished=True, - size=self.compressed_size) - except Exception as e: - print(f"\n{os.path.join(data_dir, 'archive.tar.gz')} may be corrupted. Please try deleting it and rerunning this command.\n") - print(f"Exception: ", e) + raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. Initialize the dataset with download=True to download the dataset. If you are using the example script, run with --download. This might take some time for large datasets.') + + # Otherwise, proceed with downloading. + if download_url is None: + raise ValueError(f'Sorry, {self.dataset_name} cannot be automatically downloaded. Please download it manually.') + + from wilds.datasets.download_utils import download_and_extract_archive + print(f'Downloading dataset to {data_dir}...') + print(f'You can also download the dataset manually at https://wilds.stanford.edu/downloads.') + try: + start_time = time.time() + download_and_extract_archive( + url=download_url, + download_root=data_dir, + filename='archive.tar.gz', + remove_finished=True, + size=compressed_size) + + download_time_in_minutes = (time.time() - start_time) / 60 + print(f"It took {round(download_time_in_minutes, 2)} minutes to download and uncompress the dataset.") + except Exception as e: + print(f"\n{os.path.join(data_dir, 'archive.tar.gz')} may be corrupted. Please try deleting it and rerunning this command.\n") + print(f"Exception: ", e) return data_dir diff --git a/wilds/datasets/yelp_dataset.py b/wilds/datasets/yelp_dataset.py index 39923e8f..36e9ea10 100644 --- a/wilds/datasets/yelp_dataset.py +++ b/wilds/datasets/yelp_dataset.py @@ -41,12 +41,17 @@ class YelpDataset(WILDSDataset): License: Because of the Dataset License provided by Yelp, we are unable to redistribute the data. Please download the data through the website (https://www.yelp.com/dataset/download) by - agreeing to the Dataset License. + agreeing to the Dataset License. """ - def __init__(self, root_dir='data', download=False, split_scheme='official'): - # set variables - self._dataset_name = 'yelp' - self._version = '1.0' + _dataset_name = 'yelp' + _versions_dict = { + '1.0': { + 'download_url': None, + 'compressed_size': None}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + # set variables + self._version = version if split_scheme=='official': split_scheme = 'time' self._split_scheme = split_scheme @@ -75,41 +80,54 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self.initialize_split_dicts() # eval self.initialize_eval_grouper() - self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) def get_input(self, idx): return self._input_array[idx] - def eval(self, y_pred, y_true, metadata): + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). + But they can also be other model outputs such that prediction_fn(y_pred) + are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric = Accuracy(prediction_fn=prediction_fn) if self.split_scheme=='user': # first compute groupwise accuracies g = self._eval_grouper.metadata_to_group(metadata) results = { - **self._metric.compute(y_pred, y_true), - **self._metric.compute_group_wise(y_pred, y_true, g, self._eval_grouper.n_groups) + **metric.compute(y_pred, y_true), + **metric.compute_group_wise(y_pred, y_true, g, self._eval_grouper.n_groups) } accs = [] for group_idx in range(self._eval_grouper.n_groups): group_str = self._eval_grouper.group_field_str(group_idx) - group_metric = results.pop(self._metric.group_metric_field(group_idx)) - group_counts = results.pop(self._metric.group_count_field(group_idx)) - results[f'{self._metric.name}_{group_str}'] = group_metric + group_metric = results.pop(metric.group_metric_field(group_idx)) + group_counts = results.pop(metric.group_count_field(group_idx)) + results[f'{metric.name}_{group_str}'] = group_metric results[f'count_{group_str}'] = group_counts if group_counts>0: accs.append(group_metric) accs = np.array(accs) results['10th_percentile_acc'] = np.percentile(accs, 10) - results[f'{self._metric.worst_group_metric_field}'] = self._metric.worst(accs) + results[f'{metric.worst_group_metric_field}'] = metric.worst(accs) results_str = ( - f"Average {self._metric.name}: {results[self._metric.agg_metric_field]:.3f}\n" - f"10th percentile {self._metric.name}: {results['10th_percentile_acc']:.3f}\n" - f"Worst-group {self._metric.name}: {results[self._metric.worst_group_metric_field]:.3f}\n" + f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n" + f"10th percentile {metric.name}: {results['10th_percentile_acc']:.3f}\n" + f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n" ) return results, results_str else: return self.standard_group_eval( - self._metric, + metric, self._eval_grouper, y_pred, y_true, metadata) diff --git a/wilds/download_datasets.py b/wilds/download_datasets.py new file mode 100644 index 00000000..bf085739 --- /dev/null +++ b/wilds/download_datasets.py @@ -0,0 +1,34 @@ +import os, sys +import argparse +import wilds + +def main(): + """ + Downloads the latest versions of all specified datasets, + if they do not already exist. + """ + parser = argparse.ArgumentParser() + parser.add_argument('--root_dir', required=True, + help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') + parser.add_argument('--datasets', nargs='*', default=None, + help=f'Specify a space-separated list of dataset names to download. If left unspecified, the script will download all of the official benchmark datasets. Available choices are {wilds.supported_datasets}.') + config = parser.parse_args() + + if config.datasets is None: + config.datasets = wilds.benchmark_datasets + + for dataset in config.datasets: + if dataset not in wilds.supported_datasets: + raise ValueError(f'{dataset} not recognized; must be one of {wilds.supported_datasets}.') + + print(f'Downloading the following datasets: {config.datasets}') + for dataset in config.datasets: + print(f'=== {dataset} ===') + wilds.get_dataset( + dataset=dataset, + root_dir=config.root_dir, + download=True) + + +if __name__=='__main__': + main() diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py new file mode 100644 index 00000000..1073100f --- /dev/null +++ b/wilds/get_dataset.py @@ -0,0 +1,79 @@ +import wilds + +def get_dataset(dataset, version=None, **dataset_kwargs): + """ + Returns the appropriate WILDS dataset class. + Input: + dataset (str): Name of the dataset + version (str): Dataset version number, e.g., '1.0'. + Defaults to the latest version. + dataset_kwargs: Other keyword arguments to pass to the dataset constructors. + Output: + The specified WILDSDataset class. + """ + if version is not None: + version = str(version) + + if dataset not in wilds.supported_datasets: + raise ValueError(f'The dataset {dataset} is not recognized. Must be one of {wilds.supported_datasets}.') + + if dataset == 'amazon': + from wilds.datasets.amazon_dataset import AmazonDataset + return AmazonDataset(version=version, **dataset_kwargs) + + elif dataset == 'camelyon17': + from wilds.datasets.camelyon17_dataset import Camelyon17Dataset + return Camelyon17Dataset(version=version, **dataset_kwargs) + + elif dataset == 'celebA': + from wilds.datasets.celebA_dataset import CelebADataset + return CelebADataset(version=version, **dataset_kwargs) + + elif dataset == 'civilcomments': + from wilds.datasets.civilcomments_dataset import CivilCommentsDataset + return CivilCommentsDataset(version=version, **dataset_kwargs) + + elif dataset == 'iwildcam': + if version == '1.0': + from wilds.datasets.archive.iwildcam_v1_0_dataset import IWildCamDataset + else: + from wilds.datasets.iwildcam_dataset import IWildCamDataset + return IWildCamDataset(version=version, **dataset_kwargs) + + elif dataset == 'waterbirds': + from wilds.datasets.waterbirds_dataset import WaterbirdsDataset + return WaterbirdsDataset(version=version, **dataset_kwargs) + + elif dataset == 'yelp': + from wilds.datasets.yelp_dataset import YelpDataset + return YelpDataset(version=version, **dataset_kwargs) + + elif dataset == 'ogb-molpcba': + from wilds.datasets.ogbmolpcba_dataset import OGBPCBADataset + return OGBPCBADataset(version=version, **dataset_kwargs) + + elif dataset == 'poverty': + if version == '1.0': + from wilds.datasets.archive.poverty_v1_0_dataset import PovertyMapDataset + else: + from wilds.datasets.poverty_dataset import PovertyMapDataset + return PovertyMapDataset(version=version, **dataset_kwargs) + + elif dataset == 'fmow': + if version == '1.0': + from wilds.datasets.archive.fmow_v1_0_dataset import FMoWDataset + else: + from wilds.datasets.fmow_dataset import FMoWDataset + return FMoWDataset(version=version, **dataset_kwargs) + + elif dataset == 'bdd100k': + from wilds.datasets.bdd100k_dataset import BDD100KDataset + return BDD100KDataset(version=version, **dataset_kwargs) + + elif dataset == 'py150': + from wilds.datasets.py150_dataset import Py150Dataset + return Py150Dataset(version=version, **dataset_kwargs) + + elif dataset == 'sqf': + from wilds.datasets.sqf_dataset import SQFDataset + return SQFDataset(version=version, **dataset_kwargs) diff --git a/wilds/version.py b/wilds/version.py index 3f7bf4a6..6d19cfa3 100644 --- a/wilds/version.py +++ b/wilds/version.py @@ -4,7 +4,7 @@ import logging from threading import Thread -__version__ = '1.0.0' +__version__ = '1.1.0' try: os.environ['OUTDATED_IGNORE'] = '1'