tinc_build/codegen/cel/functions/contains.rs

1use quote::quote;
2use syn::parse_quote;
3use tinc_cel::CelValue;
4
5use super::Function;
6use crate::codegen::cel::compiler::{CompileError, CompiledExpr, CompilerCtx, ConstantCompiledExpr, RuntimeCompiledExpr};
7use crate::codegen::cel::types::CelType;
8use crate::types::{ProtoModifiedValueType, ProtoType, ProtoValueType};
9
10#[derive(Debug, Clone, Default)]
11pub(crate) struct Contains;
12
13// this.contains(arg)
14// arg in this
15impl Function for Contains {
16    fn name(&self) -> &'static str {
17        "contains"
18    }
19
20    fn syntax(&self) -> &'static str {
21        "<this>.contains(<arg>)"
22    }
23
24    fn compile(&self, mut ctx: CompilerCtx) -> Result<CompiledExpr, CompileError> {
25        let Some(this) = ctx.this.take() else {
26            return Err(CompileError::syntax("missing this", self));
27        };
28
29        if ctx.args.len() != 1 {
30            return Err(CompileError::syntax("takes exactly one argument", self));
31        }
32
33        let arg = ctx.resolve(&ctx.args[0])?.into_cel()?;
34
35        if let CompiledExpr::Runtime(RuntimeCompiledExpr {
36            expr,
37            ty:
38                ty @ CelType::Proto(ProtoType::Modified(
39                    ProtoModifiedValueType::Repeated(item) | ProtoModifiedValueType::Map(item, _),
40                )),
41        }) = &this
42            && !matches!(item, ProtoValueType::Message { .. } | ProtoValueType::Enum(_))
43        {
44            let op = match &ty {
45                CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(_))) => {
46                    quote! { array_contains }
47                }
48                CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(_, _))) => {
49                    quote! { map_contains }
50                }
51                _ => unreachable!(),
52            };
53
54            return Ok(CompiledExpr::runtime(
55                CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
56                parse_quote! {
57                    ::tinc::__private::cel::#op(
58                        #expr,
59                        #arg,
60                    )
61                },
62            ));
63        }
64
65        let this = this.clone().into_cel()?;
66
67        match (this, arg) {
68            (
69                CompiledExpr::Constant(ConstantCompiledExpr { value: this }),
70                CompiledExpr::Constant(ConstantCompiledExpr { value: arg }),
71            ) => Ok(CompiledExpr::constant(CelValue::cel_contains(this, arg)?)),
72            (this, arg) => Ok(CompiledExpr::runtime(
73                CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
74                parse_quote! {
75                    ::tinc::__private::cel::CelValue::cel_contains(
76                        #this,
77                        #arg,
78                    )?
79                },
80            )),
81        }
82    }
83}
84
#[cfg(test)]
#[cfg(feature = "prost")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use quote::quote;
    use syn::parse_quote;
    use tinc_cel::CelValue;

    use crate::codegen::cel::compiler::{CompiledExpr, Compiler, CompilerCtx};
    use crate::codegen::cel::functions::{Contains, Function};
    use crate::codegen::cel::types::CelType;
    use crate::extern_paths::ExternPaths;
    use crate::path_set::PathSet;
    use crate::types::{ProtoModifiedValueType, ProtoType, ProtoTypeRegistry, ProtoValueType};

    /// Covers the argument-validation errors and constant folding of a constant
    /// list receiver with a constant argument.
    #[test]
    fn test_contains_syntax() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        // No receiver.
        insta::assert_debug_snapshot!(Contains.compile(CompilerCtx::new(compiler.child(), None, &[])), @r#"
        Err(
            InvalidSyntax {
                message: "missing this",
                syntax: "<this>.contains(<arg>)",
            },
        )
        "#);

        // Receiver present but no argument.
        insta::assert_debug_snapshot!(Contains.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[])), @r#"
        Err(
            InvalidSyntax {
                message: "takes exactly one argument",
                syntax: "<this>.contains(<arg>)",
            },
        )
        "#);

        // Empty constant list with a constant argument folds to `false`.
        let folded_args = [cel_parser::parse("1 + 1").unwrap()];
        insta::assert_debug_snapshot!(Contains.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List(Default::default()))), &folded_args)), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: Bool(
                        false,
                    ),
                },
            ),
        )
        ");
    }

    /// A runtime string receiver goes through the generic `cel_contains` path.
    #[test]
    #[cfg(not(valgrind))]
    fn test_contains_runtime_string() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let receiver = CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ProtoValueType::String)), parse_quote!(input));
        let args = [cel_parser::parse("(1 + 1).string()").unwrap()];

        let compiled = Contains
            .compile(CompilerCtx::new(compiler.child(), Some(receiver), &args))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn contains(input: &String) -> Result<bool, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#compiled)
                }

                #[test]
                fn test_contains() {
                    assert_eq!(contains(&"in2dastring".into()).unwrap(), true);
                    assert_eq!(contains(&"in3dastring".into()).unwrap(), false);
                }
            },
        ));
    }

    /// A runtime `map<string, bool>` receiver is lowered to `map_contains` (key lookup).
    #[test]
    #[cfg(not(valgrind))]
    fn test_contains_runtime_map() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let receiver = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(
                ProtoValueType::String,
                ProtoValueType::Bool,
            ))),
            parse_quote!(input),
        );
        let args = [cel_parser::parse("'value'").unwrap()];

        let compiled = Contains
            .compile(CompilerCtx::new(compiler.child(), Some(receiver), &args))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn contains(input: &std::collections::HashMap<String, bool>) -> Result<bool, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#compiled)
                }

                #[test]
                fn test_contains() {
                    assert_eq!(contains(&{
                        let mut map = std::collections::HashMap::new();
                        map.insert("value".to_string(), true);
                        map
                    }).unwrap(), true);
                    assert_eq!(contains(&{
                        let mut map = std::collections::HashMap::new();
                        map.insert("not_value".to_string(), true);
                        map
                    }).unwrap(), false);
                    assert_eq!(contains(&{
                        let mut map = std::collections::HashMap::new();
                        map.insert("xd".to_string(), true);
                        map.insert("value".to_string(), true);
                        map
                    }).unwrap(), true);
                }
            },
        ));
    }

    /// A runtime `repeated string` receiver is lowered to `array_contains`.
    #[test]
    #[cfg(not(valgrind))]
    fn test_contains_runtime_repeated() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, ExternPaths::new(crate::Mode::Prost), PathSet::default());
        let compiler = Compiler::new(&registry);

        let receiver = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::String))),
            parse_quote!(input),
        );
        let args = [cel_parser::parse("'value'").unwrap()];

        let compiled = Contains
            .compile(CompilerCtx::new(compiler.child(), Some(receiver), &args))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn contains(input: &Vec<String>) -> Result<bool, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#compiled)
                }

                #[test]
                fn test_contains() {
                    assert_eq!(contains(&vec!["value".into()]).unwrap(), true);
                    assert_eq!(contains(&vec!["not_value".into()]).unwrap(), false);
                    assert_eq!(contains(&vec!["xd".into(), "value".into()]).unwrap(), true);
                }
            },
        ));
    }
}