Skip to content

Commit

Permalink
feat: add HMMClassifier.fit multiprocessing (#259)
Browse files: browse the repository at this point in the history
  • Loading branch information
eonu authored Dec 27, 2024
1 parent 4ce8f9e commit f79f512
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 7 deletions.
4 changes: 2 additions & 2 deletions docs/source/sections/models/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ The following models provided by Sequentia all support variable length sequences
| | | | +----------+------------+
| | | | | Training | Prediction |
+=========================+==============================+================+===============+==============+==========+============+
| :class:`.HMMClassifier` | :class:`.GaussianMixtureHMM` | Classification | Real || ||
| :class:`.HMMClassifier` | :class:`.GaussianMixtureHMM` | Classification | Real || ||
| +------------------------------+----------------+---------------+--------------+----------+------------+
| | :class:`.CategoricalHMM` | Classification | Categorical || ||
| | :class:`.CategoricalHMM` | Classification | Categorical || ||
+-------------------------+------------------------------+----------------+---------------+--------------+----------+------------+
| :class:`.KNNRegressor` | Regression | Real || N/A ||
+--------------------------------------------------------+----------------+---------------+--------------+----------+------------+
Expand Down
5 changes: 3 additions & 2 deletions sequentia/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ def param_grid(**kwargs: list[t.Any]) -> list[dict[str, t.Any]]:
settings for :class:`.GaussianMixtureHMM`, which is a nested model
specified in the constructor of a :class:`.HMMClassifier`. ::
from sklearn.preprocessing import Pipeline, minmax_scale
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import minmax_scale
from sequenta.enums import PriorMode, CovarianceMode, TopologyMode
from sequentia.enums import PriorMode, CovarianceMode, TopologyMode
from sequentia.models import HMMClassifier, GaussianMixtureHMM
from sequentia.preprocessing import IndependentFunctionTransformer
from sequentia.model_selection import GridSearchCV, StratifiedKFold
Expand Down
20 changes: 18 additions & 2 deletions sequentia/models/hmm/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,24 @@ def fit(
lengths=lengths,
classes=self.classes_,
)
for X_c, lengths_c, c in dataset.iter_by_class():
self.models[c].fit(X_c, lengths=lengths_c)

# get number of jobs
n_jobs = _multiprocessing.effective_n_jobs(
self.n_jobs, x=self.classes_
)

# fit models in parallel
self.models = dict(
zip(
self.classes_,
joblib.Parallel(n_jobs=n_jobs, max_nbytes=None)(
joblib.delayed(self.models[c].fit)(
X_c, lengths=lengths_c
)
for X_c, lengths_c, c in dataset.iter_by_class()
),
)
)

# Set class priors
models: t.Iterable[int, variants.BaseHMM] = self.models.items()
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_models/hmm/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,18 @@ def assert_fit(clf: BaseHMM):
],
)
@pytest.mark.parametrize("fit_mode", list(FitMode))
@pytest.mark.parametrize("n_jobs", [1, -1])
def test_classifier_e2e(
request: SubRequest,
helpers: t.Any,
model: BaseHMM,
dataset: SequentialDataset,
prior: enums.PriorMode | dict[int, float],
fit_mode: FitMode,
n_jobs: int,
random_state: np.random.RandomState,
) -> None:
clf = HMMClassifier(prior=prior)
clf = HMMClassifier(prior=prior, n_jobs=n_jobs)
clf.add_models({i: copy.deepcopy(model) for i in range(n_classes)})

assert clf.prior == prior
Expand All @@ -156,6 +158,7 @@ def test_classifier_e2e(
variant=type(model),
model_kwargs=model.get_params(),
prior=prior,
n_jobs=n_jobs,
)
clf.fit(**train.X_y_lengths)

Expand Down

0 comments on commit f79f512

Please sign in to comment.