Changed join to DHS datasets to include children 6 to 18 years old #3

Open
wants to merge 7 commits into main
2 changes: 1 addition & 1 deletion imagery_scraping/download_imagery.py
@@ -220,7 +220,7 @@ def download_imagery(filepath, drive, year, sensor, range_km, rgb_only, parallel


export_params = {
'description': target_df[name_colname][i],
'description': str(target_df[name_colname][i]),
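# Casting to str is presumably needed because Earth Engine requires the export description to be a string (e.g. if the name column holds numeric IDs).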
'folder': drive,
'scale': resolution_m, # This is the resolution in meters
'region': region,
94 changes: 90 additions & 4 deletions modelling/dino/finetune_spatial.py
@@ -16,8 +16,55 @@
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.nn import L1Loss


"""
Fine-tunes the DinoV2 model using spatial data.
The spatial data is stored in survey_processing/processed_data and is split into folds;
see survey_processing/main.py for more information.

DinoV2's input is an RGB satellite image converted to a tensor.
The target for each satellite image is a set of DHS variables (the predict_target variable).
Once one-hot encoded, these variables form a higher-dimensional vector.
The DinoV2 model outputs a 768-dimensional vector, so we add a linear layer with a sigmoid
activation in order to get an output the same size as our target vector.

After each epoch we save the model's weights to the 'last model' file, and if the validation
error is the lowest seen so far, we also save them to the 'best model' file.

Satellite imagery is saved in the following file structure.
Subdirectories are named country code + year + satellite sensor, and filenames are the CENTROID_ID.

- imagery parent directory
--- ET2018S2
------ ET2000000090.tif
------ ET2000000213.tif
------ ...
--- RW2018S2
------ ...
--- ...
"""


def main(fold, model_name, imagery_path, imagery_source, emb_size, batch_size, num_epochs):

"""
Fine-tunes and validates the DinoV2 model using one fold of data.
Saves the last and best model weights to file.

Parameters:
fold (int): fold number
model_name (str): model name (e.g. dinov2_vitb14)
imagery_path (str): parent directory of imagery
imagery_source (str): Landsat ('L') or Sentinel ('S')
emb_size (int): size of model output, default is 768
batch_size (int): batch size
num_epochs (int): number of epochs

Returns:
None
"""

if imagery_source == 'L':
normalization = 30000.
imagery_size = 336
@@ -28,24 +75,30 @@ def main(fold, model_name, imagery_path, imagery_source, emb_size, batch_size, n
raise Exception("Unsupported imagery source")
data_folder = r'survey_processing/processed_data'

# Load the preprocessed DHS data for the fold under consideration; the target columns are taken from this data
train_df = pd.read_csv(f'{data_folder}/train_fold_{fold}.csv', index_col=0)
test_df = pd.read_csv(f'{data_folder}/test_fold_{fold}.csv', index_col=0)

# Store the file paths of all available imagery in the following list
available_imagery = []
for d in os.listdir(imagery_path):
# d[-2] is either 'S' or 'L'; see the module docstring for the file structure of saved images
if d[-2] == imagery_source:
for f in os.listdir(os.path.join(imagery_path, d)):
available_imagery.append(os.path.join(imagery_path, d, f))

# Get the filename of each image without its file extension
available_centroids = [f.split('/')[-1][:-4] for f in available_imagery]
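# e.g. (using the layout from the module docstring): '.../ET2018S2/ET2000000090.tif' -> 'ET2000000090'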
# Filter the dataframes to remove rows with no corresponding satellite image
train_df = train_df[train_df['CENTROID_ID'].isin(available_centroids)]
test_df = test_df[test_df['CENTROID_ID'].isin(available_centroids)]


def filter_contains(query):
"""
Returns the first available imagery path that contains the given query substring.

Parameters:
query (str): The substring (CENTROID_ID) to search for in each imagery path.

Returns:
Expand All @@ -55,11 +108,17 @@ def filter_contains(query):
for item in available_imagery:
if query in item:
return item
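
# Illustrative usage, with the example layout from the module docstring:
# filter_contains('ET2000000090') returns the first matching path, e.g.
# os.path.join(imagery_path, 'ET2018S2', 'ET2000000090.tif'); it returns None if no image matches.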


# Add the file path of the satellite image corresponding to each row
train_df['imagery_path'] = train_df['CENTROID_ID'].apply(filter_contains)
test_df['imagery_path'] = test_df['CENTROID_ID'].apply(filter_contains)

# DHS variables to use as target data:
# vaccination status, wealth index, height-for-age s.d., level of education, water access, sleeping arrangements, etc.
predict_target = ['h10', 'h3', 'h31', 'h5', 'h7', 'h9', 'hc70', 'hv109', 'hv121', 'hv106', 'hv201', 'hv204', 'hv205', 'hv216', 'hv225', 'hv271', 'v312']

# Find the one-hot encoded columns associated with each categorical target using regex
filtered_predict_target = []
for col in predict_target:
filtered_predict_target.extend(
@@ -69,6 +128,7 @@ def filter_contains(query):
train_df = train_df.dropna(subset=filtered_predict_target)
predict_target = sorted(filtered_predict_target)
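
# After this step, predict_target holds the expanded one-hot column names (roughly one column per
# category of each categorical variable above); the exact names are presumably produced by the
# preprocessing in survey_processing/main.py.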


def load_and_preprocess_image(path):
with rasterio.open(path) as src:
# Read the specific bands (4, 3, 2 for RGB)
@@ -85,6 +145,7 @@ def load_and_preprocess_image(path):

return img.astype(np.uint8) # Convert to uint8
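
# Illustrative usage (the 'imagery_path' column is added above):
# img = load_and_preprocess_image(train_df['imagery_path'].iloc[0])
# img is a uint8 RGB array built from bands 4, 3, 2 and rescaled by the normalization constant above.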


def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -94,13 +155,20 @@ def set_seed(seed):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Set your desired seed
seed = 42
set_seed(seed)

train, validation = train_test_split(train_df, test_size=0.2, random_state=42)


class CustomDataset(Dataset):
"""
Stores a dataframe and a transform (a composition of image transforms).
Indexing the dataset returns (image_tensor, target).
"""

def __init__(self, dataframe, transform):
self.dataframe = dataframe
self.transform = transform
@@ -118,6 +186,8 @@ def __getitem__(self, idx):
target = torch.tensor(item[predict_target], dtype=torch.float32)
return image_tensor, target # Adjust based on actual output of feature_extractor
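
# Each item is (image_tensor, target): target has length len(predict_target), and image_tensor is
# produced by the `transform` defined below (resized to imagery_size x imagery_size).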


# Convert each image to a tensor of the size expected by the model
transform = transforms.Compose([
transforms.Resize((imagery_size, imagery_size)), # Resize the image to the input size expected by the model
transforms.ToTensor(), # Convert the image to a PyTorch tensor
@@ -132,6 +202,7 @@ def __getitem__(self, idx):

base_model = torch.hub.load('facebookresearch/dinov2', model_name)


def save_checkpoint(model, optimizer, epoch, loss, filename="checkpoint.pth"):
torch.save({
'epoch': epoch,
@@ -140,8 +211,16 @@ def save_checkpoint(model, optimizer, epoch, loss, filename="checkpoint.pth"):
'loss': loss
}, filename)
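
# A checkpoint written here can later be restored with torch.load(filename), passing the stored
# state dicts back to model.load_state_dict(...) and optimizer.load_state_dict(...)
# (using the key names from the dict above).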


torch.cuda.empty_cache()


class ViTForRegression(nn.Module):
"""
Wraps the DinoV2 base model in an nn.Module and adds a linear layer with a sigmoid
activation so the output has length len(predict_target).
"""

def __init__(self, base_model):
super().__init__()
self.base_model = base_model
@@ -153,6 +232,8 @@ def forward(self, pixel_values):
# We use the last hidden state
return torch.sigmoid(self.regression_head(outputs))
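
# Shape sketch (illustrative): pixel_values [batch, 3, H, W] -> base_model -> [batch, emb_size]
# -> regression_head -> [batch, len(predict_target)], squashed into (0, 1) by the sigmoid.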


# Load the last and best model checkpoints so their losses can be compared
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTForRegression(base_model).to(device)
best_model = f'modelling/dino/model/{model_name}_{fold}_all_cluster_best_{imagery_source}.pth'
@@ -177,10 +258,12 @@ def forward(self, pixel_values):
optimizer = torch.optim.Adam([base_model_params, head_params])
loss_fn = L1Loss()

# training and validation
for epoch in range(epochs_ran+1, num_epochs):
torch.cuda.empty_cache()
model.train()
print('Training...')

for batch in tqdm(train_loader):
images, targets = batch
images, targets = images.to(device), targets.to(device)
@@ -193,7 +276,9 @@ def forward(self, pixel_values):
optimizer.zero_grad()
loss.backward()
optimizer.step()

torch.cuda.empty_cache()

# Validation phase
model.eval()
val_loss = []
@@ -214,14 +299,15 @@ def forward(self, pixel_values):
mean_val_loss = np.mean(val_loss)
mean_indiv_loss = torch.stack(indiv_loss).mean(dim=0)

# Save the best model when validation loss improves; always save the last model
if mean_val_loss < best_error:
save_checkpoint(model, optimizer, epoch, mean_val_loss, filename=best_model)
best_error = mean_val_loss
print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {mean_val_loss}, Individual Loss: {mean_indiv_loss}')
save_checkpoint(model, optimizer, epoch, mean_val_loss, filename=last_model)



# Handle command-line inputs; note that a separate command must be run to train on each fold
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run satellite image processing model training.')
parser.add_argument('--fold', type=int, help='CV fold')
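
# Illustrative invocation (only --fold is shown above; the remaining flag names are assumed to
# mirror main()'s parameters):
#   python modelling/dino/finetune_spatial.py --fold 1 --model_name dinov2_vitb14 \
#       --imagery_path /path/to/imagery --imagery_source L --emb_size 768 --batch_size 16 --num_epochs 50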