1use std::{
3 any::{Any, TypeId},
4 collections::HashMap,
5 marker::{Send, Sync},
6 sync::Arc,
7};
8
/// A bag of extra values, holding at most one value per concrete type.
///
/// Entries are keyed by the [`TypeId`] of the stored value. The map lives
/// behind an [`Arc`], so cloning an `ExtraContext` only bumps a reference
/// count and the clones share the same storage.
#[derive(Default, Debug, Clone)]
pub struct ExtraContext(Arc<HashMap<TypeId, ContextItem>>);

/// A boxed, type-erased entry of an [`ExtraContext`].
///
/// The `Send + Sync` bounds keep the whole context `Send + Sync` despite the
/// type erasure.
type ContextItem = Box<dyn Any + Send + Sync>;
15
16impl ExtraContext {
17 pub fn single_value<T: Any + Send + Sync>(value: T) -> Self {
19 [Box::new(value) as _].into_iter().collect()
20 }
21
22 #[cfg(test)]
23 fn insert<T: Any + Send + Sync>(&mut self, value: T) {
26 Arc::get_mut(&mut self.0)
27 .expect("only insert into extra context when there is a single reference to it")
28 .insert(value.type_id(), Box::new(value));
29 }
30
31 pub fn get<T: 'static>(&self) -> Option<&T> {
33 self.0
34 .get(&TypeId::of::<T>())
35 .and_then(|t| t.downcast_ref())
36 }
37
38 pub fn get_or_default<T: Clone + Default + 'static>(&self) -> T {
40 self.get().cloned().unwrap_or_default()
41 }
42}
43
44impl FromIterator<ContextItem> for ExtraContext {
45 fn from_iter<T: IntoIterator<Item = ContextItem>>(iter: T) -> Self {
46 Self(Arc::new(
47 iter.into_iter()
48 .map(|item| ((*item).type_id(), item))
49 .collect(),
50 ))
51 }
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Eq, PartialEq, Debug, Default)]
    struct Peas {
        beans: usize,
    }

    #[derive(Clone, Eq, PartialEq, Debug, Default)]
    struct Potatoes(usize);

    #[test]
    fn get_fetches_item() {
        let peas = Peas { beans: 42 };
        let potatoes = Potatoes(8);

        let mut context = ExtraContext::default();
        context.insert(peas);
        context.insert(potatoes);

        assert_eq!(&Peas { beans: 42 }, context.get::<Peas>().unwrap());
        assert_eq!(&Potatoes(8), context.get::<Potatoes>().unwrap());
    }

    #[test]
    fn get_missing_type_returns_none() {
        let mut context = ExtraContext::default();
        context.insert(Potatoes(1));

        assert_eq!(None, context.get::<Peas>());
    }

    #[test]
    fn get_or_default_fetches_item() {
        let potatoes = Potatoes(8);

        let mut context = ExtraContext::default();
        context.insert(potatoes);

        assert_eq!(Potatoes(8), context.get_or_default::<Potatoes>());
        assert_eq!(Peas::default(), context.get_or_default::<Peas>());
    }

    #[test]
    fn duplicate_types() {
        let potatoes = Potatoes(8);
        let potatoes99 = Potatoes(99);

        let mut context = ExtraContext::default();
        context.insert(potatoes);
        context.insert(potatoes99);

        assert_eq!(&Potatoes(99), context.get::<Potatoes>().unwrap());
    }

    #[test]
    fn single_value_holds_only_that_value() {
        let context = ExtraContext::single_value(Potatoes(3));

        assert_eq!(Some(&Potatoes(3)), context.get::<Potatoes>());
        assert_eq!(None, context.get::<Peas>());
    }

    /// Guards the `(*item).type_id()` deref in `FromIterator`: keying by the
    /// box instead of its contents would make every lookup below miss.
    #[test]
    fn from_iter_keys_items_by_their_inner_type() {
        let items: Vec<ContextItem> = vec![
            Box::new(Peas { beans: 1 }) as ContextItem,
            Box::new(Potatoes(2)) as ContextItem,
        ];

        let context: ExtraContext = items.into_iter().collect();

        assert_eq!(Some(&Peas { beans: 1 }), context.get::<Peas>());
        assert_eq!(Some(&Potatoes(2)), context.get::<Potatoes>());
        assert!(context.get::<ContextItem>().is_none());
    }

    #[test]
    #[should_panic(expected = "single reference")]
    fn insert_panics_when_context_is_shared() {
        let mut context = ExtraContext::default();
        let _shared = context.clone();

        context.insert(Potatoes(1));
    }
}