Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion crates/hir-def/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::{
HygieneId,
path::{GenericArgs, Path},
},
type_ref::{Mutability, Rawness},
type_ref::{ConstRef, Mutability, Rawness},
};

pub use syntax::ast::{ArithOp, BinaryOp, CmpOp, LogicOp, Ordering, RangeOp, UnaryOp};
Expand Down Expand Up @@ -76,6 +76,38 @@ impl ExprOrPatId {
}
stdx::impl_from!(ExprId, PatId for ExprOrPatId);

// FIXME: Like ExprOrPatId above, eventually encode this as a single u32?
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, salsa::Update)]
pub enum TypeRefIdOrConstRef {
TypeRefId(TypeRefId),
ConstRef(ConstRef),
}

impl TypeRefIdOrConstRef {
pub fn as_type_ref(self) -> Option<TypeRefId> {
match self {
Self::TypeRefId(v) => Some(v),
_ => None,
}
}

pub fn is_type_ref(&self) -> bool {
matches!(self, Self::TypeRefId(_))
}

pub fn as_const_ref(self) -> Option<ConstRef> {
match self {
Self::ConstRef(v) => Some(v),
_ => None,
}
}

pub fn is_expr(&self) -> bool {
matches!(self, Self::ConstRef(_))
}
}
stdx::impl_from!(TypeRefId, ConstRef for TypeRefIdOrConstRef);

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Label {
pub name: Name,
Expand Down
25 changes: 23 additions & 2 deletions crates/hir-ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use hir_def::{
TupleFieldId, TupleId, VariantId,
attrs::AttrFlags,
expr_store::{Body, ExpressionStore, HygieneId, path::Path},
hir::{BindingId, ExprId, ExprOrPatId, LabelId, PatId},
hir::{BindingId, ExprId, ExprOrPatId, LabelId, PatId, TypeRefIdOrConstRef},
lang_item::LangItems,
layout::Integer,
resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs},
Expand Down Expand Up @@ -95,7 +95,7 @@ use crate::{
},
method_resolution::CandidateId,
next_solver::{
AliasTy, Const, ConstKind, DbInterner, ErrorGuaranteed, GenericArgs, Region,
AliasTy, Const, ConstKind, DbInterner, ErrorGuaranteed, GenericArgs, Region, StoredConst,
StoredGenericArg, StoredGenericArgs, StoredTy, StoredTys, Term, Ty, TyKind, Tys,
abi::Safety,
infer::{InferCtxt, ObligationInspector, traits::ObligationCause},
Expand Down Expand Up @@ -705,6 +705,7 @@ pub struct InferenceResult {
pub(crate) type_of_pat: ArenaMap<PatId, StoredTy>,
pub(crate) type_of_binding: ArenaMap<BindingId, StoredTy>,
pub(crate) type_of_type_placeholder: FxHashMap<TypeRefId, StoredTy>,
pub(crate) const_of_const_placeholder: FxHashMap<TypeRefIdOrConstRef, StoredConst>,
pub(crate) type_of_opaque: FxHashMap<InternedOpaqueTyId, StoredTy>,

/// Whether there are any type-mismatching errors in the result.
Expand Down Expand Up @@ -1007,6 +1008,7 @@ impl InferenceResult {
type_of_pat: Default::default(),
type_of_binding: Default::default(),
type_of_type_placeholder: Default::default(),
const_of_const_placeholder: Default::default(),
type_of_opaque: Default::default(),
skipped_ref_pats: Default::default(),
has_errors: Default::default(),
Expand Down Expand Up @@ -1082,6 +1084,16 @@ impl InferenceResult {
pub fn type_of_type_placeholder<'db>(&self, type_ref: TypeRefId) -> Option<Ty<'db>> {
self.type_of_type_placeholder.get(&type_ref).map(|ty| ty.as_ref())
}
pub fn placeholder_consts<'db>(
&self,
) -> impl Iterator<Item = (TypeRefIdOrConstRef, Const<'db>)> {
self.const_of_const_placeholder
.iter()
.map(|(&type_ref_or_const, const_)| (type_ref_or_const, const_.as_ref()))
}
pub fn const_of_const_placeholder<'db>(&self, expr: TypeRefIdOrConstRef) -> Option<Const<'db>> {
self.const_of_const_placeholder.get(&expr).map(|ty| ty.as_ref())
}
pub fn type_of_expr_or_pat<'db>(&self, id: ExprOrPatId) -> Option<Ty<'db>> {
match id {
ExprOrPatId::ExprId(id) => self.type_of_expr.get(id).map(|it| it.as_ref()),
Expand Down Expand Up @@ -1365,6 +1377,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
type_of_pat,
type_of_binding,
type_of_type_placeholder,
const_of_const_placeholder,
type_of_opaque,
has_errors: _,
diagnostics: _,
Expand Down Expand Up @@ -1396,6 +1409,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
merge_arena_maps(type_of_pat, &other.type_of_pat);
merge_arena_maps(type_of_binding, &other.type_of_binding);
merge_hash_maps(type_of_type_placeholder, &other.type_of_type_placeholder);
merge_hash_maps(const_of_const_placeholder, &other.const_of_const_placeholder);
merge_hash_maps(type_of_opaque, &other.type_of_opaque);
merge_hash_maps(expr_adjustments, &other.expr_adjustments);
merge_hash_maps(pat_adjustments, &other.pat_adjustments);
Expand Down Expand Up @@ -1531,6 +1545,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
type_of_pat,
type_of_binding,
type_of_type_placeholder,
const_of_const_placeholder,
type_of_opaque,
skipped_ref_pats,
closures_data,
Expand Down Expand Up @@ -1569,6 +1584,10 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
resolver.resolve_completely(ty);
}
type_of_type_placeholder.shrink_to_fit();
for const_ in const_of_const_placeholder.values_mut() {
resolver.resolve_completely(const_);
}
const_of_const_placeholder.shrink_to_fit();
type_of_opaque.shrink_to_fit();

if let Some(nodes_with_type_mismatches) = nodes_with_type_mismatches {
Expand Down Expand Up @@ -1868,6 +1887,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
InferenceTyDiagnosticSource::Body => Some(&mut InferenceTyLoweringVarsCtx {
table: &mut self.table,
type_of_type_placeholder: &mut self.result.type_of_type_placeholder,
const_of_const_placeholder: &mut self.result.const_of_const_placeholder,
} as _),
InferenceTyDiagnosticSource::Signature => None,
};
Expand Down Expand Up @@ -2203,6 +2223,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
let mut vars_ctx = InferenceTyLoweringVarsCtx {
table: &mut self.table,
type_of_type_placeholder: &mut self.result.type_of_type_placeholder,
const_of_const_placeholder: &mut self.result.const_of_const_placeholder,
};
let mut ctx = TyLoweringContext::new(
self.db,
Expand Down
17 changes: 16 additions & 1 deletion crates/hir-ty/src/infer/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ use std::ops::{Deref, DerefMut};

use either::Either;
use hir_def::expr_store::path::Path;
use hir_def::hir::TypeRefIdOrConstRef;
use hir_def::type_ref::ConstRef;
use hir_def::{ExpressionStoreOwnerId, GenericDefId};
use hir_def::{expr_store::ExpressionStore, type_ref::TypeRefId};
use hir_def::{hir::ExprOrPatId, resolver::Resolver};
use la_arena::{Idx, RawIdx};
use rustc_hash::FxHashMap;
use thin_vec::ThinVec;

use crate::next_solver::StoredConst;
use crate::{
InferenceDiagnostic, InferenceTyDiagnosticSource, Span, TyLoweringDiagnostic,
db::{AnonConstId, HirDatabase},
Expand Down Expand Up @@ -61,6 +64,7 @@ pub(crate) struct PathDiagnosticCallbackData<'a> {
pub(super) struct InferenceTyLoweringVarsCtx<'a, 'db> {
pub(super) table: &'a mut InferenceTable<'db>,
pub(super) type_of_type_placeholder: &'a mut FxHashMap<TypeRefId, StoredTy>,
pub(super) const_of_const_placeholder: &'a mut FxHashMap<TypeRefIdOrConstRef, StoredConst>,
}

impl<'db> TyLoweringInferVarsCtx<'db> for InferenceTyLoweringVarsCtx<'_, 'db> {
Expand All @@ -74,7 +78,18 @@ impl<'db> TyLoweringInferVarsCtx<'db> for InferenceTyLoweringVarsCtx<'_, 'db> {
ty
}
fn next_const_var(&mut self, span: Span) -> Const<'db> {
self.table.infer_ctxt.next_const_var(span)
let const_ = self.table.infer_ctxt.next_const_var(span);

let type_ref_id_or_const_ref = match span {
Span::ExprId(expr) => Some(TypeRefIdOrConstRef::ConstRef(ConstRef { expr })),
Span::TypeRefId(type_ref) => Some(type_ref.into()),
_ => None,
};
if let Some(key) = type_ref_id_or_const_ref {
self.const_of_const_placeholder.insert(key, const_.store());
}

const_
}
fn next_region_var(&mut self, span: Span) -> Region<'db> {
self.table.infer_ctxt.next_region_var(span)
Expand Down
1 change: 1 addition & 0 deletions crates/hir-ty/src/infer/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ impl<'db> InferenceContext<'_, 'db> {
let mut vars_ctx = InferenceTyLoweringVarsCtx {
table: &mut self.table,
type_of_type_placeholder: &mut self.result.type_of_type_placeholder,
const_of_const_placeholder: &mut self.result.const_of_const_placeholder,
};
let mut ctx = TyLoweringContext::new(
self.db,
Expand Down
11 changes: 10 additions & 1 deletion crates/hir-ty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ use hir_def::{
TypeParamId,
db::DefDatabase,
expr_store::{Body, ExpressionStore},
hir::{BindingId, ExprId, ExprOrPatId, PatId},
hir::{BindingId, ExprId, ExprOrPatId, PatId, TypeRefIdOrConstRef},
resolver::{HasResolver, Resolver, TypeNs},
type_ref::{Rawness, TypeRefId},
};
Expand Down Expand Up @@ -539,6 +539,15 @@ impl From<ExprOrPatId> for Span {
}
}

impl From<TypeRefIdOrConstRef> for Span {
fn from(value: TypeRefIdOrConstRef) -> Self {
match value {
TypeRefIdOrConstRef::TypeRefId(idx) => idx.into(),
TypeRefIdOrConstRef::ConstRef(idx) => idx.expr.into(),
}
}
}

impl Span {
pub(crate) fn pick_best(a: Span, b: Span) -> Span {
// We prefer dummy spans to minimize the risk of false errors.
Expand Down
21 changes: 21 additions & 0 deletions crates/hir-ty/src/next_solver/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,15 @@ impl<'db> TypeVisitable<DbInterner<'db>> for Const<'db> {
}
}

impl<'db> TypeVisitable<DbInterner<'db>> for StoredConst {
fn visit_with<V: rustc_type_ir::TypeVisitor<DbInterner<'db>>>(
&self,
visitor: &mut V,
) -> V::Result {
self.as_ref().visit_with(visitor)
}
}

impl<'db> TypeSuperVisitable<DbInterner<'db>> for Const<'db> {
fn super_visit_with<V: rustc_type_ir::TypeVisitor<DbInterner<'db>>>(
&self,
Expand Down Expand Up @@ -248,6 +257,18 @@ impl<'db> TypeFoldable<DbInterner<'db>> for Const<'db> {
}
}

impl<'db> TypeFoldable<DbInterner<'db>> for StoredConst {
fn try_fold_with<F: rustc_type_ir::FallibleTypeFolder<DbInterner<'db>>>(
self,
folder: &mut F,
) -> Result<Self, F::Error> {
Ok(self.as_ref().try_fold_with(folder)?.store())
}
fn fold_with<F: rustc_type_ir::TypeFolder<DbInterner<'db>>>(self, folder: &mut F) -> Self {
self.as_ref().fold_with(folder).store()
}
}

impl<'db> TypeSuperFoldable<DbInterner<'db>> for Const<'db> {
fn try_super_fold_with<F: rustc_type_ir::FallibleTypeFolder<DbInterner<'db>>>(
self,
Expand Down
34 changes: 33 additions & 1 deletion crates/hir-ty/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use hir_def::{
AdtId, AssocItemId, DefWithBodyId, GenericDefId, HasModule, Lookup, ModuleDefId, ModuleId,
SyntheticSyntax, VariantId,
expr_store::{Body, BodySourceMap, ExpressionStore, ExpressionStoreSourceMap},
hir::{ExprId, Pat, PatId},
hir::{ExprId, Pat, PatId, TypeRefIdOrConstRef},
item_scope::ItemScope,
nameres::DefMap,
src::HasSource,
Expand Down Expand Up @@ -83,6 +83,7 @@ fn check_impl(
let mut had_annotations = false;
let mut mismatches = FxHashMap::default();
let mut types = FxHashMap::default();
let mut consts = FxHashMap::default();
let mut adjustments = FxHashMap::default();
for (file_id, annotations) in db.extract_annotations() {
for (range, expected) in annotations {
Expand All @@ -91,6 +92,8 @@ fn check_impl(
types.insert(file_range, expected);
} else if let Some(ty) = expected.strip_prefix("type: ") {
types.insert(file_range, ty.to_owned());
} else if let Some(const_) = expected.strip_prefix("const: ") {
consts.insert(file_range, const_.to_owned());
} else if expected.starts_with("expected") {
mismatches.insert(file_range, expected);
} else if let Some(adjs) = expected.strip_prefix("adjustments:") {
Expand Down Expand Up @@ -243,6 +246,29 @@ fn check_impl(
assert_eq!(actual, expected, "type annotation differs at {:#?}", range.range);
}
}

for (type_ref_id_or_const_ref, const_) in inference_result.placeholder_consts() {
let node = match type_ref_id_or_const_ref {
TypeRefIdOrConstRef::TypeRefId(type_ref) => {
type_node(body_source_map, type_ref, &db)
}
TypeRefIdOrConstRef::ConstRef(const_ref) => {
expr_node(body_source_map, const_ref.expr, &db)
}
};
let Some(node) = node else { continue };
let range = node.as_ref().original_file_range_rooted(&db);
if let Some(expected) = consts.remove(&range) {
let actual = salsa::attach(&db, || {
if display_source {
const_.display_source_code(&db, def.module(&db), true).unwrap()
} else {
const_.display_test(&db, display_target).to_string()
}
});
assert_eq!(actual, expected, "const annotation differs at {:#?}", range.range);
}
}
}

let mut buf = String::new();
Expand All @@ -261,6 +287,12 @@ fn check_impl(
format_to!(buf, "{:?}: type {}\n", t.0.range, t.1);
}
}
if !consts.is_empty() {
format_to!(buf, "Unchecked const annotations:\n");
for c in consts {
format_to!(buf, "{:?}: const {}\n", c.0.range, c.1);
}
}
if !adjustments.is_empty() {
format_to!(buf, "Unchecked adjustments annotations:\n");
for t in adjustments {
Expand Down
22 changes: 12 additions & 10 deletions crates/hir-ty/src/tests/display_source_code.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::check_types_source_code;
use super::{check, check_types_source_code};

#[test]
fn qualify_path_to_submodule() {
Expand Down Expand Up @@ -248,19 +248,21 @@ fn test() {
}

#[test]
fn type_placeholder_type() {
check_types_source_code(
fn type_and_const_placeholders() {
check(
r#"
struct S<T>(T);
struct S<T, const N: usize>([T; N]);
fn test() {
let f: S<_> = S(3);
//^ i32
let f: S<_, _> = S([1, 2]);
//^ type: i32
//^ const: 2
let f: [_; _] = [4_u32, 5, 6];
//^ u32
//^ type: u32
//^ const: 3
let f: (_, _, _) = (1_u32, 1_i32, false);
//^ u32
//^ i32
//^ bool
//^ type: u32
//^ type: i32
//^ type: bool
}
"#,
);
Expand Down
Loading
Loading