1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::parse_quote;
4use tinc_cel::CelValue;
5
6use super::Function;
7use crate::codegen::cel::compiler::{CompileError, CompiledExpr, CompilerCtx, ConstantCompiledExpr, RuntimeCompiledExpr};
8use crate::codegen::cel::types::CelType;
9use crate::types::{ProtoModifiedValueType, ProtoType, ProtoValueType};
10
/// Implementation of the CEL `filter` function, written `<this>.filter(<ident>, <expr>)`.
///
/// It keeps the elements of a list (or the keys of a map) for which `<expr>` is true.
#[derive(Default, Clone, Debug)]
pub(crate) struct Filter;
13
14fn native_impl(iter: TokenStream, item_ident: syn::Ident, compare: impl ToTokens) -> syn::Expr {
15 parse_quote!({
16 let mut collected = Vec::new();
17 let mut iter = (#iter).into_iter();
18 loop {
19 let Some(#item_ident) = iter.next() else {
20 break ::tinc::__private::cel::CelValue::List(collected.into());
21 };
22
23 if {
24 let #item_ident = #item_ident.clone();
25 #compare
26 } {
27 collected.push(#item_ident);
28 }
29 }
30 })
31}
32
33impl Function for Filter {
35 fn name(&self) -> &'static str {
36 "filter"
37 }
38
39 fn syntax(&self) -> &'static str {
40 "<this>.filter(<ident>, <expr>)"
41 }
42
43 fn compile(&self, ctx: CompilerCtx) -> Result<CompiledExpr, CompileError> {
44 let Some(this) = &ctx.this else {
45 return Err(CompileError::syntax("missing this", self));
46 };
47
48 if ctx.args.len() != 2 {
49 return Err(CompileError::syntax("invalid number of args", self));
50 }
51
52 let cel_parser::Expression::Ident(variable) = &ctx.args[0] else {
53 return Err(CompileError::syntax("first argument must be an ident", self));
54 };
55
56 match this {
57 CompiledExpr::Runtime(RuntimeCompiledExpr { expr, ty }) => {
58 let mut child_ctx = ctx.child();
59
60 match ty {
61 CelType::CelValue => {
62 child_ctx.add_variable(variable, CompiledExpr::runtime(CelType::CelValue, parse_quote!(item)));
63 }
64 CelType::Proto(ProtoType::Modified(
65 ProtoModifiedValueType::Repeated(ty) | ProtoModifiedValueType::Map(ty, _),
66 )) => {
67 child_ctx.add_variable(
68 variable,
69 CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item)),
70 );
71 }
72 v => {
73 return Err(CompileError::TypeConversion {
74 ty: Box::new(v.clone()),
75 message: "type cannot be iterated over".to_string(),
76 });
77 }
78 };
79
80 let arg = child_ctx.resolve(&ctx.args[1])?.into_bool(&child_ctx);
81
82 Ok(CompiledExpr::runtime(
83 CelType::CelValue,
84 match ty {
85 CelType::CelValue => parse_quote! {
86 ::tinc::__private::cel::CelValue::cel_filter(#expr, |item| {
87 ::core::result::Result::Ok(
88 #arg
89 )
90 })?
91 },
92 CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(ty, _))) => {
93 let cel_ty =
94 CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item))
95 .into_cel()?;
96
97 native_impl(
98 quote!(
99 (#expr).keys().map(|item| #cel_ty)
100 ),
101 parse_quote!(item),
102 arg,
103 )
104 }
105 CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ty))) => {
106 let cel_ty =
107 CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item))
108 .into_cel()?;
109
110 native_impl(
111 quote!(
112 (#expr).iter().map(|item| #cel_ty)
113 ),
114 parse_quote!(item),
115 arg,
116 )
117 }
118 _ => unreachable!(),
119 },
120 ))
121 }
122 CompiledExpr::Constant(ConstantCompiledExpr {
123 value: value @ (CelValue::List(_) | CelValue::Map(_)),
124 }) => {
125 let compile_val = |value: CelValue<'static>| {
126 let mut child_ctx = ctx.child();
127
128 child_ctx.add_variable(variable, CompiledExpr::constant(value.clone()));
129
130 child_ctx.resolve(&ctx.args[1]).map(|v| (value, v.into_bool(&child_ctx)))
131 };
132
133 let collected: Result<Vec<_>, _> = match value {
134 CelValue::List(item) => item.iter().cloned().map(compile_val).collect(),
135 CelValue::Map(item) => item.iter().map(|(key, _)| key).cloned().map(compile_val).collect(),
136 _ => unreachable!(),
137 };
138
139 let collected = collected?;
140 if collected.iter().any(|(_, c)| matches!(c, CompiledExpr::Runtime(_))) {
141 let collected = collected.into_iter().map(|(item, expr)| {
142 let item = CompiledExpr::constant(item);
143 quote! {
144 if #expr {
145 collected.push(#item);
146 }
147 }
148 });
149
150 Ok(CompiledExpr::runtime(
151 CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
152 parse_quote!({
153 let mut collected = Vec::new();
154 #(#collected)*
155 ::tinc::__private::cel::CelValue::List(collected.into())
156 }),
157 ))
158 } else {
159 Ok(CompiledExpr::constant(CelValue::List(
160 collected
161 .into_iter()
162 .filter_map(|(item, c)| match c {
163 CompiledExpr::Constant(ConstantCompiledExpr { value }) => {
164 if value.to_bool() {
165 Some(item)
166 } else {
167 None
168 }
169 }
170 _ => unreachable!("all values must be constant"),
171 })
172 .collect(),
173 )))
174 }
175 }
176 CompiledExpr::Constant(ConstantCompiledExpr { value }) => Err(CompileError::TypeConversion {
177 ty: Box::new(CelType::CelValue),
178 message: format!("{value:?} cannot be iterated over"),
179 }),
180 }
181 }
182}
183
#[cfg(test)]
#[cfg(feature = "prost")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use quote::quote;
    use syn::parse_quote;
    use tinc_cel::{CelValue, CelValueConv};

    use crate::codegen::cel::compiler::{CompiledExpr, Compiler, CompilerCtx};
    use crate::codegen::cel::functions::{Filter, Function};
    use crate::codegen::cel::types::CelType;
    use crate::extern_paths::ExternPaths;
    use crate::path_set::PathSet;
    use crate::types::{ProtoModifiedValueType, ProtoType, ProtoTypeRegistry, ProtoValueType};

    /// Covers argument validation, rejection of non-iterable receivers, and compile-time
    /// filtering of constant lists and maps.
    #[test]
    fn test_filter_syntax() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);
        // No receiver.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), None, &[])), @r#"
        Err(
            InvalidSyntax {
                message: "missing this",
                syntax: "<this>.filter(<ident>, <expr>)",
            },
        )
        "#);

        // Receiver present but no arguments.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[])), @r#"
        Err(
            InvalidSyntax {
                message: "invalid number of args",
                syntax: "<this>.filter(<ident>, <expr>)",
            },
        )
        "#);

        // A constant string cannot be iterated over.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: CelValue,
                message: "String(Borrowed(\"hi\")) cannot be iterated over",
            },
        )
        "#);

        // A runtime proto bool cannot be iterated over either.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ProtoValueType::Bool)), parse_quote!(input))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: Proto(
                    Value(
                        Bool,
                    ),
                ),
                message: "type cannot be iterated over",
            },
        )
        "#);

        // A constant list with a constant predicate is filtered at compile time.
        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List([
            CelValueConv::conv(0),
            CelValueConv::conv(1),
            CelValueConv::conv(-50),
            CelValueConv::conv(50),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x >= 1").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: List(
                        [
                            Number(
                                I64(
                                    1,
                                ),
                            ),
                            Number(
                                I64(
                                    50,
                                ),
                            ),
                        ],
                    ),
                },
            ),
        )
        ");

        // Filtering a constant map iterates over its keys.
        let input = CompiledExpr::constant(CelValue::Map(
            [
                (CelValueConv::conv("key0"), CelValueConv::conv(0)),
                (CelValueConv::conv("key1"), CelValueConv::conv(1)),
                (CelValueConv::conv("key2"), CelValueConv::conv(-50)),
                (CelValueConv::conv("key3"), CelValueConv::conv(50)),
            ]
            .into_iter()
            .collect(),
        ));

        let mut ctx = compiler.child();
        ctx.add_variable("input", input.clone());

        insta::assert_debug_snapshot!(Filter.compile(CompilerCtx::new(ctx, Some(input), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("input[x] >= 1").unwrap(),
        ])), @r#"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: List(
                        [
                            String(
                                Borrowed(
                                    "key1",
                                ),
                            ),
                            String(
                                Borrowed(
                                    "key3",
                                ),
                            ),
                        ],
                    ),
                },
            ),
        )
        "#);
    }

    /// Runtime proto `map<string, int32>` receiver: compiles the generated code and checks that
    /// it keeps the keys whose values satisfy the predicate.
    #[test]
    #[cfg(not(valgrind))]
    fn test_filter_runtime_map() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let mut compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(
                ProtoValueType::String,
                ProtoValueType::Int32,
            ))),
            parse_quote!(input),
        );

        compiler.add_variable("input", string_value.clone());

        let output = Filter
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("input[x] >= 1").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter(input: &std::collections::BTreeMap<String, i32>) -> Result<::tinc::__private::cel::CelValue<'_>, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&{
                        let mut map = std::collections::BTreeMap::new();
                        map.insert("0".to_string(), 0);
                        map.insert("1".to_string(), 1);
                        map.insert("-50".to_string(), -50);
                        map.insert("50".to_string(), 50);
                        map
                    }).unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv("1"),
                        ::tinc::__private::cel::CelValueConv::conv("50"),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Runtime proto `repeated int32` receiver: compiles the generated code and checks the
    /// filtered output.
    #[test]
    #[cfg(not(valgrind))]
    fn test_filter_runtime_repeated() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::Int32))),
            parse_quote!(input),
        );

        let output = Filter
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("x >= 1").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter(input: &Vec<i32>) -> Result<::tinc::__private::cel::CelValue<'_>, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&vec![0, 1, -50, 50]).unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(1),
                        ::tinc::__private::cel::CelValueConv::conv(50),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Runtime dynamic `CelValue` receiver, which is filtered through `CelValue::cel_filter`.
    #[test]
    #[cfg(not(valgrind))]
    fn test_filter_runtime_cel_value() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(CelType::CelValue, parse_quote!(input));

        let output = Filter
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("x > 5").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter<'a>(input: &'a ::tinc::__private::cel::CelValue<'a>) -> Result<::tinc::__private::cel::CelValue<'a>, ::tinc::__private::cel::CelError<'a>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&tinc::__private::cel::CelValue::List([
                        tinc::__private::cel::CelValueConv::conv(5),
                        tinc::__private::cel::CelValueConv::conv(1),
                        tinc::__private::cel::CelValueConv::conv(50),
                        tinc::__private::cel::CelValueConv::conv(-50),
                    ].into_iter().collect())).unwrap(), tinc::__private::cel::CelValue::List([
                        tinc::__private::cel::CelValueConv::conv(50),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Constant list receiver whose predicate is forced to runtime with `dyn(...)`, so the
    /// filtering code is generated rather than folded.
    #[test]
    #[cfg(not(valgrind))]
    fn test_filter_const_requires_runtime() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let list_value = CompiledExpr::constant(CelValue::List(
            [CelValueConv::conv(5), CelValueConv::conv(0), CelValueConv::conv(1)]
                .into_iter()
                .collect(),
        ));

        let output = Filter
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(list_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("dyn(x >= 1)").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter() -> Result<::tinc::__private::cel::CelValue<'static>, ::tinc::__private::cel::CelError<'static>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter().unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(5),
                        ::tinc::__private::cel::CelValueConv::conv(1),
                    ].into_iter().collect()));
                }
            },
        ));
    }
}