rustc_utils/
cache.rs

1//! Data structures for memoizing computations.
2//!
//! Construct new caches using [`Default::default`], then construct/retrieve
//! elements with [`get`](Cache::get). `get` should only ever be used with one
//! `compute` function[^inconsistent].
6//!
7//! In terms of choice,
8//! - [`CopyCache`] should be used for expensive computations that create cheap
9//!   (i.e. small) values.
10//! - [`Cache`] should be used for expensive computations that create expensive
11//!   (i.e. large) values.
12//!
13//! Both types of caches implement **recursion breaking**. In general because
14//! caches are supposed to be used as simple `&` (no `mut`) the reference may be
15//! freely copied, including into the `compute` closure. What this means is that
16//! a `compute` may call [`get`](Cache::get) on the cache again. This is usually
17//! safe and can be used to compute data structures that recursively depend on
18//! one another, dynamic-programming style. However if a `get` on a key `k`
19//! itself calls `get` again on the same `k` this will either create an infinite
20//! recursion or an inconsistent cache[^inconsistent].
21//!
22//! Consider a simple example where we compute the Fibonacci Series with a
23//! [`CopyCache`]:
24//!
25//! ```rs
26//! struct Fib(CopyCache<u32, u32>);
27//!
28//! impl Fib {
29//!   fn get(&self, i: u32) -> u32 {
//!     self.0.get(&i, |this| {
31//!       if this <= 1 {
32//!         return this;
33//!       }
34//!       let fib_1 = self.get(this - 1);
35//!       let fib_2 = self.get(this - 2);
36//!       fib_1 + fib_2
37//!     })
38//!   }
39//! }
40//!
41//! let cache = Fib(Default::default());
42//! let fib_5 = cache.get(5);
43//! ```
44//!
//! This use of recursive [`get`](CopyCache::get) calls is perfectly legal.
//! However if we made an error and called `cache.get(this, ...)` (forgetting
//! the decrement) we would have created an inadvertent infinite recursion.
48//!
//! To avoid this scenario both caches are implemented to detect when a
//! recursive call as described is performed and `get` will panic. If your code
//! uses recursive construction and would like to handle this case gracefully
//! use [`get_maybe_recursive`](Cache::get_maybe_recursive) instead, which returns
//! `None` from `get(k)` *iff* this call (potentially transitively)
//! originates from another `get(k)` call.
55//!
//! [^inconsistent]: For any given cache value `get` should only ever be used
//!     with one, referentially transparent `compute` function. Essentially this
//!     means running `compute(k)` should always return the same value
//!     *independent of the state of its environment*. Violation of this rule
//!     can introduce non-determinism in your program.
61use std::{cell::RefCell, hash::Hash, pin::Pin};
62
63use rustc_data_structures::fx::FxHashMap as HashMap;
64
65/// Cache for non-copyable types.
66pub struct Cache<In, Out>(RefCell<HashMap<In, Option<Pin<Box<Out>>>>>);
67
68impl<In, Out> Cache<In, Out>
69where
70  In: Hash + Eq + Clone,
71{
72  /// Size of the cache
73  pub fn len(&self) -> usize {
74    self.0.borrow().len()
75  }
76
77  /// Returns true if the cache contains the key.
78  pub fn contains_key(&self, key: &In) -> bool {
79    self.0.borrow().contains_key(key)
80  }
81
82  /// Returns the cached value for the given key, or runs `compute` if
83  /// the value is not in cache.
84  ///
85  /// # Panics
86  ///
87  /// If this is a recursive invocation for this key.
88  pub fn get(&self, key: &In, compute: impl FnOnce(In) -> Out) -> &Out {
89    self
90      .get_maybe_recursive(key, compute)
91      .unwrap_or_else(recursion_panic)
92  }
93
94  /// Returns the cached value for the given key, or runs `compute` if
95  /// the value is not in cache.
96  ///
97  /// Returns `None` if this is a recursive invocation of `get` for key `key`.
98  pub fn get_maybe_recursive<'a>(
99    &'a self,
100    key: &In,
101    compute: impl FnOnce(In) -> Out,
102  ) -> Option<&'a Out> {
103    if !self.0.borrow().contains_key(key) {
104      self.0.borrow_mut().insert(key.clone(), None);
105      let out = Box::pin(compute(key.clone()));
106      self.0.borrow_mut().insert(key.clone(), Some(out));
107    }
108
109    let cache = self.0.borrow();
110    // Important here to first `unwrap` the `Option` created by `get`, then
111    // propagate the potential option stored in the map.
112    let entry = cache.get(key).expect("invariant broken").as_ref()?;
113
114    // SAFETY: because the entry is pinned, it cannot move and this pointer will
115    // only be invalidated if Cache is dropped. The returned reference has a lifetime
116    // equal to Cache, so Cache cannot be dropped before this reference goes out of scope.
117    Some(unsafe { std::mem::transmute::<&'_ Out, &'a Out>(&**entry) })
118  }
119}
120
/// Aborts a `get` that was re-entered for the key it is currently computing.
///
/// Generic over the return type so it can be passed directly to
/// `Option::unwrap_or_else` for any cached value type. It never returns.
#[cold]
fn recursion_panic<A>() -> A {
  panic!(
    "Recursion detected! The computation of a value tried to retrieve the same from the cache. Use `get_maybe_recursive` to handle this case gracefully."
  )
}
126
127impl<In, Out> Default for Cache<In, Out> {
128  fn default() -> Self {
129    Cache(RefCell::new(HashMap::default()))
130  }
131}
132
133/// Cache for copyable types.
134pub struct CopyCache<In, Out>(RefCell<HashMap<In, Option<Out>>>);
135
136impl<In, Out> CopyCache<In, Out>
137where
138  In: Hash + Eq + Clone,
139  Out: Copy,
140{
141  /// Size of the cache
142  pub fn len(&self) -> usize {
143    self.0.borrow().len()
144  }
145  /// Returns the cached value for the given key, or runs `compute` if
146  /// the value is not in cache.
147  ///
148  /// # Panics
149  ///
150  /// If this is a recursive invocation for this key.
151  pub fn get(&self, key: &In, compute: impl FnOnce(In) -> Out) -> Out {
152    self
153      .get_maybe_recursive(key, compute)
154      .unwrap_or_else(recursion_panic)
155  }
156
157  /// Returns the cached value for the given key, or runs `compute` if
158  /// the value is not in cache.
159  ///
160  /// Returns `None` if this is a recursive invocation of `get` for key `key`.
161  pub fn get_maybe_recursive(
162    &self,
163    key: &In,
164    compute: impl FnOnce(In) -> Out,
165  ) -> Option<Out> {
166    if !self.0.borrow().contains_key(key) {
167      self.0.borrow_mut().insert(key.clone(), None);
168      let out = compute(key.clone());
169      self.0.borrow_mut().insert(key.clone(), Some(out));
170    }
171
172    *self.0.borrow_mut().get(key).expect("invariant broken")
173  }
174}
175
176impl<In, Out> Default for CopyCache<In, Out> {
177  fn default() -> Self {
178    CopyCache(RefCell::new(HashMap::default()))
179  }
180}
181
#[cfg(test)]
mod test {
  use super::*;

