From 97e9b793188ee7ecb40382a511105fd1fccf63dd Mon Sep 17 00:00:00 2001 From: Akio Gaule <10719597+akioCL@users.noreply.github.com> Date: Tue, 12 Nov 2024 11:34:17 -0300 Subject: [PATCH 1/2] Add webgpu support Signed-off-by: Akio Gaule <10719597+akioCL@users.noreply.github.com> --- .../Windows/src/WebGPUPlatformEmitter.cpp | 54 ++++++++ Platform/Windows/src/WebGPUPlatformEmitter.h | 33 +++++ src/AzslcEmitter.cpp | 23 ++-- src/AzslcReflection.cpp | 120 +++++++++++++++++- src/AzslcUtils.h | 38 ++++++ src/PadToAttributeMutator.cpp | 48 +++---- 6 files changed, 276 insertions(+), 40 deletions(-) create mode 100644 Platform/Windows/src/WebGPUPlatformEmitter.cpp create mode 100644 Platform/Windows/src/WebGPUPlatformEmitter.h diff --git a/Platform/Windows/src/WebGPUPlatformEmitter.cpp b/Platform/Windows/src/WebGPUPlatformEmitter.cpp new file mode 100644 index 0000000..bec2c87 --- /dev/null +++ b/Platform/Windows/src/WebGPUPlatformEmitter.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) Contributors to the Open 3D Engine Project. + * For complete copyright and license terms please see the LICENSE at the root of this distribution. + * + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + */ + +#include +#include "WebGPUPlatformEmitter.h" + +namespace AZ::ShaderCompiler +{ + static constexpr char WebGPUPlatformEmitterName[] = "wg"; + static const PlatformEmitter* s_platformEmitter = WebGPUPlatformEmitter::RegisterPlatformEmitter(); + + const PlatformEmitter* WebGPUPlatformEmitter::RegisterPlatformEmitter() noexcept(false) + { + static WebGPUPlatformEmitter platformEmitter; // Static linkage, will be destroyed + + static bool alreadyRegistered = false; + if (!alreadyRegistered) + { + PlatformEmitter::SetEmitter(WebGPUPlatformEmitterName, &platformEmitter); + alreadyRegistered = true; + } + + return &platformEmitter; + } + + std::string WebGPUPlatformEmitter::GetRootConstantsView(const CodeEmitter& codeEmitter, const RootSigDesc& rootSig, const Options& options, BindingPair::Set signatureQuery) const + { + std::stringstream strOut; + + const auto& structUid = codeEmitter.GetIR()->m_rootConstantStructUID; + const auto& bindInfo = rootSig.Get(structUid); + assert(structUid == bindInfo.m_uid); + const auto& rootCBForEmission = codeEmitter.GetTranslatedName(RootConstantsViewName, UsageContext::DeclarationSite); + const auto& rootConstClassForEmission = codeEmitter.GetTranslatedName(structUid.GetName(), UsageContext::ReferenceSite); + const auto& spaceX = ", space" + std::to_string(bindInfo.m_registerBinding.m_pair[signatureQuery].m_logicalSpace); + strOut << "ConstantBuffer<" << rootConstClassForEmission << "> " << rootCBForEmission << " : register(b" << bindInfo.m_registerBinding.m_pair[signatureQuery].m_registerIndex << spaceX << ");\n\n"; + return strOut.str(); + } + + uint32_t WebGPUPlatformEmitter::AlignRootConstants(uint32_t size) const + { + return Packing::AlignUp(size, Packing::s_bytesPerRegister); + } + + SubpassInputSupportFlag WebGPUPlatformEmitter::GetSubpassInputSupport() const + { + return SubpassInputSupportFlag::None; + } +} diff --git a/Platform/Windows/src/WebGPUPlatformEmitter.h b/Platform/Windows/src/WebGPUPlatformEmitter.h new file mode 100644 index 0000000..debacf7 --- /dev/null +++ b/Platform/Windows/src/WebGPUPlatformEmitter.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Contributors to the Open 3D Engine Project. + * For complete copyright and license terms please see the LICENSE at the root of this distribution. + * + * SPDX-License-Identifier: Apache-2.0 OR MIT + * + */ +#pragma once + +#include +#include + +namespace AZ::ShaderCompiler +{ + // PlatformEmitter is not a Backend by design. It's a supplement to CodeEmitter, not a replacement + struct WebGPUPlatformEmitter : CommonVulkanPlatformEmitter + { + public: + //! This method will be called once and only once when the platform emitter registers itself to the system. + //! Returns a singleton object of this class. + static const PlatformEmitter* RegisterPlatformEmitter() noexcept(false); + + [[nodiscard]] + std::string GetRootConstantsView(const CodeEmitter& codeEmitter, const RootSigDesc& rootSig, const Options& options, BindingPair::Set signatureQuery) const override final; + + uint32_t AlignRootConstants(uint32_t size) const override final; + + SubpassInputSupportFlag GetSubpassInputSupport() const override; + + private: + WebGPUPlatformEmitter() : CommonVulkanPlatformEmitter {} {}; + }; +} diff --git a/src/AzslcEmitter.cpp b/src/AzslcEmitter.cpp index 4a7e4a3..310d7a7 100644 --- a/src/AzslcEmitter.cpp +++ b/src/AzslcEmitter.cpp @@ -664,20 +664,15 @@ namespace AZ::ShaderCompiler } } - else if (attrInfo.m_attribute == "partial") - { - // Reserved for ShaderResourceGroup use. Do not re-emit - outstream << "// original attribute: [[" << attrInfo << "]]\n "; - } - - else if (attrInfo.m_attribute == "range") - { - // Reserved for integer type option variables. Do not re-emit - outstream << "// original attribute: [[" << attrInfo << "]]\n "; - } - else if (attrInfo.m_attribute == "no_specialization") - { - // Reserved for avoiding specialization of a shader option. Do not re-emit + else if ( + attrInfo.m_attribute == "partial" || // Reserved for ShaderResourceGroup use. Do not re-emit + attrInfo.m_attribute == "range" || // Reserved for integer type option variables. Do not re-emit + attrInfo.m_attribute == "no_specialization" || // Reserved for avoiding specialization of a shader option. Do not re-emit + attrInfo.m_attribute == "unrolled" || // Reserved for unrolled resource arrays. Do not re-emit + attrInfo.m_attribute == "access" || // Reserved for storage textures. Do not re-emit + attrInfo.m_attribute == "sample_type" || // Reserved for sampled textures. Do not re-emit + attrInfo.m_attribute == "binding_type") // Reserved for samplers. Do not re-emit + { outstream << "// original attribute: [[" << attrInfo << "]]\n "; } diff --git a/src/AzslcReflection.cpp b/src/AzslcReflection.cpp index 1d0836a..0f0deed 100644 --- a/src/AzslcReflection.cpp +++ b/src/AzslcReflection.cpp @@ -703,7 +703,8 @@ namespace AZ::ShaderCompiler // SRVs and UAVs for (const auto& tId : srgInfo->m_srViews) { - const auto& bindInfo = rootSig.Get(tId); + auto mangledName = tId.m_name; + auto bindInfo = rootSig.Get(tId); uint32_t strideSize = GetViewStride(tId, options.m_packDataBuffers, options); @@ -718,12 +719,74 @@ namespace AZ::ShaderCompiler strideSize = Packing::AlignUp(strideSize, Packing::s_bytesPerRegister); } + auto unrolledAttribute = m_ir->m_symbols.GetAttribute(tId, "unrolled"); + if (unrolledAttribute) + { + size_t last_index = tId.m_name.find_last_not_of("0123456789"); + std::string result = tId.m_name.substr(last_index + 1); + int arrayId = ::atoi(result.c_str()); + if (arrayId != 0) + { + continue; + } + + mangledName = QualifiedName(tId.m_name.substr(0, last_index + 1)); + uint32_t arraySize = 0; + VarInfo* varInfo = nullptr; + do + { + arraySize++; + varInfo = m_ir->GetSymbolSubAs(QualifiedName(mangledName + std::to_string(arraySize))); + } while (varInfo); + + bindInfo.m_uid.m_name = mangledName; + bindInfo.m_registerRange = arraySize; + } + + std::string format = "Unknown"; + auto formatAttribute = m_ir->m_symbols.GetAttribute(tId, "image_format"); + if (formatAttribute) + { + if (!formatAttribute->m_argList.empty() && m_ir->IsAttributeNamespaceActivated(formatAttribute->m_namespace)) + { + if (holds_alternative(formatAttribute->m_argList[0])) + { + format = Trim(get(formatAttribute->m_argList[0]), "\""); + format = ToRHIFormat(format.c_str()); + } + } + } + + std::string sampleType = "Unknown"; + auto sampleTypeAttribute = m_ir->m_symbols.GetAttribute(tId, "sample_type"); + if (sampleTypeAttribute) + { + if (!sampleTypeAttribute->m_argList.empty()) + { + if (holds_alternative(sampleTypeAttribute->m_argList[0])) + { + sampleType = Trim(get(sampleTypeAttribute->m_argList[0]), "\""); + } + } + } + + std::string usage = (isReadWriteView) ? "ReadWrite" : "Read"; + auto accessAttribute = m_ir->m_symbols.GetAttribute(tId, "access"); + if (accessAttribute) + { + if (!accessAttribute->m_argList.empty()) + { + usage = Trim(get(accessAttribute->m_argList[0]), "\""); + } + } + Json::Value dataView(Json::objectValue); - dataView["id"] = ExtractLeaf(tId.m_name).data(); + dataView["id"] = ExtractLeaf(mangledName).data(); dataView["type"] = viewName; - dataView["usage"] = (isReadWriteView) ? "ReadWrite" : "Read"; + dataView["usage"] = usage; ReflectBinding(dataView, bindInfo); dataView["stride"] = strideSize; + dataView["format"] = format; if (isBufferView) { @@ -731,6 +794,7 @@ namespace AZ::ShaderCompiler } else { + dataView["sampleType"] = sampleType; imagesList.append(dataView); } } @@ -746,9 +810,23 @@ namespace AZ::ShaderCompiler const auto* srgMemberInfo = m_ir->GetSymbolSubAs(sId.m_name); const auto& samplerInfo = *srgMemberInfo->m_samplerState; + std::string bindingType = "Unknown"; + auto bindingTypeAttribute = m_ir->m_symbols.GetAttribute(sId, "binding_type"); + if (bindingTypeAttribute) + { + if (!bindingTypeAttribute->m_argList.empty()) + { + if (holds_alternative(bindingTypeAttribute->m_argList[0])) + { + bindingType = Trim(get(bindingTypeAttribute->m_argList[0]), "\""); + } + } + } + Json::Value samplerJson(Json::objectValue); samplerJson["id"] = sId.GetNameLeaf(); samplerJson["isDynamic"] = samplerInfo.m_isDynamic; + samplerJson["bindingType"] = bindingType; ReflectBinding(samplerJson, bindInfo); if (!samplerInfo.m_isDynamic) @@ -954,6 +1032,19 @@ namespace AZ::ShaderCompiler return resourceJsonValue; }; + auto combineJsonArrays = [](Json::Value& lhs, Json::Value& rhs) + { + if (!lhs.isArray() || !rhs.isArray()) + { + return; + } + + for (const auto entry : rhs) + { + lhs.append(entry); + } + }; + optional srgConstants; // if we have SRG Constants we treat them later for (auto& srgParam : srgDesc.m_parameters) { @@ -966,7 +1057,28 @@ namespace AZ::ShaderCompiler { set dependencyList; DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, m_functionIntervals); - srgMember[srgParam.m_uid.GetNameLeaf()] = makeJsonNodeForOneResource(dependencyList, srgParam, {}); + auto jsonNodeResource = makeJsonNodeForOneResource(dependencyList, srgParam, {}); + auto unrolledAttribute = m_ir->m_symbols.GetAttribute(srgParam.m_uid, "unrolled"); + if (unrolledAttribute) + { + auto leafName = srgParam.m_uid.GetNameLeaf(); + size_t last_index = leafName.find_last_not_of("0123456789"); + string resourceName = leafName.substr(0, last_index + 1); + if (srgMember[resourceName].empty()) + { + srgMember[resourceName] = std::move(jsonNodeResource); + } + else + { + combineJsonArrays(srgMember[resourceName]["dependentFunctions"], jsonNodeResource["dependentFunctions"]); + srgMember[resourceName]["binding"]["count"] = srgMember[resourceName]["binding"]["count"].asInt() + 1; + } + } + else + { + + srgMember[srgParam.m_uid.GetNameLeaf()] = std::move(jsonNodeResource); + } } } // SRG constants (and the variant-fallback) are in one special constant buffer diff --git a/src/AzslcUtils.h b/src/AzslcUtils.h index 1d24361..b855de0 100644 --- a/src/AzslcUtils.h +++ b/src/AzslcUtils.h @@ -1362,4 +1362,42 @@ namespace AZ::ShaderCompiler { return ctx && HasStandardInitializer(ctx->variableInitializer()); } + + inline const char* ToRHIFormat(string_view format) + { + if (EqualNoCase(format, "rgba32f")) + { + return "R32G32B32A32_FLOAT"; + } + else if (EqualNoCase(format, "rgba16f")) + { + return "R16G16B16A16_FLOAT"; + } + else if (EqualNoCase(format, "r32f")) + { + return "R32_FLOAT"; + } + else if (EqualNoCase(format, "rgba8")) + { + return "R8G8B8A8_UNORM"; + } + else if (EqualNoCase(format, "rgba8snorm")) + { + return "R8G8B8A8_SNORM"; + } + else if (EqualNoCase(format, "rg32f")) + { + return "R32G32_FLOAT"; + } + else if (EqualNoCase(format, "rg16f")) + { + return "R16G16_FLOAT"; + } + else if (EqualNoCase(format, "r16f")) + { + return "R16_FLOAT"; + } + + return "Unknown"; + } } diff --git a/src/PadToAttributeMutator.cpp b/src/PadToAttributeMutator.cpp index 4aad68a..e0e6e6d 100644 --- a/src/PadToAttributeMutator.cpp +++ b/src/PadToAttributeMutator.cpp @@ -479,19 +479,21 @@ namespace AZ::ShaderCompiler const auto deltaBytes = alignedOffset - startingOffset; if (deltaBytes < numBytesToAdd && deltaBytes != 0) { - string typeName = getFloatTypeNameOfSize(deltaBytes); - auto variableName = FormatString("__pad_at%u", startingOffset); - IdentifierUID newVarUid = createVariableInSymbolTable(scopeUid.GetName(), typeName, UnqualifiedName{variableName}); - if (insertBeforeThisUid.IsEmpty()) + for (uint32_t i = 0; i < (deltaBytes >> 2); ++i) { - classInfo->PushMember(newVarUid, Kind::Variable); - } - else - { - classInfo->InsertBefore(newVarUid, Kind::Variable, insertBeforeThisUid); - m_ir.m_symbols.m_elastic.MigrateOrder(newVarUid, insertBeforeThisUid); + auto variableName = FormatString("__pad_at%u", startingOffset + i * sizeof(float)); + IdentifierUID newVarUid = createVariableInSymbolTable(scopeUid.GetName(), "float", UnqualifiedName{ variableName }); + if (insertBeforeThisUid.IsEmpty()) + { + classInfo->PushMember(newVarUid, Kind::Variable); + } + else + { + classInfo->InsertBefore(newVarUid, Kind::Variable, insertBeforeThisUid); + m_ir.m_symbols.m_elastic.MigrateOrder(newVarUid, insertBeforeThisUid); + } + numAddedVariables++; } - numAddedVariables++; numBytesToAdd -= deltaBytes; startingOffset = alignedOffset; } @@ -522,19 +524,21 @@ namespace AZ::ShaderCompiler // 3rd variable. The remainder if (numBytesToAdd > 0) { - auto variableName = FormatString("__pad_at%u", startingOffset); - string typeName = getFloatTypeNameOfSize(numBytesToAdd); - IdentifierUID newVarUid = createVariableInSymbolTable(scopeUid.GetName(), typeName, UnqualifiedName{variableName}); - if (insertBeforeThisUid.IsEmpty()) + for (uint32_t i = 0; i < (numBytesToAdd >> 2); ++i) { - classInfo->PushMember(newVarUid, Kind::Variable); - } - else - { - classInfo->InsertBefore(newVarUid, Kind::Variable, insertBeforeThisUid); - m_ir.m_symbols.m_elastic.MigrateOrder(newVarUid, insertBeforeThisUid); + auto variableName = FormatString("__pad_at%u", startingOffset + i * sizeof(float)); + IdentifierUID newVarUid = createVariableInSymbolTable(scopeUid.GetName(), "float", UnqualifiedName{ variableName }); + if (insertBeforeThisUid.IsEmpty()) + { + classInfo->PushMember(newVarUid, Kind::Variable); + } + else + { + classInfo->InsertBefore(newVarUid, Kind::Variable, insertBeforeThisUid); + m_ir.m_symbols.m_elastic.MigrateOrder(newVarUid, insertBeforeThisUid); + } + numAddedVariables++; } - numAddedVariables++; } return numAddedVariables; From 479a613f93ed7dd096ffe38f2f86d98148396a38 Mon Sep 17 00:00:00 2001 From: Akio Gaule <10719597+akioCL@users.noreply.github.com> Date: Tue, 12 Nov 2024 12:34:55 -0300 Subject: [PATCH 2/2] Fix tests Signed-off-by: Akio Gaule <10719597+akioCL@users.noreply.github.com> --- tests/Advanced/pad-to-attribute-validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Advanced/pad-to-attribute-validation.py b/tests/Advanced/pad-to-attribute-validation.py index 02cf6d3..5a6472f 100644 --- a/tests/Advanced/pad-to-attribute-validation.py +++ b/tests/Advanced/pad-to-attribute-validation.py @@ -27,8 +27,8 @@ def check_StructuredBuffer_Vs_ConstantBuffer_Padding(thefile, compilerPath, sile predicates.append(lambda expectedSize=expectedSize: j["ShaderResourceGroups"][0]["inputsForBufferViews"][0]["stride"] == expectedSize) predicates.append(lambda: j["ShaderResourceGroups"][0]["inputsForBufferViews"][0]["type"] == "StructuredBuffer") - predicates.append(lambda expectedSize=expectedSize: j["ShaderResourceGroups"][0]["inputsForSRGConstants"][27]["constantByteSize"] == expectedSize) - predicates.append(lambda: j["ShaderResourceGroups"][0]["inputsForSRGConstants"][27]["typeName"] == "/MyStruct") + predicates.append(lambda expectedSize=expectedSize: j["ShaderResourceGroups"][0]["inputsForSRGConstants"][-1]["constantByteSize"] == expectedSize) + predicates.append(lambda: j["ShaderResourceGroups"][0]["inputsForSRGConstants"][-1]["typeName"] == "/MyStruct") ok = testfuncs.verifyAllPredicates(predicates, j, silent) if ok and not silent: