use crate::{Node, WasmNotSend};
use dyn_any::DynFuture;
use std::future::Future;
use std::hash::DefaultHasher;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::Mutex;

/// Caches the output of a given Node and acts as a proxy.
///
/// Every evaluation hashes the input with `DefaultHasher`. The cached output is reused only
/// when that hash equals the one stored alongside it. Otherwise the wrapped node is evaluated,
/// and once its future completes the `(hash, output)` pair replaces the cache entry.
#[derive(Default)]
pub struct MemoNode<T, CachedNode> {
	/// Hash of the input that produced the cached output, paired with that output.
	cache: Arc<Mutex<Option<(u64, T)>>>,
	node: CachedNode,
}

impl<'i, I: Hash + 'i, T: 'i + Clone + WasmNotSend, CachedNode: 'i> Node<'i, I> for MemoNode<T, CachedNode>
where
	CachedNode: for<'any_input> Node<'any_input, I>,
	for<'a> <CachedNode as Node<'a, I>>::Output: Future<Output = T> + WasmNotSend,
{
	// TODO: This should return a reference to the cached value,
	// but that requires a lot of lifetime magic.
	type Output = DynFuture<'i, T>;

	fn eval(&'i self, input: I) -> Self::Output {
		let mut hasher = DefaultHasher::new();
		input.hash(&mut hasher);
		let hash = hasher.finish();

		// Do the lookup in its own statement so the mutex guard is dropped right here.
		// Inside an `if let` scrutinee it would stay alive through the `else` branch,
		// keeping the cache locked while the inner node is evaluated.
		// `then` (rather than `then_some`) avoids cloning a value whose hash does not match.
		let cached = self
			.cache
			.lock()
			.unwrap()
			.as_ref()
			.and_then(|(cached_hash, value)| (*cached_hash == hash).then(|| value.clone()));

		if let Some(data) = cached {
			return Box::pin(async move { data });
		}

		let fut = self.node.eval(input);
		let cache = self.cache.clone();
		Box::pin(async move {
			let value = fut.await;
			// The cache is only written once the inner future has completed.
			*cache.lock().unwrap() = Some((hash, value.clone()));
			value
		})
	}

	/// Clears the cached entry so the next evaluation runs the wrapped node again.
	fn reset(&self) {
		self.cache.lock().unwrap().take();
	}
}

impl<T, CachedNode> MemoNode<T, CachedNode> {
	/// Wraps `node` with an initially empty cache.
	pub fn new(node: CachedNode) -> MemoNode<T, CachedNode> {
		MemoNode { cache: Default::default(), node }
	}
}

/// Caches the output of a given Node and acts as a proxy.
/// In contrast to the regular `MemoNode`, this node ignores all input.
/// Using this node might result in the document not updating properly,
/// use with caution.
#[derive(Default)] pub struct ImpureMemoNode { cache: Arc>>, node: CachedNode, _phantom: std::marker::PhantomData, } impl<'i, I: 'i, T: 'i + Clone + WasmNotSend, CachedNode: 'i> Node<'i, I> for ImpureMemoNode where CachedNode: for<'any_input> Node<'any_input, I>, for<'a> >::Output: Future + WasmNotSend, { // TODO: This should return a reference to the cached cached_value // but that requires a lot of lifetime magic <- This was suggested by copilot but is pretty accurate xD type Output = DynFuture<'i, T>; fn eval(&'i self, input: I) -> Self::Output { if let Some(cached_value) = self.cache.lock().as_ref().unwrap().deref() { let data = cached_value.clone(); Box::pin(async move { data }) } else { let fut = self.node.eval(input); let cache = self.cache.clone(); Box::pin(async move { let value = fut.await; *cache.lock().unwrap() = Some(value.clone()); value }) } } fn reset(&self) { self.cache.lock().unwrap().take(); } } impl ImpureMemoNode { pub fn new(node: CachedNode) -> ImpureMemoNode { ImpureMemoNode { cache: Default::default(), node, _phantom: std::marker::PhantomData, } } } /// Stores both what a node was called with and what it returned. 
#[derive(Clone, Debug)] pub struct IORecord { pub input: I, pub output: O, } /// Caches the output of the last graph evaluation for introspection #[derive(Default)] pub struct MonitorNode { #[allow(clippy::type_complexity)] io: Arc>>>>, node: N, } impl<'i, T, I, N> Node<'i, I> for MonitorNode where I: Clone + 'static + Send + Sync, T: Clone + 'static + Send + Sync, for<'a> N: Node<'a, I, Output: Future + WasmNotSend> + 'i, { type Output = DynFuture<'i, T>; fn eval(&'i self, input: I) -> Self::Output { let io = self.io.clone(); let output_fut = self.node.eval(input.clone()); Box::pin(async move { let output = output_fut.await; *io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() })); output }) } fn serialize(&self) -> Option> { let io = self.io.lock().unwrap(); (io).as_ref().map(|output| output.clone() as Arc) } } impl MonitorNode { pub fn new(node: N) -> MonitorNode { MonitorNode { io: Arc::new(Mutex::new(None)), node } } } use std::hash::{Hash, Hasher}; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct MemoHash { hash: u64, value: T, } impl<'de, T: serde::Deserialize<'de> + Hash> serde::Deserialize<'de> for MemoHash { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { T::deserialize(deserializer).map(|value| Self::new(value)) } } impl serde::Serialize for MemoHash { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { self.value.serialize(serializer) } } impl MemoHash { pub fn new(value: T) -> Self { let hash = Self::calc_hash(&value); Self { hash, value } } pub fn new_with_hash(value: T, hash: u64) -> Self { Self { hash, value } } fn calc_hash(data: &T) -> u64 { let mut hasher = DefaultHasher::new(); data.hash(&mut hasher); hasher.finish() } pub fn inner_mut(&mut self) -> MemoHashGuard<'_, T> { MemoHashGuard { inner: self } } pub fn into_inner(self) -> T { self.value } pub fn hash_code(&self) -> u64 { self.hash } } impl From for MemoHash { fn from(value: T) -> Self { 
Self::new(value) } } impl Hash for MemoHash { fn hash(&self, state: &mut H) { self.hash.hash(state) } } impl Deref for MemoHash { type Target = T; fn deref(&self) -> &Self::Target { &self.value } } pub struct MemoHashGuard<'a, T: Hash> { inner: &'a mut MemoHash, } impl Drop for MemoHashGuard<'_, T> { fn drop(&mut self) { let hash = MemoHash::::calc_hash(&self.inner.value); self.inner.hash = hash; } } impl Deref for MemoHashGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { &self.inner.value } } impl std::ops::DerefMut for MemoHashGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.inner.value } }