Skip to content

Commit

Permalink
Fix the type info for elements of arrays and tuples in type_system_v2.
Browse files Browse the repository at this point in the history
In a case like `const X: u32[3] = [1, 2, 3]`, we were previously allowing the literals in the array to just assume the unified RHS element type of u2, only factoring the LHS into the overall type of [1, 2, 3].

PiperOrigin-RevId: 714263472
  • Loading branch information
richmckeever authored and copybara-github committed Jan 11, 2025
1 parent 1019538 commit a00d2c2
Show file tree
Hide file tree
Showing 4 changed files with 576 additions and 620 deletions.
1 change: 1 addition & 0 deletions xls/dslx/type_system_v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down
116 changes: 71 additions & 45 deletions xls/dslx/type_system_v2/inference_table_to_type_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/function_ref.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -593,20 +594,22 @@ class InferenceTableConverter {
XLS_ASSIGN_OR_RETURN(
std::vector<const TypeAnnotation*> actual_arg_annotations,
table_.GetTypeAnnotationsForTypeVariable(*actual_arg_type_var));
XLS_RETURN_IF_ERROR(
ResolveVariableTypeAnnotations(invocation, actual_arg_annotations));
TypeInfo* actual_arg_ti = base_type_info_;
if (invocation->caller_invocation().has_value()) {
actual_arg_ti =
invocation_type_info_.at(*invocation->caller_invocation());
}

// The type variable for the actual argument should have at least one
// annotation associated with it that came from the formal argument and is
// therefore dependent on the parametric we are solving for. Let's unify
// just the independent annotations(s) for the purposes of solving for the
// variable.
RemoveAnnotationsReferringToNamesWithoutTypeInfo(actual_arg_ti,
actual_arg_annotations);
auto accept_predicate = [&](const TypeAnnotation* annotation) {
return !HasAnyReferencesWithMissingTypeInfo(actual_arg_ti, annotation);
};
XLS_RETURN_IF_ERROR(ResolveVariableTypeAnnotations(
invocation, actual_arg_annotations, accept_predicate));
if (actual_arg_annotations.empty()) {
VLOG(5) << "The actual argument type variable: "
<< (*actual_arg_type_var)->ToString()
Expand Down Expand Up @@ -693,15 +696,21 @@ class InferenceTableConverter {
// type annotations that have been associated with the given type variable. If
// the information has unreconcilable conflicts, returns an error. The given
// `parametric_invocation` argument is used as a context for the evaluation of
// any expressions inside the type annotations.
// any expressions inside the type annotations. If an `accept_predicate` is
// specified, then annotations not accepted by the predicate are ignored.
absl::StatusOr<const TypeAnnotation*> UnifyTypeAnnotations(
std::optional<const ParametricInvocation*> parametric_invocation,
const NameRef* type_variable, const Span& span) {
const NameRef* type_variable, const Span& span,
std::optional<absl::FunctionRef<bool(const TypeAnnotation*)>>
accept_predicate = std::nullopt) {
VLOG(5) << "Unifying type annotations for variable "
<< type_variable->ToString();
XLS_ASSIGN_OR_RETURN(
std::vector<const TypeAnnotation*> annotations,
table_.GetTypeAnnotationsForTypeVariable(type_variable));
if (accept_predicate.has_value()) {
FilterAnnotations(annotations, *accept_predicate);
}
XLS_ASSIGN_OR_RETURN(
const TypeAnnotation* result,
UnifyTypeAnnotations(parametric_invocation, annotations, span));
Expand Down Expand Up @@ -913,10 +922,14 @@ class InferenceTableConverter {
// Returns `annotation` with any `TypeVariableTypeAnnotation`s replaced with
// the unifications of the corresponding variables. The original `annotation`
// is returned if there is nothing to replace, preserving the ability to look
// it up in `auto_literal_annotations_`.
// it up in `auto_literal_annotations_`. If `accept_predicate` is specified,
// then it is used to filter the annotations associated with encountered type
// variables (the predicate is not applied to the input `annotation` itself).
absl::StatusOr<const TypeAnnotation*> ResolveVariableTypeAnnotations(
std::optional<const ParametricInvocation*> parametric_invocation,
const TypeAnnotation* annotation) {
const TypeAnnotation* annotation,
std::optional<absl::FunctionRef<bool(const TypeAnnotation*)>>
accept_predicate = std::nullopt) {
bool replaced_anything = false;
XLS_ASSIGN_OR_RETURN(
AstNode * clone,
Expand All @@ -931,7 +944,7 @@ class InferenceTableConverter {
UnifyTypeAnnotations(
parametric_invocation,
variable_type_annotation->type_variable(),
annotation->span()));
annotation->span(), accept_predicate));
replaced_anything = true;
return const_cast<TypeAnnotation*>(unified);
}
Expand All @@ -946,15 +959,25 @@ class InferenceTableConverter {
}

// Variant that deeply resolves all `TypeVariableTypeAnnotation`s within a
// vector of annotations.
// vector of annotations. If `accept_predicate` is specified, then any
// annotations not accepted by the predicate are filtered from both
// `annotations` and the expansions of any encountered type variables.
absl::Status ResolveVariableTypeAnnotations(
std::optional<const ParametricInvocation*> parametric_invocation,
std::vector<const TypeAnnotation*>& annotations) {
for (int i = 0; i < annotations.size(); i++) {
XLS_ASSIGN_OR_RETURN(annotations[i],
ResolveVariableTypeAnnotations(parametric_invocation,
annotations[i]));
std::vector<const TypeAnnotation*>& annotations,
std::optional<absl::FunctionRef<bool(const TypeAnnotation*)>>
accept_predicate = std::nullopt) {
std::vector<const TypeAnnotation*> result;
for (const TypeAnnotation* annotation : annotations) {
if (!accept_predicate.has_value() || (*accept_predicate)(annotation)) {
XLS_ASSIGN_OR_RETURN(
const TypeAnnotation* resolved_annotation,
ResolveVariableTypeAnnotations(parametric_invocation, annotation,
accept_predicate));
result.push_back(resolved_annotation);
}
}
annotations = std::move(result);
return absl::OkStatus();
}

Expand Down Expand Up @@ -1095,36 +1118,39 @@ class InferenceTableConverter {
return context_invocation;
}

// Removes any annotations in the given vector that contain any `NameRef`
// whose type info has not (yet) been generated. The effective `TypeInfo` for
// each annotation is either `default_ti`; or, for invocation-scoped
// annotations, the `TypeInfo` for the relevant parametric invocation.
void RemoveAnnotationsReferringToNamesWithoutTypeInfo(
TypeInfo* default_ti, std::vector<const TypeAnnotation*>& annotations) {
annotations.erase(
std::remove_if(
annotations.begin(), annotations.end(),
[&](const TypeAnnotation* annotation) {
TypeInfo* ti = default_ti;
const auto it =
invocation_scoped_type_annotations_.find(annotation);
if (it != invocation_scoped_type_annotations_.end()) {
ti = invocation_type_info_.at(it->second);
}
FreeVariables vars =
GetFreeVariablesByLambda(annotation, [&](const NameRef& ref) {
if (!std::holds_alternative<const NameDef*>(
ref.name_def())) {
return false;
}
const NameDef* name_def =
std::get<const NameDef*>(ref.name_def());
return !ti->GetItem(name_def).has_value() &&
!ti->IsKnownConstExpr(name_def);
});
return vars.GetFreeVariableCount() > 0;
}),
annotations.end());
// Removes any annotations in the given vector for which `accept_predicate`
// returns false.
void FilterAnnotations(
std::vector<const TypeAnnotation*>& annotations,
absl::FunctionRef<bool(const TypeAnnotation*)> accept_predicate) {
annotations.erase(std::remove_if(annotations.begin(), annotations.end(),
[&](const TypeAnnotation* annotation) {
return !accept_predicate(annotation);
}),
annotations.end());
}

// Returns true if `annotation` contains any `NameRef` whose type info has not
// (yet) been generated. The effective `TypeInfo` is either `default_ti`; or,
// for invocation-scoped annotations, the `TypeInfo` for the relevant
// parametric invocation.
bool HasAnyReferencesWithMissingTypeInfo(TypeInfo* default_ti,
const TypeAnnotation* annotation) {
TypeInfo* ti = default_ti;
const auto it = invocation_scoped_type_annotations_.find(annotation);
if (it != invocation_scoped_type_annotations_.end()) {
ti = invocation_type_info_.at(it->second);
}
FreeVariables vars =
GetFreeVariablesByLambda(annotation, [&](const NameRef& ref) {
if (!std::holds_alternative<const NameDef*>(ref.name_def())) {
return false;
}
const NameDef* name_def = std::get<const NameDef*>(ref.name_def());
return !ti->GetItem(name_def).has_value() &&
!ti->IsKnownConstExpr(name_def);
});
return vars.GetFreeVariableCount() > 0;
}

// Returns a `SignednessAndSize` that agrees with the two given
Expand Down
45 changes: 43 additions & 2 deletions xls/dslx/type_system_v2/typecheck_module_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,18 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
// -----------------------------------------------
// FOO (u32, (s8, u32)) T0
// (4, (-2, 5)) (var:M0, var:M1) T0
// 4 M0
// (-2, 5) M1
// 4 u32 M0
// (-2, 5) (s8, u32) M1
//
// Recursive descent will ultimately put auto annotations for the literals
// in the table. Upon conversion of the table to type info, unification of
// the LHS annotation with the variable-based RHS annotation will be
// attempted.

XLS_ASSIGN_OR_RETURN(
std::optional<const TupleTypeAnnotation*> tuple_annotation,
GetDeclarationTypeAnnotation<TupleTypeAnnotation>(node));

// Create the M0, M1, ... variables and apply them to the members.
std::vector<TypeAnnotation*> member_types;
member_types.reserve(node->members().size());
Expand All @@ -189,6 +193,10 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
InferenceVariableKind::kType, member,
GenerateInternalTypeVariableName(member)));
XLS_RETURN_IF_ERROR(table_.SetTypeVariable(member, member_variable));
if (tuple_annotation.has_value()) {
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
member, (*tuple_annotation)->members()[i]));
}
member_types.push_back(
module_.Make<TypeVariableTypeAnnotation>(member_variable));
}
Expand Down Expand Up @@ -272,6 +280,10 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
// this unification will fail if the array is inadequately annotated (e.g.
// no explicit annotation on a zero-size or elliptical array).

XLS_ASSIGN_OR_RETURN(
std::optional<const ArrayTypeAnnotation*> array_annotation,
GetDeclarationTypeAnnotation<ArrayTypeAnnotation>(node));

// An empty array can't end with an ellipsis, even if unification is
// possible.
if (node->has_ellipsis() && node->members().empty()) {
Expand All @@ -298,6 +310,10 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
for (Expr* member : node->members()) {
XLS_RETURN_IF_ERROR(
table_.SetTypeVariable(member, element_type_variable));
if (array_annotation.has_value()) {
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
member, (*array_annotation)->element_type()));
}
}
Expr* element_count = module_.Make<Number>(
node->span(), absl::StrCat(node->members().size()), NumberKind::kOther,
Expand Down Expand Up @@ -790,6 +806,31 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
return absl::OkStatus();
}

// Gets the explicit type annotation (of type `T`) for a node by querying the
// type variable that it shares with a declaration, if any. This must be done
// before imposing any synthetic type annotation on the value.
template <typename T>
absl::StatusOr<std::optional<const T*>> GetDeclarationTypeAnnotation(
const AstNode* node) {
std::optional<const NameRef*> type_variable = table_.GetTypeVariable(node);
if (!type_variable.has_value()) {
return std::nullopt;
}
XLS_ASSIGN_OR_RETURN(
std::vector<const TypeAnnotation*> annotations,
table_.GetTypeAnnotationsForTypeVariable(*type_variable));
if (annotations.empty()) {
return std::nullopt;
}
// If > 1, the caller is ignoring the "before imposing an annotation on the
// RHS" precondition.
CHECK_EQ(annotations.size(), 1);
if (const auto* result = dynamic_cast<const T*>(annotations[0])) {
return result;
}
return std::nullopt;
}

Module& module_;
InferenceTable& table_;
const FileTable& file_table_;
Expand Down
Loading

0 comments on commit a00d2c2

Please sign in to comment.