vector/
extra_context.rs

1//! ExtraContext is used for passing extra data to Vector's components when Vector is used as a library.
2use std::{
3    any::{Any, TypeId},
4    collections::HashMap,
5    marker::{Send, Sync},
6    sync::Arc,
7};
8
/// A type-indexed bag of arbitrary values handed to Vector's components.
///
/// At most one value is stored per concrete type; the type itself is the key.
/// The map lives behind an [`Arc`], so cloning an `ExtraContext` only bumps a
/// reference count and never copies the stored values.
#[derive(Default, Debug, Clone)]
pub struct ExtraContext(Arc<HashMap<TypeId, ContextItem>>);

/// A type-erased, thread-shareable value stored in an [`ExtraContext`].
type ContextItem = Box<dyn Any + Send + Sync>;

impl ExtraContext {
    /// Builds a context holding exactly one value, keyed by its type `T`.
    pub fn single_value<T: Any + Send + Sync>(value: T) -> Self {
        let boxed: ContextItem = Box::new(value);
        std::iter::once(boxed).collect()
    }

    /// Stores `value` under its type, replacing any earlier value of type `T`.
    ///
    /// Test-only: mutating the shared map requires this to be the sole handle
    /// to the inner [`Arc`]; if any clone of this context is alive it panics.
    #[cfg(test)]
    fn insert<T: Any + Send + Sync>(&mut self, value: T) {
        let map = Arc::get_mut(&mut self.0)
            .expect("only insert into extra context when there is a single reference to it");
        map.insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Looks up the value stored for type `T`.
    ///
    /// Returns `None` when no value of that type is present.
    pub fn get<T: 'static>(&self) -> Option<&T> {
        let item = self.0.get(&TypeId::of::<T>())?;
        item.downcast_ref::<T>()
    }

    /// Returns a clone of the value stored for type `T`, or `T::default()`
    /// when the context holds no value of that type.
    pub fn get_or_default<T: Clone + Default + 'static>(&self) -> T {
        self.get::<T>().map_or_else(T::default, T::clone)
    }
}

impl FromIterator<ContextItem> for ExtraContext {
    /// Collects boxed values into a context, keying each by its concrete type.
    ///
    /// When several items share a type, the one yielded last is kept.
    fn from_iter<I: IntoIterator<Item = ContextItem>>(items: I) -> Self {
        let mut by_type = HashMap::new();
        for item in items {
            // Deref through the `Box` so the trait object reports the stored
            // value's type; `type_id` on the box itself would give the id of
            // `Box<dyn Any + Send + Sync>` instead.
            let key = (*item).type_id();
            by_type.insert(key, item);
        }
        Self(Arc::new(by_type))
    }
}
53
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Eq, PartialEq, Debug, Default)]
    struct Peas {
        beans: usize,
    }

    #[derive(Clone, Eq, PartialEq, Debug, Default)]
    struct Potatoes(usize);

    #[test]
    fn get_fetches_item() {
        let peas = Peas { beans: 42 };
        let potatoes = Potatoes(8);

        let mut context = ExtraContext::default();
        context.insert(peas);
        context.insert(potatoes);

        assert_eq!(&Peas { beans: 42 }, context.get::<Peas>().unwrap());
        assert_eq!(&Potatoes(8), context.get::<Potatoes>().unwrap());
    }

    #[test]
    fn get_or_default_fetches_item() {
        let potatoes = Potatoes(8);

        let mut context = ExtraContext::default();
        context.insert(potatoes);

        assert_eq!(Potatoes(8), context.get_or_default::<Potatoes>());
        assert_eq!(Peas::default(), context.get_or_default::<Peas>());
    }

    #[test]
    fn duplicate_types() {
        let potatoes = Potatoes(8);
        let potatoes99 = Potatoes(99);

        let mut context = ExtraContext::default();
        context.insert(potatoes);
        context.insert(potatoes99);

        assert_eq!(&Potatoes(99), context.get::<Potatoes>().unwrap());
    }

    /// `single_value` goes through the public `FromIterator` impl; this pins
    /// that the value is keyed by its concrete type, not by the boxed
    /// `dyn Any` wrapper.
    #[test]
    fn single_value_is_keyed_by_concrete_type() {
        let context = ExtraContext::single_value(Potatoes(3));

        assert_eq!(Some(&Potatoes(3)), context.get::<Potatoes>());
        assert!(context.get::<Peas>().is_none());
    }

    /// Collecting boxed items keeps one value per type, with the last one
    /// yielded winning, mirroring `insert`.
    #[test]
    fn from_iter_keeps_last_value_per_type() {
        let items: Vec<ContextItem> = vec![
            Box::new(Potatoes(1)),
            Box::new(Peas { beans: 2 }),
            Box::new(Potatoes(3)),
        ];
        let context: ExtraContext = items.into_iter().collect();

        assert_eq!(Some(&Potatoes(3)), context.get::<Potatoes>());
        assert_eq!(Some(&Peas { beans: 2 }), context.get::<Peas>());
    }

    /// Clones share the same underlying map, so they see the same values.
    #[test]
    fn clones_share_contents() {
        let original = ExtraContext::single_value(Peas { beans: 7 });
        let copy = original.clone();

        assert_eq!(Some(&Peas { beans: 7 }), copy.get::<Peas>());
        assert_eq!(Some(&Peas { beans: 7 }), original.get::<Peas>());
    }

    /// The test-only `insert` must refuse to mutate a map other handles can see.
    #[test]
    #[should_panic(expected = "single reference")]
    fn insert_panics_when_shared() {
        let mut context = ExtraContext::default();
        // Keep a second handle alive so `Arc::get_mut` cannot succeed.
        let _other = context.clone();
        context.insert(Potatoes(1));
    }
}