Skip to content

Commit

Permalink
Add EXT_mesh_shader validation support
Browse files Browse the repository at this point in the history
1. Each OpEntryPoint with the MeshEXT Execution Model can have at most one global OpVariable of storage class TaskPayloadWorkgroupEXT.
2. PerPrimitiveEXT only be used on a memory object declaration or a member of a structure type
3. PerPrimitiveEXT only Input in Fragment and Output in MeshEXT
4. Added Mesh vulkan validation support for following rules:
   VUID-Layer-Layer-07039 VUID-PrimitiveId-PrimitiveId-07040
   VUID-PrimitivePointIndicesEXT-PrimitivePointIndicesEXT-07042
   VUID-PrimitivePointIndicesEXT-PrimitivePointIndicesEXT-07046
   VUID-PrimitiveLineIndicesEXT-PrimitiveLineIndicesEXT-07048
   VUID-PrimitiveLineIndicesEXT-PrimitiveLineIndicesEXT-07052
   VUID-PrimitiveTriangleIndicesEXT-PrimitiveTriangleIndicesEXT-07054
   VUID-PrimitiveTriangleIndicesEXT-PrimitiveTriangleIndicesEXT-07058
   VUID-ViewportIndex-ViewportIndex-07060
   VUID-StandaloneSpirv-ExecutionModel-07330
   VUID-StandaloneSpirv-ExecutionModel-07331
   VUID-PrimitiveId-PrimitiveId-04336
   VUID-Layer-Layer-07039
   VUID-ViewportIndex-ViewportIndex-07060

   VUID-CullPrimitiveEXT-CullPrimitiveEXT-07034
   VUID-CullPrimitiveEXT-CullPrimitiveEXT-07035
   VUID-CullPrimitiveEXT-CullPrimitiveEXT-07036
  • Loading branch information
