Skip to content

Commit

Permalink
[Validator] Check size of PSV. (#6924)
Browse files Browse the repository at this point in the history
Check size of PSV part matches the PSVVersion.
Updated DxilPipelineStateValidation::ReadOrWrite to read based on
initInfo.PSVVersion.
And return fail when size mismatch in RWMode::Read.

Fixes #6817
  • Loading branch information
python3kgae authored Oct 3, 2024
1 parent b05313c commit 9221570
Show file tree
Hide file tree
Showing 7 changed files with 816 additions and 55 deletions.
1 change: 1 addition & 0 deletions docs/ReleaseNotes.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Place release notes for the upcoming release below this line and remove this lin

- The incomplete WaveMatrix implementation has been removed.
- DXIL Validator Hash is open sourced.
- DXIL container validation for PSV0 part allows any content ordering inside string and semantic index tables.

### Version 1.8.2407

Expand Down
13 changes: 10 additions & 3 deletions include/dxc/DxilContainer/DxilPipelineStateValidation.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ struct PSVStringTable {
PSVStringTable() : Table(nullptr), Size(0) {}
PSVStringTable(const char *table, uint32_t size) : Table(table), Size(size) {}
const char *Get(uint32_t offset) const {
assert(offset < Size && Table && Table[Size - 1] == '\0');
if (!(offset < Size && Table && Table[Size - 1] == '\0'))
return nullptr;
return Table + offset;
}
};
Expand Down Expand Up @@ -344,7 +345,8 @@ struct PSVSemanticIndexTable {
PSVSemanticIndexTable(const uint32_t *table, uint32_t entries)
: Table(table), Entries(entries) {}
const uint32_t *Get(uint32_t offset) const {
assert(offset < Entries && Table);
if (!(offset < Entries && Table))
return nullptr;
return Table + offset;
}
};
Expand Down Expand Up @@ -638,7 +640,8 @@ class DxilPipelineStateValidation {
_T *GetRecord(void *pRecords, uint32_t recordSize, uint32_t numRecords,
uint32_t index) const {
if (pRecords && index < numRecords && sizeof(_T) <= recordSize) {
assert((size_t)index * (size_t)recordSize <= UINT_MAX);
if (!((size_t)index * (size_t)recordSize <= UINT_MAX))
return nullptr;
return reinterpret_cast<_T *>(reinterpret_cast<uint8_t *>(pRecords) +
(index * recordSize));
}
Expand Down Expand Up @@ -1126,6 +1129,10 @@ void InitPSVSignatureElement(PSVSignatureElement0 &E,
const DxilSignatureElement &SE,
bool i1ToUnknownCompat);

// Setup PSVInitInfo with DxilModule.
// Note that the StringTable and PSVSemanticIndexTable are not done.
void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM);

// Setup shader properties for PSVRuntimeInfo* with DxilModule.
void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM);
Expand Down
34 changes: 2 additions & 32 deletions lib/DxilContainer/DxilContainerAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,30 +738,14 @@ class DxilPSVWriter : public DxilPartWriter {
DxilPSVWriter(const DxilModule &mod, uint32_t PSVVersion = UINT_MAX)
: m_Module(mod), m_PSVInitInfo(PSVVersion) {
m_Module.GetValidatorVersion(m_ValMajor, m_ValMinor);
// Constraint PSVVersion based on validator version
uint32_t PSVVersionConstraint = hlsl::GetPSVVersion(m_ValMajor, m_ValMinor);
if (PSVVersion > PSVVersionConstraint)
m_PSVInitInfo.PSVVersion = PSVVersionConstraint;

const ShaderModel *SM = m_Module.GetShaderModel();
UINT uCBuffers = m_Module.GetCBuffers().size();
UINT uSamplers = m_Module.GetSamplers().size();
UINT uSRVs = m_Module.GetSRVs().size();
UINT uUAVs = m_Module.GetUAVs().size();
m_PSVInitInfo.ResourceCount = uCBuffers + uSamplers + uSRVs + uUAVs;
hlsl::SetupPSVInitInfo(m_PSVInitInfo, m_Module);

// TODO: for >= 6.2 version, create more efficient structure
if (m_PSVInitInfo.PSVVersion > 0) {
m_PSVInitInfo.ShaderStage = (PSVShaderKind)SM->GetKind();
// Copy Dxil Signatures
m_StringBuffer.push_back('\0'); // For empty semantic name (system value)
m_PSVInitInfo.SigInputElements =
m_Module.GetInputSignature().GetElements().size();
m_SigInputElements.resize(m_PSVInitInfo.SigInputElements);
m_PSVInitInfo.SigOutputElements =
m_Module.GetOutputSignature().GetElements().size();
m_SigOutputElements.resize(m_PSVInitInfo.SigOutputElements);
m_PSVInitInfo.SigPatchConstOrPrimElements =
m_Module.GetPatchConstOrPrimSignature().GetElements().size();
m_SigPatchConstOrPrimElements.resize(
m_PSVInitInfo.SigPatchConstOrPrimElements);
uint32_t i = 0;
Expand Down Expand Up @@ -791,20 +775,6 @@ class DxilPSVWriter : public DxilPartWriter {
m_PSVInitInfo.StringTable.Size = m_StringBuffer.size();
m_PSVInitInfo.SemanticIndexTable.Table = m_SemanticIndexBuffer.data();
m_PSVInitInfo.SemanticIndexTable.Entries = m_SemanticIndexBuffer.size();
// Set up ViewID and signature dependency info
m_PSVInitInfo.UsesViewID =
m_Module.m_ShaderFlags.GetViewID() ? true : false;
m_PSVInitInfo.SigInputVectors =
m_Module.GetInputSignature().NumVectorsUsed(0);
for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
m_PSVInitInfo.SigOutputVectors[streamIndex] =
m_Module.GetOutputSignature().NumVectorsUsed(streamIndex);
}
m_PSVInitInfo.SigPatchConstOrPrimVectors = 0;
if (SM->IsHS() || SM->IsDS() || SM->IsMS()) {
m_PSVInitInfo.SigPatchConstOrPrimVectors =
m_Module.GetPatchConstOrPrimSignature().NumVectorsUsed(0);
}
}
if (!m_PSV.InitNew(m_PSVInitInfo, nullptr, &m_PSVBufferSize)) {
DXASSERT(false, "PSV InitNew failed computing size!");
Expand Down
37 changes: 37 additions & 0 deletions lib/DxilContainer/DxilPipelineStateValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,43 @@ void hlsl::InitPSVSignatureElement(PSVSignatureElement0 &E,
E.DynamicMaskAndStream |= (SE.GetDynIdxCompMask()) & 0xF;
}

void hlsl::SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM) {
// Constraint PSVVersion based on validator version
unsigned ValMajor, ValMinor;
DM.GetValidatorVersion(ValMajor, ValMinor);
unsigned PSVVersionConstraint = hlsl::GetPSVVersion(ValMajor, ValMinor);
if (InitInfo.PSVVersion > PSVVersionConstraint)
InitInfo.PSVVersion = PSVVersionConstraint;

const ShaderModel *SM = DM.GetShaderModel();
uint32_t uCBuffers = DM.GetCBuffers().size();
uint32_t uSamplers = DM.GetSamplers().size();
uint32_t uSRVs = DM.GetSRVs().size();
uint32_t uUAVs = DM.GetUAVs().size();
InitInfo.ResourceCount = uCBuffers + uSamplers + uSRVs + uUAVs;

if (InitInfo.PSVVersion > 0) {
InitInfo.ShaderStage = (PSVShaderKind)SM->GetKind();
InitInfo.SigInputElements = DM.GetInputSignature().GetElements().size();
InitInfo.SigPatchConstOrPrimElements =
DM.GetPatchConstOrPrimSignature().GetElements().size();
InitInfo.SigOutputElements = DM.GetOutputSignature().GetElements().size();

// Set up ViewID and signature dependency info
InitInfo.UsesViewID = DM.m_ShaderFlags.GetViewID() ? true : false;
InitInfo.SigInputVectors = DM.GetInputSignature().NumVectorsUsed(0);
for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
InitInfo.SigOutputVectors[streamIndex] =
DM.GetOutputSignature().NumVectorsUsed(streamIndex);
}
InitInfo.SigPatchConstOrPrimVectors = 0;
if (SM->IsHS() || SM->IsDS() || SM->IsMS()) {
InitInfo.SigPatchConstOrPrimVectors =
DM.GetPatchConstOrPrimSignature().NumVectorsUsed(0);
}
}
}

void hlsl::SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM) {
const ShaderModel *SM = DM.GetShaderModel();
pInfo->MinimumExpectedWaveLaneCount = 0;
Expand Down
Loading

0 comments on commit 9221570

Please sign in to comment.