Skip to content

Commit

Permalink
stylistic nits
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Feb 26, 2024
1 parent ac63736 commit fdd2153
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions src/refiners/training_utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, **kwargs: AttrType):

for attr_name in kwargs:
self.__setattr__(attr_name, kwargs[attr_name], check_size=False)
new_size = self.attr_length(attr_name)
new_size = self.attr_size(attr_name)
if size is not None and size != new_size:
raise ValueError(f"Attribute '{attr_name}' has size {new_size}, expected {size}")
size = new_size
Expand All @@ -79,10 +79,8 @@ def __setattr__(self, name: str, value: Any, check_size: bool = True) -> None:
raise TypeError(
f"Invalid type for attribute '{name}': Expected '{attr_type.__name__}', got '{type(value).__name__}'"
)
if isinstance(value, list):
new_size = len(value)
else:
new_size = value.shape[0]

new_size = len(value) if isinstance(value, list) else value.shape[0]

if check_size and new_size != len(self):
raise ValueError(f"Attribute '{name}' has size {new_size}, expected {len(self)}")
Expand All @@ -100,7 +98,7 @@ def collate(cls: Type[T], batch_list: list[T]) -> T:
raise ValueError(f"Cannot collate an empty list of {cls.__name__}")

for attr_name, attr_type in attr_types.items():
attr_list = [getattr(obj, attr_name) for obj in batch_list]
attr_list = [obj.__getattr__(attr_name) for obj in batch_list]

if attr_type == Tensor:
tensor_tuple = cast(tuple[Tensor, ...], tuple(attr_list))
Expand Down Expand Up @@ -134,23 +132,22 @@ def to(self: T, device: Device | None = None, dtype: DType | None = None) -> T:

return self

def attr_length(self, name: str) -> int:
def attr_size(self, name: str) -> int:
value = self.__getattr__(name)
if isinstance(value, list):
return len(value)
else:
return value.shape[0]

def __len__(self) -> int:
return self.attr_length(list(self.__class__.attr_types().keys())[0])
return self.attr_size(list(self.__class__.attr_types().keys())[0])

def to_dict(self) -> dict[str, AttrType]:
return {attr_name: getattr(self, attr_name) for attr_name in self.__class__.attr_types()}

def split(self: T) -> list[T]:
result: list[T] = []
l = len(self)
for i in range(l):
for i in range(len(self)):
args = {attr_name: getattr(self, attr_name)[i : i + 1] for attr_name in self.__class__.attr_types()}
result.append(self.__class__(**args))
return result
Expand Down

0 comments on commit fdd2153

Please sign in to comment.