scuffle_batching/
dataloader.rs

1use std::collections::{HashMap, HashSet};
2use std::future::Future;
3use std::sync::Arc;
4
5/// A trait for fetching data in batches
/// A trait for fetching data in batches
pub trait DataLoaderFetcher {
    /// The incoming key type
    type Key: Clone + Eq + std::hash::Hash + Send + Sync;
    /// The outgoing value type
    type Value: Clone + Send + Sync;

    /// Load a batch of keys
    ///
    /// Returns a map of the keys that were found to their values; keys with
    /// no value are simply absent from the map. Returning `None` signals a
    /// batch-wide failure, which surfaces as `Err(())` to every caller
    /// waiting on the batch.
    fn load(&self, keys: HashSet<Self::Key>) -> impl Future<Output = Option<HashMap<Self::Key, Self::Value>>> + Send;
}
15
16/// A builder for a [`DataLoader`]
/// A builder for a [`DataLoader`]
#[derive(Clone, Copy, Debug)]
#[must_use = "builders must be used to create a dataloader"]
pub struct DataLoaderBuilder<E> {
    // Maximum number of keys dispatched in a single `load` call.
    batch_size: usize,
    // Maximum number of concurrently executing `load` calls.
    concurrency: usize,
    // How long a batch may wait to accumulate keys before being dispatched.
    delay: std::time::Duration,
    // Ties the builder to the fetcher type without storing one.
    _phantom: std::marker::PhantomData<E>,
}
25
26impl<E> Default for DataLoaderBuilder<E> {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl<E> DataLoaderBuilder<E> {
33    /// Create a new builder
34    pub const fn new() -> Self {
35        Self {
36            batch_size: 1000,
37            concurrency: 50,
38            delay: std::time::Duration::from_millis(5),
39            _phantom: std::marker::PhantomData,
40        }
41    }
42
43    /// Set the batch size
44    #[inline]
45    pub const fn batch_size(mut self, batch_size: usize) -> Self {
46        self.with_batch_size(batch_size);
47        self
48    }
49
50    /// Set the delay
51    #[inline]
52    pub const fn delay(mut self, delay: std::time::Duration) -> Self {
53        self.with_delay(delay);
54        self
55    }
56
57    /// Set the concurrency
58    #[inline]
59    pub const fn concurrency(mut self, concurrency: usize) -> Self {
60        self.with_concurrency(concurrency);
61        self
62    }
63
64    /// Set the batch size
65    #[inline]
66    pub const fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
67        self.batch_size = batch_size;
68        self
69    }
70
71    /// Set the delay
72    #[inline]
73    pub const fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self {
74        self.delay = delay;
75        self
76    }
77
78    /// Set the concurrency
79    #[inline]
80    pub const fn with_concurrency(&mut self, concurrency: usize) -> &mut Self {
81        self.concurrency = concurrency;
82        self
83    }
84
85    /// Build the dataloader
86    #[inline]
87    pub fn build(self, executor: E) -> DataLoader<E>
88    where
89        E: DataLoaderFetcher + Send + Sync + 'static,
90    {
91        DataLoader::new(executor, self.batch_size, self.concurrency, self.delay)
92    }
93}
94
95/// A dataloader used to batch requests to a [`DataLoaderFetcher`]
96#[must_use = "dataloaders must be used to load data"]
97pub struct DataLoader<E>
98where
99    E: DataLoaderFetcher + Send + Sync + 'static,
100{
101    _auto_spawn: tokio::task::JoinHandle<()>,
102    executor: Arc<E>,
103    semaphore: Arc<tokio::sync::Semaphore>,
104    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
105    batch_size: usize,
106}
107
impl<E> DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    /// Create a new dataloader
    ///
    /// Spawns a background task that flushes partially-filled batches after
    /// `delay`. `batch_size` and `concurrency` are clamped to at least 1 so
    /// the loader can always make progress.
    pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
        let current_batch = Arc::new(tokio::sync::Mutex::new(None));
        let executor = Arc::new(executor);

        // Background task that dispatches a pending batch once it has aged
        // past `delay` without filling up.
        let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), delay));

        Self {
            executor,
            _auto_spawn: join_handle,
            semaphore,
            current_batch,
            batch_size: batch_size.max(1),
        }
    }

    /// Create a builder for a [`DataLoader`]
    #[inline]
    pub const fn builder() -> DataLoaderBuilder<E> {
        DataLoaderBuilder::new()
    }

    /// Load a single key
    /// Can return an error if the underlying [`DataLoaderFetcher`] returns an
    /// error
    ///
    /// Returns `None` if the key is not found
    pub async fn load(&self, items: E::Key) -> Result<Option<E::Value>, ()> {
        // Delegates to `load_many`; the result map has at most one entry.
        Ok(self.load_many(std::iter::once(items)).await?.into_values().next())
    }

    /// Load many keys
    /// Can return an error if the underlying [`DataLoaderFetcher`] returns an
    /// error
    ///
    /// Returns a map of keys to values which may be incomplete if any of the
    /// keys were not found
    pub async fn load_many<I>(&self, items: I) -> Result<HashMap<E::Key, E::Value>, ()>
    where
        I: IntoIterator<Item = E::Key> + Send,
    {
        // Records which keys this call contributed to one batch, together
        // with the shared result cell that batch will eventually fill.
        struct BatchWaiting<K, V> {
            keys: HashSet<K>,
            result: Arc<BatchResult<K, V>>,
        }

        // One waiter per batch this call participates in (a large input can
        // span several batches).
        let mut waiters = Vec::<BatchWaiting<E::Key, E::Value>>::new();

        // Total keys inserted; used only to pre-size the result map.
        let mut count = 0;

        {
            // Starts true so a waiter is registered for the already-pending
            // batch (if one exists) on the first iteration.
            let mut new_batch = true;
            let mut batch = self.current_batch.lock().await;

            for item in items {
                if batch.is_none() {
                    batch.replace(Batch::new(self.semaphore.clone()));
                    new_batch = true;
                }

                let batch_mut = batch.as_mut().unwrap();
                batch_mut.items.insert(item.clone());

                // First key going into this batch from this call: register a
                // waiter bound to the batch's result cell.
                if new_batch {
                    new_batch = false;
                    waiters.push(BatchWaiting {
                        keys: HashSet::new(),
                        result: batch_mut.result.clone(),
                    });
                }

                let waiting = waiters.last_mut().unwrap();
                waiting.keys.insert(item);

                count += 1;

                // Dispatch as soon as the batch is full; a subsequent key (if
                // any) will lazily start a fresh batch above.
                if batch_mut.items.len() >= self.batch_size {
                    tokio::spawn(batch.take().unwrap().spawn(self.executor.clone()));
                }
            }
        } // lock dropped here so `batch_loop` can flush any partial batch

        // Await every batch we joined and keep only the keys this call asked
        // for; keys the fetcher did not return are simply absent.
        let mut results = HashMap::with_capacity(count);
        for waiting in waiters {
            let result = waiting.result.wait().await?;
            results.extend(waiting.keys.into_iter().filter_map(|key| {
                let value = result.get(&key)?.clone();
                Some((key, value))
            }));
        }

        Ok(results)
    }
}
207
/// Background task that dispatches the pending batch once it has waited at
/// least `delay` since creation. Runs until the spawning task's handle is
/// dropped along with the runtime (it is never joined or aborted here).
async fn batch_loop<E>(
    executor: Arc<E>,
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    delay: std::time::Duration,
) where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    // How long to sleep before the next check. Shrinks when a batch was
    // created partway through the previous sleep, so each batch is flushed
    // close to `delay` after its creation rather than up to 2x `delay`.
    let mut delay_delta = delay;
    loop {
        tokio::time::sleep(delay_delta).await;

        let mut batch = current_batch.lock().await;
        // No pending batch: resume full-length sleeps.
        let Some(created_at) = batch.as_ref().map(|b| b.created_at) else {
            delay_delta = delay;
            continue;
        };

        let remaining = delay.saturating_sub(created_at.elapsed());
        if remaining == std::time::Duration::ZERO {
            // Batch has aged past `delay`: take it out and dispatch it.
            tokio::spawn(batch.take().unwrap().spawn(executor.clone()));
            delay_delta = delay;
        } else {
            // Batch was created mid-sleep: sleep only the remainder.
            delay_delta = remaining;
        }
    }
}
234
// Shared one-shot result of a dispatched batch.
struct BatchResult<K, V> {
    // Written exactly once by `Batch::spawn` with the fetcher's output.
    values: tokio::sync::OnceCell<Option<HashMap<K, V>>>,
    // Cancelled when the batch future finishes (or is dropped), waking
    // every waiter in `BatchResult::wait`.
    token: tokio_util::sync::CancellationToken,
}
239
240impl<K, V> BatchResult<K, V> {
241    fn new() -> Self {
242        Self {
243            values: tokio::sync::OnceCell::new(),
244            token: tokio_util::sync::CancellationToken::new(),
245        }
246    }
247
248    async fn wait(&self) -> Result<&HashMap<K, V>, ()> {
249        if !self.token.is_cancelled() {
250            self.token.cancelled().await;
251        }
252
253        self.values.get().ok_or(())?.as_ref().ok_or(())
254    }
255}
256
// A batch of keys accumulating until it is dispatched to the fetcher.
struct Batch<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    // Deduplicated keys collected so far.
    items: HashSet<E::Key>,
    // Result cell shared with every `load_many` call waiting on this batch.
    result: Arc<BatchResult<E::Key, E::Value>>,
    // Concurrency limiter shared with the owning `DataLoader`.
    semaphore: Arc<tokio::sync::Semaphore>,
    // When this batch was created; `batch_loop` uses it to decide flushing.
    created_at: std::time::Instant,
}
266
267impl<E> Batch<E>
268where
269    E: DataLoaderFetcher + Send + Sync + 'static,
270{
271    fn new(semaphore: Arc<tokio::sync::Semaphore>) -> Self {
272        Self {
273            items: HashSet::new(),
274            result: Arc::new(BatchResult::new()),
275            semaphore,
276            created_at: std::time::Instant::now(),
277        }
278    }
279
280    async fn spawn(self, executor: Arc<E>) {
281        let _drop_guard = self.result.token.clone().drop_guard();
282        let _ticket = self.semaphore.acquire_owned().await.unwrap();
283        let result = executor.load(self.items).await;
284
285        #[cfg_attr(all(coverage_nightly, test), coverage(off))]
286        fn unknwown_error<E>(_: E) -> ! {
287            unreachable!(
288                "batch result already set, this is a bug please report it https://github.com/scufflecloud/scuffle/issues"
289            )
290        }
291
292        self.result.values.set(result).map_err(unknwown_error).unwrap();
293    }
294}
295
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicUsize;

    use super::*;

    // Test fetcher backed by a fixed in-memory map.
    struct TestFetcher<K, V> {
        // Lookup table served by `load`; missing keys are dropped.
        values: HashMap<K, V>,
        // Artificial latency applied to every `load` call.
        delay: std::time::Duration,
        // Counts how many times `load` has been invoked.
        requests: Arc<AtomicUsize>,
        // Maximum batch size this fetcher expects; `load` asserts it.
        capacity: usize,
    }

    impl<K, V> DataLoaderFetcher for TestFetcher<K, V>
    where
        K: Clone + Eq + std::hash::Hash + Send + Sync,
        V: Clone + Send + Sync,
    {
        type Key = K;
        type Value = V;

        async fn load(&self, keys: HashSet<Self::Key>) -> Option<HashMap<Self::Key, Self::Value>> {
            // Verify the dataloader never exceeds the configured batch size.
            assert!(keys.len() <= self.capacity);
            tokio::time::sleep(self.delay).await;
            self.requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            // Keys absent from `values` are silently omitted from the result.
            Some(
                keys.into_iter()
                    .filter_map(|k| {
                        let value = self.values.get(&k)?.clone();
                        Some((k, value))
                    })
                    .collect(),
            )
        }
    }

    // Single-key and small multi-key loads complete quickly and each issues
    // exactly one fetch; an unknown key yields `None` rather than an error.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn basic() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(2).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let a = loader.load("a").await.unwrap();
        assert_eq!(a, Some(1));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);

        let start = std::time::Instant::now();
        let b = loader.load("b").await.unwrap();
        assert_eq!(b, Some(2));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        let start = std::time::Instant::now();
        let c = loader.load("c").await.unwrap();
        assert_eq!(c, Some(3));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 3);

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a", "b"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 4);

        let start = std::time::Instant::now();
        let unknown = loader.load("unknown").await.unwrap();
        assert_eq!(unknown, None);
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // With high concurrency, 10 keys split into 5 batches of 2 that run in
    // parallel, so the total time stays near a single fetch's latency.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn concurrency_high() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(2).concurrency(10).build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // With concurrency 1, the 5 batches must run sequentially, so the total
    // time is roughly 5 fetches back-to-back.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delay_low() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(35));
        assert!(start.elapsed() >= std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // A batch size larger than the input means all 10 keys fit in one batch,
    // which is only flushed after the configured delay elapses.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn batch_size() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = DataLoaderBuilder::default()
            .batch_size(100)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(10));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    // 1134 keys with batch size 100 produce 12 batches (11 full + 1 partial
    // flushed after the delay); every key resolves to its mapped value.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn high_concurrency() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = DataLoaderBuilder::default()
            .batch_size(100)
            .concurrency(10)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.load_many(0..1134).await.unwrap();
        assert_eq!(ab, HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))));
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1134 / 100 + 1);
    }

    // Loading after the background loop has been idle still completes
    // promptly: full batches dispatch immediately, the tail after the delay.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delayed_start() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
    }

    // A single key never fills a batch, so it waits for the flush delay plus
    // the fetch latency before resolving.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delayed_start_single() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Duplicate keys are deduplicated within a batch (HashSet semantics), so
    // six requested keys collapse into a single fetch of three.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn deduplication() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 4,
        };

        let loader = DataLoader::builder()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a", "a", "b", "b", "c", "c"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Two concurrent `load` calls join the same pending batch, resulting in
    // one fetch that serves both callers.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn already_batch() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(10).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let (a, b) = tokio::join!(loader.load("a"), loader.load("b"));
        assert_eq!(a, Ok(Some(1)));
        assert_eq!(b, Ok(Some(2)));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }
}