vector/api/schema/
filter.rs

1use std::collections::BTreeSet;
2
3use async_graphql::{InputObject, InputType};
4
5use super::components::{source, ComponentKind};
6
/// Early-returns `false` from the enclosing function if any argument is `Some(false)`.
///
/// Each argument is an `Option<bool>` expression. Arguments are evaluated left to right,
/// and evaluation stops at the first `Some(false)`. `None` (criterion not set) and
/// `Some(true)` (criterion satisfied) both fall through to the next argument. A trailing
/// comma is accepted.
///
/// The macro expands to `return false;`, so it must be invoked inside a function that
/// returns `bool`.
#[macro_export]
macro_rules! filter_check {
    ($($match:expr_2021),+ $(,)?) => {
        $(
            if let Some(false) = $match {
                return false;
            }
        )+
    };
}
18
#[derive(Default, InputObject)]
/// Filter for String values
// A value matches only if every populated criterion holds (see `filter_value`);
// a field left as `None` imposes no constraint, so `StringFilter::default()`
// matches every string.
pub struct StringFilter {
    // Value must be exactly this string.
    pub equals: Option<String>,
    // Value must differ from this string.
    pub not_equals: Option<String>,
    // Value must contain this substring.
    pub contains: Option<String>,
    // Value must not contain this substring.
    pub not_contains: Option<String>,
    // Value must begin with this prefix.
    pub starts_with: Option<String>,
    // Value must end with this suffix.
    pub ends_with: Option<String>,
}
29
30impl StringFilter {
31    pub fn filter_value(&self, value: &str) -> bool {
32        filter_check!(
33            // Equals
34            self.equals.as_ref().map(|s| value.eq(s)),
35            // Not equals
36            self.not_equals.as_ref().map(|s| !value.eq(s)),
37            // Contains
38            self.contains.as_ref().map(|s| value.contains(s)),
39            // Does not contain
40            self.not_contains.as_ref().map(|s| !value.contains(s)),
41            // Starts with
42            self.starts_with.as_ref().map(|s| value.starts_with(s)),
43            // Ends with
44            self.ends_with.as_ref().map(|s| value.ends_with(s))
45        );
46        true
47    }
48}
49
50#[derive(InputObject)]
51#[graphql(concrete(name = "SourceOutputTypeFilter", params(source::SourceOutputType)))]
52// Filter for GraphQL lists
53pub struct ListFilter<T: InputType + PartialEq + Eq + Ord> {
54    pub equals: Option<Vec<T>>,
55    pub not_equals: Option<Vec<T>>,
56    pub contains: Option<T>,
57    pub not_contains: Option<T>,
58}
59
60impl<T: InputType + PartialEq + Eq + Ord> ListFilter<T> {
61    pub fn filter_value(&self, value: Vec<T>) -> bool {
62        let val = BTreeSet::from_iter(value.iter());
63        filter_check!(
64            // Equals
65            self.equals
66                .as_ref()
67                .map(|s| BTreeSet::from_iter(s.iter()).eq(&val)),
68            // Not Equals
69            self.not_equals
70                .as_ref()
71                .map(|s| !BTreeSet::from_iter(s.iter()).eq(&val)),
72            // Contains
73            self.contains.as_ref().map(|s| val.contains(s)),
74            // Not Contains
75            self.not_contains.as_ref().map(|s| !val.contains(s))
76        );
77        true
78    }
79}
80
#[derive(InputObject)]
#[graphql(concrete(name = "ComponentKindFilter", params(ComponentKind)))]
// Filter for values compared by equality. A value matches only if every populated
// criterion holds (see `filter_value`); a field left as `None` imposes no constraint.
pub struct EqualityFilter<T: InputType + PartialEq + Eq> {
    // Value must equal this.
    pub equals: Option<T>,
    // Value must not equal this.
    pub not_equals: Option<T>,
}
87
88impl<T: InputType + PartialEq + Eq> EqualityFilter<T> {
89    pub fn filter_value(&self, value: T) -> bool {
90        filter_check!(
91            // Equals
92            self.equals.as_ref().map(|s| value.eq(s)),
93            // Not equals
94            self.not_equals.as_ref().map(|s| !value.eq(s))
95        );
96        true
97    }
98}
99
/// Predicate over items of type `T`, with optional OR-alternatives.
///
/// Implementations are consumed by `filter_items`. An item is kept when the filter
/// itself [`matches`](Self::matches) it, or when any filter returned by
/// [`or`](Self::or) (checked recursively by the same rule) accepts it.
pub trait CustomFilter<T> {
    /// Returns `true` if `item` satisfies this filter's own criteria.
    fn matches(&self, item: &T) -> bool;
    /// Alternative filters OR-ed with this one, or `None` if there are none.
    ///
    /// An item rejected by `matches` is still kept if any of these accepts it.
    fn or(&self) -> Option<&Vec<Self>>
    where
        // `Vec<Self>` requires a sized element type.
        Self: Sized;
}
107
108/// Returns true if a provided `Item` passes all 'AND' or 'OR' filter rules, recursively.
109fn filter_item<Item, Filter>(item: &Item, f: &Filter) -> bool
110where
111    Filter: CustomFilter<Item>,
112{
113    f.matches(item)
114        || f.or()
115            .map_or_else(|| false, |f| f.iter().any(|f| filter_item(item, f)))
116}
117
118/// Filters items based on an implementation of `CustomFilter<T>`.
119pub fn filter_items<Item, Iter, Filter>(items: Iter, f: &Filter) -> Vec<Item>
120where
121    Iter: Iterator<Item = Item>,
122    Filter: CustomFilter<Item>,
123{
124    items.filter(|c| filter_item(c, f)).collect()
125}
126
#[cfg(test)]
mod test {
    use super::{filter_items, CustomFilter, EqualityFilter, ListFilter, StringFilter};

