checker: comptime match only eval true branch

This commit is contained in:
kbkpbot 2025-09-03 08:06:27 +08:00
parent f6b60e4d9f
commit b442fc6349
2 changed files with 178 additions and 158 deletions

View file

@ -221,191 +221,194 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type {
}
}
}
if node.is_expr {
c.stmts_ending_with_expression(mut branch.stmts, c.expected_or_type)
} else {
c.stmts(mut branch.stmts)
}
c.smartcast_mut_pos = token.Pos{}
c.smartcast_cond_pos = token.Pos{}
if node.is_expr {
if branch.stmts.len == 0 && ret_type != ast.void_type {
c.error('`match` expression requires an expression as the last statement of every branch',
branch.branch_pos)
if !node.is_comptime || (node.is_comptime && comptime_match_branch_result) {
if node.is_expr {
c.stmts_ending_with_expression(mut branch.stmts, c.expected_or_type)
} else {
c.stmts(mut branch.stmts)
}
}
if !branch.is_else && cond_is_option && branch.exprs.any(it !is ast.None) {
c.error('`match` expression with Option type only checks against `none`, to match its value you must unwrap it first `var?`',
branch.pos)
}
if cond_type_sym.kind == .none {
c.error('`none` cannot be a match condition', node.pos)
}
// If the last statement is an expression, return its type
if branch.stmts.len > 0 && node.is_expr {
mut stmt := branch.stmts.last()
if mut stmt is ast.ExprStmt {
c.expected_type = if c.expected_expr_type != ast.void_type {
c.expected_expr_type
} else {
node.expected_type
c.smartcast_mut_pos = token.Pos{}
c.smartcast_cond_pos = token.Pos{}
if node.is_expr {
if branch.stmts.len == 0 && ret_type != ast.void_type {
c.error('`match` expression requires an expression as the last statement of every branch',
branch.branch_pos)
}
expr_type := c.unwrap_generic(if stmt.expr is ast.CallExpr {
stmt.typ
} else {
c.expr(mut stmt.expr)
})
unwrapped_expected_type := c.unwrap_generic(node.expected_type)
must_be_option = must_be_option || expr_type == ast.none_type
stmt.typ = expr_type
if first_iteration {
if unwrapped_expected_type.has_option_or_result()
|| c.table.type_kind(unwrapped_expected_type) in [.sum_type, .multi_return] {
c.check_match_branch_last_stmt(stmt, unwrapped_expected_type,
expr_type)
ret_type = node.expected_type
}
if !branch.is_else && cond_is_option && branch.exprs.any(it !is ast.None) {
c.error('`match` expression with Option type only checks against `none`, to match its value you must unwrap it first `var?`',
branch.pos)
}
if cond_type_sym.kind == .none {
c.error('`none` cannot be a match condition', node.pos)
}
// If the last statement is an expression, return its type
if branch.stmts.len > 0 && node.is_expr {
mut stmt := branch.stmts.last()
if mut stmt is ast.ExprStmt {
c.expected_type = if c.expected_expr_type != ast.void_type {
c.expected_expr_type
} else {
ret_type = expr_type
if expr_type.is_ptr() {
if stmt.expr is ast.Ident && stmt.expr.obj is ast.Var
&& c.table.is_interface_var(stmt.expr.obj) {
ret_type = expr_type.deref()
} else if mut stmt.expr is ast.PrefixExpr
&& stmt.expr.right is ast.Ident {
ident := stmt.expr.right as ast.Ident
if ident.obj is ast.Var && c.table.is_interface_var(ident.obj) {
node.expected_type
}
expr_type := c.unwrap_generic(if stmt.expr is ast.CallExpr {
stmt.typ
} else {
c.expr(mut stmt.expr)
})
unwrapped_expected_type := c.unwrap_generic(node.expected_type)
must_be_option = must_be_option || expr_type == ast.none_type
stmt.typ = expr_type
if first_iteration {
if unwrapped_expected_type.has_option_or_result()
|| c.table.type_kind(unwrapped_expected_type) in [.sum_type, .multi_return] {
c.check_match_branch_last_stmt(stmt, unwrapped_expected_type,
expr_type)
ret_type = node.expected_type
} else {
ret_type = expr_type
if expr_type.is_ptr() {
if stmt.expr is ast.Ident && stmt.expr.obj is ast.Var
&& c.table.is_interface_var(stmt.expr.obj) {
ret_type = expr_type.deref()
} else if mut stmt.expr is ast.PrefixExpr
&& stmt.expr.right is ast.Ident {
ident := stmt.expr.right as ast.Ident
if ident.obj is ast.Var && c.table.is_interface_var(ident.obj) {
ret_type = expr_type.deref()
}
}
}
c.expected_expr_type = expr_type
}
infer_cast_type = stmt.typ
if mut stmt.expr is ast.CastExpr {
need_explicit_cast = true
infer_cast_type = stmt.expr.typ
}
} else {
if ret_type.idx() != expr_type.idx() {
if unwrapped_expected_type.has_option_or_result()
&& c.table.sym(stmt.typ).kind == .struct
&& !c.check_types(expr_type, c.unwrap_generic(ret_type))
&& c.type_implements(stmt.typ, ast.error_type, node.pos) {
stmt.expr = ast.CastExpr{
expr: stmt.expr
typname: 'IError'
typ: ast.error_type
expr_type: stmt.typ
pos: node.pos
}
stmt.typ = ast.error_type
} else {
c.check_match_branch_last_stmt(stmt, c.unwrap_generic(ret_type),
expr_type)
if ret_type.is_number() && expr_type.is_number() && !c.inside_return {
ret_type = c.promote_num(ret_type, expr_type)
}
}
}
c.expected_expr_type = expr_type
}
infer_cast_type = stmt.typ
if mut stmt.expr is ast.CastExpr {
need_explicit_cast = true
infer_cast_type = stmt.expr.typ
}
} else {
if ret_type.idx() != expr_type.idx() {
if unwrapped_expected_type.has_option_or_result()
&& c.table.sym(stmt.typ).kind == .struct
&& !c.check_types(expr_type, c.unwrap_generic(ret_type))
&& c.type_implements(stmt.typ, ast.error_type, node.pos) {
stmt.expr = ast.CastExpr{
expr: stmt.expr
typname: 'IError'
typ: ast.error_type
expr_type: stmt.typ
pos: node.pos
if must_be_option && ret_type == ast.none_type && expr_type != ret_type {
ret_type = expr_type.set_flag(.option)
}
if stmt.typ != ast.error_type && !is_noreturn_callexpr(stmt.expr) {
ret_sym := c.table.sym(ret_type)
stmt_sym := c.table.sym(stmt.typ)
if ret_sym.kind !in [.sum_type, .interface]
&& stmt_sym.kind in [.sum_type, .interface] {
c.error('return type mismatch, it should be `${ret_sym.name}`, but it is instead `${c.table.type_to_str(expr_type)}`',
stmt.pos)
}
stmt.typ = ast.error_type
} else {
c.check_match_branch_last_stmt(stmt, c.unwrap_generic(ret_type),
expr_type)
if ret_type.is_number() && expr_type.is_number() && !c.inside_return {
ret_type = c.promote_num(ret_type, expr_type)
if ret_type.nr_muls() != stmt.typ.nr_muls()
&& stmt.typ.idx() !in [ast.voidptr_type_idx, ast.nil_type_idx] {
type_name := '&'.repeat(ret_type.nr_muls()) + ret_sym.name
c.error('return type mismatch, it should be `${type_name}`, but it is instead `${c.table.type_to_str(expr_type)}`',
stmt.pos)
}
}
}
if must_be_option && ret_type == ast.none_type && expr_type != ret_type {
ret_type = expr_type.set_flag(.option)
}
if stmt.typ != ast.error_type && !is_noreturn_callexpr(stmt.expr) {
ret_sym := c.table.sym(ret_type)
stmt_sym := c.table.sym(stmt.typ)
if ret_sym.kind !in [.sum_type, .interface]
&& stmt_sym.kind in [.sum_type, .interface] {
c.error('return type mismatch, it should be `${ret_sym.name}`, but it is instead `${c.table.type_to_str(expr_type)}`',
stmt.pos)
}
if ret_type.nr_muls() != stmt.typ.nr_muls()
&& stmt.typ.idx() !in [ast.voidptr_type_idx, ast.nil_type_idx] {
type_name := '&'.repeat(ret_type.nr_muls()) + ret_sym.name
c.error('return type mismatch, it should be `${type_name}`, but it is instead `${c.table.type_to_str(expr_type)}`',
stmt.pos)
}
}
if !node.is_sum_type {
if mut stmt.expr is ast.CastExpr {
expr_typ_sym := c.table.sym(stmt.expr.typ)
if need_explicit_cast {
if infer_cast_type != stmt.expr.typ
&& expr_typ_sym.kind !in [.interface, .sum_type] {
c.error('the type of the last expression in the first match branch was an explicit `${c.table.type_to_str(infer_cast_type)}`, not `${c.table.type_to_str(stmt.expr.typ)}`',
stmt.pos)
if !node.is_sum_type {
if mut stmt.expr is ast.CastExpr {
expr_typ_sym := c.table.sym(stmt.expr.typ)
if need_explicit_cast {
if infer_cast_type != stmt.expr.typ
&& expr_typ_sym.kind !in [.interface, .sum_type] {
c.error('the type of the last expression in the first match branch was an explicit `${c.table.type_to_str(infer_cast_type)}`, not `${c.table.type_to_str(stmt.expr.typ)}`',
stmt.pos)
}
} else {
if infer_cast_type != stmt.expr.typ
&& expr_typ_sym.kind !in [.interface, .sum_type]
&& c.promote_num(stmt.expr.typ, ast.int_type) != ast.int_type {
c.error('the type of the last expression of the first match branch was `${c.table.type_to_str(infer_cast_type)}`, which is not compatible with `${c.table.type_to_str(stmt.expr.typ)}`',
stmt.pos)
}
}
} else {
if infer_cast_type != stmt.expr.typ
&& expr_typ_sym.kind !in [.interface, .sum_type]
&& c.promote_num(stmt.expr.typ, ast.int_type) != ast.int_type {
c.error('the type of the last expression of the first match branch was `${c.table.type_to_str(infer_cast_type)}`, which is not compatible with `${c.table.type_to_str(stmt.expr.typ)}`',
stmt.pos)
}
}
} else {
if mut stmt.expr is ast.IntegerLiteral {
cast_type_sym := c.table.sym(infer_cast_type)
num := stmt.expr.val.i64()
mut needs_explicit_cast := false
if mut stmt.expr is ast.IntegerLiteral {
cast_type_sym := c.table.sym(infer_cast_type)
num := stmt.expr.val.i64()
mut needs_explicit_cast := false
match cast_type_sym.kind {
.u8 {
if !(num >= min_u8 && num <= max_u8) {
needs_explicit_cast = true
match cast_type_sym.kind {
.u8 {
if !(num >= min_u8 && num <= max_u8) {
needs_explicit_cast = true
}
}
}
.u16 {
if !(num >= min_u16 && num <= max_u16) {
needs_explicit_cast = true
.u16 {
if !(num >= min_u16 && num <= max_u16) {
needs_explicit_cast = true
}
}
}
.u32 {
if !(num >= min_u32 && num <= max_u32) {
needs_explicit_cast = true
.u32 {
if !(num >= min_u32 && num <= max_u32) {
needs_explicit_cast = true
}
}
}
.u64 {
if !(num >= min_u64 && num <= max_u64) {
needs_explicit_cast = true
.u64 {
if !(num >= min_u64 && num <= max_u64) {
needs_explicit_cast = true
}
}
}
.i8 {
if !(num >= min_i32 && num <= max_i32) {
needs_explicit_cast = true
.i8 {
if !(num >= min_i32 && num <= max_i32) {
needs_explicit_cast = true
}
}
}
.i16 {
if !(num >= min_i16 && num <= max_i16) {
needs_explicit_cast = true
.i16 {
if !(num >= min_i16 && num <= max_i16) {
needs_explicit_cast = true
}
}
}
.i32, .int {
if !(num >= min_i32 && num <= max_i32) {
needs_explicit_cast = true
.i32, .int {
if !(num >= min_i32 && num <= max_i32) {
needs_explicit_cast = true
}
}
}
.i64 {
if !(num >= min_i64 && num <= max_i64) {
needs_explicit_cast = true
.i64 {
if !(num >= min_i64 && num <= max_i64) {
needs_explicit_cast = true
}
}
.int_literal {
needs_explicit_cast = false
}
else {}
}
.int_literal {
needs_explicit_cast = false
if needs_explicit_cast {
c.error('${num} does not fit the range of `${c.table.type_to_str(infer_cast_type)}`',
stmt.pos)
}
else {}
}
if needs_explicit_cast {
c.error('${num} does not fit the range of `${c.table.type_to_str(infer_cast_type)}`',
stmt.pos)
}
}
}
}
}
} else if stmt !in [ast.Return, ast.BranchStmt] {
if ret_type != ast.void_type {
c.error('`match` expression requires an expression as the last statement of every branch',
stmt.pos)
} else if stmt !in [ast.Return, ast.BranchStmt] {
if ret_type != ast.void_type {
c.error('`match` expression requires an expression as the last statement of every branch',
stmt.pos)
}
}
}
}

View file

@ -0,0 +1,17 @@
module main
fn func[T]() bool {
$match T {
u8, u16 {
return true
}
$else {
// return false
$compile_error('fail')
}
}
}
fn test_comptime_match_eval_only_true_branch() {
assert func[u8]()
}