From fdd2153f3feaf52707137d88a97068435185c357 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 26 Feb 2024 15:05:38 +0000 Subject: [PATCH] stylistic nits --- src/refiners/training_utils/batch.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/refiners/training_utils/batch.py b/src/refiners/training_utils/batch.py index 97fb12e04..092ca049a 100644 --- a/src/refiners/training_utils/batch.py +++ b/src/refiners/training_utils/batch.py @@ -57,7 +57,7 @@ def __init__(self, **kwargs: AttrType): for attr_name in kwargs: self.__setattr__(attr_name, kwargs[attr_name], check_size=False) - new_size = self.attr_length(attr_name) + new_size = self.attr_size(attr_name) if size is not None and size != new_size: raise ValueError(f"Attribute '{attr_name}' has size {new_size}, expected {size}") size = new_size @@ -79,10 +79,8 @@ def __setattr__(self, name: str, value: Any, check_size: bool = True) -> None: raise TypeError( f"Invalid type for attribute '{name}': Expected '{attr_type.__name__}', got '{type(value).__name__}'" ) - if isinstance(value, list): - new_size = len(value) - else: - new_size = value.shape[0] + + new_size = len(value) if isinstance(value, list) else value.shape[0] if check_size and new_size != len(self): raise ValueError(f"Attribute '{name}' has size {new_size}, expected {len(self)}") @@ -100,7 +98,7 @@ def collate(cls: Type[T], batch_list: list[T]) -> T: raise ValueError(f"Cannot collate an empty list of {cls.__name__}") for attr_name, attr_type in attr_types.items(): - attr_list = [getattr(obj, attr_name) for obj in batch_list] + attr_list = [obj.__getattr__(attr_name) for obj in batch_list] if attr_type == Tensor: tensor_tuple = cast(tuple[Tensor, ...], tuple(attr_list)) @@ -134,7 +132,7 @@ def to(self: T, device: Device | None = None, dtype: DType | None = None) -> T: return self - def attr_length(self, name: str) -> int: + def attr_size(self, name: str) -> int: value = self.__getattr__(name) if isinstance(value, list): return len(value) @@ -142,15 +140,14 @@ def attr_length(self, name: str) -> int: return value.shape[0] def __len__(self) -> int: - return self.attr_length(list(self.__class__.attr_types().keys())[0]) + return self.attr_size(list(self.__class__.attr_types().keys())[0]) def to_dict(self) -> dict[str, AttrType]: return {attr_name: getattr(self, attr_name) for attr_name in self.__class__.attr_types()} def split(self: T) -> list[T]: result: list[T] = [] - l = len(self) - for i in range(l): + for i in range(len(self)): args = {attr_name: getattr(self, attr_name)[i : i + 1] for attr_name in self.__class__.attr_types()} result.append(self.__class__(**args)) return result