rustc_utils/cache.rs
1//! Data structures for memoizing computations.
2//!
//! Construct new caches using [`Default::default`], then construct/retrieve
//! elements with [`get`](Cache::get). `get` should only ever be used with one
//! `compute` function[^inconsistent].
6//!
7//! In terms of choice,
8//! - [`CopyCache`] should be used for expensive computations that create cheap
9//! (i.e. small) values.
10//! - [`Cache`] should be used for expensive computations that create expensive
11//! (i.e. large) values.
12//!
13//! Both types of caches implement **recursion breaking**. In general because
14//! caches are supposed to be used as simple `&` (no `mut`) the reference may be
15//! freely copied, including into the `compute` closure. What this means is that
16//! a `compute` may call [`get`](Cache::get) on the cache again. This is usually
17//! safe and can be used to compute data structures that recursively depend on
18//! one another, dynamic-programming style. However if a `get` on a key `k`
19//! itself calls `get` again on the same `k` this will either create an infinite
20//! recursion or an inconsistent cache[^inconsistent].
21//!
22//! Consider a simple example where we compute the Fibonacci Series with a
23//! [`CopyCache`]:
24//!
//! ```rs
//! struct Fib(CopyCache<u32, u32>);
//!
//! impl Fib {
//!   fn get(&self, i: u32) -> u32 {
//!     self.0.get(&i, |this| {
//!       if this <= 1 {
//!         return this;
//!       }
//!       let fib_1 = self.get(this - 1);
//!       let fib_2 = self.get(this - 2);
//!       fib_1 + fib_2
//!     })
//!   }
//! }
//!
//! let cache = Fib(Default::default());
//! let fib_5 = cache.get(5);
//! ```
44//!
//! This use of recursive [`get`](CopyCache::get) calls is perfectly legal.
//! However if we made an error and called `cache.get(this, ...)` (forgetting
//! the decrement) we would have created an inadvertent infinite recursion.
48//!
//! To avoid this scenario both caches are implemented to detect when a
//! recursive call as described is performed and `get` will panic. If your code
//! uses recursive construction and would like to handle this case gracefully
//! use [`get_maybe_recursive`](Cache::get_maybe_recursive) instead, which
//! returns `None` from `get(k)` *iff* this call (potentially transitively)
//! originates from another `get(k)` call.
55//!
//! [^inconsistent]: For any given cache value `get` should only ever be used
//! with one, referentially transparent `compute` function. Essentially this
//! means running `compute(k)` should always return the same value
//! *independent of the state of its environment*. Violation of this rule
//! can introduce non-determinism in your program.
61use std::{cell::RefCell, hash::Hash, pin::Pin};
62
63use rustc_data_structures::fx::FxHashMap as HashMap;
64
65/// Cache for non-copyable types.
66pub struct Cache<In, Out>(RefCell<HashMap<In, Option<Pin<Box<Out>>>>>);
67
68impl<In, Out> Cache<In, Out>
69where
70 In: Hash + Eq + Clone,
71{
72 /// Size of the cache
73 pub fn len(&self) -> usize {
74 self.0.borrow().len()
75 }
76
77 /// Returns true if the cache contains the key.
78 pub fn contains_key(&self, key: &In) -> bool {
79 self.0.borrow().contains_key(key)
80 }
81
82 /// Returns the cached value for the given key, or runs `compute` if
83 /// the value is not in cache.
84 ///
85 /// # Panics
86 ///
87 /// If this is a recursive invocation for this key.
88 pub fn get(&self, key: &In, compute: impl FnOnce(In) -> Out) -> &Out {
89 self
90 .get_maybe_recursive(key, compute)
91 .unwrap_or_else(recursion_panic)
92 }
93
94 /// Returns the cached value for the given key, or runs `compute` if
95 /// the value is not in cache.
96 ///
97 /// Returns `None` if this is a recursive invocation of `get` for key `key`.
98 pub fn get_maybe_recursive<'a>(
99 &'a self,
100 key: &In,
101 compute: impl FnOnce(In) -> Out,
102 ) -> Option<&'a Out> {
103 if !self.0.borrow().contains_key(key) {
104 self.0.borrow_mut().insert(key.clone(), None);
105 let out = Box::pin(compute(key.clone()));
106 self.0.borrow_mut().insert(key.clone(), Some(out));
107 }
108
109 let cache = self.0.borrow();
110 // Important here to first `unwrap` the `Option` created by `get`, then
111 // propagate the potential option stored in the map.
112 let entry = cache.get(key).expect("invariant broken").as_ref()?;
113
114 // SAFETY: because the entry is pinned, it cannot move and this pointer will
115 // only be invalidated if Cache is dropped. The returned reference has a lifetime
116 // equal to Cache, so Cache cannot be dropped before this reference goes out of scope.
117 Some(unsafe { std::mem::transmute::<&'_ Out, &'a Out>(&**entry) })
118 }
119}
120
/// Panics to report that computing a value requested that same value again.
///
/// The return type is generic so the function can be passed directly to
/// `Option::unwrap_or_else` for any cached value type; it never returns.
#[cold]
fn recursion_panic<A>() -> A {
  panic!(
    "Recursion detected! The computation of a value tried to retrieve the same value from the cache. Use `get_maybe_recursive` to handle this case gracefully."
  )
}
126
127impl<In, Out> Default for Cache<In, Out> {
128 fn default() -> Self {
129 Cache(RefCell::new(HashMap::default()))
130 }
131}
132
133/// Cache for copyable types.
134pub struct CopyCache<In, Out>(RefCell<HashMap<In, Option<Out>>>);
135
136impl<In, Out> CopyCache<In, Out>
137where
138 In: Hash + Eq + Clone,
139 Out: Copy,
140{
141 /// Size of the cache
142 pub fn len(&self) -> usize {
143 self.0.borrow().len()
144 }
145 /// Returns the cached value for the given key, or runs `compute` if
146 /// the value is not in cache.
147 ///
148 /// # Panics
149 ///
150 /// If this is a recursive invocation for this key.
151 pub fn get(&self, key: &In, compute: impl FnOnce(In) -> Out) -> Out {
152 self
153 .get_maybe_recursive(key, compute)
154 .unwrap_or_else(recursion_panic)
155 }
156
157 /// Returns the cached value for the given key, or runs `compute` if
158 /// the value is not in cache.
159 ///
160 /// Returns `None` if this is a recursive invocation of `get` for key `key`.
161 pub fn get_maybe_recursive(
162 &self,
163 key: &In,
164 compute: impl FnOnce(In) -> Out,
165 ) -> Option<Out> {
166 if !self.0.borrow().contains_key(key) {
167 self.0.borrow_mut().insert(key.clone(), None);
168 let out = compute(key.clone());
169 self.0.borrow_mut().insert(key.clone(), Some(out));
170 }
171
172 *self.0.borrow_mut().get(key).expect("invariant broken")
173 }
174}
175
176impl<In, Out> Default for CopyCache<In, Out> {
177 fn default() -> Self {
178 CopyCache(RefCell::new(HashMap::default()))
179 }
180}
181
#[cfg(test)]
mod test {
  use super::*;

