checker: comptime match only eval true branch

This commit is contained in:
kbkpbot 2025-09-03 08:06:27 +08:00
parent f6b60e4d9f
commit b442fc6349
2 changed files with 178 additions and 158 deletions

View file

@ -221,191 +221,194 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type {
} }
} }
} }
if node.is_expr {
c.stmts_ending_with_expression(mut branch.stmts, c.expected_or_type) if !node.is_comptime || (node.is_comptime && comptime_match_branch_result) {
} else { if node.is_expr {
c.stmts(mut branch.stmts) c.stmts_ending_with_expression(mut branch.stmts, c.expected_or_type)
} } else {
c.smartcast_mut_pos = token.Pos{} c.stmts(mut branch.stmts)
c.smartcast_cond_pos = token.Pos{}
if node.is_expr {
if branch.stmts.len == 0 && ret_type != ast.void_type {
c.error('`match` expression requires an expression as the last statement of every branch',
branch.branch_pos)
} }
} c.smartcast_mut_pos = token.Pos{}
if !branch.is_else && cond_is_option && branch.exprs.any(it !is ast.None) { c.smartcast_cond_pos = token.Pos{}
c.error('`match` expression with Option type only checks against `none`, to match its value you must unwrap it first `var?`', if node.is_expr {
branch.pos) if branch.stmts.len == 0 && ret_type != ast.void_type {
} c.error('`match` expression requires an expression as the last statement of every branch',
if cond_type_sym.kind == .none { branch.branch_pos)
c.error('`none` cannot be a match condition', node.pos)
}
// If the last statement is an expression, return its type
if branch.stmts.len > 0 && node.is_expr {
mut stmt := branch.stmts.last()
if mut stmt is ast.ExprStmt {
c.expected_type = if c.expected_expr_type != ast.void_type {
c.expected_expr_type
} else {
node.expected_type
} }
expr_type := c.unwrap_generic(if stmt.expr is ast.CallExpr { }
stmt.typ if !branch.is_else && cond_is_option && branch.exprs.any(it !is ast.None) {
} else { c.error('`match` expression with Option type only checks against `none`, to match its value you must unwrap it first `var?`',
c.expr(mut stmt.expr) branch.pos)
}) }
unwrapped_expected_type := c.unwrap_generic(node.expected_type) if cond_type_sym.kind == .none {
must_be_option = must_be_option || expr_type == ast.none_type c.error('`none` cannot be a match condition', node.pos)
stmt.typ = expr_type }
if first_iteration { // If the last statement is an expression, return its type
if unwrapped_expected_type.has_option_or_result() if branch.stmts.len > 0 && node.is_expr {
|| c.table.type_kind(unwrapped_expected_type) in [.sum_type, .multi_return] { mut stmt := branch.stmts.last()
c.check_match_branch_last_stmt(stmt, unwrapped_expected_type, if mut stmt is ast.ExprStmt {
expr_type) c.expected_type = if c.expected_expr_type != ast.void_type {
ret_type = node.expected_type c.expected_expr_type
} else { } else {
ret_type = expr_type node.expected_type
if expr_type.is_ptr() { }
if stmt.expr is ast.Ident && stmt.expr.obj is ast.Var expr_type := c.unwrap_generic(if stmt.expr is ast.CallExpr {
&& c.table.is_interface_var(stmt.expr.obj) { stmt.typ
ret_type = expr_type.deref() } else {
} else if mut stmt.expr is ast.PrefixExpr c.expr(mut stmt.expr)
&& stmt.expr.right is ast.Ident { })
ident := stmt.expr.right as ast.Ident unwrapped_expected_type := c.unwrap_generic(node.expected_type)
if ident.obj is ast.Var && c.table.is_interface_var(ident.obj) { must_be_option = must_be_option || expr_type == ast.none_type
stmt.typ = expr_type
if first_iteration {
if unwrapped_expected_type.has_option_or_result()
|| c.table.type_kind(unwrapped_expected_type) in [.sum_type, .multi_return] {
c.check_match_branch_last_stmt(stmt, unwrapped_expected_type,
expr_type)
ret_type = node.expected_type
} else {
ret_type = expr_type
if expr_type.is_ptr() {
if stmt.expr is ast.Ident && stmt.expr.obj is ast.Var
&& c.table.is_interface_var(stmt.expr.obj) {
ret_type = expr_type.deref() ret_type = expr_type.deref()
} else if mut stmt.expr is ast.PrefixExpr
&& stmt.expr.right is ast.Ident {
ident := stmt.expr.right as ast.Ident
if ident.obj is ast.Var && c.table.is_interface_var(ident.obj) {
ret_type = expr_type.deref()
}
}
}
c.expected_expr_type = expr_type
}
infer_cast_type = stmt.typ
if mut stmt.expr is ast.CastExpr {
need_explicit_cast = true
infer_cast_type = stmt.expr.typ
}
} else {
if ret_type.idx() != expr_type.idx() {
if unwrapped_expected_type.has_option_or_result()
&& c.table.sym(stmt.typ).kind == .struct
&& !c.check_types(expr_type, c.unwrap_generic(ret_type))
&& c.type_implements(stmt.typ, ast.error_type, node.pos) {
stmt.expr = ast.CastExpr{
expr: stmt.expr
typname: 'IError'
typ: ast.error_type
expr_type: stmt.typ
pos: node.pos
}
stmt.typ = ast.error_type
} else {
c.check_match_branch_last_stmt(stmt, c.unwrap_generic(ret_type),
expr_type)
if ret_type.is_number() && expr_type.is_number() && !c.inside_return {
ret_type = c.promote_num(ret_type, expr_type)
} }
} }
} }
c.expected_expr_type = expr_type if must_be_option && ret_type == ast.none_type && expr_type != ret_type {
} ret_type = expr_type.set_flag(.option)
infer_cast_type = stmt.typ }
if mut stmt.expr is ast.CastExpr { if stmt.typ != ast.error_type && !is_noreturn_callexpr(stmt.expr) {
need_explicit_cast = true ret_sym := c.table.sym(ret_type)
infer_cast_type = stmt.expr.typ stmt_sym := c.table.sym(stmt.typ)
} if ret_sym.kind !in [.sum_type, .interface]
} else { && stmt_sym.kind in [.sum_type, .interface] {
if ret_type.idx() != expr_type.idx() { c.error('return type mismatch, it should be `${ret_sym.name}`, but it is instead `${c.table.type_to_str(expr_type)}`',
if unwrapped_expected_type.has_option_or_result() stmt.pos)
&& c.table.sym(stmt.typ).kind == .struct
&& !c.check_types(expr_type, c.unwrap_generic(ret_type))
&& c.type_implements(stmt.typ, ast.error_type, node.pos) {
stmt.expr = ast.CastExpr{
expr: stmt.expr
typname: 'IError'
typ: ast.error_type
expr_type: stmt.typ
pos: node.pos
} }
stmt.typ = ast.error_type if ret_type.nr_muls() != stmt.typ.nr_muls()
} else { && stmt.typ.idx() !in [ast.voidptr_type_idx, ast.nil_type_idx] {
c.check_match_branch_last_stmt(stmt, c.unwrap_generic(ret_type), type_name := '&'.repeat(ret_type.nr_muls()) + ret_sym.name
expr_type) c.error('return type mismatch, it should be `${type_name}`, but it is instead `${c.table.type_to_str(expr_type)}`',
if ret_type.is_number() && expr_type.is_number() && !c.inside_return { stmt.pos)
ret_type = c.promote_num(ret_type, expr_type)
} }
} }
} if !node.is_sum_type {
if must_be_option && ret_type == ast.none_type && expr_type != ret_type { if mut stmt.expr is ast.CastExpr {
ret_type = expr_type.set_flag(.option) expr_typ_sym := c.table.sym(stmt.expr.typ)
} if need_explicit_cast {
if stmt.typ != ast.error_type && !is_noreturn_callexpr(stmt.expr) { if infer_cast_type != stmt.expr.typ
ret_sym := c.table.sym(ret_type) && expr_typ_sym.kind !in [.interface, .sum_type] {
stmt_sym := c.table.sym(stmt.typ) c.error('the type of the last expression in the first match branch was an explicit `${c.table.type_to_str(infer_cast_type)}`, not `${c.table.type_to_str(stmt.expr.typ)}`',
if ret_sym.kind !in [.sum_type, .interface] stmt.pos)
&& stmt_sym.kind in [.sum_type, .interface] { }
c.error('return type mismatch, it should be `${ret_sym.name}`, but it is instead `${c.table.type_to_str(expr_type)}`', } else {
stmt.pos) if infer_cast_type != stmt.expr.typ
} && expr_typ_sym.kind !in [.interface, .sum_type]
if ret_type.nr_muls() != stmt.typ.nr_muls() && c.promote_num(stmt.expr.typ, ast.int_type) != ast.int_type {
&& stmt.typ.idx() !in [ast.voidptr_type_idx, ast.nil_type_idx] { c.error('the type of the last expression of the first match branch was `${c.table.type_to_str(infer_cast_type)}`, which is not compatible with `${c.table.type_to_str(stmt.expr.typ)}`',
type_name := '&'.repeat(ret_type.nr_muls()) + ret_sym.name stmt.pos)
c.error('return type mismatch, it should be `${type_name}`, but it is instead `${c.table.type_to_str(expr_type)}`', }
stmt.pos)
}
}
if !node.is_sum_type {
if mut stmt.expr is ast.CastExpr {
expr_typ_sym := c.table.sym(stmt.expr.typ)
if need_explicit_cast {
if infer_cast_type != stmt.expr.typ
&& expr_typ_sym.kind !in [.interface, .sum_type] {
c.error('the type of the last expression in the first match branch was an explicit `${c.table.type_to_str(infer_cast_type)}`, not `${c.table.type_to_str(stmt.expr.typ)}`',
stmt.pos)
} }
} else { } else {
if infer_cast_type != stmt.expr.typ if mut stmt.expr is ast.IntegerLiteral {
&& expr_typ_sym.kind !in [.interface, .sum_type] cast_type_sym := c.table.sym(infer_cast_type)
&& c.promote_num(stmt.expr.typ, ast.int_type) != ast.int_type { num := stmt.expr.val.i64()
c.error('the type of the last expression of the first match branch was `${c.table.type_to_str(infer_cast_type)}`, which is not compatible with `${c.table.type_to_str(stmt.expr.typ)}`', mut needs_explicit_cast := false
stmt.pos)
}
}
} else {
if mut stmt.expr is ast.IntegerLiteral {
cast_type_sym := c.table.sym(infer_cast_type)
num := stmt.expr.val.i64()
mut needs_explicit_cast := false
match cast_type_sym.kind { match cast_type_sym.kind {
.u8 { .u8 {
if !(num >= min_u8 && num <= max_u8) { if !(num >= min_u8 && num <= max_u8) {
needs_explicit_cast = true needs_explicit_cast = true
}
} }
} .u16 {
.u16 { if !(num >= min_u16 && num <= max_u16) {
if !(num >= min_u16 && num <= max_u16) { needs_explicit_cast = true
needs_explicit_cast = true }
} }
} .u32 {
.u32 { if !(num >= min_u32 && num <= max_u32) {
if !(num >= min_u32 && num <= max_u32) { needs_explicit_cast = true
needs_explicit_cast = true }
} }
} .u64 {
.u64 { if !(num >= min_u64 && num <= max_u64) {
if !(num >= min_u64 && num <= max_u64) { needs_explicit_cast = true
needs_explicit_cast = true }
} }
} .i8 {
.i8 { if !(num >= min_i32 && num <= max_i32) {
if !(num >= min_i32 && num <= max_i32) { needs_explicit_cast = true
needs_explicit_cast = true }
} }
} .i16 {
.i16 { if !(num >= min_i16 && num <= max_i16) {
if !(num >= min_i16 && num <= max_i16) { needs_explicit_cast = true
needs_explicit_cast = true }
} }
} .i32, .int {
.i32, .int { if !(num >= min_i32 && num <= max_i32) {
if !(num >= min_i32 && num <= max_i32) { needs_explicit_cast = true
needs_explicit_cast = true }
} }
} .i64 {
.i64 { if !(num >= min_i64 && num <= max_i64) {
if !(num >= min_i64 && num <= max_i64) { needs_explicit_cast = true
needs_explicit_cast = true }
} }
.int_literal {
needs_explicit_cast = false
}
else {}
} }
.int_literal { if needs_explicit_cast {
needs_explicit_cast = false c.error('${num} does not fit the range of `${c.table.type_to_str(infer_cast_type)}`',
stmt.pos)
} }
else {}
}
if needs_explicit_cast {
c.error('${num} does not fit the range of `${c.table.type_to_str(infer_cast_type)}`',
stmt.pos)
} }
} }
} }
} }
} } else if stmt !in [ast.Return, ast.BranchStmt] {
} else if stmt !in [ast.Return, ast.BranchStmt] { if ret_type != ast.void_type {
if ret_type != ast.void_type { c.error('`match` expression requires an expression as the last statement of every branch',
c.error('`match` expression requires an expression as the last statement of every branch', stmt.pos)
stmt.pos) }
} }
} }
} }

View file

@ -0,0 +1,17 @@
module main
fn func[T]() bool {
$match T {
u8, u16 {
return true
}
$else {
// return false
$compile_error('fail')
}
}
}
fn test_comptime_match_eval_only_true_branch() {
assert func[u8]()
}