Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add webgpu support #92

Open
wants to merge 2 commits into
base: webgpu
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions Platform/Windows/src/WebGPUPlatformEmitter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/

#include <AzslcEmitter.h>
#include "WebGPUPlatformEmitter.h"

namespace AZ::ShaderCompiler
{
static constexpr char WebGPUPlatformEmitterName[] = "wg";
static const PlatformEmitter* s_platformEmitter = WebGPUPlatformEmitter::RegisterPlatformEmitter();

const PlatformEmitter* WebGPUPlatformEmitter::RegisterPlatformEmitter() noexcept(false)
{
static WebGPUPlatformEmitter platformEmitter; // Static linkage, will be destroyed

static bool alreadyRegistered = false;
if (!alreadyRegistered)
{
PlatformEmitter::SetEmitter(WebGPUPlatformEmitterName, &platformEmitter);
alreadyRegistered = true;
}

return &platformEmitter;
}

std::string WebGPUPlatformEmitter::GetRootConstantsView(const CodeEmitter& codeEmitter, const RootSigDesc& rootSig, const Options& options, BindingPair::Set signatureQuery) const
{
std::stringstream strOut;

const auto& structUid = codeEmitter.GetIR()->m_rootConstantStructUID;
const auto& bindInfo = rootSig.Get(structUid);
assert(structUid == bindInfo.m_uid);
const auto& rootCBForEmission = codeEmitter.GetTranslatedName(RootConstantsViewName, UsageContext::DeclarationSite);
const auto& rootConstClassForEmission = codeEmitter.GetTranslatedName(structUid.GetName(), UsageContext::ReferenceSite);
const auto& spaceX = ", space" + std::to_string(bindInfo.m_registerBinding.m_pair[signatureQuery].m_logicalSpace);
strOut << "ConstantBuffer<" << rootConstClassForEmission << "> " << rootCBForEmission << " : register(b" << bindInfo.m_registerBinding.m_pair[signatureQuery].m_registerIndex << spaceX << ");\n\n";
return strOut.str();
}

uint32_t WebGPUPlatformEmitter::AlignRootConstants(uint32_t size) const
{
return Packing::AlignUp(size, Packing::s_bytesPerRegister);
}

SubpassInputSupportFlag WebGPUPlatformEmitter::GetSubpassInputSupport() const
{
return SubpassInputSupportFlag::None;
}
}
33 changes: 33 additions & 0 deletions Platform/Windows/src/WebGPUPlatformEmitter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/
#pragma once

#include <AzslcPlatformEmitter.h>
#include <CommonVulkanPlatformEmitter.h>

namespace AZ::ShaderCompiler
{
// PlatformEmitter is not a Backend by design. It's a supplement to CodeEmitter, not a replacement
struct WebGPUPlatformEmitter : CommonVulkanPlatformEmitter
{
public:
//! This method will be called once and only once when the platform emitter registers itself to the system.
//! Returns a singleton object of this class.
static const PlatformEmitter* RegisterPlatformEmitter() noexcept(false);

[[nodiscard]]
std::string GetRootConstantsView(const CodeEmitter& codeEmitter, const RootSigDesc& rootSig, const Options& options, BindingPair::Set signatureQuery) const override final;

uint32_t AlignRootConstants(uint32_t size) const override final;

SubpassInputSupportFlag GetSubpassInputSupport() const override;

private:
WebGPUPlatformEmitter() : CommonVulkanPlatformEmitter {} {};
};
}
23 changes: 9 additions & 14 deletions src/AzslcEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,20 +664,15 @@ namespace AZ::ShaderCompiler
}
}

else if (attrInfo.m_attribute == "partial")
{
// Reserved for ShaderResourceGroup use. Do not re-emit
outstream << "// original attribute: [[" << attrInfo << "]]\n ";
}