  /// A repeated lookup returns the value from the first computation (the very
  /// same allocation) and does not run the new `compute` closure.
  #[test]
  fn test_cached() {
    let cache: Cache<usize, usize> = Cache::default();
    let x = cache.get(&0, |_| 0);
    let y = cache.get(&1, |_| 1);
    let z = cache.get(&0, |_| 2);
    assert_eq!(*x, 0);
    assert_eq!(*y, 1);
    assert_eq!(*z, 0);
    assert!(std::ptr::eq(x, z));
    assert_eq!(cache.len(), 2);
    assert!(cache.contains_key(&1));
    assert!(!cache.contains_key(&2));
  }

  /// `get_maybe_recursive` yields `None` for a self-referential lookup, while
  /// recursion over *different* keys works through plain `get`.
  #[test]
  fn test_recursion_breaking() {
    struct RecursiveUse(Cache<i32, i32>);
    impl RecursiveUse {
      fn get_infinite_recursion(&self, i: i32) -> i32 {
        self
          .0
          .get_maybe_recursive(&i, |_| i + self.get_infinite_recursion(i))
          .copied()
          .unwrap_or(-18)
      }
      fn get_safe_recursion(&self, i: i32) -> i32 {
        *self.0.get(&i, |_| {
          if i == 0 {
            0
          } else {
            self.get_safe_recursion(i - 1) + i
          }
        })
      }
    }

    let cache = RecursiveUse(Cache::default());

    // The nested lookup of key 60 is broken with `None`, so the fallback
    // `-18` is used: 60 + (-18) == 42.
    assert_eq!(cache.get_infinite_recursion(60), 42);
    assert_eq!(cache.get_safe_recursion(5), 15);
  }

  /// Plain `get` panics when the computation of a key asks for the same key.
  #[test]
  #[should_panic(expected = "Recursion detected")]
  fn test_recursion_panics() {
    struct SelfReferential(Cache<i32, i32>);
    impl SelfReferential {
      fn get(&self, i: i32) -> i32 {
        *self.0.get(&i, |_| self.get(i))
      }
    }

    SelfReferential(Cache::default()).get(1);
  }

  /// `CopyCache` memoizes the first computed value for a key.
  #[test]
  fn test_copy_cached() {
    let cache: CopyCache<usize, usize> = CopyCache::default();
    assert_eq!(cache.get(&0, |_| 0), 0);
    assert_eq!(cache.get(&1, |_| 1), 1);
    assert_eq!(cache.get(&0, |_| 2), 0);
    assert_eq!(cache.len(), 2);
  }

  /// `CopyCache` breaks self-referential lookups the same way `Cache` does.
  #[test]
  fn test_copy_recursion_breaking() {
    struct RecursiveUse(CopyCache<i32, i32>);
    impl RecursiveUse {
      fn get_infinite_recursion(&self, i: i32) -> i32 {
        self
          .0
          .get_maybe_recursive(&i, |_| i + self.get_infinite_recursion(i))
          .unwrap_or(-18)
      }
      fn get_safe_recursion(&self, i: i32) -> i32 {
        self.0.get(&i, |_| {
          if i == 0 {
            0
          } else {
            self.get_safe_recursion(i - 1) + i
          }
        })
      }
    }

    let cache = RecursiveUse(CopyCache::default());

    assert_eq!(cache.get_infinite_recursion(60), 42);
    assert_eq!(cache.get_safe_recursion(5), 15);
  }
}
225}