tinc_build/codegen/cel/functions/
map.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::parse_quote;
4use tinc_cel::CelValue;
5
6use super::Function;
7use crate::codegen::cel::compiler::{CompileError, CompiledExpr, CompilerCtx, ConstantCompiledExpr, RuntimeCompiledExpr};
8use crate::codegen::cel::types::CelType;
9use crate::types::{ProtoModifiedValueType, ProtoType, ProtoValueType};
10
/// The CEL `map` macro, written `<this>.map(<ident>, <expr>)`.
///
/// Carries no state; all behaviour lives in its [`Function`] implementation.
#[derive(Clone, Debug, Default)]
pub(crate) struct Map;
13
14fn native_impl(iter: TokenStream, item_ident: syn::Ident, map_fn: impl ToTokens) -> syn::Expr {
15    parse_quote!({
16        let mut collected = Vec::new();
17        let mut iter = (#iter).into_iter();
18        loop {
19            let Some(#item_ident) = iter.next() else {
20                break ::tinc::__private::cel::CelValue::List(collected.into());
21            };
22
23            collected.push(#map_fn);
24        }
25    })
26}
27
28// this.map(<ident>, <expr>)
29impl Function for Map {
30    fn name(&self) -> &'static str {
31        "map"
32    }
33
34    fn syntax(&self) -> &'static str {
35        "<this>.map(<ident>, <expr>)"
36    }
37
38    fn compile(&self, ctx: CompilerCtx) -> Result<CompiledExpr, CompileError> {
39        let Some(this) = &ctx.this else {
40            return Err(CompileError::syntax("missing this", self));
41        };
42
43        if ctx.args.len() != 2 {
44            return Err(CompileError::syntax("invalid number of args", self));
45        }
46
47        let cel_parser::Expression::Ident(variable) = &ctx.args[0] else {
48            return Err(CompileError::syntax("first argument must be an ident", self));
49        };
50
51        match this {
52            CompiledExpr::Runtime(RuntimeCompiledExpr { expr, ty }) => {
53                let mut child_ctx = ctx.child();
54
55                match ty {
56                    CelType::CelValue => {
57                        child_ctx.add_variable(variable, CompiledExpr::runtime(CelType::CelValue, parse_quote!(item)));
58                    }
59                    CelType::Proto(ProtoType::Modified(
60                        ProtoModifiedValueType::Repeated(ty) | ProtoModifiedValueType::Map(ty, _),
61                    )) => {
62                        child_ctx.add_variable(
63                            variable,
64                            CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item)),
65                        );
66                    }
67                    v => {
68                        return Err(CompileError::TypeConversion {
69                            ty: Box::new(v.clone()),
70                            message: "type cannot be iterated over".to_string(),
71                        });
72                    }
73                };
74
75                let arg = child_ctx.resolve(&ctx.args[1])?.into_cel()?;
76
77                Ok(CompiledExpr::runtime(
78                    CelType::CelValue,
79                    match &ty {
80                        CelType::CelValue => parse_quote! {
81                            ::tinc::__private::cel::CelValue::cel_map(#expr, |item| {
82                                ::core::result::Result::Ok(
83                                    #arg
84                                )
85                            })?
86                        },
87                        CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(_, _))) => {
88                            native_impl(quote!((#expr).keys()), parse_quote!(item), arg)
89                        }
90                        CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(_))) => {
91                            native_impl(quote!((#expr).iter()), parse_quote!(item), arg)
92                        }
93                        _ => unreachable!(),
94                    },
95                ))
96            }
97            CompiledExpr::Constant(ConstantCompiledExpr {
98                value: value @ (CelValue::List(_) | CelValue::Map(_)),
99            }) => {
100                let compile_val = |value: CelValue<'static>| {
101                    let mut child_ctx = ctx.child();
102
103                    child_ctx.add_variable(variable, CompiledExpr::constant(value));
104
105                    child_ctx.resolve(&ctx.args[1])?.into_cel()
106                };
107
108                let collected: Result<Vec<_>, _> = match value {
109                    CelValue::List(item) => item.iter().cloned().map(compile_val).collect(),
110                    CelValue::Map(item) => item.iter().map(|(key, _)| key).cloned().map(compile_val).collect(),
111                    _ => unreachable!(),
112                };
113
114                let collected = collected?;
115                if collected.iter().any(|c| matches!(c, CompiledExpr::Runtime(_))) {
116                    Ok(CompiledExpr::runtime(
117                        CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
118                        native_impl(quote!([#(#collected),*]), parse_quote!(item), quote!(item)),
119                    ))
120                } else {
121                    Ok(CompiledExpr::constant(CelValue::List(
122                        collected
123                            .into_iter()
124                            .map(|c| match c {
125                                CompiledExpr::Constant(ConstantCompiledExpr { value }) => value,
126                                _ => unreachable!("all values must be constant"),
127                            })
128                            .collect(),
129                    )))
130                }
131            }
132            CompiledExpr::Constant(ConstantCompiledExpr { value }) => Err(CompileError::TypeConversion {
133                ty: Box::new(CelType::CelValue),
134                message: format!("{value:?} cannot be iterated over"),
135            }),
136        }
137    }
138}
139
// Tests for the CEL `map` function. They cover argument/receiver validation,
// compile-time folding of constant receivers (checked with inline debug
// snapshots), and runtime code generation. For the runtime cases the generated
// code is compiled via `postcompile::compile_str!` (with `test: true`) and the
// result is snapshotted.
#[cfg(test)]
#[cfg(feature = "prost")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use quote::quote;
    use syn::parse_quote;
    use tinc_cel::{CelValue, CelValueConv};

    use crate::codegen::cel::compiler::{CompiledExpr, Compiler, CompilerCtx};
    use crate::codegen::cel::functions::{Function, Map};
    use crate::codegen::cel::types::CelType;
    use crate::extern_paths::ExternPaths;
    use crate::path_set::PathSet;
    use crate::types::{ProtoModifiedValueType, ProtoType, ProtoTypeRegistry, ProtoValueType};

    /// Validation errors plus compile-time folding of constant list/map receivers.
    #[test]
    fn test_map_syntax() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);
        // No receiver at all.
        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), None, &[])), @r#"
        Err(
            InvalidSyntax {
                message: "missing this",
                syntax: "<this>.map(<ident>, <expr>)",
            },
        )
        "#);

        // Receiver present but no arguments.
        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[])), @r#"
        Err(
            InvalidSyntax {
                message: "invalid number of args",
                syntax: "<this>.map(<ident>, <expr>)",
            },
        )
        "#);

        // A constant string is not iterable.
        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: CelValue,
                message: "String(Borrowed(\"hi\")) cannot be iterated over",
            },
        )
        "#);

        // A runtime bool is not iterable.
        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ProtoValueType::Bool)), parse_quote!(input))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: Proto(
                    Value(
                        Bool,
                    ),
                ),
                message: "type cannot be iterated over",
            },
        )
        "#);

        // A constant list with a constant body folds to a constant list.
        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List([
            CelValueConv::conv(0),
            CelValueConv::conv(1),
            CelValueConv::conv(-50),
            CelValueConv::conv(50),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x + 1").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: List(
                        [
                            Number(
                                I64(
                                    1,
                                ),
                            ),
                            Number(
                                I64(
                                    2,
                                ),
                            ),
                            Number(
                                I64(
                                    -49,
                                ),
                            ),
                            Number(
                                I64(
                                    51,
                                ),
                            ),
                        ],
                    ),
                },
            ),
        )
        ");

        // A constant map is iterated by key; the body looks each key back up
        // in the same map (also registered as the variable `input`).
        let input = CompiledExpr::constant(CelValue::Map(
            [
                (CelValueConv::conv("key0"), CelValueConv::conv(0)),
                (CelValueConv::conv("key1"), CelValueConv::conv(1)),
                (CelValueConv::conv("key2"), CelValueConv::conv(-50)),
                (CelValueConv::conv("key3"), CelValueConv::conv(50)),
            ]
            .into_iter()
            .collect(),
        ));

        let mut ctx = compiler.child();
        ctx.add_variable("input", input.clone());

        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(ctx, Some(input), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("input[x]").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: List(
                        [
                            Number(
                                I64(
                                    0,
                                ),
                            ),
                            Number(
                                I64(
                                    1,
                                ),
                            ),
                            Number(
                                I64(
                                    -50,
                                ),
                            ),
                            Number(
                                I64(
                                    50,
                                ),
                            ),
                        ],
                    ),
                },
            ),
        )
        ");
    }

    /// Runtime proto map receiver (`BTreeMap<String, i32>`): iterates keys and
    /// doubles the looked-up values.
    // Skipped under valgrind — presumably because postcompile invokes the compiler.
    #[test]
    #[cfg(not(valgrind))]
    fn test_map_runtime_map() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let mut compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(
                ProtoValueType::String,
                ProtoValueType::Int32,
            ))),
            parse_quote!(input),
        );

        compiler.add_variable("input", string_value.clone());

        let output = Map
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("input[x] * 2").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter(input: &std::collections::BTreeMap<String, i32>) -> Result<::tinc::__private::cel::CelValue<'_>, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&{
                        let mut map = std::collections::BTreeMap::new();
                        map.insert("0".to_string(), 0);
                        map.insert("1".to_string(), 1);
                        map.insert("2".to_string(), -50);
                        map.insert("3".to_string(), 50);
                        map
                    }).unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(0),
                        ::tinc::__private::cel::CelValueConv::conv(2),
                        ::tinc::__private::cel::CelValueConv::conv(-100),
                        ::tinc::__private::cel::CelValueConv::conv(100),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Runtime repeated receiver (`Vec<i32>`): multiplies each element by 100.
    #[test]
    #[cfg(not(valgrind))]
    fn test_map_runtime_repeated() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::Int32))),
            parse_quote!(input),
        );

        let output = Map
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("x * 100").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter(input: &Vec<i32>) -> Result<::tinc::__private::cel::CelValue<'_>, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&vec![0, 1, -50, 50]).unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(0),
                        ::tinc::__private::cel::CelValueConv::conv(100),
                        ::tinc::__private::cel::CelValueConv::conv(-5000),
                        ::tinc::__private::cel::CelValueConv::conv(5000),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Runtime dynamic `CelValue` receiver: compiled via `CelValue::cel_map`.
    #[test]
    #[cfg(not(valgrind))]
    fn test_map_runtime_cel_value() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(CelType::CelValue, parse_quote!(input));

        let output = Map
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("x + 1").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter<'a>(input: &'a ::tinc::__private::cel::CelValue<'a>) -> Result<::tinc::__private::cel::CelValue<'a>, ::tinc::__private::cel::CelError<'a>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&tinc::__private::cel::CelValue::List([
                        tinc::__private::cel::CelValueConv::conv(5),
                        tinc::__private::cel::CelValueConv::conv(1),
                        tinc::__private::cel::CelValueConv::conv(50),
                         tinc::__private::cel::CelValueConv::conv(-50),
                    ].into_iter().collect())).unwrap(), tinc::__private::cel::CelValue::List([
                        tinc::__private::cel::CelValueConv::conv(6),
                        tinc::__private::cel::CelValueConv::conv(2),
                        tinc::__private::cel::CelValueConv::conv(51),
                         tinc::__private::cel::CelValueConv::conv(-49),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Constant list receiver whose body cannot be folded. Judging by the test
    /// name, `dyn(...)` forces runtime evaluation, so the compiler must emit
    /// runtime code for a constant input.
    #[test]
    #[cfg(not(valgrind))]
    fn test_map_const_requires_runtime() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let list_value = CompiledExpr::constant(CelValue::List(
            [CelValueConv::conv(5), CelValueConv::conv(0), CelValueConv::conv(1)]
                .into_iter()
                .collect(),
        ));

        let output = Map
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(list_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("dyn(x / 2)").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter() -> Result<::tinc::__private::cel::CelValue<'static>, ::tinc::__private::cel::CelError<'static>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter().unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(2),
                        ::tinc::__private::cel::CelValueConv::conv(0),
                        ::tinc::__private::cel::CelValueConv::conv(0),
                    ].into_iter().collect()));
                }
            },
        ));
    }
}