else if (attrInfo.m_attribute == "range")
{
// Reserved for integer type option variables. Do not re-emit
outstream << "// original attribute: [[" << attrInfo << "]]\n ";
}
else if (attrInfo.m_attribute == "no_specialization")
{
// Reserved for avoiding specialization of a shader option. Do not re-emit
else if (
attrInfo.m_attribute == "partial" || // Reserved for ShaderResourceGroup use. Do not re-emit
attrInfo.m_attribute == "range" || // Reserved for integer type option variables. Do not re-emit
attrInfo.m_attribute == "no_specialization" || // Reserved for avoiding specialization of a shader option. Do not re-emit
attrInfo.m_attribute == "unrolled" || // Reserved for unrolled resource arrays. Do not re-emit
attrInfo.m_attribute == "access" || // Reserved for storage textures. Do not re-emit
attrInfo.m_attribute == "sample_type" || // Reserved for sampled textures. Do not re-emit
attrInfo.m_attribute == "binding_type") // Reserved for samplers. Do not re-emit
{
outstream << "// original attribute: [[" << attrInfo << "]]\n ";
}

Expand Down
120 changes: 116 additions & 4 deletions src/AzslcReflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,8 @@ namespace AZ::ShaderCompiler
// SRVs and UAVs
for (const auto& tId : srgInfo->m_srViews)
{
const auto& bindInfo = rootSig.Get(tId);
auto mangledName = tId.m_name;
auto bindInfo = rootSig.Get(tId);

uint32_t strideSize = GetViewStride(tId, options.m_packDataBuffers, options);

Expand All @@ -718,19 +719,82 @@ namespace AZ::ShaderCompiler
strideSize = Packing::AlignUp(strideSize, Packing::s_bytesPerRegister);
}

auto unrolledAttribute = m_ir->m_symbols.GetAttribute(tId, "unrolled");
if (unrolledAttribute)
{
size_t last_index = tId.m_name.find_last_not_of("0123456789");
std::string result = tId.m_name.substr(last_index + 1);
int arrayId = ::atoi(result.c_str());
if (arrayId != 0)
{
continue;
}

mangledName = QualifiedName(tId.m_name.substr(0, last_index + 1));
uint32_t arraySize = 0;
VarInfo* varInfo = nullptr;
do
{
arraySize++;
varInfo = m_ir->GetSymbolSubAs<VarInfo>(QualifiedName(mangledName + std::to_string(arraySize)));
} while (varInfo);

bindInfo.m_uid.m_name = mangledName;
bindInfo.m_registerRange = arraySize;
}

std::string format = "Unknown";
auto formatAttribute = m_ir->m_symbols.GetAttribute(tId, "image_format");
if (formatAttribute)
{
if (!formatAttribute->m_argList.empty() && m_ir->IsAttributeNamespaceActivated(formatAttribute->m_namespace))
{
if (holds_alternative<string>(formatAttribute->m_argList[0]))
{
format = Trim(get<string>(formatAttribute->m_argList[0]), "\"");
format = ToRHIFormat(format.c_str());
}
}
}

std::string sampleType = "Unknown";
auto sampleTypeAttribute = m_ir->m_symbols.GetAttribute(tId, "sample_type");
if (sampleTypeAttribute)
{
if (!sampleTypeAttribute->m_argList.empty())
{
if (holds_alternative<string>(sampleTypeAttribute->m_argList[0]))
{
sampleType = Trim(get<string>(sampleTypeAttribute->m_argList[0]), "\"");
}
}
}

std::string usage = (isReadWriteView) ? "ReadWrite" : "Read";
auto accessAttribute = m_ir->m_symbols.GetAttribute(tId, "access");
if (accessAttribute)
{
if (!accessAttribute->m_argList.empty())
{
usage = Trim(get<string>(accessAttribute->m_argList[0]), "\"");
}
}

Json::Value dataView(Json::objectValue);
dataView["id"] = ExtractLeaf(tId.m_name).data();
dataView["id"] = ExtractLeaf(mangledName).data();
dataView["type"] = viewName;
dataView["usage"] = (isReadWriteView) ? "ReadWrite" : "Read";
dataView["usage"] = usage;
ReflectBinding(dataView, bindInfo);
dataView["stride"] = strideSize;
dataView["format"] = format;

if (isBufferView)
{
buffersList.append(dataView);
}
else
{
dataView["sampleType"] = sampleType;
imagesList.append(dataView);
}
}
Expand All @@ -746,9 +810,23 @@ namespace AZ::ShaderCompiler
const auto* srgMemberInfo = m_ir->GetSymbolSubAs<VarInfo>(sId.m_name);
const auto& samplerInfo = *srgMemberInfo->m_samplerState;

