
Merge pull request #8 from ToyotaResearchInstitute/att_experiments_branch

Att experiments branch
deepakgopinath authored May 15, 2021
2 parents 61a502f + 03b006f commit 7c9412a
Showing 3 changed files with 238 additions and 0 deletions.
Empty file added src/chm/experiments/__init__.py
139 changes: 139 additions & 0 deletions src/chm/experiments/chm_experiments.py
@@ -0,0 +1,139 @@
# Copyright 2020 Toyota Research Institute. All rights reserved.
import torch
import random
import numpy as np
import json
import os

from abc import ABC, abstractmethod
from chm.model.model_wrapper import ModelWrapper
from chm.trainers.cognitive_heatmap_trainer import CHMTrainer
from chm.experiments.chm_inference_engine import CHMInferenceEngine
from collections import OrderedDict


class ChmExperiment(ABC):
def __init__(self, args, session_hash, training_experiment=False):
"""
Abstract base class for different CHM experiments for denoising, calibration and awareness estimation
Parameters:
-----------
        args: argparse.Namespace
            Contains all args specified in the args_file plus any additional args set by the arg_setter of the derived classes
        session_hash: str
            Unique string identifying the session id.
        training_experiment: bool
            Bool indicating whether the experiment requires training or only inference.
"""

# create params dict from the args
self.params_dict = vars(args)

# select training device
if self.params_dict["no_cuda"] or not torch.cuda.is_available():
device = torch.device("cpu")
else:
device = torch.device("cuda")

# set random seed
        # if a random seed is provided, use it; otherwise generate one
random_seed = self.params_dict["random_seed"] or random.randint(1, 10000)
print("Random seed used for experiment is ", random_seed)
self.params_dict["random_seed"] = random_seed # update the params dict

# set random seed for numpy and torch
np.random.seed(self.params_dict["random_seed"])
torch.manual_seed(self.params_dict["random_seed"])

# init variables
self.results_aggregator = {} # this is a dictionary that saves results on the inference dataset
self.training_experiment = training_experiment
# folder to save the results
self.results_save_folder = os.path.join(os.path.expanduser("~"), "cognitive_heatmap", "results")

# create model wrapper instance
self.model_wrapper = ModelWrapper(self.params_dict, session_hash)

@abstractmethod
def initialize_functors(self):
"""
Abstract method to be implemented by the derived class.
        Defines the input_process_functors (used to process the data_input before inference) and the output_process_functors (used to process and compute metrics on the inference output)
"""
pass

def _perform_experiment(self):
"""
Depending on the type of experiment, either training or inference is performed.
"""
if self.training_experiment:
print("Launching training experiment")
trainer = CHMTrainer(self.params_dict)
trainer.fit(self.model_wrapper, ds_type=self.params_dict["inference_ds_type"])
else:
print("Launching inference experiment")
inference_engine = CHMInferenceEngine(self.params_dict)
inference_engine.infer(self.model_wrapper)

self.results_aggregator = self.model_wrapper.get_results_aggregator()

@abstractmethod
def perform_experiment(self):
"""
        Abstract method implemented by the derived class. Must call _perform_experiment() at the end.
"""
pass

def shape_results(self, results_aggregator):
"""
        Function to further shape the results dictionary before saving to disk.
        Parameters:
        ----------
        results_aggregator: OrderedDict
            Dict containing the results of a particular experiment. Each key in the dict corresponds to a different metric computed on the output
        Returns:
        -------
        results_aggregator: OrderedDict
            Reshaped results dictionary
"""
return results_aggregator

    def save_experiment(self, name_values=None):
"""
Function that saves the results dictionary onto disk.
Parameters:
----------
        name_values: OrderedDict, optional
            Ordered dict of names and values that are combined to form the results filename. Defaults to an empty OrderedDict.
Returns:
-------
None
"""
        # avoid a mutable default argument; fall back to an empty OrderedDict
        if name_values is None:
            name_values = OrderedDict()
        assert isinstance(name_values, OrderedDict)
# create the results folder
os.makedirs(self.results_save_folder, exist_ok=True)

# create filename
results_filename = "experiment"
for k in name_values:
results_filename += "_" + k + "_" + str(name_values[k])

# full path to the results filename
results_filename = os.path.join(self.results_save_folder, results_filename)
results_filename += ".json"
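        # e.g. (hypothetical values), name_values = OrderedDict([("model", "chm"), ("seed", 1234)])
        # yields ~/cognitive_heatmap/results/experiment_model_chm_seed_1234.json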

# shape results before saving.
experiment_results = self.shape_results(self.results_aggregator)

        # add the name_values entries to the results and write them to disk as JSON
for k in name_values:
experiment_results[k] = name_values[k]
with open(results_filename, "w") as fp:
json.dump(experiment_results, fp, indent=2)
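
For orientation, here is a minimal sketch of how a concrete experiment might subclass this base class. The class name, the arg handling, and the no-op functors are hypothetical illustrations, not part of this commit:

from collections import OrderedDict

from chm.experiments.chm_experiments import ChmExperiment


class AwarenessEstimationExperiment(ChmExperiment):
    """Hypothetical inference-only experiment."""

    def __init__(self, args, session_hash):
        super().__init__(args, session_hash, training_experiment=False)
        self.initialize_functors()

    def initialize_functors(self):
        # a real experiment would attach input/output processing functors
        # to self.model_wrapper here; this sketch leaves them as no-ops
        pass

    def perform_experiment(self):
        # base-class contract: finish by calling _perform_experiment()
        self._perform_experiment()
        self.save_experiment(OrderedDict([("seed", self.params_dict["random_seed"])]))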
99 changes: 99 additions & 0 deletions src/chm/experiments/chm_inference_engine.py
@@ -0,0 +1,99 @@
# Copyright 2020 Toyota Research Institute. All rights reserved.
import tqdm

from chm.utils.chm_consts import InferenceMode


class CHMInferenceEngine(object):
"""
Class that defines the inference engine using the CHM.
"""

def __init__(self, params_dict):
"""
Parameters:
----------
params_dict: dict
Dictionary containing the args passed from the experiment script
"""
self.params_dict = params_dict

# type of dataset used for inference. Either 'train' or 'test'
self.inference_ds_type = self.params_dict.get("inference_ds_type", "test")
        # maximum number of batches on which to run inference; may be overridden via module.input_process_dict in infer()
self.max_batch_num = self.params_dict.get("max_inference_num_batches", 20)
# Inference mode. Determines whether the side-channel input gaze needs to be dropped out or not [WITH_GAZE, WITHOUT_GAZE, BOTH].
self.inference_mode = InferenceMode.BOTH
# Boolean which determines whether the loss needs to be computed during inference.
self.is_compute_loss = False
        # populate force_value_strs according to the inference mode (set in place; no return value)
        self.set_force_value_strs()

def set_force_value_strs(self):
"""
Sets the forced_dropout strings according to the inference mode
"""
if self.inference_mode == InferenceMode.BOTH:
self.force_value_strs = ["with_gaze", "without_gaze"]
elif self.inference_mode == InferenceMode.WITH_GAZE:
self.force_value_strs = ["with_gaze"]
elif self.inference_mode == InferenceMode.WITHOUT_GAZE:
self.force_value_strs = ["without_gaze"]

def infer(self, module):
"""
Performs inference using CHM model.
Parameters:
----------
module: ModelWrapper
This ModelWrapper instance contains the model used for inference.
Returns:
--------
None
"""
module.inference_engine = self
        # update max_batch_num from module.input_process_dict, if provided
if "max_batch_num" in module.input_process_dict and module.input_process_dict["max_batch_num"] is not None:
self.max_batch_num = module.input_process_dict["max_batch_num"]

# inference mode - from module.input_process_dict
if "inference_mode" in module.input_process_dict and module.input_process_dict["inference_mode"] is not None:
self.inference_mode = module.input_process_dict["inference_mode"]
            # depending on the inference mode, update force_value_strs
self.set_force_value_strs()

# is_compute_loss - from module.input_process_dict
if "is_compute_loss" in module.input_process_dict:
if module.input_process_dict["is_compute_loss"]:
self.is_compute_loss = True

# set model mode.
assert "driver_facing" in self.params_dict["dropout_ratio"]
if self.params_dict["dropout_ratio"]["driver_facing"] < (1 - 5e-2):
module.train(False)
else:
module.train(True)
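            # note (interpretation, not from the source): train(True) leaves dropout active,
            # so with a driver-facing dropout ratio of ~1.0 the gaze side channel stays dropped out at inference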

# get all dataloaders.
gaze_dataloaders, awareness_dataloaders, _ = module.get_dataloaders()

        # create the appropriate dataloader based on inference_ds_type; inference only runs on the gaze and awareness datasets
if self.inference_ds_type == "train":
dataloader_tqdm = tqdm.tqdm(
enumerate(zip(gaze_dataloaders["train"], awareness_dataloaders["train"])),
desc="inference_train",
)
        elif self.inference_ds_type == "test":
            dataloader_tqdm = tqdm.tqdm(
                enumerate(zip(gaze_dataloaders["test"], awareness_dataloaders["test"])),
                desc="inference_test",
            )
        else:
            raise ValueError("inference_ds_type must be either 'train' or 'test'")

# go through the dataloader
for i, data_batch in dataloader_tqdm:
            if self.max_batch_num is not None and i > self.max_batch_num:
                break
module.inference_step(data_batch, i, self.force_value_strs, self.is_compute_loss)

print("END OF INFERENCE")