  /// A second `get` on a known key returns the stored value (same address)
  /// and ignores the closure it was given.
  #[test]
  fn test_cached() {
    let cache = Cache::<usize, usize>::default();
    let first = cache.get(&0, |_| 0);
    let other = cache.get(&1, |_| 1);
    let again = cache.get(&0, |_| 2);
    assert_eq!((*first, *other, *again), (0, 1, 0));
    assert!(std::ptr::eq(first, again));
  }

  /// Self-referential computations are cut off with `None`, while recursion
  /// over distinct keys runs to completion.
  #[test]
  fn test_recursion_breaking() {
    struct RecursiveUse(Cache<i32, i32>);

    impl RecursiveUse {
      /// Asks for key `i` from inside the computation of key `i`; the nested
      /// lookup yields `None`, which is replaced by the sentinel `-18`.
      fn get_infinite_recursion(&self, i: i32) -> i32 {
        match self
          .0
          .get_maybe_recursive(&i, |_| i + self.get_infinite_recursion(i))
        {
          Some(value) => *value,
          None => -18,
        }
      }

      /// Sums `0..=i`, touching only smaller keys from each computation.
      fn get_safe_recursion(&self, i: i32) -> i32 {
        *self.0.get(&i, |_| match i {
          0 => 0,
          n => self.get_safe_recursion(n - 1) + n,
        })
      }
    }

    let cache = RecursiveUse(Cache::default());

    // 60 + (-18): the outer computation finishes with the sentinel folded in.
    assert_eq!(cache.get_infinite_recursion(60), 42);
    assert_eq!(cache.get_safe_recursion(5), 15);
  }
}