Skip to content

Commit

Permalink
Multichain provider (#1265)
Browse files Browse the repository at this point in the history
* Fixes validation enforcement from new Aquarius.
* Adapt compute tests.
* Adds compatibility with older providers.
  • Loading branch information
calina-c authored Mar 21, 2023
1 parent dd51e32 commit dc45be3
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 45 deletions.
4 changes: 2 additions & 2 deletions READMEs/c2d-flow.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ ALGO_ddo = ocean.assets.resolve(ALGO_did)

compute_service = DATA_ddo.services[1]
algo_service = ALGO_ddo.services[0]
free_c2d_env = ocean.compute.get_free_c2d_environment(compute_service.service_endpoint)
free_c2d_env = ocean.compute.get_free_c2d_environment(compute_service.service_endpoint, DATA_ddo.chain_id)

from datetime import datetime, timedelta
from ocean_lib.models.compute_input import ComputeInput
Expand Down Expand Up @@ -246,5 +246,5 @@ It modifies the contents of the given ComputeInput as follows:
This means you can reuse the same ComputeInput and you don't need to regenerate it everytime it is sent to `pay_for_compute_service`. This step makes sure you are not paying unnecessary or duplicated fees.

If you wish to upgrade the compute resources, you can use any (paid) C2D environment.
Inspect the results of `ocean.ocean_compute.get_c2d_environments(service.service_endpoint)` and `ocean.retrieve_provider_fees_for_compute(datasets, algorithm_data, consumer_address, compute_environment, duration)` for a preview of what you will pay.
Inspect the results of `ocean.ocean_compute.get_c2d_environments(service.service_endpoint, DATA_ddo.chain_id)` and `ocean.retrieve_provider_fees_for_compute(datasets, algorithm_data, consumer_address, compute_environment, duration)` for a preview of what you will pay.
Don't forget to handle any minting, allowance or approvals on the desired token to ensure transactions pass.
2 changes: 1 addition & 1 deletion ocean_lib/assets/ddo.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def add_service(self, service: Service) -> None:
:param service: To add service, Service
"""
service.encrypt_files(self.nft_address)
service.encrypt_files(self.nft_address, self.chain_id)

logger.debug(
f"Adding service with service type {service.type} with did {self.did}"
Expand Down
62 changes: 47 additions & 15 deletions ocean_lib/data_provider/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from json import JSONDecodeError
from typing import Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from requests.models import PreparedRequest

import requests
from enforce_typing import enforce_types
Expand Down Expand Up @@ -90,27 +91,34 @@ def get_service_endpoints(provider_uri: str) -> Dict[str, List[str]]:

@staticmethod
@enforce_types
def get_c2d_environments(provider_uri: str) -> Optional[str]:
def get_c2d_environments(provider_uri: str, chain_id: int) -> Optional[str]:
"""
Return the provider address
"""
try:
_, envs_endpoint = DataServiceProviderBase.build_c2d_environments_endpoint(
provider_uri
provider_uri, chain_id
)
environments = DataServiceProviderBase._http_method(
"get", envs_endpoint
"get",
envs_endpoint,
).json()

return environments
except requests.exceptions.RequestException:
if str(chain_id) not in environments:
logger.warning(
"You might be using an older provider. ocean.py can not verify the chain id."
)
return environments

return environments[str(chain_id)]
except (requests.exceptions.RequestException, KeyError):
pass

return []

@staticmethod
@enforce_types
def get_provider_address(provider_uri: str) -> Optional[str]:
def get_provider_address(provider_uri: str, chain_id: int) -> Optional[str]:
"""
Return the provider address
"""
Expand All @@ -119,7 +127,13 @@ def get_provider_address(provider_uri: str) -> Optional[str]:
"get", provider_uri
).json()

return provider_info["providerAddress"]
if "providerAddress" in provider_info:
logger.warning(
"You might be using an older provider. ocean.py can not verify the chain id."
)
return provider_info["providerAddress"]

return provider_info["providerAddresses"][str(chain_id)]
except requests.exceptions.RequestException:
pass

Expand Down Expand Up @@ -153,9 +167,14 @@ def get_root_uri(service_endpoint: str) -> str:
except (requests.exceptions.RequestException, JSONDecodeError):
raise InvalidURL(f"InvalidURL {service_endpoint}.")

if "providerAddress" not in response:
if "providerAddresses" not in response and "providerAddress" not in response:
raise InvalidURL(
f"Invalid Provider URL {service_endpoint}, no providerAddress."
f"Invalid Provider URL {service_endpoint}, no providerAddresses."
)

if "providerAddress" in response:
logger.warning(
"You might be using an older provider. ocean.py can not verify the chain id."
)

return result
Expand All @@ -172,17 +191,28 @@ def is_valid_provider(provider_uri: str) -> bool:

@staticmethod
@enforce_types
def build_endpoint(service_name: str, provider_uri: str) -> Tuple[str, str]:
def build_endpoint(
service_name: str, provider_uri: str, params: Optional[dict] = None
) -> Tuple[str, str]:
provider_uri = DataServiceProviderBase.get_root_uri(provider_uri)
service_endpoints = DataServiceProviderBase.get_service_endpoints(provider_uri)

method, url = service_endpoints[service_name]
return method, urljoin(provider_uri, url)
url = urljoin(provider_uri, url)

if params:
req = PreparedRequest()
req.prepare_url(url, params)
url = req.url

return method, url

@staticmethod
@enforce_types
def build_encrypt_endpoint(provider_uri: str) -> Tuple[str, str]:
return DataServiceProviderBase.build_endpoint("encrypt", provider_uri)
def build_encrypt_endpoint(provider_uri: str, chain_id: int) -> Tuple[str, str]:
return DataServiceProviderBase.build_endpoint(
"encrypt", provider_uri, {"chainId": chain_id}
)

@staticmethod
@enforce_types
Expand Down Expand Up @@ -216,9 +246,11 @@ def build_fileinfo(provider_uri: str) -> Tuple[str, str]:

@staticmethod
@enforce_types
def build_c2d_environments_endpoint(provider_uri: str) -> Tuple[str, str]:
def build_c2d_environments_endpoint(
provider_uri: str, chain_id: int
) -> Tuple[str, str]:
return DataServiceProviderBase.build_endpoint(
"computeEnvironments", provider_uri
"computeEnvironments", provider_uri, {"chainId": chain_id}
)

@staticmethod
Expand Down
6 changes: 4 additions & 2 deletions ocean_lib/data_provider/data_encryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ class DataEncryptor(DataServiceProviderBase):
@staticmethod
@enforce_types
def encrypt(
objects_to_encrypt: Union[list, str, bytes, dict], provider_uri: str
objects_to_encrypt: Union[list, str, bytes, dict],
provider_uri: str,
chain_id: int,
) -> Response:
if isinstance(objects_to_encrypt, dict):
data = json.dumps(objects_to_encrypt, separators=(",", ":"))
Expand All @@ -34,7 +36,7 @@ def encrypt(
payload = objects_to_encrypt

_, encrypt_endpoint = DataServiceProviderBase.build_encrypt_endpoint(
provider_uri
provider_uri, chain_id
)

response = DataServiceProviderBase._http_method(
Expand Down
21 changes: 11 additions & 10 deletions ocean_lib/data_provider/test/test_data_service_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_encrypt(provider_wallet, file1, file2):
key = provider_wallet.private_key
# Encrypt file objects
res = {"files": [file1.to_dict(), file2.to_dict()]}
result = DataEncryptor.encrypt(res, DEFAULT_PROVIDER_URL)
result = DataEncryptor.encrypt(res, DEFAULT_PROVIDER_URL, 8996)
encrypted_files = result.content.decode("utf-8")
assert result.status_code == 201
assert result.headers["Content-type"] == "text/plain"
Expand All @@ -214,7 +214,7 @@ def test_encrypt(provider_wallet, file1, file2):

# Encrypt a simple string
test_string = "hello_world"
encrypt_result = DataEncryptor.encrypt(test_string, DEFAULT_PROVIDER_URL)
encrypt_result = DataEncryptor.encrypt(test_string, DEFAULT_PROVIDER_URL, 8996)
encrypted_document = encrypt_result.content.decode("utf-8")
assert result.status_code == 201
assert result.headers["Content-type"] == "text/plain"
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_expose_endpoints():
def test_c2d_environments():
"""Tests that the test ocean-compute env exists on the DataServiceProvider."""
provider_uri = DEFAULT_PROVIDER_URL
c2d_envs = DataSP.get_c2d_environments(provider_uri)
c2d_envs = DataSP.get_c2d_environments(provider_uri, 8996)
c2d_env_ids = [elem["id"] for elem in c2d_envs]
assert "ocean-compute" in c2d_env_ids, "ocean-compute env not found."

Expand All @@ -294,7 +294,7 @@ def test_c2d_environments():
def test_provider_address():
"""Tests that a provider address exists on the DataServiceProvider."""
provider_uri = DEFAULT_PROVIDER_URL
provider_address = DataSP.get_provider_address(provider_uri)
provider_address = DataSP.get_provider_address(provider_uri, 8996)
assert provider_address, "Failed to get provider address."


Expand All @@ -303,10 +303,10 @@ def test_provider_address_with_url():
"""Tests that a URL version of provider address exists on the DataServiceProvider."""
p_ocean_instance = get_publisher_ocean_instance()
provider_address = DataSP.get_provider_address(
DataSP.get_url(p_ocean_instance.config_dict)
DataSP.get_url(p_ocean_instance.config_dict), 8996
)
assert provider_address, "Failed to get provider address."
assert DataSP.get_provider_address("not a url") is None
assert DataSP.get_provider_address("not a url", 8996) is None


@pytest.mark.integration
Expand Down Expand Up @@ -401,8 +401,9 @@ def get_service_endpoints(_provider_uri=None):
assert DataSP.build_initialize_compute_endpoint(provider_uri)[1] == urljoin(
base_uri, endpoints["initializeCompute"][1]
)
assert DataSP.build_encrypt_endpoint(provider_uri)[1] == urljoin(
base_uri, endpoints["encrypt"][1]
assert (
DataSP.build_encrypt_endpoint(provider_uri, 8996)[1]
== urljoin(base_uri, endpoints["encrypt"][1]) + "?chainId=8996"
)
assert DataSP.build_fileinfo(provider_uri)[1] == urljoin(
base_uri, endpoints["fileinfo"][1]
Expand Down Expand Up @@ -442,13 +443,13 @@ def test_encrypt_failure():
DataEncryptor.set_http_client(http_client)

with pytest.raises(OceanEncryptAssetUrlsError):
DataEncryptor.encrypt({}, DEFAULT_PROVIDER_URL)
DataEncryptor.encrypt({}, DEFAULT_PROVIDER_URL, 8996)

http_client = HttpClientEmptyMock()
DataSP.set_http_client(http_client)

with pytest.raises(DataProviderException):
DataEncryptor.encrypt({}, DEFAULT_PROVIDER_URL)
DataEncryptor.encrypt({}, DEFAULT_PROVIDER_URL, 8996)

DataSP.set_http_client(get_requests_session())

Expand Down
10 changes: 7 additions & 3 deletions ocean_lib/ocean/ocean_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def _encrypt_ddo(
flags = bytes([2])
# Encrypt DDO
encrypt_response = DataEncryptor.encrypt(
objects_to_encrypt=ddo_string, provider_uri=provider_uri
objects_to_encrypt=ddo_string,
provider_uri=provider_uri,
chain_id=ddo.chain_id,
)
document = encrypt_response.text
return document, flags, ddo_hash
Expand All @@ -137,7 +139,9 @@ def _encrypt_ddo(

# Encrypt DDO
encrypt_response = DataEncryptor.encrypt(
objects_to_encrypt=compressed_document, provider_uri=provider_uri
objects_to_encrypt=compressed_document,
provider_uri=provider_uri,
chain_id=ddo.chain_id,
)

document = encrypt_response.text
Expand Down Expand Up @@ -584,7 +588,7 @@ def update(
assert ddo.chain_id == self._chain_id

for service in ddo.services:
service.encrypt_files(ddo.nft_address)
service.encrypt_files(ddo.nft_address, ddo.chain_id)

# Validation by Aquarius
validation_result, errors_or_proof = self.validate(ddo)
Expand Down
8 changes: 4 additions & 4 deletions ocean_lib/ocean/ocean_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ def stop(self, ddo: DDO, service: Service, job_id: str, wallet) -> Dict[str, Any
return job_info

@enforce_types
def get_c2d_environments(self, service_endpoint: str) -> str:
return DataServiceProvider.get_c2d_environments(service_endpoint)
def get_c2d_environments(self, service_endpoint: str, chain_id: int) -> str:
return DataServiceProvider.get_c2d_environments(service_endpoint, chain_id)

@enforce_types
def get_free_c2d_environment(self, service_endpoint: str) -> str:
environments = self.get_c2d_environments(service_endpoint)
def get_free_c2d_environment(self, service_endpoint: str, chain_id) -> str:
environments = self.get_c2d_environments(service_endpoint, chain_id)
return next(env for env in environments if float(env["priceMin"]) == float(0))
3 changes: 2 additions & 1 deletion ocean_lib/services/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def update_compute_values(
self.compute_values["allowRawAlgorithm"] = allow_raw_algorithm

@enforce_types
def encrypt_files(self, nft_address: str) -> Response:
def encrypt_files(self, nft_address: str, chain_id: int) -> Response:
if self.files and isinstance(self.files, str):
return

Expand All @@ -319,6 +319,7 @@ def encrypt_files(self, nft_address: str) -> Response:
"files": files,
},
self.service_endpoint,
chain_id,
)

self.files = encrypt_response.content.decode("utf-8")
21 changes: 15 additions & 6 deletions tests/integration/ganache/test_compute_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import List, Optional

import pytest
import requests
from attr import dataclass

import ocean_lib
Expand Down Expand Up @@ -221,9 +222,12 @@ def run_compute_test(
dataset_and_userdata.ddo, ServiceTypes.CLOUD_COMPUTE
)

free_c2d_env = ocean_instance.compute.get_free_c2d_environment(
service.service_endpoint
)
try:
free_c2d_env = ocean_instance.compute.get_free_c2d_environment(
service.service_endpoint, 8996
)
except StopIteration:
assert False, "No free c2d environment found."

time_difference = (
timedelta(hours=1) if "reuse_order" not in scenarios else timedelta(seconds=30)
Expand Down Expand Up @@ -485,7 +489,12 @@ def test_compute_trusted_algorithm(

@pytest.mark.integration
@skip_on(
(ocean_lib.exceptions.DataProviderException, TypeError),
(
ocean_lib.exceptions.DataProviderException,
requests.exceptions.ConnectionError,
TypeError,
AssertionError,
),
reason="Fix provider issue #606",
)
def test_compute_update_trusted_algorithm(
Expand Down Expand Up @@ -540,7 +549,7 @@ def test_compute_update_trusted_algorithm(


@pytest.mark.integration
@skip_on(TypeError, reason="Fix provider issue #606")
@skip_on((TypeError, AssertionError), reason="Fix provider issue #606")
def test_compute_trusted_publisher(
publisher_wallet,
publisher_ocean,
Expand Down Expand Up @@ -576,7 +585,7 @@ def test_compute_trusted_publisher(


@pytest.mark.integration
@skip_on(TypeError, reason="Fix provider issue #606")
@skip_on((TypeError, AssertionError), reason="Fix provider issue #606")
def test_compute_just_provider_fees(
publisher_wallet,
publisher_ocean,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/ganache/test_disconnecting_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _iterative_encrypt(mock):
global exception_flag
for _ in range(5):
try:
DataEncryptor.encrypt({}, mock["PROVIDER_URL"])
DataEncryptor.encrypt({}, mock["PROVIDER_URL"], 8996)
except requests.exceptions.InvalidURL as err:
exception_flag = 1
assert err.args[0] == "InvalidURL http://foourl.com."
Expand Down

0 comments on commit dc45be3

Please sign in to comment.