tinc_build/codegen/cel/functions/filter.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::parse_quote;
4use tinc_cel::CelValue;
5
6use super::Function;
7use crate::codegen::cel::compiler::{CompileError, CompiledExpr, CompilerCtx, ConstantCompiledExpr, RuntimeCompiledExpr};
8use crate::codegen::cel::types::CelType;
9use crate::types::{ProtoModifiedValueType, ProtoType, ProtoValueType};
10
/// Marker type for the CEL `filter` function (`<this>.filter(<ident>, <expr>)`).
///
/// It carries no state; all behaviour lives in its `Function` implementation.
#[derive(Clone, Debug, Default)]
pub(crate) struct Filter;
13
14fn native_impl(iter: TokenStream, item_ident: syn::Ident, compare: impl ToTokens) -> syn::Expr {
15    parse_quote!({
16        let mut collected = Vec::new();
17        let mut iter = (#iter).into_iter();
18        loop {
19            let Some(#item_ident) = iter.next() else {
20                break ::tinc::__private::cel::CelValue::List(collected.into());
21            };
22
23            if {
24                let #item_ident = #item_ident.clone();
25                #compare
26            } {
27                collected.push(#item_ident);
28            }
29        }
30    })
31}
32
33// this.filter(<ident>, <expr>)
34impl Function for Filter {
35    fn name(&self) -> &'static str {
36        "filter"
37    }
38
39    fn syntax(&self) -> &'static str {
40        "<this>.filter(<ident>, <expr>)"
41    }
42
43    fn compile(&self, ctx: CompilerCtx) -> Result<CompiledExpr, CompileError> {
44        let Some(this) = &ctx.this else {
45            return Err(CompileError::syntax("missing this", self));
46        };
47
48        if ctx.args.len() != 2 {
49            return Err(CompileError::syntax("invalid number of args", self));
50        }
51
52        let cel_parser::Expression::Ident(variable) = &ctx.args[0] else {
53            return Err(CompileError::syntax("first argument must be an ident", self));
54        };
55
56        match this {
57            CompiledExpr::Runtime(RuntimeCompiledExpr { expr, ty }) => {
58                let mut child_ctx = ctx.child();
59
60                match ty {
61                    CelType::CelValue => {
62                        child_ctx.add_variable(variable, CompiledExpr::runtime(CelType::CelValue, parse_quote!(item)));
63                    }
64                    CelType::Proto(ProtoType::Modified(
65                        ProtoModifiedValueType::Repeated(ty) | ProtoModifiedValueType::Map(ty, _),
66                    )) => {
67                        child_ctx.add_variable(
68                            variable,
69                            CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item)),
70                        );
71                    }
72                    v => {
73                        return Err(CompileError::TypeConversion {
74                            ty: Box::new(v.clone()),
75                            message: "type cannot be iterated over".to_string(),
76                        });
77                    }
78                };
79
80                let arg = child_ctx.resolve(&ctx.args[1])?.into_bool(&child_ctx);
81
82                Ok(CompiledExpr::runtime(
83                    CelType::CelValue,
84                    match ty {
85                        CelType::CelValue => parse_quote! {
86                            ::tinc::__private::cel::CelValue::cel_filter(#expr, |item| {
87                                ::core::result::Result::Ok(
88                                    #arg
89                                )
90                            })?
91                        },
92                        CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(ty, _))) => {
93                            let cel_ty =
94                                CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item))
95                                    .into_cel()?;
96
97                            native_impl(
98                                quote!(
99                                    (#expr).keys().map(|item| #cel_ty)
100                                ),
101                                parse_quote!(item),
102                                arg,
103                            )
104                        }
105                        CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ty))) => {
106                            let cel_ty =
107                                CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item))
108                                    .into_cel()?;
109
110                            native_impl(
111                                quote!(
112                                    (#expr).iter().map(|item| #cel_ty)
113                                ),
114                                parse_quote!(item),
115                                arg,
116                            )
117                        }
118                        _ => unreachable!(),
119                    },
120                ))
121            }
122            CompiledExpr::Constant(ConstantCompiledExpr {
123                value: value @ (CelValue::List(_) | CelValue::Map(_)),
124            }) => {
125                let compile_val = |value: CelValue<'static>| {
126                    let mut child_ctx = ctx.child();
127
128                    child_ctx.add_variable(variable, CompiledExpr::constant(value.clone()));
129
130                    child_ctx.resolve(&ctx.args[1]).map(|v| (value, v.into_bool(&child_ctx)))
131                };
132
133                let collected: Result<Vec<_>, _> = match value {
134                    CelValue::List(item) => item.iter().cloned().map(compile_val).collect(),
135                    CelValue::Map(item) => item.iter().map(|(key, _)| key).cloned().map(compile_val).collect(),
136                    _ => unreachable!(),
137                };
138
139                let collected = collected?;
140                if collected.iter().any(|(_, c)| matches!(c, CompiledExpr::Runtime(_))) {
141                    let collected = collected.into_iter().map(|(item, expr)| {
142                        let item = CompiledExpr::constant(item);
143                        quote! {
144                            if #expr {
145                                collected.push(#item);
146                            }
147                        }
148                    });
149
150                    Ok(CompiledExpr::runtime(
151                        CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
152                        parse_quote!({
153                            let mut collected = Vec::new();
154                            #(#collected)*
155                            ::tinc::__private::cel::CelValue::List(collected.into())
156                        }),
157                    ))
158                } else {
159                    Ok(CompiledExpr::constant(CelValue::List(
160                        collected
161                            .into_iter()
162                            .filter_map(|(item, c)| match c {
163                                CompiledExpr::Constant(ConstantCompiledExpr { value }) => {
164                                    if value.to_bool() {
165                                        Some(item)
166                                    } else {
167                                        None
168                                    }
169                                }
170                                _ => unreachable!("all values must be constant"),
171                            })
172                            .collect(),
173                    )))
174                }
175            }
176            CompiledExpr::Constant(ConstantCompiledExpr { value }) => Err(CompileError::TypeConversion {
177                ty: Box::new(CelType::CelValue),
178                message: format!("{value:?} cannot be iterated over"),
179            }),
180        }
181    }
182}
183
// Tests for `Filter`. They cover argument validation, constant folding over lists and maps,
// and, via `postcompile`, that the code generated for runtime receivers compiles and returns
// the expected filtered list.
#[cfg(test)]
#[cfg(feature = "prost")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use quote::quote;
    use syn::parse_quote;
    use tinc_cel::{CelValue, CelValueConv};

    use crate::codegen::cel::compiler::{CompiledExpr, Compiler, CompilerCtx};
    use crate::codegen::cel::functions::{Filter, Function};
    use crate::codegen::cel::types::CelType;
    use crate::extern_paths::ExternPaths;
    use crate::path_set::PathSet;
    use crate::types::{ProtoModifiedValueType, ProtoType, ProtoTypeRegistry, ProtoValueType};

    // Compile-time behaviour: error paths and fully constant-folded results.
    #[test]
    fn test_filter_syntax() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);
        // No receiver.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), None, &[])), @r#"
        Err(
            InvalidSyntax {
                message: "missing this",
                syntax: "<this>.filter(<ident>, <expr>)",
            },
        )
        "#);

        // Receiver present but no arguments.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[])), @r#"
        Err(
            InvalidSyntax {
                message: "invalid number of args",
                syntax: "<this>.filter(<ident>, <expr>)",
            },
        )
        "#);

        // A constant string cannot be iterated over.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: CelValue,
                message: "String(Borrowed(\"hi\")) cannot be iterated over",
            },
        )
        "#);

        // A runtime proto bool cannot be iterated over.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ProtoValueType::Bool)), parse_quote!(input))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: Proto(
                    Value(
                        Bool,
                    ),
                ),
                message: "type cannot be iterated over",
            },
        )
        "#);

        // A constant list with a constant predicate folds to a constant list.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List([
            CelValueConv::conv(0),
            CelValueConv::conv(1),
            CelValueConv::conv(-50),
            CelValueConv::conv(50),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x >= 1").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: List(
                        [
                            Number(
                                I64(
                                    1,
                                ),
                            ),
                            Number(
                                I64(
                                    50,
                                ),
                            ),
                        ],
                    ),
                },
            ),
        )
        ");

        // Constant map: the predicate is evaluated per key, and the result is the list of
        // keys whose value in `input` is >= 1.
        let input = CompiledExpr::constant(CelValue::Map(
            [
                (CelValueConv::conv("key0"), CelValueConv::conv(0)),
                (CelValueConv::conv("key1"), CelValueConv::conv(1)),
                (CelValueConv::conv("key2"), CelValueConv::conv(-50)),
                (CelValueConv::conv("key3"), CelValueConv::conv(50)),
            ]
            .into_iter()
            .collect(),
        ));

        let mut ctx = compiler.child();
        ctx.add_variable("input", input.clone());

        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(ctx, Some(input), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("input[x] >= 1").unwrap(),
        ])), @r#"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: List(
                        [
                            String(
                                Borrowed(
                                    "key1",
                                ),
                            ),
                            String(
                                Borrowed(
                                    "key3",
                                ),
                            ),
                        ],
                    ),
                },
            ),
        )
        "#);
    }

    // Runtime `map<string, int32>` receiver: the generated code filters the map's keys.
    #[test]
    #[cfg(not(valgrind))]
    fn test_filter_runtime_map() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let mut compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(
                ProtoValueType::String,
                ProtoValueType::Int32,
            ))),
            parse_quote!(input),
        );

        // Also expose the receiver as `input` so the predicate can index into it.
        compiler.add_variable("input", string_value.clone());

        let output = Filter
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("input[x] >= 1").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter(input: &std::collections::BTreeMap<String, i32>) -> Result<::tinc::__private::cel::CelValue<'_>, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&{
                        let mut map = std::collections::BTreeMap::new();
                        map.insert("0".to_string(), 0);
                        map.insert("1".to_string(), 1);
                        map.insert("-50".to_string(), -50);
                        map.insert("50".to_string(), 50);
                        map
                    }).unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv("1"),
                        ::tinc::__private::cel::CelValueConv::conv("50"),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    // Runtime `repeated int32` receiver: the generated code keeps the elements >= 1.
    #[test]
    #[cfg(not(valgrind))]
    fn test_filter_runtime_repeated() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::Int32))),
            parse_quote!(input),
        );

        let output = Filter
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("x >= 1").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter(input: &Vec<i32>) -> Result<::tinc::__private::cel::CelValue<'_>, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&vec![0, 1, -50, 50]).unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(1),
                        ::tinc::__private::cel::CelValueConv::conv(50),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    // Dynamic `CelValue` receiver: the generated code goes through `CelValue::cel_filter`.
    #[test]
    #[cfg(not(valgrind))]
    fn test_filter_runtime_cel_value() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(CelType::CelValue, parse_quote!(input));

        let output = Filter
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("x > 5").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter<'a>(input: &'a ::tinc::__private::cel::CelValue<'a>) -> Result<::tinc::__private::cel::CelValue<'a>, ::tinc::__private::cel::CelError<'a>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&tinc::__private::cel::CelValue::List([
                        tinc::__private::cel::CelValueConv::conv(5),
                        tinc::__private::cel::CelValueConv::conv(1),
                        tinc::__private::cel::CelValueConv::conv(50),
                         tinc::__private::cel::CelValueConv::conv(-50),
                    ].into_iter().collect())).unwrap(), tinc::__private::cel::CelValue::List([
                        tinc::__private::cel::CelValueConv::conv(50),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    // Constant list whose predicate is wrapped in `dyn(..)`, so it cannot be folded. The
    // function must emit runtime code that tests each constant element.
    #[test]
    #[cfg(not(valgrind))]
    fn test_filter_const_requires_runtime() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let list_value = CompiledExpr::constant(CelValue::List(
            [CelValueConv::conv(5), CelValueConv::conv(0), CelValueConv::conv(1)]
                .into_iter()
                .collect(),
        ));

        let output = Filter
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(list_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("dyn(x >= 1)").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter() -> Result<::tinc::__private::cel::CelValue<'static>, ::tinc::__private::cel::CelError<'static>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter().unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(5),
                        ::tinc::__private::cel::CelValueConv::conv(1),
                    ].into_iter().collect()));
                }
            },
        ));
    }
}