use std::future::Future;
use std::sync::Arc;

use tokio::sync::oneshot;
5
/// One-shot reply handle for a single request within a batch.
///
/// A [`BatchExecutor`] receives one `BatchResponse` per request and uses it
/// to deliver that request's response back to the awaiting caller.
pub struct BatchResponse<Resp> {
    // Sending half of the oneshot channel; the matching receiver is awaited
    // by `Batcher::execute_many`.
    send: oneshot::Sender<Resp>,
}
10
11impl<Resp> BatchResponse<Resp> {
12 #[must_use]
14 pub fn new(send: oneshot::Sender<Resp>) -> Self {
15 Self { send }
16 }
17
18 #[inline(always)]
20 pub fn send(self, response: Resp) {
21 let _ = self.send.send(response);
22 }
23
24 #[inline(always)]
26 pub fn send_ok<O, E>(self, response: O)
27 where
28 Resp: From<Result<O, E>>,
29 {
30 self.send(Ok(response).into())
31 }
32
33 #[inline(always)]
35 pub fn send_err<O, E>(self, error: E)
36 where
37 Resp: From<Result<O, E>>,
38 {
39 self.send(Err(error).into())
40 }
41
42 #[inline(always)]
44 pub fn send_none<T>(self)
45 where
46 Resp: From<Option<T>>,
47 {
48 self.send(None.into())
49 }
50
51 #[inline(always)]
53 pub fn send_some<T>(self, value: T)
54 where
55 Resp: From<Option<T>>,
56 {
57 self.send(Some(value).into())
58 }
59}
60
/// An executor that processes many requests in one batched call.
///
/// Implementors receive a `Vec` of `(request, reply-handle)` pairs and
/// answer each request through its [`BatchResponse`]. A request left
/// unanswered resolves to `None` on the caller's side.
pub trait BatchExecutor {
    /// A single request submitted to the batcher.
    type Request: Send + 'static;
    /// The response delivered back for a request.
    type Response: Send + Sync + 'static;

