tinc_build/codegen/cel/functions/all.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::parse_quote;
4use tinc_cel::CelValue;
5
6use super::Function;
7use crate::codegen::cel::compiler::{CompileError, CompiledExpr, CompilerCtx, ConstantCompiledExpr, RuntimeCompiledExpr};
8use crate::codegen::cel::types::CelType;
9use crate::types::{ProtoModifiedValueType, ProtoType, ProtoValueType};
10
/// The CEL `all` comprehension macro, written `<this>.all(<ident>, <expr>)`.
///
/// This is a stateless unit type; all of its behaviour lives in its `Function` implementation.
#[derive(Clone, Debug, Default)]
pub(crate) struct All;
13
14fn native_impl(iter: TokenStream, item_ident: syn::Ident, compare: impl ToTokens) -> syn::Expr {
15    parse_quote!({
16        let mut iter = (#iter).into_iter();
17        loop {
18            let Some(#item_ident) = iter.next() else {
19                break true;
20            };
21
22            if !(#compare) {
23                break false;
24            }
25        }
26    })
27}
28
29// this.all(<ident>, <expr>)
30impl Function for All {
31    fn name(&self) -> &'static str {
32        "all"
33    }
34
35    fn syntax(&self) -> &'static str {
36        "<this>.all(<ident>, <expr>)"
37    }
38
39    fn compile(&self, ctx: CompilerCtx) -> Result<CompiledExpr, CompileError> {
40        let Some(this) = &ctx.this else {
41            return Err(CompileError::syntax("missing this", self));
42        };
43
44        if ctx.args.len() != 2 {
45            return Err(CompileError::syntax("invalid number of args, expected 2", self));
46        }
47
48        let cel_parser::Expression::Ident(variable) = &ctx.args[0] else {
49            return Err(CompileError::syntax("first argument must be an ident", self));
50        };
51
52        match this {
53            CompiledExpr::Runtime(RuntimeCompiledExpr { expr, ty }) => {
54                let mut child_ctx = ctx.child();
55
56                match ty {
57                    CelType::CelValue => {
58                        child_ctx.add_variable(variable, CompiledExpr::runtime(CelType::CelValue, parse_quote!(item)));
59                    }
60                    CelType::Proto(ProtoType::Modified(
61                        ProtoModifiedValueType::Repeated(ty) | ProtoModifiedValueType::Map(ty, _),
62                    )) => {
63                        child_ctx.add_variable(
64                            variable,
65                            CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item)),
66                        );
67                    }
68                    v => {
69                        return Err(CompileError::TypeConversion {
70                            ty: Box::new(v.clone()),
71                            message: "type cannot be iterated over".to_string(),
72                        });
73                    }
74                };
75
76                let arg = child_ctx.resolve(&ctx.args[1])?.into_bool(&child_ctx);
77
78                Ok(CompiledExpr::runtime(
79                    CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
80                    match &ty {
81                        CelType::CelValue => parse_quote! {
82                            ::tinc::__private::cel::CelValue::cel_all(#expr, |item| {
83                                ::core::result::Result::Ok(
84                                    #arg
85                                )
86                            })?
87                        },
88                        CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(_, _))) => {
89                            native_impl(quote!((#expr).keys()), parse_quote!(item), arg)
90                        }
91                        CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(_))) => {
92                            native_impl(quote!((#expr).iter()), parse_quote!(item), arg)
93                        }
94                        _ => unreachable!(),
95                    },
96                ))
97            }
98            CompiledExpr::Constant(ConstantCompiledExpr {
99                value: value @ (CelValue::List(_) | CelValue::Map(_)),
100            }) => {
101                let compile_val = |value: CelValue<'static>| {
102                    let mut child_ctx = ctx.child();
103
104                    child_ctx.add_variable(variable, CompiledExpr::constant(value));
105
106                    child_ctx.resolve(&ctx.args[1]).map(|v| v.into_bool(&child_ctx))
107                };
108
109                let collected: Result<Vec<_>, _> = match value {
110                    CelValue::List(item) => item.iter().cloned().map(compile_val).collect(),
111                    CelValue::Map(item) => item.iter().map(|(key, _)| key).cloned().map(compile_val).collect(),
112                    _ => unreachable!(),
113                };
114
115                let collected = collected?;
116                if collected.iter().any(|c| matches!(c, CompiledExpr::Runtime(_))) {
117                    Ok(CompiledExpr::runtime(
118                        CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
119                        native_impl(quote!([#(#collected),*]), parse_quote!(item), quote!(item)),
120                    ))
121                } else {
122                    Ok(CompiledExpr::constant(CelValue::Bool(collected.into_iter().all(
123                        |c| match c {
124                            CompiledExpr::Constant(ConstantCompiledExpr { value }) => value.to_bool(),
125                            _ => unreachable!("all values must be constant"),
126                        },
127                    ))))
128                }
129            }
130            CompiledExpr::Constant(ConstantCompiledExpr { value }) => Err(CompileError::TypeConversion {
131                ty: Box::new(CelType::CelValue),
132                message: format!("{value:?} cannot be iterated over"),
133            }),
134        }
135    }
136}
137
// Tests for `All`:
// - argument/type validation and compile-time constant folding, checked with inline insta
//   debug snapshots;
// - generated runtime code, compiled and executed with `postcompile` and checked against
//   file-based insta snapshots.
#[cfg(test)]
#[cfg(feature = "prost")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use quote::quote;
    use syn::parse_quote;
    use tinc_cel::{CelValue, CelValueConv};

    use crate::codegen::cel::compiler::{CompiledExpr, Compiler, CompilerCtx};
    use crate::codegen::cel::functions::{All, Function};
    use crate::codegen::cel::types::CelType;
    use crate::extern_paths::ExternPaths;
    use crate::path_set::PathSet;
    use crate::types::{ProtoModifiedValueType, ProtoType, ProtoTypeRegistry, ProtoValueType};

    /// Covers the argument/type error paths and constant folding over constant lists and maps.
    #[test]
    fn test_all_syntax() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);
        // No receiver (`this`) at all.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), None, &[])), @r#"
        Err(
            InvalidSyntax {
                message: "missing this",
                syntax: "<this>.all(<ident>, <expr>)",
            },
        )
        "#);

        // Receiver present but zero arguments.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List(Default::default()))), &[])), @r#"
        Err(
            InvalidSyntax {
                message: "invalid number of args, expected 2",
                syntax: "<this>.all(<ident>, <expr>)",
            },
        )
        "#);

        // A constant string is not iterable.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: CelValue,
                message: "String(Borrowed(\"hi\")) cannot be iterated over",
            },
        )
        "#);

        // A runtime proto bool is not iterable.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ProtoValueType::Bool)), parse_quote!(input))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: Proto(
                    Value(
                        Bool,
                    ),
                ),
                message: "type cannot be iterated over",
            },
        )
        "#);

        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List(Default::default()))), &[
            cel_parser::parse("1 + 1").unwrap(), // not an ident
            cel_parser::parse("x + 2").unwrap(),
        ])), @r#"
        Err(
            InvalidSyntax {
                message: "first argument must be an ident",
                syntax: "<this>.all(<ident>, <expr>)",
            },
        )
        "#);

        // Constant list where every element satisfies the predicate: folds to constant `true`.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List([
            CelValueConv::conv(4),
            CelValueConv::conv(3),
            CelValueConv::conv(10),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x > 2").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: Bool(
                        true,
                    ),
                },
            ),
        )
        ");

        // Constant list with a failing element: folds to constant `false`.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List([
            CelValueConv::conv(2),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x > 2").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: Bool(
                        false,
                    ),
                },
            ),
        )
        ");

        // Constant map: the predicate is applied to the keys (key 2 fails `x > 2`).
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::Map([
            (CelValueConv::conv(2), CelValue::Null),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x > 2").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: Bool(
                        false,
                    ),
                },
            ),
        )
        ");

        // A constant number is not iterable.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValueConv::conv(1))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x > 2").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: CelValue,
                message: "Number(I64(1)) cannot be iterated over",
            },
        )
        "#);
    }

    /// Runtime `CelValue` receiver: the generated code is compiled and run against CEL lists
    /// built from `i32` slices (including the empty slice, which must yield `true`).
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_cel_value() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        // A runtime `CelValue` named `input`; despite the binding name it is exercised with lists below.
        let map = CompiledExpr::runtime(CelType::CelValue, parse_quote!(input));

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(map),
                &[
                    cel_parser::parse("x").unwrap(), // comprehension variable
                    cel_parser::parse("x > 2").unwrap(),
                ],
            ))
            .unwrap();

        let result = postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn all<'a>(
                    input: ::tinc::__private::cel::CelValue<'a>,
                ) -> Result<bool, ::tinc::__private::cel::CelError<'a>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_all() {
                    assert_eq!(all(::tinc::__private::cel::CelValueConv::conv(&[0, 1, 2] as &[i32])).unwrap(), false);
                    assert_eq!(all(::tinc::__private::cel::CelValueConv::conv(&[3, 4, 5] as &[i32])).unwrap(), true);
                    assert_eq!(all(::tinc::__private::cel::CelValueConv::conv(&[] as &[i32])).unwrap(), true);
                }
            },
        );

        insta::assert_snapshot!(result);
    }

    /// Runtime proto `map<int32, float>` receiver: the predicate is checked against the map's keys
    /// (values are ignored), using a `BTreeMap` as the runtime representation.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_proto_map() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let map = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(
                ProtoValueType::Int32,
                ProtoValueType::Float,
            ))),
            parse_quote!(input),
        );

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(map),
                &[
                    cel_parser::parse("x").unwrap(), // comprehension variable
                    cel_parser::parse("x > 2").unwrap(),
                ],
            ))
            .unwrap();

        let result = postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn all(
                    input: &std::collections::BTreeMap<i32, f32>,
                ) -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_all() {
                    assert_eq!(all(&{
                        let mut map = std::collections::BTreeMap::new();
                        map.insert(3, 2.0);
                        map.insert(4, 2.0);
                        map.insert(5, 2.0);
                        map
                    }).unwrap(), true);
                    assert_eq!(all(&{
                        let mut map = std::collections::BTreeMap::new();
                        map.insert(3, 2.0);
                        map.insert(1, 2.0);
                        map.insert(5, 2.0);
                        map
                    }).unwrap(), false);
                    assert_eq!(all(&std::collections::BTreeMap::new()).unwrap(), true)
                }
            },
        );

        insta::assert_snapshot!(result);
    }

    /// Runtime proto `repeated int32` receiver, represented as `&Vec<i32>`.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_proto_repeated() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let repeated = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::Int32))),
            parse_quote!(input),
        );

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(repeated),
                &[
                    cel_parser::parse("x").unwrap(), // comprehension variable
                    cel_parser::parse("x > 2").unwrap(),
                ],
            ))
            .unwrap();

        let result = postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn all(
                    input: &Vec<i32>,
                ) -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_all() {
                    assert_eq!(all(&vec![1, 2, 3]).unwrap(), false);
                    assert_eq!(all(&vec![3, 4, 60]).unwrap(), true);
                    assert_eq!(all(&vec![]).unwrap(), true);
                }
            },
        );

        insta::assert_snapshot!(result);
    }

    /// Constant list receiver whose predicate is wrapped in `dyn(...)`. Presumably this keeps the
    /// per-element results from folding to constants, exercising the branch that emits a runtime
    /// loop over precomputed results. The generated code must still evaluate to `false` for `[0]`.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_const_needs_runtime() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let list = CompiledExpr::constant(CelValue::List([CelValue::Number(0.into())].into_iter().collect()));

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(list),
                &[
                    cel_parser::parse("x").unwrap(), // comprehension variable
                    cel_parser::parse("dyn(x > 2)").unwrap(),
                ],
            ))
            .unwrap();

        let result = postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn all() -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_all() {
                    assert_eq!(all().unwrap(), false);
                }
            },
        );

        insta::assert_snapshot!(result);
    }

    /// The same generated code for a runtime `repeated int32` must compile and behave identically
    /// whether `input` is a slice (`&[i32]`) or a `&Vec<i32>`.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_runtime() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let list = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::Int32))),
            parse_quote!(input),
        );

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(list),
                &[
                    cel_parser::parse("x").unwrap(), // comprehension variable
                    cel_parser::parse("x > 2").unwrap(),
                ],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn runtime_slice(
                    input: &[i32],
                ) -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[allow(dead_code)]
                fn runtime_vec(
                    input: &Vec<i32>,
                ) -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_empty_lists() {
                    assert!(runtime_slice(&[]).unwrap());
                    assert!(runtime_vec(&vec![]).unwrap());
                    assert!(runtime_slice(&[3, 4, 5]).unwrap());
                    assert!(runtime_vec(&vec![3, 4, 5]).unwrap());
                    assert!(!runtime_slice(&[3, 4, 5, 2]).unwrap());
                    assert!(!runtime_vec(&vec![3, 4, 5, 2]).unwrap());
                }
            },
        ));
    }
}