1use std::collections::{HashMap, HashSet};
2use std::future::Future;
3use std::sync::Arc;
4
/// A fetcher that resolves a whole batch of keys in one call.
///
/// Implementors receive the deduplicated key set for a batch and return
/// `Some(map)` on success — keys with no value are simply absent from the
/// map — or `None` to signal that the entire batch failed (every waiter
/// then observes an error).
pub trait DataLoaderFetcher {
    /// Key type used to request values; must be hashable so batches can
    /// deduplicate requests.
    type Key: Clone + Eq + std::hash::Hash + Send + Sync;
    /// Value type produced for each key; cloned when the same key is
    /// awaited by multiple callers.
    type Value: Clone + Send + Sync;

    /// Loads values for `keys`. Returning `None` fails the whole batch.
    fn load(&self, keys: HashSet<Self::Key>) -> impl Future<Output = Option<HashMap<Self::Key, Self::Value>>> + Send;
}
15
/// Builder for [`DataLoader`].
///
/// Defaults (see [`DataLoaderBuilder::new`]): `batch_size = 1000`,
/// `concurrency = 50`, `delay = 5ms`.
#[derive(Clone, Copy, Debug)]
#[must_use = "builders must be used to create a dataloader"]
pub struct DataLoaderBuilder<E> {
    // Maximum number of keys collected before a batch is dispatched
    // immediately (without waiting for the delay).
    batch_size: usize,
    // Maximum number of batches allowed to execute at the same time.
    concurrency: usize,
    // How long an open batch may wait for more keys before dispatch.
    delay: std::time::Duration,
    // Ties the builder to the fetcher type without storing a value of it.
    _phantom: std::marker::PhantomData<E>,
}
25
26impl<E> Default for DataLoaderBuilder<E> {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl<E> DataLoaderBuilder<E> {
33 pub const fn new() -> Self {
35 Self {
36 batch_size: 1000,
37 concurrency: 50,
38 delay: std::time::Duration::from_millis(5),
39 _phantom: std::marker::PhantomData,
40 }
41 }
42
43 #[inline]
45 pub const fn batch_size(mut self, batch_size: usize) -> Self {
46 self.with_batch_size(batch_size);
47 self
48 }
49
50 #[inline]
52 pub const fn delay(mut self, delay: std::time::Duration) -> Self {
53 self.with_delay(delay);
54 self
55 }
56
57 #[inline]
59 pub const fn concurrency(mut self, concurrency: usize) -> Self {
60 self.with_concurrency(concurrency);
61 self
62 }
63
64 #[inline]
66 pub const fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
67 self.batch_size = batch_size;
68 self
69 }
70
71 #[inline]
73 pub const fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self {
74 self.delay = delay;
75 self
76 }
77
78 #[inline]
80 pub const fn with_concurrency(&mut self, concurrency: usize) -> &mut Self {
81 self.concurrency = concurrency;
82 self
83 }
84
85 #[inline]
87 pub fn build(self, executor: E) -> DataLoader<E>
88 where
89 E: DataLoaderFetcher + Send + Sync + 'static,
90 {
91 DataLoader::new(executor, self.batch_size, self.concurrency, self.delay)
92 }
93}
94
/// Coalesces individual key lookups into bulk [`DataLoaderFetcher::load`]
/// calls.
///
/// Keys requested through [`DataLoader::load`] / [`DataLoader::load_many`]
/// accumulate in a shared batch that is dispatched either when it reaches
/// `batch_size` keys or once the configured delay has elapsed since the
/// batch was opened.
#[must_use = "dataloaders must be used to load data"]
pub struct DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    // Handle to the background `batch_loop` task that flushes the current
    // batch on a timer.
    // NOTE(review): dropping a tokio JoinHandle detaches the task rather
    // than aborting it, so the loop appears to keep running after the
    // loader is dropped — confirm this is intended.
    _auto_spawn: tokio::task::JoinHandle<()>,
    executor: Arc<E>,
    // Caps how many batches may execute concurrently (see `Batch::spawn`).
    semaphore: Arc<tokio::sync::Semaphore>,
    // The batch currently accepting keys, if one is open.
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    // Dispatch threshold; clamped to >= 1 in `new`.
    batch_size: usize,
}
107
impl<E> DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    /// Creates a loader and spawns the background task that flushes an open
    /// batch once `delay` has elapsed since it was created.
    ///
    /// Both `concurrency` and `batch_size` are clamped to at least 1.
    pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
        let current_batch = Arc::new(tokio::sync::Mutex::new(None));
        let executor = Arc::new(executor);

        // Timer-driven flusher; shares the executor and batch slot with us.
        let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), delay));

        Self {
            executor,
            _auto_spawn: join_handle,
            semaphore,
            current_batch,
            batch_size: batch_size.max(1),
        }
    }

    /// Returns a [`DataLoaderBuilder`] with default settings.
    #[inline]
    pub const fn builder() -> DataLoaderBuilder<E> {
        DataLoaderBuilder::new()
    }

    /// Loads a single key.
    ///
    /// Returns `Ok(None)` when the fetcher produced no value for the key,
    /// and `Err(())` when the batch containing it failed.
    pub async fn load(&self, items: E::Key) -> Result<Option<E::Value>, ()> {
        Ok(self.load_many(std::iter::once(items)).await?.into_values().next())
    }

    /// Loads many keys, batching them with any concurrently requested keys.
    ///
    /// Keys that the fetcher did not return are absent from the resulting
    /// map. Returns `Err(())` if any batch this call participated in failed.
    pub async fn load_many<I>(&self, items: I) -> Result<HashMap<E::Key, E::Value>, ()>
    where
        I: IntoIterator<Item = E::Key> + Send,
    {
        // One entry per batch this call contributed keys to: the keys we
        // added plus the shared result handle to await.
        struct BatchWaiting<K, V> {
            keys: HashSet<K>,
            result: Arc<BatchResult<K, V>>,
        }

        let mut waiters = Vec::<BatchWaiting<E::Key, E::Value>>::new();

        // Counts inserted keys so the result map can be pre-sized.
        let mut count = 0;

        {
            // Starts `true` so the first iteration registers a waiter even
            // when joining a batch that was already open.
            let mut new_batch = true;
            // Lock is held for the whole insertion loop so the batch cannot
            // be flushed out from under us mid-iteration.
            let mut batch = self.current_batch.lock().await;

            for item in items {
                if batch.is_none() {
                    batch.replace(Batch::new(self.semaphore.clone()));
                    new_batch = true;
                }

                let batch_mut = batch.as_mut().unwrap();
                batch_mut.items.insert(item.clone());

                if new_batch {
                    new_batch = false;
                    waiters.push(BatchWaiting {
                        keys: HashSet::new(),
                        result: batch_mut.result.clone(),
                    });
                }

                let waiting = waiters.last_mut().unwrap();
                waiting.keys.insert(item);

                count += 1;

                // Full batch: dispatch immediately; the next key (if any)
                // opens a fresh batch above.
                if batch_mut.items.len() >= self.batch_size {
                    tokio::spawn(batch.take().unwrap().spawn(self.executor.clone()));
                }
            }
        }

        // Await every batch we joined and collect only the keys this call
        // asked for (other callers may share the same batches).
        let mut results = HashMap::with_capacity(count);
        for waiting in waiters {
            let result = waiting.result.wait().await?;
            results.extend(waiting.keys.into_iter().filter_map(|key| {
                let value = result.get(&key)?.clone();
                Some((key, value))
            }));
        }

        Ok(results)
    }
}
207
208async fn batch_loop<E>(
209 executor: Arc<E>,
210 current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
211 delay: std::time::Duration,
212) where
213 E: DataLoaderFetcher + Send + Sync + 'static,
214{
215 let mut delay_delta = delay;
216 loop {
217 tokio::time::sleep(delay_delta).await;
218
219 let mut batch = current_batch.lock().await;
220 let Some(created_at) = batch.as_ref().map(|b| b.created_at) else {
221 delay_delta = delay;
222 continue;
223 };
224
225 let remaining = delay.saturating_sub(created_at.elapsed());
226 if remaining == std::time::Duration::ZERO {
227 tokio::spawn(batch.take().unwrap().spawn(executor.clone()));
228 delay_delta = delay;
229 } else {
230 delay_delta = remaining;
231 }
232 }
233}
234
/// Shared result slot for one batch, awaited by every caller that
/// contributed keys to it.
struct BatchResult<K, V> {
    // Set exactly once by `Batch::spawn`; `Some(None)` means the fetcher
    // failed the whole batch.
    values: tokio::sync::OnceCell<Option<HashMap<K, V>>>,
    // Cancelled when the batch future completes (or is dropped), waking
    // all waiters.
    token: tokio_util::sync::CancellationToken,
}
239
240impl<K, V> BatchResult<K, V> {
241 fn new() -> Self {
242 Self {
243 values: tokio::sync::OnceCell::new(),
244 token: tokio_util::sync::CancellationToken::new(),
245 }
246 }
247
248 async fn wait(&self) -> Result<&HashMap<K, V>, ()> {
249 if !self.token.is_cancelled() {
250 self.token.cancelled().await;
251 }
252
253 self.values.get().ok_or(())?.as_ref().ok_or(())
254 }
255}
256
/// A batch of keys awaiting dispatch, plus the shared state its waiters
/// observe.
struct Batch<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    // Deduplicated keys collected so far.
    items: HashSet<E::Key>,
    // Result slot shared with every waiter that joined this batch.
    result: Arc<BatchResult<E::Key, E::Value>>,
    // Concurrency limiter, acquired in `spawn` before fetching.
    semaphore: Arc<tokio::sync::Semaphore>,
    // Used by `batch_loop` to compute the dispatch deadline.
    created_at: std::time::Instant,
}
266
267impl<E> Batch<E>
268where
269 E: DataLoaderFetcher + Send + Sync + 'static,
270{
271 fn new(semaphore: Arc<tokio::sync::Semaphore>) -> Self {
272 Self {
273 items: HashSet::new(),
274 result: Arc::new(BatchResult::new()),
275 semaphore,
276 created_at: std::time::Instant::now(),
277 }
278 }
279
280 async fn spawn(self, executor: Arc<E>) {
281 let _drop_guard = self.result.token.clone().drop_guard();
282 let _ticket = self.semaphore.acquire_owned().await.unwrap();
283 let result = executor.load(self.items).await;
284
285 #[cfg_attr(all(coverage_nightly, test), coverage(off))]
286 fn unknwown_error<E>(_: E) -> ! {
287 unreachable!(
288 "batch result already set, this is a bug please report it https://github.com/scufflecloud/scuffle/issues"
289 )
290 }
291
292 self.result.values.set(result).map_err(unknwown_error).unwrap();
293 }
294}
295
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicUsize;

    use super::*;

    // Fetcher backed by a fixed map. Sleeps `delay` per call, counts calls
    // in `requests`, and asserts no batch exceeds `capacity` keys.
    struct TestFetcher<K, V> {
        values: HashMap<K, V>,
        delay: std::time::Duration,
        requests: Arc<AtomicUsize>,
        capacity: usize,
    }

    impl<K, V> DataLoaderFetcher for TestFetcher<K, V>
    where
        K: Clone + Eq + std::hash::Hash + Send + Sync,
        V: Clone + Send + Sync,
    {
        type Key = K;
        type Value = V;

        async fn load(&self, keys: HashSet<Self::Key>) -> Option<HashMap<Self::Key, Self::Value>> {
            // Guards the batch-size invariant: no batch may exceed capacity.
            assert!(keys.len() <= self.capacity);
            tokio::time::sleep(self.delay).await;
            self.requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            Some(
                keys.into_iter()
                    .filter_map(|k| {
                        let value = self.values.get(&k)?.clone();
                        Some((k, value))
                    })
                    .collect(),
            )
        }
    }

    // Sequential single-key loads each produce one fetch; unknown keys
    // yield Ok(None). Timing bounds are generous to avoid flakiness.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn basic() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(2).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let a = loader.load("a").await.unwrap();
        assert_eq!(a, Some(1));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);

        let start = std::time::Instant::now();
        let b = loader.load("b").await.unwrap();
        assert_eq!(b, Some(2));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        let start = std::time::Instant::now();
        let c = loader.load("c").await.unwrap();
        assert_eq!(c, Some(3));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 3);

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a", "b"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 4);

        let start = std::time::Instant::now();
        let unknown = loader.load("unknown").await.unwrap();
        assert_eq!(unknown, None);
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // 10 keys with batch_size 2 split into 5 batches; concurrency 10 lets
    // them all run in parallel, so wall time stays near one fetch delay.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn concurrency_high() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(2).concurrency(10).build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // concurrency 1 serializes the 5 batches, so total time is at least
    // ~5 x 5ms of fetcher delay.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn delay_low() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(35));
        assert!(start.elapsed() >= std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // 10 keys fit inside one batch of 100, so the batch waits out the full
    // 10ms delay and exactly one fetch is made.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn batch_size() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = DataLoaderBuilder::default()
            .batch_size(100)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(10));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    // 1134 keys / batch_size 100 => 12 batches (11 full + 1 partial that
    // waits out the delay), all values round-trip correctly.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn high_concurrency() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = DataLoaderBuilder::default()
            .batch_size(100)
            .concurrency(10)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.load_many(0..1134).await.unwrap();
        assert_eq!(ab, HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))));
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1134 / 100 + 1);
    }

    // The background loop idles before any load; requesting after an idle
    // period must still dispatch promptly.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn delayed_start() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
    }

    // A single under-capacity key must wait out the full batch delay
    // before it is dispatched.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn delayed_start_single() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Duplicate keys collapse into one batch entry, so six requested keys
    // fit a batch of 4 and trigger a single fetch.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn deduplication() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 4,
        };

        let loader = DataLoader::builder()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a", "a", "b", "b", "c", "c"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Two concurrent single-key loads join the same open batch and share
    // one fetch.
    #[cfg(not(valgrind))] #[tokio::test]
    async fn already_batch() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(10).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let (a, b) = tokio::join!(loader.load("a"), loader.load("b"));
        assert_eq!(a, Ok(Some(1)));
        assert_eq!(b, Ok(Some(2)));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }
}