vector/api/schema/
filter.rs1use std::collections::BTreeSet;
2
3use async_graphql::{InputObject, InputType};
4
5use super::components::{source, ComponentKind};
6
/// Early-returns `false` from the enclosing function if any of the given
/// `Option<bool>` checks is `Some(false)`.
///
/// Checks are evaluated left to right and evaluation stops at the first
/// failing one. `None` (condition not set) and `Some(true)` (condition met)
/// both let evaluation continue; callers write `true` after the macro as their
/// fall-through result.
#[macro_export]
macro_rules! filter_check {
    ($($check:expr_2021),+) => {
        $(
            // `None` means "no constraint", so only an explicit `false` rejects.
            if let Some(false) = $check {
                return false;
            }
        )+
    };
}
18
19#[derive(Default, InputObject)]
20pub struct StringFilter {
22 pub equals: Option<String>,
23 pub not_equals: Option<String>,
24 pub contains: Option<String>,
25 pub not_contains: Option<String>,
26 pub starts_with: Option<String>,
27 pub ends_with: Option<String>,
28}
29
30impl StringFilter {
31 pub fn filter_value(&self, value: &str) -> bool {
32 filter_check!(
33 self.equals.as_ref().map(|s| value.eq(s)),
35 self.not_equals.as_ref().map(|s| !value.eq(s)),
37 self.contains.as_ref().map(|s| value.contains(s)),
39 self.not_contains.as_ref().map(|s| !value.contains(s)),
41 self.starts_with.as_ref().map(|s| value.starts_with(s)),
43 self.ends_with.as_ref().map(|s| value.ends_with(s))
45 );
46 true
47 }
48}
49
50#[derive(InputObject)]
51#[graphql(concrete(name = "SourceOutputTypeFilter", params(source::SourceOutputType)))]
52pub struct ListFilter<T: InputType + PartialEq + Eq + Ord> {
54 pub equals: Option<Vec<T>>,
55 pub not_equals: Option<Vec<T>>,
56 pub contains: Option<T>,
57 pub not_contains: Option<T>,
58}
59
60impl<T: InputType + PartialEq + Eq + Ord> ListFilter<T> {
61 pub fn filter_value(&self, value: Vec<T>) -> bool {
62 let val = BTreeSet::from_iter(value.iter());
63 filter_check!(
64 self.equals
66 .as_ref()
67 .map(|s| BTreeSet::from_iter(s.iter()).eq(&val)),
68 self.not_equals
70 .as_ref()
71 .map(|s| !BTreeSet::from_iter(s.iter()).eq(&val)),
72 self.contains.as_ref().map(|s| val.contains(s)),
74 self.not_contains.as_ref().map(|s| !val.contains(s))
76 );
77 true
78 }
79}
80
81#[derive(InputObject)]
82#[graphql(concrete(name = "ComponentKindFilter", params(ComponentKind)))]
83pub struct EqualityFilter<T: InputType + PartialEq + Eq> {
84 pub equals: Option<T>,
85 pub not_equals: Option<T>,
86}
87
88impl<T: InputType + PartialEq + Eq> EqualityFilter<T> {
89 pub fn filter_value(&self, value: T) -> bool {
90 filter_check!(
91 self.equals.as_ref().map(|s| value.eq(s)),
93 self.not_equals.as_ref().map(|s| !value.eq(s))
95 );
96 true
97 }
98}
99
/// A caller-defined filter over items of type `T`.
///
/// An item is accepted by [`filter_items`] when it satisfies the filter's own
/// rule ([`CustomFilter::matches`]) or, recursively, any of the alternative
/// filters returned by [`CustomFilter::or`].
pub trait CustomFilter<T> {
    /// Returns `true` if `item` satisfies this filter's own criteria.
    fn matches(&self, item: &T) -> bool;

    /// Returns the alternative ("OR") filters to try when `matches` fails,
    /// or `None` if there are none.
    fn or(&self) -> Option<&Vec<Self>>
    where
        Self: Sized;
}
107
108fn filter_item<Item, Filter>(item: &Item, f: &Filter) -> bool
110where
111 Filter: CustomFilter<Item>,
112{
113 f.matches(item)
114 || f.or()
115 .map_or_else(|| false, |f| f.iter().any(|f| filter_item(item, f)))
116}
117
118pub fn filter_items<Item, Iter, Filter>(items: Iter, f: &Filter) -> Vec<Item>
120where
121 Iter: Iterator<Item = Item>,
122 Filter: CustomFilter<Item>,
123{
124 items.filter(|c| filter_item(c, f)).collect()
125}
126
#[cfg(test)]
mod test {
    use super::{EqualityFilter, ListFilter, StringFilter};

