all: support short lambda expressions like a.sorted(|x,y| x > y), in all callsites that accept a fn callback (#19390)

This commit is contained in:
Delyan Angelov 2023-09-20 17:22:16 +03:00 committed by GitHub
parent 175a3b2684
commit f93d257d29
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 385 additions and 25 deletions

View file

@ -0,0 +1,37 @@
fn test_sort_with_lambda_expr() {
a := [5, 2, 1, 9, 8]
dump(a)
sorted01 := a.sorted(a < b)
sorted02 := a.sorted(a > b)
dump(sorted01)
dump(sorted02)
sorted01_with_compare_fn := a.sorted_with_compare(fn (a &int, b &int) int {
return *a - *b
})
sorted02_with_compare_fn := a.sorted_with_compare(fn (a &int, b &int) int {
return *b - *a
})
dump(sorted01_with_compare_fn)
dump(sorted02_with_compare_fn)
///////////////////////////////////////////
sorted01_lambda_expr := a.sorted(|ix, iy| ix < iy)
sorted02_lambda_expr := a.sorted(|ii, jj| ii > jj)
dump(sorted01_lambda_expr)
dump(sorted02_lambda_expr)
sorted01_with_compare_lambda_expr := a.sorted_with_compare(|x, y| *x - *y)
sorted02_with_compare_lambda_expr := a.sorted_with_compare(|e1, e2| *e2 - *e1)
dump(sorted01_with_compare_lambda_expr)
dump(sorted02_with_compare_lambda_expr)
assert sorted01 == sorted01_with_compare_fn
assert sorted02 == sorted02_with_compare_fn
assert sorted01 == sorted01_lambda_expr
assert sorted02 == sorted02_lambda_expr
assert sorted01 == sorted01_with_compare_lambda_expr
assert sorted02 == sorted02_with_compare_lambda_expr
}

View file

@ -38,6 +38,7 @@ pub type Expr = AnonFn
| InfixExpr
| IntegerLiteral
| IsRefType
| LambdaExpr
| Likely
| LockExpr
| MapInit
@ -506,7 +507,7 @@ pub mut:
decl FnDecl
inherited_vars []Param
typ Type // the type of anonymous fn. Both .typ and .decl.name are auto generated
has_gen map[string]bool // has been generated
has_gen map[string]bool // a map of the names of all generic anon functions, generated from it
}
// function or method declaration
@ -782,11 +783,11 @@ pub:
share ShareType
is_mut bool
is_autofree_tmp bool
is_arg bool // fn args should not be autofreed
is_auto_deref bool
is_inherited bool
has_inherited bool
pub mut:
is_arg bool // fn args should not be autofreed
is_auto_deref bool
expr Expr
typ Type
orig_type Type // original sumtype type; 0 if it's not a sumtype
@ -891,6 +892,7 @@ pub mut:
generic_fns []&FnDecl
global_labels []string // from `asm { .globl labelname }`
template_paths []string // all the .html/.md files that were processed with $tmpl
unique_prefix string // a hash of the `.path` field, used for making anon fn generation unique
}
[unsafe]
@ -1253,14 +1255,6 @@ pub mut:
// ct_conds is filled by the checker, based on the current nesting of `$if cond1 {}` blocks
}
/*
// filter(), map(), sort()
pub struct Lambda {
pub:
name string
}
*/
// variable assign statement
[minify]
pub struct AssignStmt {
@ -1790,6 +1784,20 @@ pub:
pos token.Pos
}
pub struct LambdaExpr {
pub:
pos token.Pos
params []Ident
pub mut:
pos_expr token.Pos
expr Expr
pos_end token.Pos
scope &Scope = unsafe { nil }
func &AnonFn = unsafe { nil }
is_checked bool
typ Type
}
pub struct Likely {
pub:
pos token.Pos
@ -1977,7 +1985,7 @@ pub fn (expr Expr) pos() token.Pos {
IsRefType, Likely, LockExpr, MapInit, MatchExpr, None, OffsetOf, OrExpr, ParExpr,
PostfixExpr, PrefixExpr, RangeExpr, SelectExpr, SelectorExpr, SizeOf, SqlExpr,
StringInterLiteral, StringLiteral, StructInit, TypeNode, TypeOf, UnsafeExpr, ComptimeType,
Nil {
LambdaExpr, Nil {
return expr.pos
}
IndexExpr {
@ -2169,6 +2177,12 @@ pub fn (node Node) children() []Node {
TypeOf, ArrayDecompose {
children << node.expr
}
LambdaExpr {
for p in node.params {
children << Node(Expr(p))
}
children << node.expr
}
LockExpr, OrExpr {
return node.stmts.map(Node(it))
}

View file

@ -15,6 +15,11 @@ pub fn (f &FnDecl) get_name() string {
}
}
// get_anon_fn_name returns the unique anonymous function name, based on the prefix, the func signature and its position in the source code
pub fn (table &Table) get_anon_fn_name(prefix string, func &Fn, pos int) string {
return 'anon_fn_${prefix}_${table.fn_type_signature(func)}_${pos}'
}
// get_name returns the real name for the function calling
pub fn (f &CallExpr) get_name() string {
if f.name != '' && f.name.all_after_last('.')[0].is_capital() && f.name.contains('__static__') {
@ -609,6 +614,10 @@ pub fn (x Expr) str() string {
}
return 'typeof(${x.expr.str()})'
}
LambdaExpr {
ilist := x.params.map(it.name).join(', ')
return '|${ilist}| ${x.expr.str()}'
}
Likely {
return '_likely_(${x.expr.str()})'
}

View file

@ -2783,6 +2783,9 @@ pub fn (mut c Checker) expr(mut node ast.Expr) ast.Type {
ast.IntegerLiteral {
return c.int_lit(mut node)
}
ast.LambdaExpr {
return c.lambda_expr(mut node, c.expected_type)
}
ast.LockExpr {
return c.lock_expr(mut node)
}

View file

@ -2597,6 +2597,10 @@ fn (mut c Checker) array_builtin_method_call(mut node ast.CallExpr, left_type as
if method_name in ['filter', 'map', 'any', 'all'] {
// position of `it` doesn't matter
scope_register_it(mut node.scope, node.pos, elem_typ)
} else if method_name == 'sorted_with_compare' && node.args.len == 1 {
if mut node.args[0].expr is ast.LambdaExpr {
c.support_lambda_expr_in_sort(elem_typ.ref(), ast.int_type, mut node.args[0].expr)
}
} else if method_name == 'sort' || method_name == 'sorted' {
if method_name == 'sort' {
if node.left is ast.CallExpr {
@ -2611,7 +2615,9 @@ fn (mut c Checker) array_builtin_method_call(mut node ast.CallExpr, left_type as
if node.args.len > 1 {
c.error('expected 0 or 1 argument, but got ${node.args.len}', node.pos)
} else if node.args.len == 1 {
if node.args[0].expr is ast.InfixExpr {
if mut node.args[0].expr is ast.LambdaExpr {
c.support_lambda_expr_in_sort(elem_typ.ref(), ast.bool_type, mut node.args[0].expr)
} else if node.args[0].expr is ast.InfixExpr {
if node.args[0].expr.op !in [.gt, .lt] {
c.error('`.${method_name}()` can only use `<` or `>` comparison',
node.pos)

View file

@ -0,0 +1,123 @@
module checker
import v.ast
pub fn (mut c Checker) lambda_expr(mut node ast.LambdaExpr, exp_typ ast.Type) ast.Type {
// defer { eprintln('> line: ${@LINE} | exp_typ: $exp_typ | node: ${voidptr(node)} | node.typ: ${node.typ}') }
if node.is_checked {
return node.typ
}
if !c.inside_fn_arg {
c.error('lambda expressions are allowed only inside function or method callsites',
node.pos)
return ast.void_type
}
if exp_typ == 0 {
c.error('lambda expressions are allowed only in places expecting function callbacks',
node.pos)
return ast.void_type
}
exp_sym := c.table.sym(exp_typ)
if exp_sym.kind != .function {
c.error('a lambda expression was used, but `${exp_sym.kind}` was expected', node.pos)
return ast.void_type
}
if exp_sym.info is ast.FnType {
if node.params.len != exp_sym.info.func.params.len {
c.error('lambda expression has ${node.params.len} params, but the expected fn callback needs ${exp_sym.info.func.params.len} params',
node.pos)
return ast.void_type
}
mut params := []ast.Param{}
for idx, mut x in node.params {
eparam := exp_sym.info.func.params[idx]
eparam_type := eparam.typ
eparam_auto_deref := eparam.typ.is_ptr()
if mut v := node.scope.find(x.name) {
if mut v is ast.Var {
v.is_arg = true
v.typ = eparam_type
v.expr = ast.empty_expr
v.is_auto_deref = eparam_auto_deref
}
}
c.ident(mut x)
x.obj.typ = eparam_type
params << ast.Param{
pos: x.pos
name: x.name
typ: eparam_type
type_pos: x.pos
is_auto_rec: eparam_auto_deref
}
}
/////
is_variadic := false
return_type := exp_sym.info.func.return_type
return_type_pos := node.pos
mut stmts := []ast.Stmt{}
mut return_stmt := ast.Return{
pos: node.pos
exprs: [node.expr]
}
stmts << return_stmt
mut func := ast.Fn{
params: params
is_variadic: is_variadic
return_type: return_type
is_method: false
}
name := c.table.get_anon_fn_name(c.file.unique_prefix, func, node.pos.pos)
func.name = name
idx := c.table.find_or_register_fn_type(func, true, false)
typ := ast.new_type(idx)
node.func = &ast.AnonFn{
decl: ast.FnDecl{
name: name
short_name: ''
mod: c.file.mod.name
stmts: stmts
return_type: return_type
return_type_pos: return_type_pos
params: params
is_variadic: is_variadic
is_method: false
is_anon: true
no_body: false
pos: node.pos.extend(node.pos_end)
file: c.file.path
scope: node.scope.parent
}
typ: typ
}
c.anon_fn(mut node.func)
}
node.is_checked = true
node.typ = exp_typ
return exp_typ
}
pub fn (mut c Checker) support_lambda_expr_in_sort(param_type ast.Type, return_type ast.Type, mut expr ast.LambdaExpr) {
is_auto_rec := param_type.is_ptr()
mut expected_fn := ast.Fn{
params: [
ast.Param{
name: 'zza'
typ: param_type
is_auto_rec: is_auto_rec
},
ast.Param{
name: 'zzb'
typ: param_type
is_auto_rec: is_auto_rec
},
]
return_type: return_type
}
expected_fn_type := ast.new_type(c.table.find_or_register_fn_type(expected_fn, true,
false))
c.lambda_expr(mut expr, expected_fn_type)
}

View file

@ -574,7 +574,7 @@ pub fn (mut e Eval) expr(expr ast.Expr, expecting ast.Type) Object {
ast.ConcatExpr, ast.DumpExpr, ast.EmptyExpr, ast.EnumVal, ast.GoExpr, ast.SpawnExpr,
ast.IfGuardExpr, ast.IsRefType, ast.Likely, ast.LockExpr, ast.MapInit, ast.MatchExpr,
ast.Nil, ast.NodeError, ast.None, ast.OffsetOf, ast.OrExpr, ast.RangeExpr, ast.SelectExpr,
ast.SqlExpr, ast.TypeNode, ast.TypeOf {
ast.SqlExpr, ast.TypeNode, ast.TypeOf, ast.LambdaExpr {
e.error('unhandled expression ${typeof(expr).name}')
}
}

View file

@ -671,6 +671,17 @@ pub fn (mut f Fmt) expr(node_ ast.Expr) {
ast.IntegerLiteral {
f.write(node.val)
}
ast.LambdaExpr {
f.write('|')
for i, x in node.params {
f.expr(x)
if i < node.params.len - 1 {
f.write(', ')
}
}
f.write('| ')
f.expr(node.expr)
}
ast.Likely {
f.likely(node)
}

View file

@ -80,9 +80,12 @@ fn (mut g Gen) array_init(node ast.ArrayInit, var_name string) {
}
fn (mut g Gen) fixed_array_init(node ast.ArrayInit, array_type Type, var_name string) {
if node.has_index {
prev_inside_lambda := g.inside_lambda
g.inside_lambda = true
defer {
g.inside_lambda = prev_inside_lambda
}
if node.has_index {
past := g.past_tmp_var_from_var_name(var_name)
defer {
g.past_tmp_var_done(past)
@ -130,7 +133,6 @@ fn (mut g Gen) fixed_array_init(node ast.ArrayInit, array_type Type, var_name st
g.writeln('}')
g.indent--
g.writeln('}')
g.inside_lambda = false
return
}
need_tmp_var := g.inside_call && !g.inside_struct_init && node.exprs.len == 0
@ -246,6 +248,11 @@ fn (mut g Gen) struct_has_array_or_map_field(elem_typ ast.Type) bool {
// `[]int{len: 6, cap: 10, init: index * index}`
fn (mut g Gen) array_init_with_fields(node ast.ArrayInit, elem_type Type, is_amp bool, shared_styp string, var_name string) {
prev_inside_lambda := g.inside_lambda
g.inside_lambda = true
defer {
g.inside_lambda = prev_inside_lambda
}
elem_styp := g.typ(elem_type.typ)
noscan := g.check_noscan(elem_type.typ)
is_default_array := elem_type.unaliased_sym.kind == .array && node.has_default
@ -253,7 +260,6 @@ fn (mut g Gen) array_init_with_fields(node ast.ArrayInit, elem_type Type, is_amp
needs_more_defaults := node.has_len && (g.struct_has_array_or_map_field(elem_type.typ)
|| elem_type.unaliased_sym.kind in [.array, .map])
if node.has_index { // []int{len: 6, init: index * index} when variable it is used in init expression
g.inside_lambda = true
past := g.past_tmp_var_from_var_name(var_name)
defer {
@ -331,7 +337,6 @@ fn (mut g Gen) array_init_with_fields(node ast.ArrayInit, elem_type Type, is_amp
g.indent--
g.writeln('}')
g.set_current_pos_as_last_stmt_pos()
g.inside_lambda = false
return
}
if is_default_array {
@ -447,11 +452,15 @@ fn (mut g Gen) write_closure_fn(mut expr ast.AnonFn) {
// `nums.map(it % 2 == 0)`
fn (mut g Gen) gen_array_map(node ast.CallExpr) {
prev_inside_lambda := g.inside_lambda
g.inside_lambda = true
defer {
g.inside_lambda = prev_inside_lambda
}
past := g.past_tmp_var_new()
defer {
g.past_tmp_var_done(past)
g.inside_lambda = false
}
ret_typ := g.typ(node.return_type)
@ -599,6 +608,8 @@ fn (mut g Gen) gen_array_sort(node ast.CallExpr) {
mut compare_fn := 'compare_${g.unique_file_path_hash}_${elem_stype.replace('*', '_ptr')}'
mut comparison_type := g.unwrap(ast.void_type)
mut left_expr, mut right_expr := '', ''
mut use_lambda := false
mut lambda_fn_name := ''
// the only argument can only be an infix expression like `a < b` or `b.field > a.field`
if node.args.len == 0 {
comparison_type = g.unwrap(info.elem_type.set_nr_muls(0))
@ -610,6 +621,12 @@ fn (mut g Gen) gen_array_sort(node ast.CallExpr) {
}
left_expr = '*a'
right_expr = '*b'
} else if node.args[0].expr is ast.LambdaExpr {
lambda_fn_name = node.args[0].expr.func.decl.name
compare_fn = '${lambda_fn_name}_lambda_wrapper'
use_lambda = true
mut lambda_node := unsafe { node.args[0].expr }
g.gen_anon_fn_decl(mut lambda_node.func)
} else {
infix_expr := node.args[0].expr as ast.InfixExpr
comparison_type = g.unwrap(infix_expr.left_type.set_nr_muls(0))
@ -663,6 +680,8 @@ fn (mut g Gen) gen_array_sort(node ast.CallExpr) {
'${g.typ(comparison_type.typ)}__lt(${left_expr}, ${right_expr})'
} else if comparison_type.unaliased_sym.has_method('<') {
'${g.typ(comparison_type.unaliased)}__lt(${left_expr}, ${right_expr})'
} else if use_lambda {
'${lambda_fn_name}(a, b)'
} else {
'${left_expr} < ${right_expr}'
}

View file

@ -3329,6 +3329,10 @@ fn (mut g Gen) expr(node_ ast.Expr) {
is_ref_type := g.contains_ptr(node_typ)
g.write('/*IsRefType*/ ${is_ref_type}')
}
ast.LambdaExpr {
g.gen_anon_fn(mut node.func)
// g.write('/* lambda expr: ${node_.str()} */')
}
ast.Likely {
if node.is_likely {
g.write('_likely_')

View file

@ -587,6 +587,9 @@ pub fn (mut f Gen) expr(node_ ast.Expr) {
ast.IntegerLiteral {
f.write(node.val)
}
ast.LambdaExpr {
eprintln('> TODO: implement ast.LambdaExpr in the Go backend')
}
ast.Likely {
f.likely(node)
}

View file

@ -962,6 +962,9 @@ fn (mut g JsGen) expr(node_ ast.Expr) {
ast.IntegerLiteral {
g.gen_integer_literal_expr(node)
}
ast.LambdaExpr {
eprintln('> TODO: implement short lambda expressions in the JS backend')
}
ast.Likely {
g.write('(')
g.expr(node.expr)

View file

@ -349,6 +349,9 @@ fn (mut w Walker) expr(node_ ast.Expr) {
}
}
}
ast.LambdaExpr {
w.expr(node.func)
}
ast.Likely {
w.expr(node.expr)
}

View file

@ -178,6 +178,13 @@ fn (mut p Parser) check_expr(precedence int) !ast.Expr {
}
p.inside_unsafe = false
}
.pipe, .logical_or {
if nnn := p.lambda_expr() {
node = nnn
} else {
return error('unexpected lambda expression')
}
}
.key_lock, .key_rlock {
node = p.lock_expr()
}
@ -461,6 +468,7 @@ fn (mut p Parser) check_expr(precedence int) !ast.Expr {
}
}
}
if inside_array_lit {
if p.tok.kind in [.minus, .mul, .amp, .arrow] && p.tok.pos + 1 == p.peek_tok.pos
&& p.prev_tok.pos + p.prev_tok.len + 1 != p.peek_tok.pos {
@ -827,3 +835,66 @@ fn (mut p Parser) process_custom_orm_operators() {
}
}
}
fn (mut p Parser) lambda_expr() ?ast.LambdaExpr {
if !p.inside_call_args {
return none
}
// a) `f(||expr)` for a callback lambda expression with 0 arguments
// b) `f(|a_1,...,a_n| expr_with_a_1_etc_till_a_n)` for a callback with several arguments
if !(p.tok.kind == .logical_or
|| (p.peek_token(1).kind == .name && p.peek_token(2).kind == .pipe)
|| (p.peek_token(1).kind == .name && p.peek_token(2).kind == .comma)) {
return none
}
p.open_scope()
defer {
p.close_scope()
}
mut pos := p.tok.pos()
mut params := []ast.Ident{}
if p.tok.kind == .logical_or {
p.check(.logical_or)
} else {
p.check(.pipe)
for {
if p.tok.kind == .eof {
break
}
ident := p.ident(ast.Language.v)
if p.scope.known_var(ident.name) {
p.error_with_pos('redefinition of parameter `${ident.name}`', ident.pos)
}
params << ident
p.scope.register(ast.Var{
name: ident.name
is_mut: ident.is_mut
is_stack_obj: true
pos: ident.pos
is_used: true
is_arg: true
})
if p.tok.kind == .pipe {
p.next()
break
}
p.check(.comma)
}
}
pos_expr := p.tok.pos()
e := p.expr(0)
pos_end := p.tok.pos()
return ast.LambdaExpr{
pos: pos
pos_expr: pos_expr
pos_end: pos_end
params: params
expr: e
scope: p.scope
}
}

View file

@ -97,6 +97,11 @@ fn (mut p Parser) call_expr(language ast.Language, mod string) ast.CallExpr {
}
fn (mut p Parser) call_args() []ast.CallArg {
prev_inside_call_args := true
p.inside_call_args = true
defer {
p.inside_call_args = prev_inside_call_args
}
mut args := []ast.CallArg{}
start_pos := p.tok.pos()
for p.tok.kind != .rpar {
@ -806,7 +811,7 @@ fn (mut p Parser) anon_fn() ast.AnonFn {
return_type: return_type
is_method: false
}
name := 'anon_fn_${p.unique_prefix}_${p.table.fn_type_signature(func)}_${p.tok.pos}'
name := p.table.get_anon_fn_name(p.unique_prefix, func, p.tok.pos)
keep_fn_name := p.cur_fn_name
p.cur_fn_name = name
if p.tok.kind == .lcbr {

View file

@ -44,6 +44,7 @@ mut:
inside_for bool
inside_fn bool // true even with implicit main
inside_fn_return bool
inside_call_args bool // true inside f( .... )
inside_unsafe_fn bool
inside_str_interp bool
inside_array_lit bool
@ -379,6 +380,7 @@ pub fn (mut p Parser) parse() &ast.File {
notices: notices
global_labels: p.global_labels
template_paths: p.template_paths
unique_prefix: p.unique_prefix
}
}

View file

@ -0,0 +1,47 @@
fn f0(cb fn () int) int {
return cb() * 10
}
fn f1(cb fn (a int) int) int {
return cb(10)
}
fn f2(cb fn (a int, b int) int) int {
return cb(10, 10)
}
fn f3(cb fn (a int, b int, c int) int) int {
return cb(10, 10, 10)
}
enum MyEnum {
no
xyz = 4
other = 10
}
fn f3_different(cb fn (a int, b string, c MyEnum) string) string {
return cb(10, 'abc', .xyz)
}
fn test_lambda_expr() {
assert f0(|| 4) == 40
assert f1(|x| x + 4) == 14
assert f2(|xx, yy| xx + yy + 4) == 24
assert f3(|xxx, yyy, zzz| xxx + yyy + zzz + 4) == 34
assert f3_different(|xxx, yyy, zzz| yyy + ',${xxx}, ${yyy}, ${zzz}') == 'abc,10, abc, xyz'
}
fn doit(x int, y int, cb fn (a int, b int) string) string {
dump(cb)
dump(x)
dump(y)
return cb(x, y)
}
fn test_fn_with_callback_called_with_lambda_expression() {
assert doit(10, 20, fn (aaa int, bbb int) string {
return 'a: ${aaa}, b: ${bbb}'
}) == 'a: 10, b: 20'
assert doit(100, 200, |a, b| 'a: ${a}, b: ${b}') == 'a: 100, b: 200'
}