diff --git a/celltypist/train.py b/celltypist/train.py index f0bdcda..fc6d128 100644 --- a/celltypist/train.py +++ b/celltypist/train.py @@ -165,7 +165,7 @@ def train(X = None, with_mean: bool = True, check_expression: bool = True, #LR param - C: float = 1.0, solver: Optional[str] = None, max_iter: int = 1000, n_jobs: Optional[int] = None, + C: float = 1.0, solver: Optional[str] = None, max_iter: Optional[int] = None, n_jobs: Optional[int] = None, #SGD param use_SGD: bool = False, alpha: float = 0.0001, #mini-batch @@ -218,7 +218,7 @@ def train(X = None, Maximum number of iterations before reaching the minimum of the cost function. Try to decrease `max_iter` if the cost function does not converge for a long time. This argument is for both traditional and SGD logistic classifiers, and will be ignored if mini-batch SGD training is conducted (`use_SGD = True` and `mini_batch = True`). - (Default: 1000) + Default to 200, 500, and 1000 for large (>=500k cells), medium (50k to <500k cells), and small (<50k cells) datasets, respectively. n_jobs Number of CPUs used. Default to one CPU. `-1` means all CPUs are used. This argument is for both traditional and SGD logistic classifiers. @@ -314,6 +314,14 @@ def train(X = None, #sklearn (Cython) does not support very large sparse matrices for the time being if isinstance(indata, spmatrix) and ((indata.indices.dtype == 'int64') or (indata.indptr.dtype == 'int64')): indata = indata.toarray() + #max_iter + if max_iter is None: + if indata.shape[0] < 50000: + max_iter = 1000 + elif indata.shape[0] < 500000: + max_iter = 500 + else: + max_iter = 200 #classifier if use_SGD or feature_selection: classifier = _SGDClassifier(indata = indata, labels = labels, alpha = alpha, max_iter = max_iter, n_jobs = n_jobs, mini_batch = mini_batch, batch_number = batch_number, batch_size = batch_size, epochs = epochs, balance_cell_type = balance_cell_type, **kwargs)