    /// Executes one batch, replying to each request via its handle.
    fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) -> impl Future<Output = ()> + Send;
}
73
/// Configuration builder for a [`Batcher`].
#[derive(Clone, Copy, Debug)]
#[must_use = "builders must be used to create a batcher"]
pub struct BatcherBuilder<E> {
    // Maximum number of requests grouped into one executor call.
    batch_size: usize,
    // Maximum number of batches executing concurrently.
    concurrency: usize,
    // How long an unfilled batch waits before being flushed.
    delay: std::time::Duration,
    // Ties the builder to the executor type without storing one.
    _marker: std::marker::PhantomData<E>,
}
83
impl<E> Default for BatcherBuilder<E> {
    /// Equivalent to [`BatcherBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
89
90impl<E> BatcherBuilder<E> {
91 pub const fn new() -> Self {
93 Self {
94 batch_size: 1000,
95 concurrency: 50,
96 delay: std::time::Duration::from_millis(5),
97 _marker: std::marker::PhantomData,
98 }
99 }
100
101 #[inline]
103 pub const fn batch_size(mut self, batch_size: usize) -> Self {
104 self.with_batch_size(batch_size);
105 self
106 }
107
108 #[inline]
110 pub const fn delay(mut self, delay: std::time::Duration) -> Self {
111 self.with_delay(delay);
112 self
113 }
114
115 #[inline]
117 pub const fn concurrency(mut self, concurrency: usize) -> Self {
118 self.with_concurrency(concurrency);
119 self
120 }
121
122 #[inline]
124 pub const fn with_concurrency(&mut self, concurrency: usize) -> &mut Self {
125 self.concurrency = concurrency;
126 self
127 }
128
129 #[inline]
131 pub const fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
132 self.batch_size = batch_size;
133 self
134 }
135
136 #[inline]
138 pub const fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self {
139 self.delay = delay;
140 self
141 }
142
143 #[inline]
145 pub fn build(self, executor: E) -> Batcher<E>
146 where
147 E: BatchExecutor + Send + Sync + 'static,
148 {
149 Batcher::new(executor, self.batch_size, self.concurrency, self.delay)
150 }
151}
152
153#[must_use = "batchers must be used to execute batches"]
155pub struct Batcher<E>
156where
157 E: BatchExecutor + Send + Sync + 'static,
158{
159 _auto_spawn: tokio::task::JoinHandle<()>,
160 executor: Arc<E>,
161 semaphore: Arc<tokio::sync::Semaphore>,
162 current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
163 batch_size: usize,
164}
165
/// A group of pending requests waiting to be flushed to the executor.
struct Batch<E>
where
    E: BatchExecutor + Send + Sync + 'static,
{
    // Requests collected so far, each paired with its reply handle.
    items: Vec<(E::Request, BatchResponse<E::Response>)>,
    // Shared concurrency limiter; a permit is acquired before executing.
    semaphore: Arc<tokio::sync::Semaphore>,
    // When this batch was opened; the flush loop uses this to enforce the
    // configured delay.
    created_at: std::time::Instant,
}
174
175impl<E> Batcher<E>
176where
177 E: BatchExecutor + Send + Sync + 'static,
178{
179 pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
181 let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
182 let current_batch = Arc::new(tokio::sync::Mutex::new(None));
183 let executor = Arc::new(executor);
184
185 let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), delay));
186
187 Self {
188 executor,
189 _auto_spawn: join_handle,
190 semaphore,
191 current_batch,
192 batch_size: batch_size.max(1),
193 }
194 }
195
196 pub const fn builder() -> BatcherBuilder<E> {
198 BatcherBuilder::new()
199 }
200
201 pub async fn execute(&self, items: E::Request) -> Option<E::Response> {
203 self.execute_many(std::iter::once(items)).await.pop()?
204 }
205
206 pub async fn execute_many<I>(&self, items: I) -> Vec<Option<E::Response>>
208 where
209 I: IntoIterator<Item = E::Request>,
210 {
211 let mut responses = Vec::new();
212
213 {
214 let mut batch = self.current_batch.lock().await;
215
216 for item in items {
217 if batch.is_none() {
218 batch.replace(Batch::new(self.semaphore.clone()));
219 }
220
221 let batch_mut = batch.as_mut().unwrap();
222 let (tx, rx) = oneshot::channel();
223 batch_mut.items.push((item, BatchResponse::new(tx)));
224 responses.push(rx);
225
226 if batch_mut.items.len() >= self.batch_size {
227 tokio::spawn(batch.take().unwrap().spawn(self.executor.clone()));
228 }
229 }
230 }
231
232 let mut results = Vec::with_capacity(responses.len());
233 for response in responses {
234 results.push(response.await.ok());
235 }
236
237 results
238 }
239}
240
/// Background task that flushes the current batch once it has been open for
/// at least `delay`.
///
/// Runs forever (until the task is aborted or detached); it sleeps for an
/// adaptive interval so that a batch created shortly before a tick still
/// waits out its full delay rather than being flushed early.
async fn batch_loop<E>(
    executor: Arc<E>,
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    delay: std::time::Duration,
) where
    E: BatchExecutor + Send + Sync + 'static,
{
    // How long to sleep before the next check; starts at the full delay.
    let mut delay_delta = delay;
    loop {
        tokio::time::sleep(delay_delta).await;

        let mut batch = current_batch.lock().await;
        // No batch currently open: reset the interval and keep waiting.
        let Some(created_at) = batch.as_ref().map(|b| b.created_at) else {
            delay_delta = delay;
            continue;
        };

        let remaining = delay.saturating_sub(created_at.elapsed());
        if remaining == std::time::Duration::ZERO {
            // Batch is old enough: flush it on its own task.
            // `unwrap` is safe — we just observed `Some` above and still
            // hold the lock.
            tokio::spawn(batch.take().unwrap().spawn(executor.clone()));
            delay_delta = delay;
        } else {
            // Batch was opened recently: sleep just long enough for it to
            // reach the full delay.
            delay_delta = remaining;
        }
    }
}
267
268impl<E> Batch<E>
269where
270 E: BatchExecutor + Send + Sync + 'static,
271{
272 fn new(semaphore: Arc<tokio::sync::Semaphore>) -> Self {
273 Self {
274 created_at: std::time::Instant::now(),
275 items: Vec::new(),
276 semaphore,
277 }
278 }
279
280 async fn spawn(self, executor: Arc<E>) {
281 let _ticket = self.semaphore.acquire_owned().await;
282 executor.execute(self.items).await;
283 }
284}
285
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::atomic::AtomicUsize;

    use super::*;

    /// Map-backed executor used by most tests: replies only for known keys,
    /// counts how many batches it ran, and asserts every batch respects
    /// `capacity`.
    struct TestExecutor<K, V> {
        values: HashMap<K, V>,
        delay: std::time::Duration,
        requests: Arc<AtomicUsize>,
        capacity: usize,
    }

    impl<K, V> BatchExecutor for TestExecutor<K, V>
    where
        K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
        V: Clone + Send + Sync + 'static,
    {
        type Request = K;
        type Response = V;

        async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
            tokio::time::sleep(self.delay).await;

            assert!(requests.len() <= self.capacity);

            self.requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            for (request, response) in requests {
                // Unknown keys are deliberately left unanswered so the
                // caller observes `None`.
                if let Some(value) = self.values.get(&request) {
                    response.send(value.clone());
                }
            }
        }
    }

    // Sequential single requests and a small multi-request each resolve
    // promptly, and each submission triggers exactly one executor call.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn basic() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder().batch_size(2).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let a = loader.execute("a").await;
        assert_eq!(a, Some(1));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);

        let start = std::time::Instant::now();
        let b = loader.execute("b").await;
        assert_eq!(b, Some(2));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        let start = std::time::Instant::now();
        let c = loader.execute("c").await;
        assert_eq!(c, Some(3));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 3);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a", "b"]).await;
        assert_eq!(ab, vec![Some(1), Some(2)]);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 4);

        let start = std::time::Instant::now();
        let unknown = loader.execute("unknown").await;
        assert_eq!(unknown, None);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // With concurrency 10 and batch size 2, ten requests split into five
    // batches that all run in parallel.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn concurrency_high() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder().batch_size(2).concurrency(10).build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // With concurrency 1 the five batches run serially; total time reflects
    // the executor delay times the batch count.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delay_low() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder()
            .batch_size(2)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() < std::time::Duration::from_millis(35));
        assert!(start.elapsed() >= std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // A batch smaller than `batch_size` is only flushed after the delay, so
    // all ten requests go out in a single executor call.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn batch_size() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = BatcherBuilder::default()
            .batch_size(100)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(10));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    // 1134 requests at batch size 100 produce 12 batches (11 full plus the
    // remainder), all answered correctly under concurrency 10.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn high_concurrency() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = BatcherBuilder::default()
            .batch_size(100)
            .concurrency(10)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(0..1134).await;
        assert_eq!(ab, (0..1134).map(|i| Some(i * 2 + 5)).collect::<Vec<_>>());
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1134 / 100 + 1);
    }

    // An idle period before the first request must not cause a stale flush:
    // requests submitted after the batcher sat idle still batch normally.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delayed_start() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = BatcherBuilder::default()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
    }

    // A single under-sized request submitted mid-tick still waits out the
    // full delay before being flushed (adaptive sleep in `batch_loop`).
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn delayed_start_single() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = BatcherBuilder::default()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a"]).await;
        assert_eq!(ab, vec![Some(1)]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Duplicate keys are NOT coalesced: every submitted request occupies a
    // batch slot and receives its own response.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn no_deduplication() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 4,
        };

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a", "a", "b", "b", "c", "c"]).await;
        assert_eq!(ab, vec![Some(1), Some(1), Some(2), Some(2), Some(3), Some(3)]);
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Exercises `send_ok`/`send_err` with a `Result` response type.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn result() {
        let requests = Arc::new(AtomicUsize::new(0));

        struct TestExecutor(Arc<AtomicUsize>);

        impl BatchExecutor for TestExecutor {
            type Request = &'static str;
            type Response = Result<usize, ()>;

            async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                for (request, response) in requests {
                    match request.parse() {
                        Ok(value) => response.send_ok(value),
                        Err(_) => response.send_err(()),
                    }
                }
            }
        }

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(TestExecutor(requests.clone()));

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["1", "1", "2", "2", "3", "3", "hello"]).await;
        assert_eq!(
            ab,
            vec![
                Some(Ok(1)),
                Some(Ok(1)),
                Some(Ok(2)),
                Some(Ok(2)),
                Some(Ok(3)),
                Some(Ok(3)),
                Some(Err(()))
            ]
        );
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Exercises `send_some`/`send_none` with an `Option` response type.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn option() {
        let requests = Arc::new(AtomicUsize::new(0));

        struct TestExecutor(Arc<AtomicUsize>);

        impl BatchExecutor for TestExecutor {
            type Request = &'static str;
            type Response = Option<usize>;

            async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                for (request, response) in requests {
                    match request.parse() {
                        Ok(value) => response.send_some(value),
                        Err(_) => response.send_none(),
                    }
                }
            }
        }

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(TestExecutor(requests.clone()));

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["1", "1", "2", "2", "3", "3", "hello"]).await;
        assert_eq!(
            ab,
            vec![
                Some(Some(1)),
                Some(Some(1)),
                Some(Some(2)),
                Some(Some(2)),
                Some(Some(3)),
                Some(Some(3)),
                Some(None)
            ]
        );
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }
}