pmistryNV committed Jan 9, 2025
1 parent 1a0658f commit da790d9
Show file tree
Hide file tree
Showing 11 changed files with 1,837 additions and 59 deletions.
1 change: 1 addition & 0 deletions source/val/validate_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, spv::Decoration dec,
case spv::Decoration::Stream:
case spv::Decoration::RestrictPointer:
case spv::Decoration::AliasedPointer:
case spv::Decoration::PerPrimitiveEXT:
if (target->opcode() != spv::Op::OpVariable &&
target->opcode() != spv::Op::OpUntypedVariableKHR &&
target->opcode() != spv::Op::OpFunctionParameter &&
Expand Down
279 changes: 225 additions & 54 deletions source/val/validate_builtins.cpp

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions source/val/validate_decorations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
int num_workgroup_variables = 0;
int num_workgroup_variables_with_block = 0;
int num_workgroup_variables_with_aliased = 0;
bool has_task_payload = false;
for (const auto& desc : descs) {
std::unordered_set<Instruction*> seen_vars;
std::unordered_set<spv::BuiltIn> input_var_builtin;
Expand All @@ -786,6 +787,19 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
const auto sc_index = 2u;
const spv::StorageClass storage_class =
var_instr->GetOperandAs<spv::StorageClass>(sc_index);
if (vstate.version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
// SPV_EXT_mesh_shader, at most one task payload is permitted
// per entry point
if (storage_class == spv::StorageClass::TaskPayloadWorkgroupEXT) {
if (has_task_payload) {
return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
<< "There can be at most one OpVariable with storage "
"class TaskPayloadWorkgroupEXT associated with "
"an OpEntryPoint";
}
has_task_payload = true;
}
}
if (vstate.version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
// Starting in 1.4, OpEntryPoint must list all global variables
// it statically uses and those interfaces must be unique.
Expand Down
4 changes: 4 additions & 0 deletions source/val/validate_instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,10 @@ spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
spv::ExecutionMode::LocalSizeId) {
_.RegisterEntryPointLocalSize(entry_point, inst);
}
if (inst->GetOperandAs<spv::ExecutionMode>(1) ==
spv::ExecutionMode::OutputPrimitivesEXT) {
_.RegisterEntryPointOutputPrimitivesEXT(entry_point, inst);
}
} else if (opcode == spv::Op::OpVariable) {
const auto storage_class = inst->GetOperandAs<spv::StorageClass>(2);
if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) {
Expand Down
53 changes: 51 additions & 2 deletions source/val/validate_mesh_shading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,32 @@
// Validates ray query instructions from SPV_KHR_ray_query

#include "source/opcode.h"
#include "source/spirv_target_env.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"

namespace spvtools {
namespace val {

bool IsInterfaceVariable(ValidationState_t& _, const Instruction* inst,
spv::ExecutionModel model) {
bool foundInterface = false;
for (auto entry_point : _.entry_points()) {
const auto* models = _.GetExecutionModels(entry_point);
if (models->find(model) == models->end()) return false;
for (const auto& desc : _.entry_point_descriptions(entry_point)) {
for (auto interface : desc.interfaces) {
if (inst->id() == interface) {
foundInterface = true;
break;
}
}
}
}
return foundInterface;
}

spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
const spv::Op opcode = inst->opcode();
switch (opcode) {
Expand Down Expand Up @@ -103,15 +122,45 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Primitive Count must be a 32-bit unsigned int scalar";
}

break;
}

case spv::Op::OpWritePackedPrimitiveIndices4x8NV: {
// No validation rules (for the moment).
break;
}

case spv::Op::OpVariable: {
if (_.HasCapability(spv::Capability::MeshShadingEXT)) {
bool meshInterfaceVar =
IsInterfaceVariable(_, inst, spv::ExecutionModel::MeshEXT);
bool fragInterfaceVar =
IsInterfaceVariable(_, inst, spv::ExecutionModel::Fragment);

const spv::StorageClass storage_class =
inst->GetOperandAs<spv::StorageClass>(2);
bool storage_output = (storage_class == spv::StorageClass::Output);
bool storage_input = (storage_class == spv::StorageClass::Input);

if (_.HasDecoration(inst->id(), spv::Decoration::PerPrimitiveEXT)) {
if (fragInterfaceVar && !storage_input) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "PerPrimitiveEXT decoration must be applied only to "
"variables in the Input Storage Class in the Fragment "
"Execution Model.";
}

if (meshInterfaceVar && !storage_output) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(4336)
<< "PerPrimitiveEXT decoration must be applied only to "
"variables in the Output Storage Class in the "
"Storage Class in the MeshEXT Execution Model.";
}
}
}
break;
}
default:
break;
}
Expand Down
19 changes: 19 additions & 0 deletions source/val/validate_mode_setting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,15 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
"tessellation execution model.";
}
}
if (spvIsVulkanEnv(_.context()->target_env)) {
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
inst->GetOperandAs<uint32_t>(2) == 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(7330)
<< "In mesh shaders using the MeshEXT Execution Model the "
"OutputVertices Execution Mode must be greater than 0";
}
}
break;
case spv::ExecutionMode::OutputLinesEXT:
case spv::ExecutionMode::OutputTrianglesEXT:
Expand All @@ -557,6 +566,16 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
"execution "
"model.";
}
if (mode == spv::ExecutionMode::OutputPrimitivesEXT &&
spvIsVulkanEnv(_.context()->target_env)) {
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
inst->GetOperandAs<uint32_t>(2) == 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(7331)
<< "In mesh shaders using the MeshEXT Execution Model the "
"OutputPrimitivesEXT Execution Mode must be greater than 0";
}
}
break;
case spv::ExecutionMode::QuadDerivativesKHR:
if (!std::all_of(models->begin(), models->end(),
Expand Down
38 changes: 36 additions & 2 deletions source/val/validation_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2054,6 +2054,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04330);
case 4334:
return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04334);
case 4336:
return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04336);
case 4337:
return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04337);
case 4345:
Expand Down Expand Up @@ -2382,30 +2384,62 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
return VUID_WRAP(VUID-StandaloneSpirv-OpTypeImage-06924);
case 6925:
return VUID_WRAP(VUID-StandaloneSpirv-Uniform-06925);
case 7034:
return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-07034);
case 7035:
return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-07035);
case 7036:
return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-07036);
case 7038:
return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-07038);
case 7039:
return VUID_WRAP(VUID-Layer-Layer-07039);
case 7040:
return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-07040);
case 7041:
return VUID_WRAP(VUID-PrimitivePointIndicesEXT-PrimitivePointIndicesEXT-07041);
case 7042:
return VUID_WRAP(VUID-PrimitivePointIndicesEXT-PrimitivePointIndicesEXT-07042);
case 7043:
return VUID_WRAP(VUID-PrimitivePointIndicesEXT-PrimitivePointIndicesEXT-07043);
case 7044:
return VUID_WRAP(VUID-PrimitivePointIndicesEXT-PrimitivePointIndicesEXT-07044);
case 7046:
return VUID_WRAP(VUID-PrimitivePointIndicesEXT-PrimitivePointIndicesEXT-07046);
case 7047:
return VUID_WRAP(VUID-PrimitiveLineIndicesEXT-PrimitiveLineIndicesEXT-07047);
case 7048:
return VUID_WRAP(VUID-PrimitiveLineIndicesEXT-PrimitiveLineIndicesEXT-07048);
case 7049:
return VUID_WRAP(VUID-PrimitiveLineIndicesEXT-PrimitiveLineIndicesEXT-07049);
case 7050:
return VUID_WRAP(VUID-PrimitiveLineIndicesEXT-PrimitiveLineIndicesEXT-07050);
case 7052:
return VUID_WRAP(VUID-PrimitiveLineIndicesEXT-PrimitiveLineIndicesEXT-07052);
case 7053:
return VUID_WRAP(VUID-PrimitiveTriangleIndicesEXT-PrimitiveTriangleIndicesEXT-07053);
case 7054:
return VUID_WRAP(VUID-PrimitiveTriangleIndicesEXT-PrimitiveTriangleIndicesEXT-07054);
case 7055:
return VUID_WRAP(VUID-PrimitiveTriangleIndicesEXT-PrimitiveTriangleIndicesEXT-07055);
case 7056:
return VUID_WRAP(VUID-PrimitiveTriangleIndicesEXT-PrimitiveTriangleIndicesEXT-07056);
case 7058:
return VUID_WRAP(VUID-PrimitiveTriangleIndicesEXT-PrimitiveTriangleIndicesEXT-07058);
case 7059:
return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-07059);
case 7060:
return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-07060);
case 7102:
return VUID_WRAP(VUID-StandaloneSpirv-MeshEXT-07102);
case 7320:
return VUID_WRAP(VUID-StandaloneSpirv-ExecutionModel-07320);
case 7290:
return VUID_WRAP(VUID-StandaloneSpirv-Input-07290);
case 7320:
return VUID_WRAP(VUID-StandaloneSpirv-ExecutionModel-07320);
case 7330:
return VUID_WRAP(VUID-StandaloneSpirv-ExecutionModel-07330);
case 7331:
return VUID_WRAP(VUID-StandaloneSpirv-ExecutionModel-07331);
case 7650:
return VUID_WRAP(VUID-StandaloneSpirv-Base-07650);
case 7651:
Expand Down
22 changes: 22 additions & 0 deletions source/val/validation_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,24 @@ class ValidationState_t {
const Instruction* inst) {
entry_point_to_local_size_or_id_[entry_point] = inst;
}