std::string bindingType = "Unknown";
auto bindingTypeAttribute = m_ir->m_symbols.GetAttribute(sId, "binding_type");
if (bindingTypeAttribute)
{
if (!bindingTypeAttribute->m_argList.empty())
{
if (holds_alternative<string>(bindingTypeAttribute->m_argList[0]))
{
bindingType = Trim(get<string>(bindingTypeAttribute->m_argList[0]), "\"");
}
}
}

Json::Value samplerJson(Json::objectValue);
samplerJson["id"] = sId.GetNameLeaf();
samplerJson["isDynamic"] = samplerInfo.m_isDynamic;
samplerJson["bindingType"] = bindingType;
ReflectBinding(samplerJson, bindInfo);

if (!samplerInfo.m_isDynamic)
Expand Down Expand Up @@ -954,6 +1032,19 @@ namespace AZ::ShaderCompiler
return resourceJsonValue;
};

auto combineJsonArrays = [](Json::Value& lhs, Json::Value& rhs)
{
if (!lhs.isArray() || !rhs.isArray())
{
return;
}

for (const auto entry : rhs)
{
lhs.append(entry);
}
};

optional<RootSigDesc::SrgParamDesc> srgConstants; // if we have SRG Constants we treat them later
for (auto& srgParam : srgDesc.m_parameters)
{
Expand All @@ -966,7 +1057,28 @@ namespace AZ::ShaderCompiler
{
set<IdentifierUID> dependencyList;
DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, m_functionIntervals);
srgMember[srgParam.m_uid.GetNameLeaf()] = makeJsonNodeForOneResource(dependencyList, srgParam, {});
auto jsonNodeResource = makeJsonNodeForOneResource(dependencyList, srgParam, {});
auto unrolledAttribute = m_ir->m_symbols.GetAttribute(srgParam.m_uid, "unrolled");
if (unrolledAttribute)
{
auto leafName = srgParam.m_uid.GetNameLeaf();
size_t last_index = leafName.find_last_not_of("0123456789");
string resourceName = leafName.substr(0, last_index + 1);
if (srgMember[resourceName].empty())
{
srgMember[resourceName] = std::move(jsonNodeResource);
}
else
{
combineJsonArrays(srgMember[resourceName]["dependentFunctions"], jsonNodeResource["dependentFunctions"]);
srgMember[resourceName]["binding"]["count"] = srgMember[resourceName]["binding"]["count"].asInt() + 1;
}
}
else
{

srgMember[srgParam.m_uid.GetNameLeaf()] = std::move(jsonNodeResource);
}
}
}
// SRG constants (and the variant-fallback) are in one special constant buffer
Expand Down
38 changes: 38 additions & 0 deletions src/AzslcUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1362,4 +1362,42 @@ namespace AZ::ShaderCompiler
{
return ctx && HasStandardInitializer(ctx->variableInitializer());
}

