Skip to content

Commit

Permalink
RSDK-9590 - upgrade numpy (#805)
Browse files Browse the repository at this point in the history
  • Loading branch information
stuqdog authored Dec 20, 2024
1 parent 36a1530 commit f66b367
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/update_protos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- uses: arduino/setup-protoc@v3
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
version: "28.2"
version: "29.1"

- name: Install uv
uses: astral-sh/setup-uv@v3
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ buf: clean
rm -rf src/viam/gen
chmod +x plugin/main.py
uv pip install protoletariat
uv pip install protobuf --upgrade
uv pip install protobuf==5.29.1
$(eval API_VERSION := $(shell grep 'API_VERSION' src/viam/version_metadata.py | awk -F '"' '{print $$2}'))
buf generate buf.build/viamrobotics/api:${API_VERSION}
buf generate buf.build/viamrobotics/goutils
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dynamic = [
dependencies = [
"googleapis-common-protos>=1.65.0",
"grpclib>=0.4.7",
"protobuf==5.28.2",
"protobuf==5.29.1",
"typing-extensions>=4.12.2",
"pymongo>=4.10.1"
]
Expand Down Expand Up @@ -47,7 +47,7 @@ dev-dependencies = [
"myst-nb>=1.0.0; python_version>='3.9'",
"nbmake>=1.5.4",
"numpy<1.25.0; python_version<'3.9'",
"numpy>=1.26.2,<2; python_version>='3.9'",
"numpy>=1.26.2; python_version>='3.9'",
"pillow>=10.4.0",
"pyright>=1.1.382.post1",
"pytest-asyncio>=0.24.0",
Expand Down
12 changes: 11 additions & 1 deletion src/viam/services/mlmodel/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
from packaging.version import Version

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -37,7 +38,16 @@ def make_ndarray(flat_data, dtype, shape):
"""Takes flat data (protobuf RepeatedScalarFieldContainer | bytes) to output an ndarray
of appropriate dtype and shape"""
make_array = np.frombuffer if dtype == np.int8 or dtype == np.uint8 else np.array
return make_array(flat_data, dtype).reshape(shape)
# As per proto, int16 and uint16 are stored as uint32. As of numpy v2, this creates
# some strange interactions with negative values for int16. Specifically, we end up
# trying to create an np.Int16 value with an out of bounds int due to rollover.
# Creating our array as a uint32 array initially and then casting to int16 solves this.
if Version(np.__version__) >= Version("2") and dtype == np.int16:
arr = np.astype(make_array(flat_data, np.uint32), np.int16) # pyright: ignore [reportAttributeAccessIssue]

else:
arr = make_array(flat_data, dtype)
return arr.reshape(shape)

ndarrays: Dict[str, NDArray] = dict()
for name, flat_tensor in flat_tensors.tensors.items():
Expand Down
Loading

0 comments on commit f66b367

Please sign in to comment.