Skip to content

Commit

Permalink
Set default iteration in train based on data size
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuanXu1 committed Aug 22, 2022
1 parent a95a536 commit 13093d1
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions celltypist/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def train(X = None,
with_mean: bool = True,
check_expression: bool = True,
#LR param
C: float = 1.0, solver: Optional[str] = None, max_iter: int = 1000, n_jobs: Optional[int] = None,
C: float = 1.0, solver: Optional[str] = None, max_iter: Optional[int] = None, n_jobs: Optional[int] = None,
#SGD param
use_SGD: bool = False, alpha: float = 0.0001,
#mini-batch
Expand Down Expand Up @@ -218,7 +218,7 @@ def train(X = None,
Maximum number of iterations before reaching the minimum of the cost function.
Try to decrease `max_iter` if the cost function does not converge for a long time.
This argument is for both traditional and SGD logistic classifiers, and will be ignored if mini-batch SGD training is conducted (`use_SGD = True` and `mini_batch = True`).
(Default: 1000)
Default to 200, 500, and 1000 for large (>500k cells), medium (50-500k), and small (<50k) datasets, respectively.
n_jobs
Number of CPUs used. Default to one CPU. `-1` means all CPUs are used.
This argument is for both traditional and SGD logistic classifiers.
Expand Down Expand Up @@ -314,6 +314,14 @@ def train(X = None,
#sklearn (Cython) does not support very large sparse matrices for the time being
if isinstance(indata, spmatrix) and ((indata.indices.dtype == 'int64') or (indata.indptr.dtype == 'int64')):
indata = indata.toarray()
#max_iter
if max_iter is None:
if indata.shape[0] < 50000:
max_iter = 1000
elif indata.shape[0] < 500000:
max_iter = 500
else:
max_iter = 200
#classifier
if use_SGD or feature_selection:
classifier = _SGDClassifier(indata = indata, labels = labels, alpha = alpha, max_iter = max_iter, n_jobs = n_jobs, mini_batch = mini_batch, batch_number = batch_number, batch_size = batch_size, epochs = epochs, balance_cell_type = balance_cell_type, **kwargs)
Expand Down

0 comments on commit 13093d1

Please sign in to comment.