Skip to content

Commit

Permalink
Update docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter9192 committed Jan 30, 2024
1 parent cabb7ad commit 5f3ed0b
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions src/dumme/dumme.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,26 @@
# BaseEstimator has boilerplate for things like get_params, set_params, _validate_data
class MixedEffectsModel(RegressorMixin, BaseEstimator):
"""
This is the core class to instantiate, train, and predict using a mixed effects random forest model.
It roughly adheres to the sklearn estimator API.
Note that the user must pass in an already instantiated fixed_effects_model that adheres to the
sklearn regression estimator API, i.e. must have a fit() and predict() method defined.
It assumes a data model of the form:
Scikit-learn compatbile implementation of a mixed-effects model of the form
.. math::
y = f(X) + b_i Z + e
* y is the target variable. The current code only supports regression for now, e.g. continuously varying scalar value
* y is the target variable. The current code only supports regression for
now, e.g. continuously varying scalar value
* X is the fixed effect features. Assume p dimensional
* f(.) is the nonlinear fixed effects mode, e.g. random forest
* Z is the random effect features. Assume q dimensional.
* e is iid noise ~N(0, sigma_e²)
* i is the cluster index. Assume k clusters in the training.
* bi is the random effect coefficients. They are different per cluster i but are assumed to be drawn from the same distribution ~N(0, Sigma_b) where Sigma_b is learned from the data.
* bi is the random effect coefficients. They are different per cluster i but
are assumed to be drawn from the same distribution ~N(0, Sigma_b) where
Sigma_b is learned from the data.
Args:
gll_early_stop_threshold (float): early stopping threshold on GLL improvement
max_iterations (int): maximum number of EM iterations to train
gll_early_stop_threshold (float): early stopping threshold on GLL
improvement max_iterations (int): maximum number of EM iterations
"""

def __init__(
Expand Down

0 comments on commit 5f3ed0b

Please sign in to comment.