    #[test]
    fn string_equals() {
        let value = "test";

        let sf = StringFilter {
            equals: value.to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value(value));
        assert!(!sf.filter_value("not found"));
    }

    #[test]
    fn string_not_equals() {
        let value = "value";
        let diff_value = "different value";

        let sf = StringFilter {
            not_equals: diff_value.to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value(value));
        assert!(!sf.filter_value(diff_value));
    }

    #[test]
    fn string_contains() {
        let sf = StringFilter {
            contains: "234".to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value("12345"));
        assert!(!sf.filter_value("xxx"));
    }

    #[test]
    fn string_not_contains() {
        let contains = "xyz";

        let sf = StringFilter {
            not_contains: contains.to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value("abc"));
        assert!(!sf.filter_value(contains));
    }

    #[test]
    fn string_starts_with() {
        let sf = StringFilter {
            starts_with: "abc".to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value("abcdef"));
        assert!(!sf.filter_value("xyz"));
    }

    #[test]
    fn string_ends_with() {
        let sf = StringFilter {
            ends_with: "456".to_string().into(),
            ..Default::default()
        };

        assert!(sf.filter_value("123456"));
        assert!(!sf.filter_value("123"));
    }

    #[test]
    fn string_multiple_all_match() {
        let value = "123456";
        let sf = StringFilter {
            equals: value.to_string().into(),
            not_equals: "xyz".to_string().into(),
            contains: "234".to_string().into(),
            not_contains: "678".to_string().into(),
            starts_with: "123".to_string().into(),
            ends_with: "456".to_string().into(),
        };

        assert!(sf.filter_value(value));
        assert!(!sf.filter_value("should fail"));
    }

    /// A filter with no conditions set must accept every value.
    #[test]
    fn string_empty_filter_matches_everything() {
        let sf = StringFilter::default();

        assert!(sf.filter_value(""));
        assert!(sf.filter_value("anything"));
    }

    /// `ListFilter` has no `Default`; this builds one with every condition unset.
    fn empty_list_filter() -> ListFilter<i32> {
        ListFilter {
            equals: None,
            not_equals: None,
            contains: None,
            not_contains: None,
        }
    }

    #[test]
    fn list_empty_filter_matches_everything() {
        let lf = empty_list_filter();

        assert!(lf.filter_value(vec![]));
        assert!(lf.filter_value(vec![1, 2, 3]));
    }

    /// `equals` compares lists as sets: order and duplicates are ignored.
    #[test]
    fn list_equals_ignores_order_and_duplicates() {
        let lf = ListFilter {
            equals: Some(vec![1, 2]),
            ..empty_list_filter()
        };

        assert!(lf.filter_value(vec![1, 2]));
        assert!(lf.filter_value(vec![2, 1]));
        assert!(lf.filter_value(vec![1, 2, 2]));
        assert!(!lf.filter_value(vec![1]));
        assert!(!lf.filter_value(vec![1, 2, 3]));
    }

    #[test]
    fn list_not_equals() {
        let lf = ListFilter {
            not_equals: Some(vec![1]),
            ..empty_list_filter()
        };

        assert!(!lf.filter_value(vec![1]));
        assert!(!lf.filter_value(vec![1, 1]));
        assert!(lf.filter_value(vec![1, 2]));
    }

    #[test]
    fn list_contains() {
        let lf = ListFilter {
            contains: Some(3),
            ..empty_list_filter()
        };

        assert!(lf.filter_value(vec![1, 3]));
        assert!(!lf.filter_value(vec![1, 2]));
        assert!(!lf.filter_value(vec![]));
    }

    #[test]
    fn list_not_contains() {
        let lf = ListFilter {
            not_contains: Some(3),
            ..empty_list_filter()
        };

        assert!(lf.filter_value(vec![1, 2]));
        assert!(lf.filter_value(vec![]));
        assert!(!lf.filter_value(vec![3]));
    }

    #[test]
    fn equality_equals() {
        let ef = EqualityFilter {
            equals: Some(5),
            not_equals: None,
        };

        assert!(ef.filter_value(5));
        assert!(!ef.filter_value(6));
    }

    #[test]
    fn equality_not_equals() {
        let ef = EqualityFilter {
            equals: None,
            not_equals: Some(5),
        };

        assert!(!ef.filter_value(5));
        assert!(ef.filter_value(6));
    }

    #[test]
    fn equality_empty_filter_matches_everything() {
        let ef: EqualityFilter<i32> = EqualityFilter {
            equals: None,
            not_equals: None,
        };

        assert!(ef.filter_value(0));
        assert!(ef.filter_value(-1));
    }
}