1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::parse_quote;
4use tinc_cel::CelValue;
5
6use super::Function;
7use crate::codegen::cel::compiler::{CompileError, CompiledExpr, CompilerCtx, ConstantCompiledExpr, RuntimeCompiledExpr};
8use crate::codegen::cel::types::CelType;
9use crate::types::{ProtoModifiedValueType, ProtoType, ProtoValueType};
10
/// The CEL `map` function, invoked as `<this>.map(<ident>, <expr>)`.
#[derive(Debug, Clone, Default)]
pub(crate) struct Map;
13
/// Builds a block expression that maps every element of `iter` at runtime.
///
/// The generated code calls `.into_iter()` on `iter`, binds each yielded element to
/// `item_ident`, evaluates `map_fn` and pushes its result onto a `Vec`. When the
/// iterator is exhausted, the loop breaks with the collected values wrapped in
/// `::tinc::__private::cel::CelValue::List`, which becomes the value of the block.
fn native_impl(iter: TokenStream, item_ident: syn::Ident, map_fn: impl ToTokens) -> syn::Expr {
    parse_quote!({
        let mut collected = Vec::new();
        let mut iter = (#iter).into_iter();
        loop {
            // `break <value>` makes the `loop` (and therefore the block) evaluate to the list.
            let Some(#item_ident) = iter.next() else {
                break ::tinc::__private::cel::CelValue::List(collected.into());
            };

            collected.push(#map_fn);
        }
    })
}
27
28impl Function for Map {
30 fn name(&self) -> &'static str {
31 "map"
32 }
33
34 fn syntax(&self) -> &'static str {
35 "<this>.map(<ident>, <expr>)"
36 }
37
38 fn compile(&self, ctx: CompilerCtx) -> Result<CompiledExpr, CompileError> {
39 let Some(this) = &ctx.this else {
40 return Err(CompileError::syntax("missing this", self));
41 };
42
43 if ctx.args.len() != 2 {
44 return Err(CompileError::syntax("invalid number of args", self));
45 }
46
47 let cel_parser::Expression::Ident(variable) = &ctx.args[0] else {
48 return Err(CompileError::syntax("first argument must be an ident", self));
49 };
50
51 match this {
52 CompiledExpr::Runtime(RuntimeCompiledExpr { expr, ty }) => {
53 let mut child_ctx = ctx.child();
54
55 match ty {
56 CelType::CelValue => {
57 child_ctx.add_variable(variable, CompiledExpr::runtime(CelType::CelValue, parse_quote!(item)));
58 }
59 CelType::Proto(ProtoType::Modified(
60 ProtoModifiedValueType::Repeated(ty) | ProtoModifiedValueType::Map(ty, _),
61 )) => {
62 child_ctx.add_variable(
63 variable,
64 CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item)),
65 );
66 }
67 v => {
68 return Err(CompileError::TypeConversion {
69 ty: Box::new(v.clone()),
70 message: "type cannot be iterated over".to_string(),
71 });
72 }
73 };
74
75 let arg = child_ctx.resolve(&ctx.args[1])?.into_cel()?;
76
77 Ok(CompiledExpr::runtime(
78 CelType::CelValue,
79 match &ty {
80 CelType::CelValue => parse_quote! {
81 ::tinc::__private::cel::CelValue::cel_map(#expr, |item| {
82 ::core::result::Result::Ok(
83 #arg
84 )
85 })?
86 },
87 CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(_, _))) => {
88 native_impl(quote!((#expr).keys()), parse_quote!(item), arg)
89 }
90 CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(_))) => {
91 native_impl(quote!((#expr).iter()), parse_quote!(item), arg)
92 }
93 _ => unreachable!(),
94 },
95 ))
96 }
97 CompiledExpr::Constant(ConstantCompiledExpr {
98 value: value @ (CelValue::List(_) | CelValue::Map(_)),
99 }) => {
100 let compile_val = |value: CelValue<'static>| {
101 let mut child_ctx = ctx.child();
102
103 child_ctx.add_variable(variable, CompiledExpr::constant(value));
104
105 child_ctx.resolve(&ctx.args[1])?.into_cel()
106 };
107
108 let collected: Result<Vec<_>, _> = match value {
109 CelValue::List(item) => item.iter().cloned().map(compile_val).collect(),
110 CelValue::Map(item) => item.iter().map(|(key, _)| key).cloned().map(compile_val).collect(),
111 _ => unreachable!(),
112 };
113
114 let collected = collected?;
115 if collected.iter().any(|c| matches!(c, CompiledExpr::Runtime(_))) {
116 Ok(CompiledExpr::runtime(
117 CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
118 native_impl(quote!([#(#collected),*]), parse_quote!(item), quote!(item)),
119 ))
120 } else {
121 Ok(CompiledExpr::constant(CelValue::List(
122 collected
123 .into_iter()
124 .map(|c| match c {
125 CompiledExpr::Constant(ConstantCompiledExpr { value }) => value,
126 _ => unreachable!("all values must be constant"),
127 })
128 .collect(),
129 )))
130 }
131 }
132 CompiledExpr::Constant(ConstantCompiledExpr { value }) => Err(CompileError::TypeConversion {
133 ty: Box::new(CelType::CelValue),
134 message: format!("{value:?} cannot be iterated over"),
135 }),
136 }
137 }
138}
139
#[cfg(test)]
#[cfg(feature = "prost")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use quote::quote;
    use syn::parse_quote;
    use tinc_cel::{CelValue, CelValueConv};

    use crate::codegen::cel::compiler::{CompiledExpr, Compiler, CompilerCtx};
    use crate::codegen::cel::functions::{Function, Map};
    use crate::codegen::cel::types::CelType;
    use crate::extern_paths::ExternPaths;
    use crate::path_set::PathSet;
    use crate::types::{ProtoModifiedValueType, ProtoType, ProtoTypeRegistry, ProtoValueType};

    /// Covers the syntax/type error paths and the fully constant list and map cases.
    #[test]
    fn test_map_syntax() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);
        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), None, &[])), @r#"
        Err(
            InvalidSyntax {
                message: "missing this",
                syntax: "<this>.map(<ident>, <expr>)",
            },
        )
        "#);

        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[])), @r#"
        Err(
            InvalidSyntax {
                message: "invalid number of args",
                syntax: "<this>.map(<ident>, <expr>)",
            },
        )
        "#);

        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: CelValue,
                message: "String(Borrowed(\"hi\")) cannot be iterated over",
            },
        )
        "#);

        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ProtoValueType::Bool)), parse_quote!(input))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: Proto(
                    Value(
                        Bool,
                    ),
                ),
                message: "type cannot be iterated over",
            },
        )
        "#);

        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List([
            CelValueConv::conv(0),
            CelValueConv::conv(1),
            CelValueConv::conv(-50),
            CelValueConv::conv(50),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x + 1").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: List(
                        [
                            Number(
                                I64(
                                    1,
                                ),
                            ),
                            Number(
                                I64(
                                    2,
                                ),
                            ),
                            Number(
                                I64(
                                    -49,
                                ),
                            ),
                            Number(
                                I64(
                                    51,
                                ),
                            ),
                        ],
                    ),
                },
            ),
        )
        ");

        let input = CompiledExpr::constant(CelValue::Map(
            [
                (CelValueConv::conv("key0"), CelValueConv::conv(0)),
                (CelValueConv::conv("key1"), CelValueConv::conv(1)),
                (CelValueConv::conv("key2"), CelValueConv::conv(-50)),
                (CelValueConv::conv("key3"), CelValueConv::conv(50)),
            ]
            .into_iter()
            .collect(),
        ));

        let mut ctx = compiler.child();
        ctx.add_variable("input", input.clone());

        insta::assert_debug_snapshot!(Map.compile(CompilerCtx::new(ctx, Some(input), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("input[x]").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: List(
                        [
                            Number(
                                I64(
                                    0,
                                ),
                            ),
                            Number(
                                I64(
                                    1,
                                ),
                            ),
                            Number(
                                I64(
                                    -50,
                                ),
                            ),
                            Number(
                                I64(
                                    50,
                                ),
                            ),
                        ],
                    ),
                },
            ),
        )
        ");
    }

    /// Maps over the keys of a runtime proto map.
    #[test]
    #[cfg(not(valgrind))]
    fn test_map_runtime_map() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let mut compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(
                ProtoValueType::String,
                ProtoValueType::Int32,
            ))),
            parse_quote!(input),
        );

        compiler.add_variable("input", string_value.clone());

        let output = Map
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("input[x] * 2").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter(input: &std::collections::BTreeMap<String, i32>) -> Result<::tinc::__private::cel::CelValue<'_>, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&{
                        let mut map = std::collections::BTreeMap::new();
                        map.insert("0".to_string(), 0);
                        map.insert("1".to_string(), 1);
                        map.insert("2".to_string(), -50);
                        map.insert("3".to_string(), 50);
                        map
                    }).unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(0),
                        ::tinc::__private::cel::CelValueConv::conv(2),
                        ::tinc::__private::cel::CelValueConv::conv(-100),
                        ::tinc::__private::cel::CelValueConv::conv(100),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Maps over the elements of a runtime repeated field.
    #[test]
    #[cfg(not(valgrind))]
    fn test_map_runtime_repeated() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::Int32))),
            parse_quote!(input),
        );

        let output = Map
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("x * 100").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter(input: &Vec<i32>) -> Result<::tinc::__private::cel::CelValue<'_>, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&vec![0, 1, -50, 50]).unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(0),
                        ::tinc::__private::cel::CelValueConv::conv(100),
                        ::tinc::__private::cel::CelValueConv::conv(-5000),
                        ::tinc::__private::cel::CelValueConv::conv(5000),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Maps over a runtime `CelValue` receiver.
    #[test]
    #[cfg(not(valgrind))]
    fn test_map_runtime_cel_value() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(CelType::CelValue, parse_quote!(input));

        let output = Map
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("x + 1").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter<'a>(input: &'a ::tinc::__private::cel::CelValue<'a>) -> Result<::tinc::__private::cel::CelValue<'a>, ::tinc::__private::cel::CelError<'a>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter(&tinc::__private::cel::CelValue::List([
                        tinc::__private::cel::CelValueConv::conv(5),
                        tinc::__private::cel::CelValueConv::conv(1),
                        tinc::__private::cel::CelValueConv::conv(50),
                        tinc::__private::cel::CelValueConv::conv(-50),
                    ].into_iter().collect())).unwrap(), tinc::__private::cel::CelValue::List([
                        tinc::__private::cel::CelValueConv::conv(6),
                        tinc::__private::cel::CelValueConv::conv(2),
                        tinc::__private::cel::CelValueConv::conv(51),
                        tinc::__private::cel::CelValueConv::conv(-49),
                    ].into_iter().collect()));
                }
            },
        ));
    }

    /// Maps over a constant list whose mapped expression has to be evaluated at runtime.
    #[test]
    #[cfg(not(valgrind))]
    fn test_map_const_requires_runtime() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let list_value = CompiledExpr::constant(CelValue::List(
            [CelValueConv::conv(5), CelValueConv::conv(0), CelValueConv::conv(1)]
                .into_iter()
                .collect(),
        ));

        let output = Map
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(list_value),
                &[cel_parser::parse("x").unwrap(), cel_parser::parse("dyn(x / 2)").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn filter() -> Result<::tinc::__private::cel::CelValue<'static>, ::tinc::__private::cel::CelError<'static>> {
                    Ok(#output)
                }

                #[test]
                fn test_filter() {
                    assert_eq!(filter().unwrap(), ::tinc::__private::cel::CelValue::List([
                        ::tinc::__private::cel::CelValueConv::conv(2),
                        ::tinc::__private::cel::CelValueConv::conv(0),
                        ::tinc::__private::cel::CelValueConv::conv(0),
                    ].into_iter().collect()));
                }
            },
        ));
    }
}