inline const char* ToRHIFormat(string_view format)
{
if (EqualNoCase(format, "rgba32f"))
{
return "R32G32B32A32_FLOAT";
}
else if (EqualNoCase(format, "rgba16f"))
{
return "R16G16B16A16_FLOAT";
}
else if (EqualNoCase(format, "r32f"))
{
return "R32_FLOAT";
}
else if (EqualNoCase(format, "rgba8"))
{
return "R8G8B8A8_UNORM";
}
else if (EqualNoCase(format, "rgba8snorm"))
{
return "R8G8B8A8_SNORM";
}
else if (EqualNoCase(format, "rg32f"))
{
return "R32G32_FLOAT";
}
else if (EqualNoCase(format, "rg16f"))
{
return "R16G16_FLOAT";
}
else if (EqualNoCase(format, "r16f"))
{
return "R16_FLOAT";
}

return "Unknown";
}
}
48 changes: 26 additions & 22 deletions src/PadToAttributeMutator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,19 +479,21 @@ namespace AZ::ShaderCompiler
const auto deltaBytes = alignedOffset - startingOffset;
if (deltaBytes < numBytesToAdd && deltaBytes != 0)
{
string typeName = getFloatTypeNameOfSize(deltaBytes);
auto variableName = FormatString("__pad_at%u", startingOffset);
IdentifierUID newVarUid = createVariableInSymbolTable(scopeUid.GetName(), typeName, UnqualifiedName{variableName});
if (insertBeforeThisUid.IsEmpty())
for (uint32_t i = 0; i < (deltaBytes >> 2); ++i)
{
classInfo->PushMember(newVarUid, Kind::Variable);
}
else
{
classInfo->InsertBefore(newVarUid, Kind::Variable, insertBeforeThisUid);
m_ir.m_symbols.m_elastic.MigrateOrder(newVarUid, insertBeforeThisUid);
auto variableName = FormatString("__pad_at%u", startingOffset + i * sizeof(float));
IdentifierUID newVarUid = createVariableInSymbolTable(scopeUid.GetName(), "float", UnqualifiedName{ variableName });
if (insertBeforeThisUid.IsEmpty())
{
classInfo->PushMember(newVarUid, Kind::Variable);
}
else
{
classInfo->InsertBefore(newVarUid, Kind::Variable, insertBeforeThisUid);
m_ir.m_symbols.m_elastic.MigrateOrder(newVarUid, insertBeforeThisUid);
}
numAddedVariables++;
}
numAddedVariables++;
numBytesToAdd -= deltaBytes;
startingOffset = alignedOffset;
}
Expand Down Expand Up @@ -522,19 +524,21 @@ namespace AZ::ShaderCompiler
// 3rd variable. The remainder
if (numBytesToAdd > 0)
{
auto variableName = FormatString("__pad_at%u", startingOffset);
string typeName = getFloatTypeNameOfSize(numBytesToAdd);
IdentifierUID newVarUid = createVariableInSymbolTable(scopeUid.GetName(), typeName, UnqualifiedName{variableName});
if (insertBeforeThisUid.IsEmpty())
for (uint32_t i = 0; i < (numBytesToAdd >> 2); ++i)
{
classInfo->PushMember(newVarUid, Kind::Variable);
}
else
{
classInfo->InsertBefore(newVarUid, Kind::Variable, insertBeforeThisUid);
m_ir.m_symbols.m_elastic.MigrateOrder(newVarUid, insertBeforeThisUid);
auto variableName = FormatString("__pad_at%u", startingOffset + i * sizeof(float));
IdentifierUID newVarUid = createVariableInSymbolTable(scopeUid.GetName(), "float", UnqualifiedName{ variableName });
if (insertBeforeThisUid.IsEmpty())
{
classInfo->PushMember(newVarUid, Kind::Variable);
}
else
{
classInfo->InsertBefore(newVarUid, Kind::Variable, insertBeforeThisUid);
m_ir.m_symbols.m_elastic.MigrateOrder(newVarUid, insertBeforeThisUid);
}
numAddedVariables++;
}
numAddedVariables++;
}

return numAddedVariables;
Expand Down
Loading
Loading