From f8661c35389b195f72d659775bd4366f85d29ca9 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 9 May 2024 12:57:25 -0400 Subject: [PATCH] fix typecheck errors in sparse matrix & iknn --- lenskit/algorithms/item_knn.py | 8 ++++---- lenskit/data/matrix.py | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/lenskit/algorithms/item_knn.py b/lenskit/algorithms/item_knn.py index e6b9259a3..51b3b1e2a 100644 --- a/lenskit/algorithms/item_knn.py +++ b/lenskit/algorithms/item_knn.py @@ -365,7 +365,7 @@ class ItemItem(Predictor): AGG_WA = intern("weighted-average") RATING_AGGS = [AGG_WA] # the aggregates that use rating values - nnbrs: int | None + nnbrs: int min_nbrs: int min_sim: float save_nbrs: int | None @@ -387,7 +387,7 @@ class ItemItem(Predictor): def __init__( self, - nnbrs: int | None, + nnbrs: int, min_nbrs: int = 1, min_sim: float = 1.0e-6, save_nbrs: int | None = None, @@ -560,9 +560,9 @@ def predict_for_user(self, user, items, ratings=None): _logger.debug("user %s missing, returning empty predictions", user) return pd.Series(np.nan, index=items) upos = self.user_index_.get_loc(user) - row = self.rating_matrix_[upos] + row = self.rating_matrix_[upos] # type: ignore ratings = pd.Series( - row.values(), + row.values().numpy(), index=pd.Index(self.item_index_[row.indices()[0]]), ) diff --git a/lenskit/data/matrix.py b/lenskit/data/matrix.py index ab0467546..1f5cead93 100644 --- a/lenskit/data/matrix.py +++ b/lenskit/data/matrix.py @@ -11,21 +11,20 @@ from __future__ import annotations import logging -from collections import namedtuple -from typing import Generic, Literal, NamedTuple, TypeVar, overload import numpy as np import pandas as pd import scipy.sparse as sps import torch as t from csr import CSR +from typing_extensions import Generic, Literal, NamedTuple, TypeVar, overload _log = logging.getLogger(__name__) M = TypeVar("M", CSR, sps.csr_matrix, sps.coo_matrix, t.Tensor) -class RatingMatrix(NamedTuple): +class RatingMatrix(NamedTuple, Generic[M]): """ A rating matrix with associated indices.