diff --git a/vlib/v/ast/scope.v b/vlib/v/ast/scope.v index 0667aa72e6..f2855ce59b 100644 --- a/vlib/v/ast/scope.v +++ b/vlib/v/ast/scope.v @@ -129,6 +129,13 @@ pub fn (mut s Scope) update_ct_var_kind(name string, kind ComptimeVarKind) { } } +pub fn (mut s Scope) update_smartcasts(name string, typ Type) { + mut obj := unsafe { s.objects[name] } + if mut obj is Var { + obj.smartcasts = [typ] + } +} + // selector_expr: name.field_name pub fn (mut s Scope) register_struct_field(name string, field ScopeStructField) { if f := s.struct_fields[name] { diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 188730e8b2..6e9f9ae232 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -3794,7 +3794,7 @@ fn (mut c Checker) concat_expr(mut node ast.ConcatExpr) ast.Type { } // smartcast takes the expression with the current type which should be smartcasted to the target type in the given scope -fn (mut c Checker) smartcast(mut expr ast.Expr, cur_type ast.Type, to_type_ ast.Type, mut scope ast.Scope) { +fn (mut c Checker) smartcast(mut expr ast.Expr, cur_type ast.Type, to_type_ ast.Type, mut scope ast.Scope, is_comptime bool) { sym := c.table.sym(cur_type) to_type := if sym.kind == .interface_ && c.table.sym(to_type_).kind != .interface_ { to_type_.ref() @@ -3852,7 +3852,7 @@ fn (mut c Checker) smartcast(mut expr ast.Expr, cur_type ast.Type, to_type_ ast. orig_type = expr.obj.typ } is_inherited = expr.obj.is_inherited - ct_type_var = if expr.obj.ct_type_var == .field_var { + ct_type_var = if is_comptime && expr.obj.ct_type_var != .no_comptime { .smartcast } else { .no_comptime @@ -3861,9 +3861,15 @@ fn (mut c Checker) smartcast(mut expr ast.Expr, cur_type ast.Type, to_type_ ast. // smartcast either if the value is immutable or if the mut argument is explicitly given if (!is_mut || expr.is_mut) && !is_already_casted { smartcasts << to_type + if var := scope.find_var(expr.name) { + if is_comptime && var.ct_type_var == .smartcast { + scope.update_smartcasts(expr.name, to_type) + return + } + } scope.register(ast.Var{ name: expr.name - typ: if ct_type_var == .smartcast { to_type } else { cur_type } + typ: cur_type pos: expr.pos is_used: true is_mut: expr.is_mut diff --git a/vlib/v/checker/comptime.v b/vlib/v/checker/comptime.v index 5d207f3b37..cf7fb0ca66 100644 --- a/vlib/v/checker/comptime.v +++ b/vlib/v/checker/comptime.v @@ -217,7 +217,7 @@ fn (mut c Checker) comptime_for(mut node ast.ComptimeFor) { c.unwrap_generic(node.typ) } else { node.typ = c.expr(mut node.expr) - node.typ + c.unwrap_generic(node.typ) } sym := c.table.final_sym(typ) if sym.kind == .placeholder || typ.has_flag(.generic) { diff --git a/vlib/v/checker/for.v b/vlib/v/checker/for.v index 11eb72b2a3..ddea370e09 100644 --- a/vlib/v/checker/for.v +++ b/vlib/v/checker/for.v @@ -285,7 +285,7 @@ fn (mut c Checker) for_stmt(mut node ast.ForStmt) { if node.cond.right is ast.TypeNode && node.cond.left in [ast.Ident, ast.SelectorExpr] { if c.table.type_kind(node.cond.left_type) in [.sum_type, .interface_] { c.smartcast(mut node.cond.left, node.cond.left_type, node.cond.right_type, mut - node.scope) + node.scope, false) } } } diff --git a/vlib/v/checker/if.v b/vlib/v/checker/if.v index 4c48337723..cf9fd581f1 100644 --- a/vlib/v/checker/if.v +++ b/vlib/v/checker/if.v @@ -142,7 +142,6 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { } if left is ast.SelectorExpr { comptime_field_name = left.expr.str() - c.comptime.type_map[comptime_field_name] = got_type is_comptime_type_is_expr = true if comptime_field_name == c.comptime.comptime_for_field_var { left_type := c.unwrap_generic(c.comptime.comptime_for_field_type) @@ -177,10 +176,13 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { left_type := c.unwrap_generic(left.typ) skip_state = c.check_compatible_types(left_type, right as ast.TypeNode) } else if left is ast.Ident { - is_comptime_type_is_expr = true mut checked_type := ast.void_type + is_comptime_type_is_expr = true if var := left.scope.find_var(left.name) { checked_type = c.unwrap_generic(var.typ) + if var.smartcasts.len > 0 { + checked_type = c.unwrap_generic(var.smartcasts.last()) + } } skip_state = c.check_compatible_types(checked_type, right as ast.TypeNode) } @@ -344,7 +346,7 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { if comptime_field_name.len > 0 { if comptime_field_name == c.comptime.comptime_for_method_var { c.comptime.type_map[comptime_field_name] = c.comptime.comptime_for_method_ret_type - } else { + } else if comptime_field_name == c.comptime.comptime_for_field_var { c.comptime.type_map[comptime_field_name] = c.comptime.comptime_for_field_type } } @@ -516,11 +518,14 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope) { c.smartcast_if_conds(mut node.right, mut scope) } else if node.left is ast.Ident && node.op == .ne && node.right is ast.None { c.smartcast(mut node.left, node.left_type, node.left_type.clear_flag(.option), mut - scope) + scope, false) } else if node.op == .key_is { - if node.left_type == ast.Type(0) { + if node.left is ast.Ident && c.comptime.is_comptime_var(node.left) { + node.left_type = c.comptime.get_comptime_var_type(node.left) + } else { node.left_type = c.expr(mut node.left) } + mut is_comptime := false right_expr := node.right right_type := match right_expr { ast.TypeNode { @@ -531,6 +536,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope) { } ast.Ident { if right_expr.name == c.comptime.comptime_for_variant_var { + is_comptime = true c.comptime.type_map['${c.comptime.comptime_for_variant_var}.typ'] } else { c.error('invalid type `${right_expr}`', right_expr.pos) @@ -544,7 +550,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope) { } if right_type != ast.Type(0) { right_sym := c.table.sym(right_type) - mut expr_type := c.expr(mut node.left) + mut expr_type := c.unwrap_generic(node.left_type) left_sym := c.table.sym(expr_type) if left_sym.kind == .aggregate { expr_type = (left_sym.info as ast.Aggregate).sum_type @@ -581,7 +587,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope) { } if left_sym.kind in [.interface_, .sum_type] { c.smartcast(mut node.left, node.left_type, right_type, mut - scope) + scope, is_comptime) } } } diff --git a/vlib/v/checker/infix.v b/vlib/v/checker/infix.v index 24d0b17988..a82cc76a37 100644 --- a/vlib/v/checker/infix.v +++ b/vlib/v/checker/infix.v @@ -675,7 +675,8 @@ fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type { if typ != ast.none_type_idx { c.error('`${op}` can only be used to test for none in sql', node.pos) } - } else if left_sym.kind !in [.interface_, .sum_type] { + } else if left_sym.kind !in [.interface_, .sum_type] + && !c.comptime.is_comptime_var(node.left) { c.error('`${op}` can only be used with interfaces and sum types', node.pos) // can be used in sql too, but keep err simple } else if mut left_sym.info is ast.SumType { diff --git a/vlib/v/checker/match.v b/vlib/v/checker/match.v index 4cfc7e6513..db4934ffd1 100644 --- a/vlib/v/checker/match.v +++ b/vlib/v/checker/match.v @@ -476,7 +476,8 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym ast.TypeSym expr_type = expr_types[0].typ } - c.smartcast(mut node.cond, node.cond_type, expr_type, mut branch.scope) + c.smartcast(mut node.cond, node.cond_type, expr_type, mut branch.scope, + false) } } } diff --git a/vlib/v/comptime/comptimeinfo.v b/vlib/v/comptime/comptimeinfo.v index c734b9b5d7..27eb0a1110 100644 --- a/vlib/v/comptime/comptimeinfo.v +++ b/vlib/v/comptime/comptimeinfo.v @@ -54,7 +54,7 @@ pub fn (mut ct ComptimeInfo) get_comptime_var_type(node ast.Expr) ast.Type { node.obj.typ } .smartcast { - ct.type_map['${ct.comptime_for_variant_var}.typ'] or { ast.void_type } + ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ } } .key_var, .value_var { // key and value variables from normal for stmt @@ -77,9 +77,6 @@ pub fn (mut ct ComptimeInfo) get_comptime_var_type(node ast.Expr) ast.Type { ct.comptime_for_variant_var { return ct.type_map['${ct.comptime_for_variant_var}.typ'] } - ct.comptime_for_enum_var { - return ct.type_map['${ct.comptime_for_enum_var}.typ'] - } else { // field_var.typ from $for field return ct.comptime_for_field_type diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 3c1e1f862c..35143d10f6 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -4546,7 +4546,7 @@ fn (mut g Gen) ident(node ast.Ident) { } } if node.obj.ct_type_var == .smartcast { - cur_variant_sym := g.table.sym(g.comptime.type_map['${g.comptime.comptime_for_variant_var}.typ']) + cur_variant_sym := g.table.sym(g.unwrap_generic(g.comptime.get_comptime_var_type(node))) g.write('${dot}_${cur_variant_sym.cname}') } else if !is_option_unwrap && obj_sym.kind in [.sum_type, .interface_] { diff --git a/vlib/v/gen/c/comptime.v b/vlib/v/gen/c/comptime.v index c176a6c0ec..bb065ff859 100644 --- a/vlib/v/gen/c/comptime.v +++ b/vlib/v/gen/c/comptime.v @@ -896,7 +896,11 @@ fn (mut g Gen) comptime_for(node ast.ComptimeFor) { if sym.info.vals.len > 0 { g.writeln('\tEnumData ${node.val_var} = {0};') } + g.push_new_comptime_info() for val in sym.info.vals { + g.comptime.comptime_for_enum_var = node.val_var + g.comptime.type_map['${node.val_var}.typ'] = node.typ + g.writeln('/* enum vals ${i} */ {') g.writeln('\t${node.val_var}.name = _SLIT("${val}");') g.write('\t${node.val_var}.value = ') @@ -918,6 +922,7 @@ fn (mut g Gen) comptime_for(node ast.ComptimeFor) { g.writeln('}') i++ } + g.pop_comptime_info() } } } else if node.kind == .attributes { diff --git a/vlib/v/gen/c/infix.v b/vlib/v/gen/c/infix.v index dd37f4c5a0..abf8674907 100644 --- a/vlib/v/gen/c/infix.v +++ b/vlib/v/gen/c/infix.v @@ -673,7 +673,11 @@ fn (mut g Gen) infix_expr_in_optimization(left ast.Expr, right ast.ArrayInit) { // infix_expr_is_op generates code for `is` and `!is` fn (mut g Gen) infix_expr_is_op(node ast.InfixExpr) { - mut left_sym := g.table.sym(node.left_type) + mut left_sym := if g.comptime.is_comptime_var(node.left) { + g.table.sym(g.unwrap_generic(g.comptime.get_comptime_var_type(node.left))) + } else { + g.table.sym(node.left_type) + } is_aggregate := left_sym.kind == .aggregate if is_aggregate { parent_left_type := (left_sym.info as ast.Aggregate).sum_type diff --git a/vlib/v/tests/comptime_var_is_check_test.v b/vlib/v/tests/comptime_var_is_check_test.v new file mode 100644 index 0000000000..91fe32e46f --- /dev/null +++ b/vlib/v/tests/comptime_var_is_check_test.v @@ -0,0 +1,19 @@ +type TestSum = int | string + +fn gen[T, R](sum T) R { + $if T is $sumtype { + $for v in sum.variants { + if sum is v { + $if sum is R { + return sum + } + } + } + } + return R{} +} + +fn test_main() { + assert dump(gen[TestSum, string](TestSum('foo'))) == 'foo' + assert dump(gen[TestSum, int](TestSum(123))) == 123 +}