Skip to content

Commit

Permalink
fix numpy.pyi and k_means tests
Browse files Browse the repository at this point in the history
Signed-off-by: Elazar Gershuni <[email protected]>
  • Loading branch information
elazarg committed Nov 26, 2024
1 parent 6ae45a0 commit 8827903
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 37 deletions.
4 changes: 2 additions & 2 deletions experiment/k_means/instrumented.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def run(X: np.ndarray, k: int, max_iterations: int) -> np.ndarray:
for i in transaction.iterate(range(max_iterations)): # type: int
clusters = [list[int]() for _ in range(k)]
for sample_i in range(len(X)):
r = np.argmin(np.linalg.norm(X[sample_i] - centroids, axis=1))
r = np.argmin(np.linalg.norm(X[sample_i] - centroids, None, 1))
clusters[r].append(sample_i)
prev_centroids = centroids
centroids = np.array([np.mean(X[cluster], axis=0) for cluster in clusters])
centroids = np.array([np.mean(X[cluster], 0) for cluster in clusters])
diff = centroids - prev_centroids
if not diff.any():
break
Expand Down
4 changes: 4 additions & 0 deletions test_data/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ def list_append(y: int):
x.append(y)


def list_set(x: list[int]):
x[0] = ""


def list_add(k: int):
x = [1] + [k]
y = [(1,)] + [(2,)]
Expand Down
40 changes: 5 additions & 35 deletions typeshed_mini/numpy.pyi
Original file line number Diff line number Diff line change
@@ -1,47 +1,33 @@

class ndarray:
@property
def size(self: ndarray) -> int:...

def size(self: ndarray) -> int: ...
@property
def ndim(self: ndarray) -> int:...

def ndim(self: ndarray) -> int: ...
@property
@new
def T(self: ndarray) -> ndarray:...

def T(self: ndarray) -> ndarray: ...
@new
def __add__(self: ndarray, other: ndarray) -> ndarray: ...

@new
def __sub__(self: ndarray, other: ndarray) -> ndarray: ...

@new
def __mul__(self: ndarray, other: ndarray) -> ndarray: ...

@new
def __truediv__(self: ndarray, other: ndarray) -> ndarray: ...
@new
def __truediv__(self: ndarray, other: float) -> ndarray: ...

@new
def __radd__(self: ndarray, other: float) -> ndarray: ...

@new
def __rsub__(self: ndarray, other: float) -> ndarray: ...

@new
def __rmul__(self: ndarray, other: float) -> ndarray: ...

@new
def __rtruediv__(self: ndarray, other: float) -> ndarray: ...

@new
def __gt__(self: ndarray, other) -> ndarray: ...

@new
def __lt__(self: ndarray, other) -> ndarray: ...

@new
def __getitem__(self: ndarray, key: slice) -> ndarray: ...
@new
Expand All @@ -50,9 +36,7 @@ class ndarray:
def __getitem__(self: ndarray, key: ndarray) -> ndarray: ...
@new
def __getitem__(self: ndarray, key: list[int]) -> ndarray: ...

def __getitem__(self: ndarray, key: int) -> float: ...

@update(ndarray)
def __setitem__(self: ndarray, key: int, value: float) -> None: ...
@update(ndarray)
Expand All @@ -63,18 +47,14 @@ class ndarray:

@new
def astype(self: ndarray, dtype) -> list[int]: ...

@new
def mean(self: ndarray) -> ndarray: ...
def std(self: ndarray) -> float: ...

@property
@new
def shape(self: ndarray) -> list[int]: ...

def any(self: ndarray) -> bool: ...
def all(self: ndarray) -> bool: ...

@new
def reshape(self: ndarray, shape: tuple) -> ndarray: ...
@new
Expand All @@ -87,43 +67,33 @@ class ndarray:
@property
@new
def c_() -> ndarray: ...

@new
def setdiff1d(a: ndarray, b: ndarray) -> ndarray: ...

@new
def unique(arg: ndarray) -> ndarray: ...

@new
def append(arr: ndarray, value: float) -> ndarray: ...

# def append(arr: ndarray, values: ndarray) -> ndarray: ...

@new
def zeros(dims: tuple) -> ndarray: ...

@new
def zeros(dims: int) -> ndarray: ...

@new
def ones(dims: tuple | int) -> ndarray: ...

@new
def mean(x: ndarray, axis: int) -> ndarray: ...

@new
def dot(x: ndarray, y: ndarray) -> ndarray: ...

def sum(x: ndarray) -> float: ...
def argmin(x: ndarray) -> int: ...

@new
def concatenate(arrays: tuple | ndarray) -> ndarray: ...

@module
class random:
@staticmethod
def seed(seed: int) -> None: ...

@staticmethod
@new
def choice(a: ndarray | int, size: int) -> ndarray: ...
Expand All @@ -132,7 +102,7 @@ class random:
class linalg:
@staticmethod
@new
def norm(a: ndarray, axis: int) -> ndarray: ...
def norm(a: ndarray, ord: None, axis: int) -> ndarray: ...

@new
def array(object) -> ndarray: ...
Expand Down

0 comments on commit 8827903

Please sign in to comment.