diff --git a/common/src/main/java/dev/cel/common/internal/BUILD.bazel b/common/src/main/java/dev/cel/common/internal/BUILD.bazel index 8219aa73..88d7dd51 100644 --- a/common/src/main/java/dev/cel/common/internal/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/internal/BUILD.bazel @@ -181,6 +181,7 @@ java_library( ], deps = [ "//common/annotations", + "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", ], ) diff --git a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java index f66fe26a..91b205f6 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java @@ -16,7 +16,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static java.util.Arrays.stream; import dev.cel.expr.ExprValue; import com.google.common.collect.ImmutableList; @@ -60,7 +59,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Function; import org.jspecify.nullness.Nullable; /** @@ -139,10 +137,6 @@ public final class ProtoAdapter { public static final BidiConverter DOUBLE_CONVERTER = BidiConverter.of(Number::doubleValue, Number::floatValue); - private static final ImmutableMap WELL_KNOWN_PROTOS = - stream(WellKnownProto.values()) - .collect(toImmutableMap(WellKnownProto::typeName, Function.identity())); - private final DynamicProto dynamicProto; private final boolean enableUnsignedLongs; @@ -163,7 +157,8 @@ public Object adaptProtoToValue(MessageOrBuilder proto) { } // If the proto is not a well-known type, then the input Message is what's expected as the // output return value. - WellKnownProto wellKnownProto = WELL_KNOWN_PROTOS.get(typeName(proto.getDescriptorForType())); + WellKnownProto wellKnownProto = + WellKnownProto.getByDescriptorName(typeName(proto.getDescriptorForType())); if (wellKnownProto == null) { return proto; } @@ -328,7 +323,7 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) { * considered, such as a packing an {@code google.protobuf.StringValue} into a {@code Any} value. */ public Optional adaptValueToProto(Object value, String protoTypeName) { - WellKnownProto wellKnownProto = WELL_KNOWN_PROTOS.get(protoTypeName); + WellKnownProto wellKnownProto = WellKnownProto.getByDescriptorName(protoTypeName); if (wellKnownProto == null) { if (value instanceof Message) { return Optional.of((Message) value); @@ -644,7 +639,7 @@ private static boolean isWrapperType(FieldDescriptor fieldDescriptor) { return false; } String fieldTypeName = fieldDescriptor.getMessageType().getFullName(); - WellKnownProto wellKnownProto = WELL_KNOWN_PROTOS.get(fieldTypeName); + WellKnownProto wellKnownProto = WellKnownProto.getByDescriptorName(fieldTypeName); return wellKnownProto != null && wellKnownProto.isWrapperType(); } diff --git a/common/src/main/java/dev/cel/common/internal/WellKnownProto.java b/common/src/main/java/dev/cel/common/internal/WellKnownProto.java index f3520f44..14da4396 100644 --- a/common/src/main/java/dev/cel/common/internal/WellKnownProto.java +++ b/common/src/main/java/dev/cel/common/internal/WellKnownProto.java @@ -14,6 +14,10 @@ package dev.cel.common.internal; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Arrays.stream; + +import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.BoolValue; import com.google.protobuf.BytesValue; @@ -31,6 +35,7 @@ import com.google.protobuf.UInt64Value; import com.google.protobuf.Value; import dev.cel.common.annotations.Internal; +import java.util.function.Function; /** * WellKnownProto types used throughout CEL. These types are specially handled to ensure that @@ -58,6 +63,14 @@ public enum WellKnownProto { private final Descriptor descriptor; private final boolean isWrapperType; + private static final ImmutableMap WELL_KNOWN_PROTO_MAP; + + static { + WELL_KNOWN_PROTO_MAP = + stream(WellKnownProto.values()) + .collect(toImmutableMap(WellKnownProto::typeName, Function.identity())); + } + WellKnownProto(Descriptor descriptor) { this(descriptor, /* isWrapperType= */ false); } @@ -75,7 +88,11 @@ public String typeName() { return descriptor.getFullName(); } - boolean isWrapperType() { + public boolean isWrapperType() { return isWrapperType; } + + public static WellKnownProto getByDescriptorName(String name) { + return WELL_KNOWN_PROTO_MAP.get(name); + } } diff --git a/common/src/main/java/dev/cel/common/values/BUILD.bazel b/common/src/main/java/dev/cel/common/values/BUILD.bazel index 39d98005..5c4d857c 100644 --- a/common/src/main/java/dev/cel/common/values/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/values/BUILD.bazel @@ -35,6 +35,7 @@ CEL_VALUES_SOURCES = [ PROTO_MESSAGE_VALUE_SOURCES = [ "ProtoCelValueConverter.java", "ProtoMessageValue.java", + "ProtoWrapperValue.java", ] java_library( diff --git a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java index 57415e22..d0a12d14 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java @@ -18,9 +18,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.math.LongMath.checkedAdd; import static com.google.common.math.LongMath.checkedSubtract; -import static java.util.Arrays.stream; -import com.google.common.collect.ImmutableMap; import com.google.common.primitives.UnsignedLong; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.Any; @@ -56,7 +54,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; /** * {@code CelValueConverter} handles bidirectional conversion between native Java and protobuf @@ -70,9 +67,6 @@ @Immutable @Internal public final class ProtoCelValueConverter extends CelValueConverter { - private static final ImmutableMap WELL_KNOWN_PROTOS = - stream(WellKnownProto.values()) - .collect(toImmutableMap(WellKnownProto::typeName, Function.identity())); private final CelDescriptorPool celDescriptorPool; private final DynamicProto dynamicProto; @@ -90,7 +84,7 @@ public CelValue fromProtoMessageToCelValue(MessageOrBuilder message) { } WellKnownProto wellKnownProto = - WELL_KNOWN_PROTOS.get(message.getDescriptorForType().getFullName()); + WellKnownProto.getByDescriptorName(message.getDescriptorForType().getFullName()); if (wellKnownProto == null) { return ProtoMessageValue.create((Message) message, celDescriptorPool, this); } diff --git a/common/src/main/java/dev/cel/common/values/ProtoWrapperValue.java b/common/src/main/java/dev/cel/common/values/ProtoWrapperValue.java new file mode 100644 index 00000000..fbc62854 --- /dev/null +++ b/common/src/main/java/dev/cel/common/values/ProtoWrapperValue.java @@ -0,0 +1,115 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import com.google.auto.value.AutoValue; +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.ByteString; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.MessageOrBuilder; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import dev.cel.common.internal.WellKnownProto; +import dev.cel.common.types.NullableType; + +/** ProtoWrapperValue represents a */ +@Immutable +@AutoValue +public abstract class ProtoWrapperValue extends StructValue { + + @Override + public abstract CelValue value(); + + abstract WellKnownProto wellKnownProto(); + + /** + * Retrieves the underlying value being held in the wrapper. For example, if this is a + * `google.protobuf.IntValue', a Java long is returned. + */ + public Object nativeValue() { + if (wellKnownProto().equals(WellKnownProto.BYTES_VALUE)) { + // Return the proto ByteString as the underlying primitive value rather than a mutable byte + // array. + return ByteString.copyFrom(((BytesValue) value()).value().toByteArray()); + } + return value().value(); + } + + @Override + public abstract NullableType celType(); + + @Override + public boolean isZeroValue() { + return value().isZeroValue(); + } + + @Override + public CelValue select(String fieldName) { + throw new UnsupportedOperationException("Wrappers do not support field selection"); + } + + @Override + public boolean hasField(String fieldName) { + throw new UnsupportedOperationException("Wrappers do not support presence tests"); + } + + public static ProtoWrapperValue create( + MessageOrBuilder wrapperMessage, boolean enableUnsignedLongs) { + WellKnownProto wellKnownProto = getWellKnownProtoFromWrapperName(wrapperMessage); + CelValue celValue = newCelValueFromWrapper(wrapperMessage, wellKnownProto, enableUnsignedLongs); + NullableType nullableType = NullableType.create(celValue.celType()); + return new AutoValue_ProtoWrapperValue(celValue, wellKnownProto, nullableType); + } + + private static CelValue newCelValueFromWrapper( + MessageOrBuilder message, WellKnownProto wellKnownProto, boolean enableUnsignedLongs) { + switch (wellKnownProto) { + case BOOL_VALUE: + return BoolValue.create(((com.google.protobuf.BoolValue) message).getValue()); + case BYTES_VALUE: + return BytesValue.create( + CelByteString.of(((com.google.protobuf.BytesValue) message).getValue().toByteArray())); + case DOUBLE_VALUE: + return DoubleValue.create(((com.google.protobuf.DoubleValue) message).getValue()); + case FLOAT_VALUE: + return DoubleValue.create(((FloatValue) message).getValue()); + case INT32_VALUE: + return IntValue.create(((Int32Value) message).getValue()); + case INT64_VALUE: + return IntValue.create(((Int64Value) message).getValue()); + case STRING_VALUE: + return StringValue.create(((com.google.protobuf.StringValue) message).getValue()); + case UINT32_VALUE: + return UintValue.create(((UInt32Value) message).getValue(), enableUnsignedLongs); + case UINT64_VALUE: + return UintValue.create(((UInt64Value) message).getValue(), enableUnsignedLongs); + default: + throw new IllegalArgumentException( + "Should only be called for wrapper types. Got: " + wellKnownProto.name()); + } + } + + private static WellKnownProto getWellKnownProtoFromWrapperName(MessageOrBuilder message) { + WellKnownProto wellKnownProto = + WellKnownProto.getByDescriptorName(message.getDescriptorForType().getFullName()); + if (!wellKnownProto.isWrapperType()) { + throw new IllegalArgumentException("Expected a wrapper type. Got: " + wellKnownProto.name()); + } + + return wellKnownProto; + } +} diff --git a/common/src/main/java/dev/cel/common/values/UintValue.java b/common/src/main/java/dev/cel/common/values/UintValue.java index c376c76e..aa76dce8 100644 --- a/common/src/main/java/dev/cel/common/values/UintValue.java +++ b/common/src/main/java/dev/cel/common/values/UintValue.java @@ -21,22 +21,25 @@ import dev.cel.common.types.SimpleType; /** - * UintValue represents CelValue for unsigned longs, leveraging Guava's implementation of {@link - * UnsignedLong}. - * - *

TODO: Look into potentially accepting a primitive `long` to avoid boxing/unboxing - * when the interpreter is augmented to work directly on CelValue. + * UintValue represents CelValue for unsigned longs. This either leverages Guava's implementation of + * {@link UnsignedLong}, or just holds a primitive long. */ -@AutoValue @Immutable +@AutoValue +@AutoValue.CopyAnnotations +@SuppressWarnings("Immutable") // value is either a boxed long or an immutable UnsignedLong. public abstract class UintValue extends CelValue { @Override - public abstract UnsignedLong value(); + public abstract Number value(); @Override public boolean isZeroValue() { - return UnsignedLong.ZERO.equals(value()); + if (value() instanceof UnsignedLong) { + return UnsignedLong.ZERO.equals(value()); + } else { + return value().longValue() == 0; + } } @Override @@ -47,4 +50,9 @@ public CelType celType() { public static UintValue create(UnsignedLong value) { return new AutoValue_UintValue(value); } + + public static UintValue create(long value, boolean enableUnsignedLongs) { + Number unsignedLong = enableUnsignedLongs ? UnsignedLong.fromLongBits(value) : value; + return new AutoValue_UintValue(unsignedLong); + } } diff --git a/common/src/test/java/dev/cel/common/values/ProtoMessageValueTest.java b/common/src/test/java/dev/cel/common/values/ProtoMessageValueTest.java index c1546061..10740bf2 100644 --- a/common/src/test/java/dev/cel/common/values/ProtoMessageValueTest.java +++ b/common/src/test/java/dev/cel/common/values/ProtoMessageValueTest.java @@ -172,7 +172,7 @@ public void hasField_extensionField_throwsWhenDescriptorMissing() { } private enum SelectFieldTestCase { - // // Primitives + // Primitives BOOL("single_bool", BoolValue.create(true)), INT32("single_int32", IntValue.create(4L)), INT64("single_int64", IntValue.create(5L)), diff --git a/common/src/test/java/dev/cel/common/values/ProtoWrapperValueTest.java b/common/src/test/java/dev/cel/common/values/ProtoWrapperValueTest.java new file mode 100644 index 00000000..07b35fa5 --- /dev/null +++ b/common/src/test/java/dev/cel/common/values/ProtoWrapperValueTest.java @@ -0,0 +1,121 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.primitives.UnsignedLong; +import com.google.protobuf.ByteString; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.Message; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.Value; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.common.types.NullableType; +import dev.cel.common.types.SimpleType; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class ProtoWrapperValueTest { + + private enum WrapperTestCase { + BOOL_WRAPPER(com.google.protobuf.BoolValue.of(true), BoolValue.create(true)), + INT32_WRAPPER(Int32Value.of(5), IntValue.create(5L)), + INT64_WRAPPER(Int64Value.of(10), IntValue.create(10L)), + UINT32_WRAPPER(UInt32Value.of(5), UintValue.create(5L, false)), + UINT64_WRAPPER( + UInt64Value.of(UnsignedLong.MAX_VALUE.longValue()), + UintValue.create(UnsignedLong.MAX_VALUE.longValue(), false)), + FLOAT_WRAPPER(FloatValue.of(5.4f), DoubleValue.create(5.4f)), + DOUBLE_WRAPPER(com.google.protobuf.DoubleValue.of(6.5d), DoubleValue.create(6.5d)), + STRING_WRAPPER(com.google.protobuf.StringValue.of("hello"), StringValue.create("hello")); + + private final Message protoMessage; + private final CelValue wrappedCelValue; + + WrapperTestCase(Message wrapperMessage, CelValue wrappedCelValue) { + this.protoMessage = wrapperMessage; + this.wrappedCelValue = wrappedCelValue; + } + } + + @Test + public void emptyValue() { + ProtoWrapperValue wrapperValue = ProtoWrapperValue.create(Int64Value.of(0L), false); + + assertThat(wrapperValue.value()).isEqualTo(IntValue.create(0L)); + assertThat(wrapperValue.isZeroValue()).isTrue(); + } + + @Test + public void constructWrapperValue(@TestParameter WrapperTestCase testCase) { + ProtoWrapperValue wrapperValue = ProtoWrapperValue.create(testCase.protoMessage, false); + + assertThat(wrapperValue.value()).isEqualTo(testCase.wrappedCelValue); + assertThat(wrapperValue.nativeValue()).isEqualTo(testCase.wrappedCelValue.value()); + assertThat(wrapperValue.isZeroValue()).isFalse(); + } + + @Test + public void constructWrapperValue_unsignedLong() { + ProtoWrapperValue wrapperValue = + ProtoWrapperValue.create(UInt64Value.of(UnsignedLong.MAX_VALUE.longValue()), true); + + assertThat(wrapperValue.value()) + .isEqualTo(UintValue.create(UnsignedLong.MAX_VALUE.longValue(), true)); + assertThat(wrapperValue.nativeValue()).isEqualTo(UnsignedLong.MAX_VALUE); + } + + @Test + public void constructBytesWrapperValue() { + ProtoWrapperValue wrapperValue = + ProtoWrapperValue.create( + com.google.protobuf.BytesValue.of(ByteString.copyFrom(new byte[] {0x02})), false); + + assertThat(wrapperValue.value()) + .isEqualTo(BytesValue.create(CelByteString.of(new byte[] {0x02}))); + assertThat(wrapperValue.nativeValue()).isEqualTo(ByteString.copyFrom(new byte[] {0x02})); + assertThat(wrapperValue.isZeroValue()).isFalse(); + } + + @Test + public void celTypeTest() { + ProtoWrapperValue value = + ProtoWrapperValue.create(com.google.protobuf.StringValue.of(""), false); + + assertThat(value.celType()).isEqualTo(NullableType.create(SimpleType.STRING)); + } + + @Test + public void fieldSelection_throws() { + ProtoWrapperValue value = ProtoWrapperValue.create(Int64Value.of(1), false); + + assertThrows(UnsupportedOperationException.class, () -> value.hasField("bogus")); + assertThrows(UnsupportedOperationException.class, () -> value.select("bogus")); + } + + @Test + public void nonWrapperType_throws() { + assertThrows( + IllegalArgumentException.class, + () -> ProtoWrapperValue.create(Value.getDefaultInstance(), false)); + } +}