/// Registers that the entry point maximum number of primitives
/// mesh shader will ever emit
void RegisterEntryPointOutputPrimitivesEXT(uint32_t entry_point,
const Instruction* inst) {
entry_point_to_output_primitives_[entry_point] = inst;
}

/// Returns the maximum number of primitives mesh shader can emit
uint32_t GetOutputPrimitivesEXT(uint32_t entry_point) {
auto entry = entry_point_to_output_primitives_.find(entry_point);
if (entry != entry_point_to_output_primitives_.end()) {
auto inst = entry->second;
return inst->GetOperandAs<uint32_t>(2);
}
return 0;
}

/// Returns whether the entry point declares its local size
bool EntryPointHasLocalSizeOrId(uint32_t entry_point) const {
return entry_point_to_local_size_or_id_.find(entry_point) !=
Expand Down Expand Up @@ -971,6 +989,10 @@ class ValidationState_t {
std::unordered_map<uint32_t, const Instruction*>
entry_point_to_local_size_or_id_;

// Mapping entry point -> OutputPrimitivesEXT execution mode instruction
std::unordered_map<uint32_t, const Instruction*>
entry_point_to_output_primitives_;

/// Mapping function -> array of entry points inside this
/// module which can (indirectly) call the function.
std::unordered_map<uint32_t, std::vector<uint32_t>> function_to_entry_points_;
Expand Down
Loading

0 comments on commit da790d9

Please sign in to comment.