    #[test]
    fn string_equals() {
        let value = "test";

        let sf = StringFilter {
            equals: value.to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value(value));
        assert!(!sf.filter_value("not found"));
    }

    #[test]
    fn string_not_equals() {
        let value = "value";
        let diff_value = "different value";

        let sf = StringFilter {
            not_equals: diff_value.to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value(value));
        assert!(!sf.filter_value(diff_value));
    }

    #[test]
    fn string_contains() {
        let sf = StringFilter {
            contains: "234".to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value("12345"));
        assert!(!sf.filter_value("xxx"));
    }

    #[test]
    fn string_not_contains() {
        let contains = "xyz";

        let sf = StringFilter {
            not_contains: contains.to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value("abc"));
        assert!(!sf.filter_value(contains));
    }

    #[test]
    fn string_starts_with() {
        let sf = StringFilter {
            starts_with: "abc".to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value("abcdef"));
        assert!(!sf.filter_value("xyz"));
    }

    #[test]
    fn string_ends_with() {
        let sf = StringFilter {
            ends_with: "456".to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value("123456"));
        assert!(!sf.filter_value("123"));
    }

    #[test]
    fn string_multiple_all_match() {
        let value = "123456";
        let sf = StringFilter {
            equals: value.to_string().into(),
            not_equals: "xyz".to_string().into(),
            contains: "234".to_string().into(),
            not_contains: "678".to_string().into(),
            starts_with: "123".to_string().into(),
            ends_with: "456".to_string().into(),
        };

        assert!(sf.filter_value(value));
        assert!(!sf.filter_value("should fail"));
    }

    #[test]
    fn string_default_matches_everything() {
        let sf = StringFilter::default();

        assert!(sf.filter_value(""));
        assert!(sf.filter_value("anything"));
    }

    /// A `ListFilter` with no criteria set.
    fn empty_list_filter() -> ListFilter<i32> {
        ListFilter {
            equals: None,
            not_equals: None,
            contains: None,
            not_contains: None,
        }
    }

    #[test]
    fn list_empty_filter_matches_everything() {
        assert!(empty_list_filter().filter_value(vec![]));
        assert!(empty_list_filter().filter_value(vec![1, 2, 3]));
    }

    #[test]
    fn list_equals_ignores_order_and_duplicates() {
        let lf = ListFilter {
            equals: Some(vec![1, 2, 3]),
            ..empty_list_filter()
        };

        assert!(lf.filter_value(vec![3, 1, 2]));
        assert!(lf.filter_value(vec![1, 2, 2, 3]));
        assert!(!lf.filter_value(vec![1, 2]));
    }

    #[test]
    fn list_not_equals() {
        let lf = ListFilter {
            not_equals: Some(vec![1, 2]),
            ..empty_list_filter()
        };

        assert!(!lf.filter_value(vec![2, 1]));
        assert!(lf.filter_value(vec![1]));
    }

    #[test]
    fn list_contains_and_not_contains() {
        let contains = ListFilter {
            contains: Some(2),
            ..empty_list_filter()
        };
        assert!(contains.filter_value(vec![1, 2]));
        assert!(!contains.filter_value(vec![]));

        let not_contains = ListFilter {
            not_contains: Some(2),
            ..empty_list_filter()
        };
        assert!(not_contains.filter_value(vec![1, 3]));
        assert!(!not_contains.filter_value(vec![2]));
    }

    #[test]
    fn equality_filter() {
        let open = EqualityFilter::<i32> {
            equals: None,
            not_equals: None,
        };
        assert!(open.filter_value(7));

        let eq = EqualityFilter {
            equals: Some(3),
            not_equals: None,
        };
        assert!(eq.filter_value(3));
        assert!(!eq.filter_value(4));

        let ne = EqualityFilter {
            equals: None,
            not_equals: Some(3),
        };
        assert!(!ne.filter_value(3));
        assert!(ne.filter_value(4));
    }

    /// Test filter accepting integers `>= min`, optionally OR-ed with alternatives.
    struct AtLeast {
        min: i32,
        alternatives: Option<Vec<AtLeast>>,
    }

    impl CustomFilter<i32> for AtLeast {
        fn matches(&self, item: &i32) -> bool {
            *item >= self.min
        }

        fn or(&self) -> Option<&Vec<Self>> {
            self.alternatives.as_ref()
        }
    }

    #[test]
    fn filter_items_without_alternatives() {
        let f = AtLeast {
            min: 5,
            alternatives: None,
        };

        assert_eq!(filter_items(vec![1, 5, 10, 2].into_iter(), &f), vec![5, 10]);
    }

    #[test]
    fn filter_items_follows_nested_or_chain() {
        let f = AtLeast {
            min: 100,
            alternatives: Some(vec![AtLeast {
                min: 50,
                alternatives: Some(vec![AtLeast {
                    min: 8,
                    alternatives: None,
                }]),
            }]),
        };

        assert_eq!(filter_items(vec![1, 5, 10, 200].into_iter(), &f), vec![10, 200]);
    }
}