From 90b869b440e92c33bdfc89c09369cedd1c711b90 Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Sun, 10 Dec 2023 15:38:21 -0500 Subject: [PATCH] Reduce memory overhead for TransportableObject * Always represent TransportableObject internally as a single array of bytes. Various properties, such as `header`, or `object_string`, decode various segments of the byte array. * Store the serialized object as raw picklebytes without base64-encoding. As a result, `get_deserialized()` no longer needs to create a temporary copy of the raw picklebytes. The data segment is directly unpickled. Base64-encoding is applied to the data segment or the entire internal buffer whenever a print friendly representation of the `TransportableObject` is desired. * Since the properties of `TransportableObject` are simply views into the underlying buffer, `TransportableObject` may itself be serialized efficiently by simply writing out the byte array. ` Add backward compatibility layer for deserialization --- CHANGELOG.md | 4 + covalent/_workflow/transportable_object.py | 254 +++++++++--------- .../_service/assets_test.py | 3 +- .../covalent_tests/workflow/transport_test.py | 178 ++++++++---- 4 files changed, 258 insertions(+), 181 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f670f14b..6041d16e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Changed + +- Improved memory overhead for operations involving TransportableObject + ## [0.235.1-rc.0] - 2024-06-10 ### Authors diff --git a/covalent/_workflow/transportable_object.py b/covalent/_workflow/transportable_object.py index 7dcb4e073..acea69715 100644 --- a/covalent/_workflow/transportable_object.py +++ b/covalent/_workflow/transportable_object.py @@ -19,7 +19,7 @@ import base64 import json import platform -from typing import Any, Callable, Tuple +from typing import Any, Callable, Dict, Tuple import cloudpickle @@ -29,77 +29,12 @@ DATA_OFFSET_BYTES = 8 HEADER_OFFSET = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES BYTE_ORDER = "big" - - -class _TOArchive: - """Archived transportable object.""" - - def __init__(self, header: bytes, object_string: bytes, data: bytes): - """ - Initialize TOArchive. - - Args: - header: Archived transportable object header. - object_string: Archived transportable object string. - data: Archived transportable object data. - - Returns: - None - """ - - self.header = header - self.object_string = object_string - self.data = data - - def cat(self) -> bytes: - """ - Concatenate TOArchive. - - Returns: - Concatenated TOArchive. - - """ - - header_size = len(self.header) - string_size = len(self.object_string) - data_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size + string_size - string_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size - - data_offset = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER, signed=False) - string_offset = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER, signed=False) - - return string_offset + data_offset + self.header + self.object_string + self.data - - @staticmethod - def load(serialized: bytes, header_only: bool, string_only: bool) -> "_TOArchive": - """ - Load TOArchive object from serialized bytes. - - Args: - serialized: Serialized transportable object. - header_only: Load header only. - string_only: Load string only. - - Returns: - Archived transportable object. - - """ - - string_offset = TOArchiveUtils.string_offset(serialized) - header = TOArchiveUtils.parse_header(serialized, string_offset) - object_string = b"" - data = b"" - - if not header_only: - data_offset = TOArchiveUtils.data_offset(serialized) - object_string = TOArchiveUtils.parse_string(serialized, string_offset, data_offset) - - if not string_only: - data = TOArchiveUtils.parse_data(serialized, data_offset) - return _TOArchive(header, object_string, data) +TOBJ_FMT_STR = "0.1" class TOArchiveUtils: + """Utilities for reading serialized TransportableObjects""" + @staticmethod def data_offset(serialized: bytes) -> int: size64 = serialized[STRING_OFFSET_BYTES : STRING_OFFSET_BYTES + DATA_OFFSET_BYTES] @@ -119,24 +54,38 @@ def string_byte_range(serialized: bytes) -> Tuple[int, int]: @staticmethod def data_byte_range(serialized: bytes) -> Tuple[int, int]: - """Return byte range for the b64 picklebytes""" + """Return byte range for the picklebytes""" start_byte = TOArchiveUtils.data_offset(serialized) return start_byte, -1 @staticmethod - def parse_header(serialized: bytes, string_offset: int) -> bytes: + def header(serialized: bytes) -> dict: + string_offset = TOArchiveUtils.string_offset(serialized) header = serialized[HEADER_OFFSET:string_offset] - return header + return json.loads(header.decode("utf-8")) @staticmethod - def parse_string(serialized: bytes, string_offset: int, data_offset: int) -> bytes: + def string_segment(serialized: bytes) -> bytes: + string_offset = TOArchiveUtils.string_offset(serialized) + data_offset = TOArchiveUtils.data_offset(serialized) return serialized[string_offset:data_offset] @staticmethod - def parse_data(serialized: bytes, data_offset: int) -> bytes: + def data_segment(serialized: bytes) -> bytes: + data_offset = TOArchiveUtils.data_offset(serialized) return serialized[data_offset:] +class _ByteArrayFile: + """File-like interface for appending to a bytearray.""" + + def __init__(self, buf: bytearray): + self._buf = buf + + def write(self, data: bytes): + self._buf.extend(data) + + class TransportableObject: """ A function is converted to a transportable object by serializing it using cloudpickle @@ -149,13 +98,13 @@ class TransportableObject: """ def __init__(self, obj: Any) -> None: - b64object = base64.b64encode(cloudpickle.dumps(obj)) - object_string_u8 = str(obj).encode("utf-8") + self._buffer = bytearray() - self._object = b64object.decode("utf-8") - self._object_string = object_string_u8.decode("utf-8") + # Reserve space for the byte offsets to be written at the end + self._buffer.extend(b"\0" * HEADER_OFFSET) - self._header = { + _header = { + "format": TOBJ_FMT_STR, "py_version": platform.python_version(), "cloudpickle_version": cloudpickle.__version__, "attrs": { @@ -164,23 +113,48 @@ def __init__(self, obj: Any) -> None: }, } + # Write header and object string + header_u8 = json.dumps(_header).encode("utf-8") + header_len = len(header_u8) + + object_string_u8 = str(obj).encode("utf-8") + object_string_len = len(object_string_u8) + + self._buffer.extend(header_u8) + self._buffer.extend(object_string_u8) + del object_string_u8 + + # Append picklebytes (not base64-encoded) + cloudpickle.dump(obj, _ByteArrayFile(self._buffer)) + + # Write byte offsets + string_offset = HEADER_OFFSET + header_len + data_offset = string_offset + object_string_len + + string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER) + data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER) + self._buffer[:STRING_OFFSET_BYTES] = string_offset_bytes + self._buffer[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes + @property def python_version(self): - return self._header["py_version"] + return self.header["py_version"] @property def header(self): - return self._header + return TOArchiveUtils.header(self._buffer) @property def attrs(self): - return self._header["attrs"] + return self.header["attrs"] @property def object_string(self): # For compatibility with older Covalent try: - return self._object_string + return ( + TOArchiveUtils.string_segment(memoryview(self._buffer)).tobytes().decode("utf-8") + ) except AttributeError: return self.__dict__["object_string"] @@ -201,11 +175,15 @@ def get_deserialized(self) -> Callable: """ - return cloudpickle.loads(base64.b64decode(self._object.encode("utf-8"))) + return cloudpickle.loads(TOArchiveUtils.data_segment(memoryview(self._buffer))) def to_dict(self) -> dict: """Return a JSON-serializable dictionary representation of self""" - return {"type": "TransportableObject", "attributes": self.__dict__.copy()} + attr_dict = { + "buffer_b64": base64.b64encode(memoryview(self._buffer)).decode("utf-8"), + } + + return {"type": "TransportableObject", "attributes": attr_dict} @staticmethod def from_dict(object_dict) -> "TransportableObject": @@ -219,7 +197,7 @@ def from_dict(object_dict) -> "TransportableObject": """ sc = TransportableObject(None) - sc.__dict__ = object_dict["attributes"] + sc._buffer = base64.b64decode(object_dict["attributes"]["buffer_b64"].encode("utf-8")) return sc def get_serialized(self) -> str: @@ -233,7 +211,9 @@ def get_serialized(self) -> str: object: The serialized transportable object. """ - return self._object + # For backward compatibility + data_segment = TOArchiveUtils.data_segment(memoryview(self._buffer)) + return base64.b64encode(data_segment).decode("utf-8") def serialize(self) -> bytes: """ @@ -246,7 +226,7 @@ def serialize(self) -> bytes: pickled_object: The serialized object alongwith the python version. """ - return _to_archive(self).cat() + return self._buffer def serialize_to_json(self) -> str: """ @@ -295,9 +275,7 @@ def make_transportable(obj) -> "TransportableObject": return TransportableObject(obj) @staticmethod - def deserialize( - serialized: bytes, *, header_only: bool = False, string_only: bool = False - ) -> "TransportableObject": + def deserialize(serialized: bytes) -> "TransportableObject": """ Deserialize the transportable object. @@ -307,9 +285,58 @@ def deserialize( Returns: object: The deserialized transportable object. """ + to = TransportableObject(None) + header = TOArchiveUtils.header(serialized) + + # For backward compatibility + if header.get("format") is None: + # Re-encode TObj serialized using older versions of the SDK, + # characterized by the lack of a "format" field in the + # header. TObj was previously serialized as + # [offsets][header][string][b64-encoded picklebytes], + # whereas starting from format 0.1 we store them as + # [offsets][header][string][picklebytes]. + to._buffer = TransportableObject._upgrade_tobj_format(serialized, header) + else: + to._buffer = serialized + return to + + @staticmethod + def _upgrade_tobj_format(serialized: bytes, header: Dict) -> bytes: + """Re-encode a serialized TObj in the newer format. + + This involves adding a format version in the header and + base64-decoding the data segment. Because the header at the + beginning of the byte array, the string and data offsets need + to be recomputed. + """ + buf = bytearray() + + # Upgrade header and recompute byte offsets + header["format"] = TOBJ_FMT_STR + serialized_header = json.dumps(header).encode("utf-8") + string_offset = HEADER_OFFSET + len(serialized_header) + + # This is just a view into the bytearray and consumes + # negligible space on its own. + string_segment = TOArchiveUtils.string_segment(serialized) + + data_offset = string_offset + len(string_segment) + string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER) + data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER) + + # Write the new byte offsets + buf.extend(b"\0" * HEADER_OFFSET) + buf[:STRING_OFFSET_BYTES] = string_offset_bytes + buf[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes - ar = _TOArchive.load(serialized, header_only, string_only) - return _from_archive(ar) + buf.extend(serialized_header) + buf.extend(string_segment) + + # base64-decode the data segment into raw picklebytes + buf.extend(base64.b64decode(TOArchiveUtils.data_segment(serialized))) + + return buf @staticmethod def deserialize_list(collection: list) -> list: @@ -356,44 +383,3 @@ def deserialize_dict(collection: dict) -> dict: else: raise TypeError("Couldn't deserialize collection") return new_dict - - -def _to_archive(to: TransportableObject) -> _TOArchive: - """ - Convert a TransportableObject to a _TOArchive. - - Args: - to: Transportable object to be converted. - - Returns: - Archived transportable object. - - """ - - header = json.dumps(to._header).encode("utf-8") - object_string = to._object_string.encode("utf-8") - data = to._object.encode("utf-8") - return _TOArchive(header=header, object_string=object_string, data=data) - - -def _from_archive(ar: _TOArchive) -> TransportableObject: - """ - Convert a _TOArchive to a TransportableObject. - - Args: - ar: Archived transportable object. - - Returns: - Transportable object. - - """ - - decoded_object_str = ar.object_string.decode("utf-8") - decoded_data = ar.data.decode("utf-8") - decoded_header = json.loads(ar.header.decode("utf-8")) - to = TransportableObject(None) - to._header = decoded_header - to._object_string = decoded_object_str or "" - to._object = decoded_data or "" - - return to diff --git a/tests/covalent_dispatcher_tests/_service/assets_test.py b/tests/covalent_dispatcher_tests/_service/assets_test.py index 5f704ca43..8b939c6c4 100644 --- a/tests/covalent_dispatcher_tests/_service/assets_test.py +++ b/tests/covalent_dispatcher_tests/_service/assets_test.py @@ -16,6 +16,7 @@ """Unit tests for the FastAPI asset endpoints""" +import base64 import tempfile from contextlib import contextmanager from typing import Generator @@ -704,7 +705,7 @@ def test_get_pickle_offsets(): start, end = _get_tobj_pickle_offsets(f"file://{write_file.name}") - assert data[start:].decode("utf-8") == tobj.get_serialized() + assert data[start:] == base64.b64decode(tobj.get_serialized().encode("utf-8")) def test_generate_partial_file_slice(): diff --git a/tests/covalent_tests/workflow/transport_test.py b/tests/covalent_tests/workflow/transport_test.py index 40de076c1..809491dc6 100644 --- a/tests/covalent_tests/workflow/transport_test.py +++ b/tests/covalent_tests/workflow/transport_test.py @@ -16,6 +16,8 @@ """Unit tests for transport graph.""" +import base64 +import json import platform from unittest.mock import call @@ -32,6 +34,11 @@ encode_metadata, pickle_modules_by_value, ) +from covalent._workflow.transportable_object import ( + BYTE_ORDER, + DATA_OFFSET_BYTES, + STRING_OFFSET_BYTES, +) from covalent.executor import LocalExecutor from covalent.triggers import BaseTrigger @@ -80,6 +87,108 @@ def workflow_transport_graph(): return tg +# For testing TObj back-compat -- copied from earlier SDK +class _TOArchive: + """Archived transportable object.""" + + def __init__(self, header: bytes, object_string: bytes, data: bytes): + """ + Initialize TOArchive. + + Args: + header: Archived transportable object header. + object_string: Archived transportable object string. + data: Archived transportable object data. + + Returns: + None + """ + + self.header = header + self.object_string = object_string + self.data = data + + def cat(self) -> bytes: + """ + Concatenate TOArchive. + + Returns: + Concatenated TOArchive. + + """ + + header_size = len(self.header) + string_size = len(self.object_string) + data_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size + string_size + string_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size + + data_offset = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER, signed=False) + string_offset = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER, signed=False) + + return string_offset + data_offset + self.header + self.object_string + self.data + + +# Copied from previous SDK version +class LegacyTransportableObject: + """ + A function is converted to a transportable object by serializing it using cloudpickle + and then whenever executing it, the transportable object is deserialized. The object + will also contain additional info like the python version used to serialize it. + + Attributes: + _object: The serialized object. + python_version: The python version used on the client's machine. + """ + + def __init__(self, obj) -> None: + b64object = base64.b64encode(cloudpickle.dumps(obj)) + object_string_u8 = str(obj).encode("utf-8") + + self._object = b64object.decode("utf-8") + self._object_string = object_string_u8.decode("utf-8") + + self._header = { + "py_version": platform.python_version(), + "cloudpickle_version": cloudpickle.__version__, + "attrs": { + "doc": getattr(obj, "__doc__", ""), + "name": getattr(obj, "__name__", ""), + }, + } + + # For testing TObj back-compat + @staticmethod + def _to_archive(to) -> _TOArchive: + """ + Convert a TransportableObject to a _TOArchive. + + Args: + to: Transportable object to be converted. + + Returns: + Archived transportable object. + + """ + + header = json.dumps(to._header).encode("utf-8") + object_string = to._object_string.encode("utf-8") + data = to._object.encode("utf-8") + return _TOArchive(header=header, object_string=object_string, data=data) + + def serialize(self) -> bytes: + """ + Serialize the transportable object. + + Args: + None + + Returns: + pickled_object: The serialized object alongwith the python version. + """ + + return LegacyTransportableObject._to_archive(self).cat() + + def test_transportable_object_python_version(transportable_object): """Test that the transportable object retrieves the correct python version.""" @@ -87,27 +196,22 @@ def test_transportable_object_python_version(transportable_object): assert to.python_version == platform.python_version() -def test_transportable_object_eq(transportable_object): +def test_transportable_object_eq(): """Test the __eq__ magic method of TransportableObject""" - import copy - - to = transportable_object - to_new = TransportableObject(None) - to_new.__dict__ = copy.deepcopy(to.__dict__) - assert to.__eq__(to_new) - - to_new._header["py_version"] = "3.5.1" - assert not to.__eq__(to_new) - - assert not to.__eq__({}) + to = TransportableObject(1) + to_new = TransportableObject(1) + to_new_2 = TransportableObject(2) + assert to == to_new + assert to != to_new_2 + assert to != 1 def test_transportable_object_get_serialized(transportable_object): """Test serialized transportable object retrieval.""" to = transportable_object - assert to.get_serialized() == to._object + assert to.get_serialized() == base64.b64encode(cloudpickle.dumps(subtask)).decode("utf-8") def test_transportable_object_get_deserialized(transportable_object): @@ -124,15 +228,8 @@ def test_transportable_object_from_dict(transportable_object): to_new = TransportableObject.from_dict(object_dict) assert to == to_new - - -def test_transportable_object_to_dict_attributes(transportable_object): - """Test attributes from `to_dict` contain correct name and docstrings""" - - tr_dict = transportable_object.to_dict() - - assert tr_dict["attributes"]["_header"]["attrs"]["doc"] == subtask.__doc__ - assert tr_dict["attributes"]["_header"]["attrs"]["name"] == subtask.__name__ + assert to_new.header == to.header + assert to_new.object_string == to.object_string def test_transportable_object_serialize_to_json(transportable_object): @@ -148,7 +245,9 @@ def test_transportable_object_deserialize_from_json(transportable_object): to = transportable_object json_string = to.serialize_to_json() deserialized_to = TransportableObject.deserialize_from_json(json_string) - assert to.__dict__ == deserialized_to.__dict__ + assert to == deserialized_to + assert deserialized_to.header == to.header + assert deserialized_to.object_string == to.object_string def test_transportable_object_make_transportable_idempotent(transportable_object): @@ -169,29 +268,6 @@ def test_transportable_object_serialize_deserialize(transportable_object): assert new_to.python_version == to.python_version -def test_transportable_object_sedeser_string_only(): - """Test extracting string only from serialized to""" - x = 123 - to = TransportableObject(x) - - ser = to.serialize() - new_to = TransportableObject.deserialize(ser, string_only=True) - assert new_to.object_string == to.object_string - assert new_to._object == "" - - -def test_transportable_object_sedeser_header_only(): - """Test extracting header only from serialized to""" - x = 123 - to = TransportableObject(x) - - ser = to.serialize() - new_to = TransportableObject.deserialize(ser, header_only=True) - - assert new_to.object_string == "" - assert new_to._header - - def test_transportable_object_deserialize_list(transportable_object): deserialized = [1, 2, {"a": 3, "b": [4, 5]}] serialized_list = [ @@ -222,6 +298,16 @@ def test_transportable_object_deserialize_dict(transportable_object): assert TransportableObject.deserialize_dict(serialized_dict) == deserialized +def test_tobj_deserialize_back_compat(): + lto = LegacyTransportableObject({"a": 5}) + serialized = lto.serialize() + to = TransportableObject.deserialize(serialized) + obj = to.get_deserialized() + assert obj == {"a": 5} + obj2 = TransportableObject.deserialize(to.serialize()).get_deserialized() + assert obj2 == {"a": 5} + + def test_transport_graph_initialization(): """Test the initialization of an empty transport graph."""