-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from ToyotaResearchInstitute/att_experiments_br…
…anch Att experiments branch
- Loading branch information
Showing
3 changed files
with
238 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright 2020 Toyota Research Institute. All rights reserved. | ||
import torch | ||
import random | ||
import numpy as np | ||
import json | ||
import os | ||
|
||
from abc import ABC, abstractmethod | ||
from chm.model.model_wrapper import ModelWrapper | ||
from chm.trainers.cognitive_heatmap_trainer import CHMTrainer | ||
from chm.experiments.chm_inference_engine import CHMInferenceEngine | ||
from collections import OrderedDict | ||
|
||
|
||
class ChmExperiment(ABC): | ||
def __init__(self, args, session_hash, training_experiment=False): | ||
""" | ||
Abstract base class for different CHM experiments for denoising, calibration and awareness estimation | ||
Parameters: | ||
----------- | ||
args: argparse.Namespace | ||
Contains all args specified in the args_file and any additional arg_setter (specified in the derived classes) | ||
session_hash: str | ||
Unique string indicating the sessions id. | ||
training_experiment: bool | ||
Bool indicating whether the experiment require training or just inference. | ||
""" | ||
|
||
# create params dict from the args | ||
self.params_dict = vars(args) | ||
|
||
# select training device | ||
if self.params_dict["no_cuda"] or not torch.cuda.is_available(): | ||
device = torch.device("cpu") | ||
else: | ||
device = torch.device("cuda") | ||
|
||
# set random seed | ||
# if random seed is provided use it, if not create a random seed | ||
random_seed = self.params_dict["random_seed"] or random.randint(1, 10000) | ||
print("Random seed used for experiment is ", random_seed) | ||
self.params_dict["random_seed"] = random_seed # update the params dict | ||
|
||
# set random seed for numpy and torch | ||
np.random.seed(self.params_dict["random_seed"]) | ||
torch.manual_seed(self.params_dict["random_seed"]) | ||
|
||
# init variables | ||
self.results_aggregator = {} # this is a dictionary that saves results on the inference dataset | ||
self.training_experiment = training_experiment | ||
# folder to save the results | ||
self.results_save_folder = os.path.join(os.path.expanduser("~"), "cognitive_heatmap", "results") | ||
|
||
# create model wrapper instance | ||
self.model_wrapper = ModelWrapper(self.params_dict, session_hash) | ||
|
||
@abstractmethod | ||
def initialize_functors(self): | ||
""" | ||
Abstract method to be implemented by the derived class. | ||
Defines the input_process_functors (used to process the data_input before inference) and output_process_functors (to process and compute metrics on the the inference output) | ||
""" | ||
pass | ||
|
||
def _perform_experiment(self): | ||
""" | ||
Depending on the type of experiment, either training or inference is performed. | ||
""" | ||
if self.training_experiment: | ||
print("Launching training experiment") | ||
trainer = CHMTrainer(self.params_dict) | ||
trainer.fit(self.model_wrapper, ds_type=self.params_dict["inference_ds_type"]) | ||
else: | ||
print("Launching inference experiment") | ||
inference_engine = CHMInferenceEngine(self.params_dict) | ||
inference_engine.infer(self.model_wrapper) | ||
|
||
self.results_aggregator = self.model_wrapper.get_results_aggregator() | ||
|
||
@abstractmethod | ||
def perform_experiment(self): | ||
""" | ||
Abstract method implemented by the derived class. Has to call _perform_experiment() at the end. | ||
""" | ||
pass | ||
|
||
def shape_results(self, results_aggregator): | ||
""" | ||
Function to futher shape the results dictionary before saving to disk | ||
Parameters: | ||
--------- | ||
result_aggregator: OrderedDict | ||
Dict containing the results of a particular experiment. Each key in the dict corresponds to a different metric computed on the output | ||
Returns: | ||
------- | ||
result_aggregator: OrderedDict | ||
Reshaped results dictionary | ||
""" | ||
return results_aggregator | ||
|
||
def save_experiment(self, name_values=OrderedDict()): | ||
""" | ||
Function that saves the results dictionary onto disk. | ||
Parameters: | ||
---------- | ||
name_values: OrderedDict() | ||
Dictionary containing various strings that are combined to form the filename. | ||
Returns: | ||
------- | ||
None | ||
""" | ||
assert type(name_values) is OrderedDict | ||
# create the results folder | ||
os.makedirs(self.results_save_folder, exist_ok=True) | ||
|
||
# create filename | ||
results_filename = "experiment" | ||
for k in name_values: | ||
results_filename += "_" + k + "_" + str(name_values[k]) | ||
|
||
# full path to the results filename | ||
results_filename = os.path.join(self.results_save_folder, results_filename) | ||
results_filename += ".json" | ||
|
||
# shape results before saving. | ||
experiment_results = self.shape_results(self.results_aggregator) | ||
|
||
# save filename | ||
for k in name_values: | ||
experiment_results[k] = name_values[k] | ||
with open(results_filename, "w") as fp: | ||
json.dump(experiment_results, fp, indent=2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2020 Toyota Research Institute. All rights reserved. | ||
import tqdm | ||
|
||
from chm.utils.chm_consts import InferenceMode | ||
|
||
|
||
class CHMInferenceEngine(object): | ||
""" | ||
Class that defines the inference engine using the CHM. | ||
""" | ||
|
||
def __init__(self, params_dict): | ||
""" | ||
Parameters: | ||
---------- | ||
params_dict: dict | ||
Dictionary containing the args passed from the experiment script | ||
""" | ||
self.params_dict = params_dict | ||
|
||
# type of dataset used for inference. Either 'train' or 'test' | ||
self.inference_ds_type = self.params_dict.get("inference_ds_type", "test") | ||
# Maximum number of batches to perform inference. Populated by the input_process_dict | ||
self.max_batch_num = self.params_dict.get("max_inference_num_batches", 20) | ||
# Inference mode. Determines whether the side-channel input gaze needs to be dropped out or not [WITH_GAZE, WITHOUT_GAZE, BOTH]. | ||
self.inference_mode = InferenceMode.BOTH | ||
# Boolean which determines whether the loss needs to be computed during inference. | ||
self.is_compute_loss = False | ||
self.force_value_strs = self.set_force_value_strs() | ||
|
||
def set_force_value_strs(self): | ||
""" | ||
Sets the forced_dropout strings according to the inference mode | ||
""" | ||
if self.inference_mode == InferenceMode.BOTH: | ||
self.force_value_strs = ["with_gaze", "without_gaze"] | ||
elif self.inference_mode == InferenceMode.WITH_GAZE: | ||
self.force_value_strs = ["with_gaze"] | ||
elif self.inference_mode == InferenceMode.WITHOUT_GAZE: | ||
self.force_value_strs = ["without_gaze"] | ||
|
||
def infer(self, module): | ||
""" | ||
Performs inference using CHM model. | ||
Parameters: | ||
---------- | ||
module: ModelWrapper | ||
This ModelWrapper instance contains the model used for inference. | ||
Returns: | ||
-------- | ||
None | ||
""" | ||
module.inference_engine = self | ||
# update max_batch_num pull from module.input_process_dict | ||
if "max_batch_num" in module.input_process_dict and module.input_process_dict["max_batch_num"] is not None: | ||
self.max_batch_num = module.input_process_dict["max_batch_num"] | ||
|
||
# inference mode - from module.input_process_dict | ||
if "inference_mode" in module.input_process_dict and module.input_process_dict["inference_mode"] is not None: | ||
self.inference_mode = module.input_process_dict["inference_mode"] | ||
# depending on inference mode update the force_value_str | ||
self.set_force_value_strs() | ||
|
||
# is_compute_loss - from module.input_process_dict | ||
if "is_compute_loss" in module.input_process_dict: | ||
if module.input_process_dict["is_compute_loss"]: | ||
self.is_compute_loss = True | ||
|
||
# set model mode. | ||
assert "driver_facing" in self.params_dict["dropout_ratio"] | ||
if self.params_dict["dropout_ratio"]["driver_facing"] < (1 - 5e-2): | ||
module.train(False) | ||
else: | ||
module.train(True) | ||
|
||
# get all dataloaders. | ||
gaze_dataloaders, awareness_dataloaders, _ = module.get_dataloaders() | ||
|
||
# create the proper dataloader based on and on the inference_ds_type. inference only happens on the gaze and awareness ds's | ||
if self.inference_ds_type == "train": | ||
dataloader_tqdm = tqdm.tqdm( | ||
enumerate(zip(gaze_dataloaders["train"], awareness_dataloaders["train"])), | ||
desc="inference_train", | ||
) | ||
elif self.inference_ds_type == "test": | ||
dataloader_tqdm = tqdm.tqdm( | ||
enumerate(zip(gaze_dataloaders["test"], awareness_dataloaders["test"])), | ||
desc="inference_test", | ||
) | ||
|
||
# go through the dataloader | ||
for i, data_batch in dataloader_tqdm: | ||
if not self.max_batch_num is None: | ||
if i > self.max_batch_num: | ||
break | ||
module.inference_step(data_batch, i, self.force_value_strs, self.is_compute_loss) | ||
|
||
print("END OF INFERENCE") |