diff --git a/functions/errors/contract/src/main/proto/errors.proto b/functions/errors/contract/src/main/proto/errors.proto index 3b7c624f..b0293513 100644 --- a/functions/errors/contract/src/main/proto/errors.proto +++ b/functions/errors/contract/src/main/proto/errors.proto @@ -14,6 +14,7 @@ service FailingService { rpc Fail (ErrorMessage) returns (google.protobuf.Empty); rpc FailAndHandle (ErrorMessage) returns (ErrorMessage); rpc InvokeExternalAndHandleFailure(FailRequest) returns (ErrorMessage); + rpc HandleNotFound (FailRequest) returns (ErrorMessage); } message FailRequest { diff --git a/functions/errors/impl/src/main/java/dev/restate/e2e/functions/errors/FailingService.java b/functions/errors/impl/src/main/java/dev/restate/e2e/functions/errors/FailingService.java index 9353a100..c02bcc8b 100644 --- a/functions/errors/impl/src/main/java/dev/restate/e2e/functions/errors/FailingService.java +++ b/functions/errors/impl/src/main/java/dev/restate/e2e/functions/errors/FailingService.java @@ -1,8 +1,10 @@ package dev.restate.e2e.functions.errors; import com.google.protobuf.Empty; +import com.google.rpc.Code; import dev.restate.e2e.functions.utils.NumberSortHttpServerUtils; import dev.restate.sdk.RestateContext; +import io.grpc.MethodDescriptor; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; @@ -85,4 +87,24 @@ public void invokeExternalAndHandleFailure( responseObserver.onNext(ErrorMessage.newBuilder().setErrorMessage(finalMessage).build()); responseObserver.onCompleted(); } + + @Override + public void handleNotFound(FailRequest request, StreamObserver responseObserver) { + var methodDescriptor = + FailingServiceGrpc.getFailMethod().toBuilder() + .setFullMethodName( + MethodDescriptor.generateFullMethodName( + FailingServiceGrpc.SERVICE_NAME, "UnknownFn")) + .build(); + try { + RestateContext.current().call(methodDescriptor, ErrorMessage.getDefaultInstance()).await(); + } catch (StatusRuntimeException e) { + if (e.getStatus().getCode().value() == Code.NOT_FOUND_VALUE) { + responseObserver.onNext(ErrorMessage.newBuilder().setErrorMessage("notfound").build()); + responseObserver.onCompleted(); + } + } + + throw new IllegalStateException("This should be unreachable"); + } } diff --git a/tests/src/test/kotlin/dev/restate/e2e/ErrorsTest.kt b/tests/src/test/kotlin/dev/restate/e2e/ErrorsTest.kt index 48e1e1db..f8d84094 100644 --- a/tests/src/test/kotlin/dev/restate/e2e/ErrorsTest.kt +++ b/tests/src/test/kotlin/dev/restate/e2e/ErrorsTest.kt @@ -67,4 +67,13 @@ class ErrorsTest { .extracting(ErrorMessage::getErrorMessage) .isEqualTo("begin:external_call:internal_call") } + + @Test + fun propagate404(@InjectBlockingStub stub: FailingServiceBlockingStub) { + assertThat( + stub.handleNotFound( + FailRequest.newBuilder().setKey(UUID.randomUUID().toString()).build())) + .extracting(ErrorMessage::getErrorMessage) + .isEqualTo("notfound") + } }