diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 4b4f26983b..8918fb0c6c 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -4904,7 +4904,7 @@ class HLSLExternalSource : public ExternalSemaSource { } bool IsValidTemplateArgumentType(SourceLocation argLoc, const QualType &type, - bool requireScalar) { + bool requireScalar, bool allowObject) { if (type.isNull()) { return false; } @@ -4929,7 +4929,7 @@ class HLSLExternalSource : public ExternalSemaSource { if (qt->isArrayType()) { const ArrayType *arrayType = qt->getAsArrayTypeUnsafe(); return IsValidTemplateArgumentType(argLoc, arrayType->getElementType(), - false); + false, allowObject); } else if (objectKind == AR_TOBJ_VECTOR) { bool valid = true; if (!IsValidVectorSize(GetHLSLVecSize(type))) { @@ -4964,9 +4964,12 @@ class HLSLExternalSource : public ExternalSemaSource { objectKind = ClassifyRecordType(recordType); switch (objectKind) { case AR_TOBJ_OBJECT: - m_sema->Diag(argLoc, diag::err_hlsl_objectintemplateargument) << type; - return false; case AR_TOBJ_COMPOUND: { + if (objectKind == AR_TOBJ_OBJECT && !allowObject) { + m_sema->Diag(argLoc, diag::err_hlsl_objectintemplateargument) + << type; + return false; + } const RecordDecl *recordDecl = recordType->getDecl(); if (recordDecl->isInvalidDecl()) return false; @@ -4975,8 +4978,9 @@ class HLSLExternalSource : public ExternalSemaSource { bool result = true; while (begin != end) { const FieldDecl *fieldDecl = *begin; - if (!IsValidTemplateArgumentType(argLoc, fieldDecl->getType(), - false)) { + if (!IsValidTemplateArgumentType( + argLoc, fieldDecl->getType(), false, + allowObject && objectKind != AR_TOBJ_OBJECT)) { m_sema->Diag(argLoc, diag::note_field_type_usage) << fieldDecl->getType() << fieldDecl->getIdentifier() << type; result = false; @@ -5193,7 +5197,10 @@ class HLSLExternalSource : public ExternalSemaSource { QualType argType = arg.getAsType(); // Skip dependent types. Types will be checked later, when concrete. if (!argType->isDependentType()) { - if (!IsValidTemplateArgumentType(argSrcLoc, argType, requireScalar)) { + bool allowObject = + templateName == "SpirvType" || templateName == "SpirvOpaqueType"; + if (!IsValidTemplateArgumentType(argSrcLoc, argType, requireScalar, + allowObject)) { // NOTE: IsValidTemplateArgumentType emits its own diagnostics return true; } diff --git a/tools/clang/test/CodeGenSPIRV/spv.inline.type.hlsl b/tools/clang/test/CodeGenSPIRV/spv.inline.type.hlsl index 9db81a4483..8735e21a05 100644 --- a/tools/clang/test/CodeGenSPIRV/spv.inline.type.hlsl +++ b/tools/clang/test/CodeGenSPIRV/spv.inline.type.hlsl @@ -1,12 +1,8 @@ // RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s -// TODO(6498): enable Array test when using `Texture2D` with an alias template of `SpirvType` is fixed -// CHECK-TODO: %type_Array_type_2d_image = OpTypeArray %type_2d_image -// template -// using Array = vk::SpirvOpaqueType; - // CHECK: %spirvIntrinsicType = OpTypeArray %type_2d_image %uint_4 -typedef vk::SpirvOpaqueType > ArrayTex2D; +template +using Array = vk::SpirvOpaqueType >; // CHECK: %spirvIntrinsicType_0 = OpTypeInt 8 0 using uint8_t [[vk::ext_capability(/* Int8 */ 39)]] = vk::SpirvType >, vk::Literal > >; @@ -15,8 +11,7 @@ using uint8_t [[vk::ext_capability(/* Int8 */ 39)]] = vk::SpirvType image; - ArrayTex2D image; + Array image; // CHECK: %byte = OpVariable %_ptr_Function_spirvIntrinsicType_0 uint8_t byte;