Skip to content
This repository has been archived by the owner on May 1, 2024. It is now read-only.

Fixing MRO and circular imports #19

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from platforms.platform import get_platform


class Dataset(abc.ABC, torch.utils.data.Dataset):
class Dataset(torch.utils.data.Dataset, abc.ABC):
"""The base class for all datasets in this framework."""

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from platforms.platform import get_platform


class Model(abc.ABC, torch.nn.Module):
class Model(torch.nn.Module, abc.ABC):
"""The base class used by all models in this codebase."""

@staticmethod
Expand Down
9 changes: 1 addition & 8 deletions platforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,7 @@ class Platform(Hparams):

@property
def device_str(self):
# GPU device.
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
device_ids = ','.join([str(x) for x in range(torch.cuda.device_count())])
return f'cuda:{device_ids}'

# CPU device.
else:
return 'cpu'
return 'cuda' if torch.cuda.is_available() else 'cpu'

@property
def torch_device(self):
Expand Down
9 changes: 6 additions & 3 deletions pruning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import abc

from foundations.hparams import PruningHparams
from models import base
from pruning.mask import Mask

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pruning.mask import Mask
from models import base


class Strategy(abc.ABC):
Expand All @@ -18,5 +21,5 @@ def get_pruning_hparams() -> type:

@staticmethod
@abc.abstractmethod
def prune(pruning_hparams: PruningHparams, trained_model: base.Model, current_mask: Mask = None) -> Mask:
def prune(pruning_hparams: PruningHparams, trained_model: 'base.Model', current_mask: 'Mask' = None) -> 'Mask':
pass
8 changes: 6 additions & 2 deletions pruning/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
import torch

from foundations import paths
from models import base
from platforms.platform import get_platform

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from models import base



class Mask(dict):
def __init__(self, other_dict=None):
Expand All @@ -30,7 +34,7 @@ def __setitem__(self, key, value):
super(Mask, self).__setitem__(key, value)

@staticmethod
def ones_like(model: base.Model) -> 'Mask':
def ones_like(model: 'base.Model') -> 'Mask':
mask = Mask()
for name in model.prunable_layer_names:
mask[name] = torch.ones(list(model.state_dict()[name].shape))
Expand Down
7 changes: 5 additions & 2 deletions pruning/sparse_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import numpy as np

from foundations import hparams
import models.base
from pruning import base
from pruning.mask import Mask

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from models import base as models_base


@dataclasses.dataclass
class PruningHparams(hparams.PruningHparams):
Expand All @@ -29,7 +32,7 @@ def get_pruning_hparams() -> type:
return PruningHparams

@staticmethod
def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
def prune(pruning_hparams: PruningHparams, trained_model: 'models_base.Model', current_mask: Mask = None):
current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

# Determine the number of weights that need to be pruned.
Expand Down