diff --git a/tuning/aim_loader.py b/tuning/aim_loader.py index 32a2a9f18..204c50a22 100644 --- a/tuning/aim_loader.py +++ b/tuning/aim_loader.py @@ -12,11 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Standard -import os - # Local -from tuning.utils.import_utils import is_aim_available +from tuning.utils.import_utils import get_aim_config, is_aim_available if is_aim_available(): # Third Party @@ -25,9 +22,7 @@ def get_aimstack_callback(): # Initialize a new run - aim_server = os.environ.get("AIMSTACK_SERVER") - aim_db = os.environ.get("AIMSTACK_DB") - aim_experiment = os.environ.get("AIMSTACK_EXPERIMENT") + aim_server, aim_db, aim_experiment = get_aim_config() if aim_experiment is None: aim_experiment = "" diff --git a/tuning/utils/import_utils.py b/tuning/utils/import_utils.py index bc21ee62d..1e1991e66 100644 --- a/tuning/utils/import_utils.py +++ b/tuning/utils/import_utils.py @@ -12,11 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Standard +from typing import Optional, Tuple +import os + # Third Party from transformers.utils.import_utils import _is_package_available _is_aim_available = _is_package_available("aim") +_aim_server = os.environ.get("AIMSTACK_SERVER") +_aim_db = os.environ.get("AIMSTACK_DB") +_aim_experiment = os.environ.get("AIMSTACK_EXPERIMENT") + + +def get_aim_config() -> Tuple[Optional[str], Optional[None], Optional[None]]: + """ + Returns: aim_server, aim_db, aim_experiment + """ + return _aim_server, _aim_db, _aim_experiment def is_aim_available(): - return _is_aim_available + return ( + _is_aim_available + and (_aim_server is not None) + and (_aim_db is not None) + and (_aim_experiment is not None) + )