Skip to content

Commit

Permalink
Improve Cose Signed types (#209)
Browse files Browse the repository at this point in the history
* Make CoseSigned fully payload-type aware

* Add serializer as parameter for several methods in CoseSigned

* WIP: Do not encode byte string wrapper again

* Always serialize payload of CoseSigned as bytes

* Remove println from tests

* draft commit of shame

* Extract descriptor building

* Remove println from tests

* Add documentation

* Reject CoseSigned ByteStringWrapper

* Prevent double wrapping payloads in CoseSigned

---------

Co-authored-by: Bernd Prünster <[email protected]>
Co-authored-by: Simon Mueller <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2024
1 parent 1b92455 commit 67298ae
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 98 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
* Additional conversion methods for Java BigInteger and iospin BigInteger
* Refactor `CryptoPublicKey.Rsa` to use `Asn1Integer`
* Fixes JWS/COSE encoding for non-standard exponents (with MSBit 1)
* Add type parameter to `CoseSigned` for its payload
* Add companion method `CoseSigned.fromObject` to create a `CoseSigned` with a typed payload (outside of the usual `ByteArray`)
* Add type parameter to `CoseSigned` for its payload (tagging with tag 24 when necessary)
* Changes primary constructor visibility to `internal` to check for `ByteStringWrapper` as payload type, which shall be rejected
* Do not use DID key identifiers as keyId for `CoseKey`
* Fix BitSet iterator

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ class Asn1EncodingTest : FreeSpec({
}
val parsed = Asn1Element.parse(seq.derEncoded)
parsed.shouldNotBeNull()
println(parsed.prettyPrint())

}

"Old and new encoder produce the same bytes" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,40 @@ import kotlinx.serialization.cbor.CborArray
/**
* Representation of a signed COSE_Sign1 object, i.e. consisting of protected header, unprotected header and payload.
*
* If the payload is a generic [ByteArray], then it will be serialized as-is. Should the payload be any other type,
* the [CoseSignedSerializer] will tag it with 24 (see [RFC8949 3.4.5.1](https://www.rfc-editor.org/rfc/rfc8949.html#name-encoded-cbor-data-item)) during serialization.
* In order to prevent nested wrapping of the payload and the resulting type erasure
* payloads of type [ByteStringWrapper] will be rejected.
* In this case the payload could be handed over as the wrapped class itself or manually serialized to [ByteArray]
*
* See [RFC 9052](https://www.rfc-editor.org/rfc/rfc9052.html).
*/
@OptIn(ExperimentalSerializationApi::class)
@Serializable(with = CoseSignedSerializer::class)
@CborArray
data class CoseSigned<P : Any?>(
@ConsistentCopyVisibility
data class CoseSigned<P : Any?> internal constructor(
@ByteString
val protectedHeader: ByteStringWrapper<CoseHeader>,
val unprotectedHeader: CoseHeader?,
@ByteString
val payload: ByteArray?,
val payload: P?,
@ByteString
val rawSignature: ByteArray,
) {

@Throws(IllegalArgumentException::class)
constructor(
protectedHeader: CoseHeader,
unprotectedHeader: CoseHeader?,
payload: ByteArray?,
signature: CryptoSignature.RawByteEncodable
payload: P?,
signature: CryptoSignature.RawByteEncodable,
) : this(
protectedHeader = ByteStringWrapper(value = protectedHeader),
unprotectedHeader = unprotectedHeader,
payload = payload,
payload = when (payload) {
is ByteStringWrapper<*> -> throw IllegalArgumentException("payload shall not be ByteStringWrapper")
else -> payload
},
rawSignature = signature.rawByteArray
)

Expand All @@ -49,18 +59,8 @@ data class CoseSigned<P : Any?>(
else CryptoSignature.RSAorHMAC(rawSignature)
}

fun serialize(): ByteArray = coseCompliantSerializer.encodeToByteArray(CoseSignedSerializer(), this)

/**
* Decodes the payload of this object into a [ByteStringWrapper] containing an object of type [P].
*
* Note that this does not work if the payload is directly a [ByteArray].
*/
fun getTypedPayload(deserializer: KSerializer<P>): KmmResult<ByteStringWrapper<P>?> = catching {
payload?.let {
coseCompliantSerializer.decodeFromByteArray(ByteStringWrapperSerializer(deserializer), it)
}
}
fun serialize(parameterSerializer: KSerializer<P>): ByteArray = coseCompliantSerializer
.encodeToByteArray(CoseSignedSerializer(parameterSerializer), this)

override fun equals(other: Any?): Boolean {
if (this === other) return true
Expand Down Expand Up @@ -88,58 +88,38 @@ data class CoseSigned<P : Any?>(
override fun toString(): String {
return "CoseSigned(protectedHeader=${protectedHeader.value}," +
" unprotectedHeader=$unprotectedHeader," +
" payload=${payload?.encodeToString(Base16Strict)}," +
" payload=${if (payload is ByteArray) payload.encodeToString(Base16Strict) else payload}," +
" signature=${rawSignature.encodeToString(Base16Strict)})"
}

companion object {
fun deserialize(it: ByteArray): KmmResult<CoseSigned<ByteArray>> = catching {
coseCompliantSerializer.decodeFromByteArray<CoseSigned<ByteArray>>(it)
}

/**
* Creates a [CoseSigned] object from the given parameters,
* encapsulating the [payload] into a [ByteStringWrapper].
*
* This has to be an inline function with a reified type parameter,
* so it can't be a constructor (leads to a runtime error).
*/
inline fun <reified P : Any> fromObject(
protectedHeader: CoseHeader,
unprotectedHeader: CoseHeader?,
payload: P,
signature: CryptoSignature.RawByteEncodable
) = CoseSigned<P>(
protectedHeader = ByteStringWrapper(value = protectedHeader),
unprotectedHeader = unprotectedHeader,
payload = when (payload) {
is ByteArray -> payload
is ByteStringWrapper<*> -> coseCompliantSerializer.encodeToByteArray(payload)
else -> coseCompliantSerializer.encodeToByteArray(ByteStringWrapper(payload))
},
rawSignature = signature.rawByteArray
)
fun <P : Any> deserialize(parameterSerializer: KSerializer<P>, it: ByteArray): KmmResult<CoseSigned<P>> =
catching {
coseCompliantSerializer.decodeFromByteArray(CoseSignedSerializer(parameterSerializer), it)
}

/**
* Called by COSE signing implementations to get the bytes that will be
* used as the input for signature calculation of a `COSE_Sign1` object
*/
inline fun <reified P : Any> prepareCoseSignatureInput(
fun <P : Any> prepareCoseSignatureInput(
protectedHeader: CoseHeader,
payload: P?,
serializer: KSerializer<P>,
externalAad: ByteArray = byteArrayOf(),
): ByteArray = CoseSignatureInput(
contextString = "Signature1",
protectedHeader = ByteStringWrapper(protectedHeader),
externalAad = externalAad,
payload = when (payload) {
null -> null
is ByteArray -> payload
is ByteStringWrapper<*> -> coseCompliantSerializer.encodeToByteArray(payload)
else -> coseCompliantSerializer.encodeToByteArray(ByteStringWrapper(payload))
else -> coseCompliantSerializer.encodeToByteArray(
ByteStringWrapperSerializer(serializer),
ByteStringWrapper(payload)
)
},
).serialize()


}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,44 +1,102 @@
package at.asitplus.signum.indispensable.cosef

import at.asitplus.signum.indispensable.cosef.io.ByteStringWrapper
import at.asitplus.signum.indispensable.cosef.io.ByteStringWrapperSerializer
import at.asitplus.signum.indispensable.cosef.io.coseCompliantSerializer
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.builtins.ByteArraySerializer
import kotlinx.serialization.cbor.ValueTags
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.descriptors.buildSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.encoding.decodeStructure
import kotlinx.serialization.encoding.encodeStructure

class CoseSignedSerializer<P : Any?> : KSerializer<CoseSigned<P>> {
/**
* Serializes [CoseSigned] with a typed payload,
* also adding Tag 24 to the payload, if it is a typed object, i.e. not a byte array.
*/
class CoseSignedSerializer<P : Any?>(
private val parameterSerializer: KSerializer<P>,
) : KSerializer<CoseSigned<P>> {

@OptIn(InternalSerializationApi::class)
override val descriptor: SerialDescriptor = buildSerialDescriptor("CoseSigned", StructureKind.LIST) {
element("protectedHeader", ByteStringWrapperSerializer(CoseHeader.serializer()).descriptor)
element("unprotectedHeader", CoseHeader.serializer().descriptor)
element("payload", ByteArraySerializer().descriptor)
element("payload", ByteStringWrapperSerializer(parameterSerializer).descriptor)
element("signature", ByteArraySerializer().descriptor)
}

override fun deserialize(decoder: Decoder): CoseSigned<P> {
return decoder.decodeStructure(descriptor) {
val protectedHeader = decodeSerializableElement(descriptor, 0, ByteStringWrapperSerializer(CoseHeader.serializer()))
val unprotectedHeader = decodeNullableSerializableElement(descriptor, 1, CoseHeader.serializer())
val payload = decodeNullableSerializableElement(descriptor, 2, ByteArraySerializer())
val signature = decodeSerializableElement(descriptor, 3, ByteArraySerializer())
CoseSigned(protectedHeader, unprotectedHeader, payload, signature)
override fun deserialize(decoder: Decoder): CoseSigned<P> = decoder.decodeStructure(descriptor) {
val protectedHeader =
decodeSerializableElement(descriptor, 0, ByteStringWrapperSerializer(CoseHeader.serializer()))
val unprotectedHeader = decodeNullableSerializableElement(descriptor, 1, CoseHeader.serializer())
val payload: ByteArray? = decodeNullableSerializableElement(descriptor, 2, ByteArraySerializer())
val signature: ByteArray = decodeSerializableElement(descriptor, 3, ByteArraySerializer())
runCatching {
val typedPayload: P? = payload?.let {
coseCompliantSerializer.decodeFromByteArray(parameterSerializer, it)
}
CoseSigned(protectedHeader, unprotectedHeader, typedPayload, signature)
}.getOrElse {
@Suppress("UNCHECKED_CAST")
CoseSigned(protectedHeader, unprotectedHeader, payload as P, signature)
}
}

override fun serialize(encoder: Encoder, value: CoseSigned<P>) {
encoder.encodeStructure(descriptor) {
encodeSerializableElement(descriptor, 0, ByteStringWrapperSerializer(CoseHeader.serializer()), value.protectedHeader)
encodeSerializableElement(
descriptor,
0,
ByteStringWrapperSerializer(CoseHeader.serializer()),
value.protectedHeader
)
encodeNullableSerializableElement(descriptor, 1, CoseHeader.serializer(), value.unprotectedHeader)
encodeNullableSerializableElement(descriptor, 2, ByteArraySerializer(), value.payload)
if (value.payload != null && value.payload::class != ByteArray::class) {
encodeNullableSerializableElement(
buildTag24SerialDescriptor(),
2,
ByteStringWrapperSerializer(parameterSerializer),
ByteStringWrapper(value.payload)
)
} else {
encodeNullableSerializableElement(descriptor, 2, parameterSerializer, value.payload)
}
encodeSerializableElement(descriptor, 3, ByteArraySerializer(), value.rawSignature)
}
}

private fun buildTag24SerialDescriptor(): SerialDescriptor = object : SerialDescriptor {
@ExperimentalSerializationApi
override val serialName: String = descriptor.serialName

@ExperimentalSerializationApi
override val kind: SerialKind = descriptor.kind

@ExperimentalSerializationApi
override val elementsCount: Int = descriptor.elementsCount

@ExperimentalSerializationApi
override fun getElementName(index: Int): String = descriptor.getElementName(index)

@ExperimentalSerializationApi
override fun getElementIndex(name: String): Int = descriptor.getElementIndex(name)

@ExperimentalSerializationApi
override fun getElementAnnotations(index: Int): List<Annotation> =
if (index != 2) descriptor.getElementAnnotations(index) else listOf(ValueTags(24u))

@ExperimentalSerializationApi
override fun getElementDescriptor(index: Int): SerialDescriptor = descriptor.getElementDescriptor(index)

@ExperimentalSerializationApi
override fun isElementOptional(index: Int): Boolean = descriptor.isElementOptional(index)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,8 @@ class ByteStringWrapper<T>(
override fun toString(): String {
return "ByteStringWrapper(value=$value, serialized=${serialized.contentToString()})"
}

}


@OptIn(ExperimentalSerializationApi::class)
class ByteStringWrapperSerializer<T>(private val dataSerializer: KSerializer<T>) :
KSerializer<ByteStringWrapper<T>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package at.asitplus.signum.indispensable.cosef

import at.asitplus.signum.indispensable.CryptoSignature
import at.asitplus.signum.indispensable.cosef.io.Base16Strict
import at.asitplus.signum.indispensable.cosef.io.ByteStringWrapper
import io.kotest.core.spec.style.FreeSpec
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
Expand All @@ -18,16 +17,16 @@ class CoseEqualsTest : FreeSpec({
"equals with byte array" {
checkAll(Arb.byteArray(length = Arb.int(0, 10), content = Arb.byte())) { bytes ->
val bytesSigned1 = CoseSigned<ByteArray>(
protectedHeader = ByteStringWrapper(CoseHeader()),
protectedHeader = CoseHeader(),
unprotectedHeader = null,
payload = bytes,
rawSignature = bytes
signature = CryptoSignature.RSAorHMAC(bytes)
)
val bytesSigned2 = CoseSigned<ByteArray>(
protectedHeader = ByteStringWrapper(CoseHeader()),
protectedHeader = CoseHeader(),
unprotectedHeader = null,
payload = bytes,
rawSignature = bytes
signature = CryptoSignature.RSAorHMAC(bytes)
)

bytesSigned1 shouldBe bytesSigned1
Expand All @@ -37,16 +36,16 @@ class CoseEqualsTest : FreeSpec({

val reversed = bytes.reversedArray().let { it + it + 1 + 3 + 5 }
val reversedSigned1 = CoseSigned<ByteArray>(
protectedHeader = ByteStringWrapper(CoseHeader()),
protectedHeader = CoseHeader(),
unprotectedHeader = null,
payload = reversed,
rawSignature = reversed
signature = CryptoSignature.RSAorHMAC(reversed)
)
val reversedSigned2 = CoseSigned<ByteArray>(
protectedHeader = ByteStringWrapper(CoseHeader()),
protectedHeader = CoseHeader(),
unprotectedHeader = null,
payload = reversed,
rawSignature = reversed
signature = CryptoSignature.RSAorHMAC(reversed)
)

reversedSigned2 shouldBe reversedSigned2
Expand All @@ -72,13 +71,13 @@ class CoseEqualsTest : FreeSpec({
"equals with data class" {
checkAll(Arb.byteArray(length = Arb.int(0, 10), content = Arb.byte())) { bytes ->
val payload = DataClass(content = bytes.encodeToString(Base16Strict))
val bytesSigned1 = CoseSigned.fromObject<DataClass>(
val bytesSigned1 = CoseSigned(
protectedHeader = CoseHeader(),
unprotectedHeader = null,
payload = payload,
signature = CryptoSignature.RSAorHMAC(bytes)
)
val bytesSigned2 = CoseSigned.fromObject<DataClass>(
val bytesSigned2 = CoseSigned(
protectedHeader = CoseHeader(),
unprotectedHeader = null,
payload = payload,
Expand All @@ -90,19 +89,20 @@ class CoseEqualsTest : FreeSpec({
bytesSigned1.hashCode() shouldBe bytesSigned1.hashCode()
bytesSigned1.hashCode() shouldBe bytesSigned2.hashCode()

val reversed = DataClass(content = bytes.reversedArray().let { it + it + 1 + 3 + 5 }.encodeToString(Base16Strict))
val reversedSigned1 = CoseSigned.fromObject<DataClass>(
val reversed =
DataClass(content = bytes.reversedArray().let { it + it + 1 + 3 + 5 }.encodeToString(Base16Strict))
val reversedSigned1 = CoseSigned(
protectedHeader = CoseHeader(),
unprotectedHeader = null,
payload = reversed,
signature = CryptoSignature.RSAorHMAC(bytes)
)
val reversedSigned2 = CoseSigned.fromObject<DataClass>(
val reversedSigned2 = CoseSigned(
protectedHeader = CoseHeader(),
unprotectedHeader = null,
payload = reversed,
signature = CryptoSignature.RSAorHMAC(bytes)
).also { println(it.serialize().encodeToString(Base16Strict))}
)

reversedSigned2 shouldBe reversedSigned2
reversedSigned2 shouldBe reversedSigned1
Expand Down
Loading

0 comments on commit 67298ae

Please sign in to comment.