From d2af0dc96a099cc4f9b73eee6e08dbe455dfd1b3 Mon Sep 17 00:00:00 2001 From: Felipe Pena Date: Fri, 16 Feb 2024 17:04:34 -0300 Subject: [PATCH] cgen, ast, checker: fix auto deref arg when fn expects ref (#20846) --- cmd/tools/vast/vast.v | 1 + vlib/v/ast/ast.v | 1 + vlib/v/checker/fn.v | 4 ++++ vlib/v/gen/c/cgen.v | 3 ++- vlib/v/gen/c/fn.v | 14 ++++++++---- vlib/v/tests/sumtype_ptr_arg_test.v | 33 +++++++++++++++++++++++++++++ 6 files changed, 51 insertions(+), 5 deletions(-) create mode 100644 vlib/v/tests/sumtype_ptr_arg_test.v diff --git a/cmd/tools/vast/vast.v b/cmd/tools/vast/vast.v index a311ee9428..9ea22164f7 100644 --- a/cmd/tools/vast/vast.v +++ b/cmd/tools/vast/vast.v @@ -1560,6 +1560,7 @@ fn (t Tree) call_arg(node ast.CallArg) &Node { obj.add_terse('is_mut', t.bool_node(node.is_mut)) obj.add_terse('share', t.enum_node(node.share)) obj.add_terse('expr', t.expr(node.expr)) + obj.add_terse('should_be_ptr', t.bool_node(node.should_be_ptr)) obj.add('is_tmp_autofree', t.bool_node(node.is_tmp_autofree)) obj.add('pos', t.pos(node.pos)) obj.add('comments', t.array_node_comment(node.comments)) diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 084b1a22b9..1976a88f4a 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -803,6 +803,7 @@ pub mut: typ Type is_tmp_autofree bool // this tells cgen that a tmp variable has to be used for the arg expression in order to free it after the call pos token.Pos + should_be_ptr bool // fn expects a ptr for this arg // tmp_name string // for autofree } diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index b062bb2329..1a9db9f62b 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -1162,6 +1162,8 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. } else { func.params[i] } + // registers if the arg must be passed by ref to disable auto deref args + call_arg.should_be_ptr = param.typ.is_ptr() && !param.is_mut if func.is_variadic && call_arg.expr is ast.ArrayDecompose { if i > func.params.len - 1 { c.error('too many arguments in call to `${func.name}`', node.pos) @@ -1969,6 +1971,8 @@ fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type { } else { info.func.params[i] } + // registers if the arg must be passed by ref to disable auto deref args + arg.should_be_ptr = param.typ.is_ptr() && !param.is_mut if c.table.sym(param.typ).kind == .interface_ { // cannot hide interface expected type to make possible to pass its interface type automatically earg_types << if targ.idx() != param.typ.idx() { param.typ } else { targ } diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 9a975e62b9..011793a4f1 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -212,6 +212,7 @@ mut: // where an aggregate (at least two types) is generated // sum type deref needs to know which index to deref because unions take care of the correct field aggregate_type_idx int + arg_no_auto_deref bool // smartcast must not be dereferenced branch_parent_pos int // used in BranchStmt (continue/break) for autofree stop position returned_var_name string // to detect that a var doesn't need to be freed since it's being returned infix_left_var_name string // a && if expr @@ -4695,7 +4696,7 @@ fn (mut g Gen) ident(node ast.Ident) { } styp := g.base_type(node.obj.typ) g.write('*(${styp}*)') - } else { + } else if !g.arg_no_auto_deref { g.write('*') } } else if (g.inside_interface_deref && g.table.is_interface_var(node.obj)) diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index 75843fcbce..debf1bb21c 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -2248,6 +2248,7 @@ fn (mut g Gen) call_args(node ast.CallExpr) { if is_variadic && i == expected_types.len - 1 { break } + mut is_smartcast := false if arg.expr is ast.Ident { if arg.expr.obj is ast.Var { if arg.expr.obj.smartcasts.len > 0 { @@ -2260,6 +2261,7 @@ fn (mut g Gen) call_args(node ast.CallExpr) { if cast_sym.info is ast.Aggregate { expected_types[i] = cast_sym.info.types[g.aggregate_type_idx] } + is_smartcast = true } } } @@ -2295,7 +2297,7 @@ fn (mut g Gen) call_args(node ast.CallExpr) { g.write('/*autofree arg*/' + name) } } else { - g.ref_or_deref_arg(arg, expected_types[i], node.language) + g.ref_or_deref_arg(arg, expected_types[i], node.language, is_smartcast) } } else { if use_tmp_var_autofree { @@ -2366,7 +2368,8 @@ fn (mut g Gen) call_args(node ast.CallExpr) { noscan := g.check_noscan(arr_info.elem_type) g.write('new_array_from_c_array${noscan}(${variadic_count}, ${variadic_count}, sizeof(${elem_type}), _MOV((${elem_type}[${variadic_count}]){') for j in arg_nr .. args.len { - g.ref_or_deref_arg(args[j], arr_info.elem_type, node.language) + g.ref_or_deref_arg(args[j], arr_info.elem_type, node.language, + false) if j < args.len - 1 { g.write(', ') } @@ -2393,7 +2396,7 @@ fn (mut g Gen) keep_alive_call_pregen(node ast.CallExpr) int { expected_type := node.expected_arg_types[i] typ := g.table.sym(expected_type).cname g.write('${typ} __tmp_arg_${tmp_cnt_save + i} = ') - g.ref_or_deref_arg(arg, expected_type, node.language) + g.ref_or_deref_arg(arg, expected_type, node.language, false) g.writeln(';') } g.empty_line = false @@ -2410,7 +2413,7 @@ fn (mut g Gen) keep_alive_call_postgen(node ast.CallExpr, tmp_cnt_save int) { } @[inline] -fn (mut g Gen) ref_or_deref_arg(arg ast.CallArg, expected_type ast.Type, lang ast.Language) { +fn (mut g Gen) ref_or_deref_arg(arg ast.CallArg, expected_type ast.Type, lang ast.Language, is_smartcast bool) { arg_typ := if arg.expr is ast.ComptimeSelector { g.unwrap_generic(g.comptime.get_comptime_var_type(arg.expr)) } else { @@ -2525,7 +2528,10 @@ fn (mut g Gen) ref_or_deref_arg(arg ast.CallArg, expected_type ast.Type, lang as } } } + // check if the argument must be dereferenced or not + g.arg_no_auto_deref = is_smartcast && !arg_is_ptr && !exp_is_ptr && arg.should_be_ptr g.expr_with_cast(arg.expr, arg_typ, expected_type) + g.arg_no_auto_deref = false if needs_closing { g.write(')') } diff --git a/vlib/v/tests/sumtype_ptr_arg_test.v b/vlib/v/tests/sumtype_ptr_arg_test.v new file mode 100644 index 0000000000..4d573abf0e --- /dev/null +++ b/vlib/v/tests/sumtype_ptr_arg_test.v @@ -0,0 +1,33 @@ +struct Foo { + foo i32 +} + +struct Bar { + bar f32 +} + +type Foobar = Bar | Foo + +fn match_sum(sum Foobar) { + match sum { + Foo { + sum_foo(sum) + } + Bar { + sum_bar(sum) + } + } +} + +fn sum_foo(i &Foo) { + assert true +} + +fn sum_bar(i Bar) { + assert true +} + +fn test_main() { + match_sum(Foo{ foo: 5 }) + match_sum(Bar{ bar: 5 }) +}