scuffle_batching/
batch.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use tokio::sync::oneshot;
5
/// A response to a batch request
pub struct BatchResponse<Resp> {
    // One-shot channel delivering the single response back to the caller
    // awaiting this request; consumed by `send`.
    send: oneshot::Sender<Resp>,
}
10
11impl<Resp> BatchResponse<Resp> {
12    /// Create a new batch response
13    #[must_use]
14    pub fn new(send: oneshot::Sender<Resp>) -> Self {
15        Self { send }
16    }
17
18    /// Send a response back to the requester
19    #[inline(always)]
20    pub fn send(self, response: Resp) {
21        let _ = self.send.send(response);
22    }
23
24    /// Send a successful response back to the requester
25    #[inline(always)]
26    pub fn send_ok<O, E>(self, response: O)
27    where
28        Resp: From<Result<O, E>>,
29    {
30        self.send(Ok(response).into())
31    }
32
33    /// Send an error response back to the requestor
34    #[inline(always)]
35    pub fn send_err<O, E>(self, error: E)
36    where
37        Resp: From<Result<O, E>>,
38    {
39        self.send(Err(error).into())
40    }
41
42    /// Send a `None` response back to the requestor
43    #[inline(always)]
44    pub fn send_none<T>(self)
45    where
46        Resp: From<Option<T>>,
47    {
48        self.send(None.into())
49    }
50
51    /// Send a value response back to the requestor
52    #[inline(always)]
53    pub fn send_some<T>(self, value: T)
54    where
55        Resp: From<Option<T>>,
56    {
57        self.send(Some(value).into())
58    }
59}
60
/// A trait for executing batches
pub trait BatchExecutor {
    /// The incoming request type
    type Request: Send + 'static;
    /// The outgoing response type
    type Response: Send + Sync + 'static;

    /// Execute a batch of requests
    /// You must call `send` on the `BatchResponse` to send the response back to
    /// the client
    ///
    /// Dropping a `BatchResponse` without sending closes its channel, so the
    /// corresponding caller receives `None` from `Batcher::execute`.
    fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) -> impl Future<Output = ()> + Send;
}
73
/// A builder for a [`Batcher`]
#[derive(Clone, Copy, Debug)]
#[must_use = "builders must be used to create a batcher"]
pub struct BatcherBuilder<E> {
    // Maximum number of requests per batch (default 1000, see `new`).
    batch_size: usize,
    // Maximum number of batches executing at once (default 50).
    concurrency: usize,
    // How long a partially-filled batch waits before being executed (default 5ms).
    delay: std::time::Duration,
    // Ties the builder to the executor type without storing an executor.
    _marker: std::marker::PhantomData<E>,
}
83
impl<E> Default for BatcherBuilder<E> {
    // Delegates to `new`, which supplies the default settings.
    fn default() -> Self {
        Self::new()
    }
}
89
90impl<E> BatcherBuilder<E> {
91    /// Create a new builder
92    pub const fn new() -> Self {
93        Self {
94            batch_size: 1000,
95            concurrency: 50,
96            delay: std::time::Duration::from_millis(5),
97            _marker: std::marker::PhantomData,
98        }
99    }
100
101    /// Set the batch size
102    #[inline]
103    pub const fn batch_size(mut self, batch_size: usize) -> Self {
104        self.with_batch_size(batch_size);
105        self
106    }
107
108    /// Set the delay
109    #[inline]
110    pub const fn delay(mut self, delay: std::time::Duration) -> Self {
111        self.with_delay(delay);
112        self
113    }
114
115    /// Set the concurrency to 1
116    #[inline]
117    pub const fn concurrency(mut self, concurrency: usize) -> Self {
118        self.with_concurrency(concurrency);
119        self
120    }
121
122    /// Set the concurrency
123    #[inline]
124    pub const fn with_concurrency(&mut self, concurrency: usize) -> &mut Self {
125        self.concurrency = concurrency;
126        self
127    }
128
129    /// Set the batch size
130    #[inline]
131    pub const fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
132        self.batch_size = batch_size;
133        self
134    }
135
136    /// Set the delay
137    #[inline]
138    pub const fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self {
139        self.delay = delay;
140        self
141    }
142
143    /// Build the batcher
144    #[inline]
145    pub fn build(self, executor: E) -> Batcher<E>
146    where
147        E: BatchExecutor + Send + Sync + 'static,
148    {
149        Batcher::new(executor, self.batch_size, self.concurrency, self.delay)
150    }
151}
152
/// A batcher used to batch requests to a [`BatchExecutor`]
#[must_use = "batchers must be used to execute batches"]
pub struct Batcher<E>
where
    E: BatchExecutor + Send + Sync + 'static,
{
    // Handle to the background flush task spawned in `new`; stored so it is
    // associated with this batcher (NOTE(review): it is never aborted — see
    // `Batcher::new`).
    _auto_spawn: tokio::task::JoinHandle<()>,
    executor: Arc<E>,
    // Limits how many batches may execute concurrently (see `Batch::spawn`).
    semaphore: Arc<tokio::sync::Semaphore>,
    // The batch currently accumulating requests, if any.
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    // Maximum number of requests per batch (clamped to at least 1 in `new`).
    batch_size: usize,
}
165
// A single pending batch: the accumulated requests plus what is needed to
// execute them under the shared concurrency limit.
struct Batch<E>
where
    E: BatchExecutor + Send + Sync + 'static,
{
    // Pending requests paired with their response channels.
    items: Vec<(E::Request, BatchResponse<E::Response>)>,
    // Shared concurrency limiter; a permit is acquired before execution.
    semaphore: Arc<tokio::sync::Semaphore>,
    // Creation time, used by `batch_loop` to enforce the flush delay.
    created_at: std::time::Instant,
}
174
impl<E> Batcher<E>
where
    E: BatchExecutor + Send + Sync + 'static,
{
    /// Create a new batcher
    ///
    /// `batch_size` and `concurrency` are clamped to a minimum of 1. A
    /// background task running [`batch_loop`] is spawned to flush
    /// partially-filled batches once they have waited `delay`.
    ///
    /// NOTE(review): the background task's `JoinHandle` is stored but never
    /// aborted, so the loop keeps running after this `Batcher` is dropped —
    /// confirm this is intended.
    pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
        let current_batch = Arc::new(tokio::sync::Mutex::new(None));
        let executor = Arc::new(executor);

        let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), delay));

        Self {
            executor,
            _auto_spawn: join_handle,
            semaphore,
            current_batch,
            batch_size: batch_size.max(1),
        }
    }

    /// Create a builder for a [`Batcher`]
    pub const fn builder() -> BatcherBuilder<E> {
        BatcherBuilder::new()
    }

    /// Execute a single request
    ///
    /// Returns `None` if the executor dropped this request's
    /// [`BatchResponse`] without sending a value.
    pub async fn execute(&self, items: E::Request) -> Option<E::Response> {
        self.execute_many(std::iter::once(items)).await.pop()?
    }

    /// Execute many requests
    ///
    /// Results are returned in input order; an entry is `None` when no
    /// response was sent for that item.
    pub async fn execute_many<I>(&self, items: I) -> Vec<Option<E::Response>>
    where
        I: IntoIterator<Item = E::Request>,
    {
        let mut responses = Vec::new();

        {
            // Hold the batch lock while appending so batches are cut at
            // exactly `batch_size` items.
            let mut batch = self.current_batch.lock().await;

            for item in items {
                // Lazily start a fresh batch; its `created_at` timestamp is
                // what `batch_loop` uses to time the delayed flush.
                if batch.is_none() {
                    batch.replace(Batch::new(self.semaphore.clone()));
                }

                let batch_mut = batch.as_mut().unwrap();
                let (tx, rx) = oneshot::channel();
                batch_mut.items.push((item, BatchResponse::new(tx)));
                responses.push(rx);

                // Flush immediately once full; otherwise the background loop
                // flushes the batch after the configured delay.
                if batch_mut.items.len() >= self.batch_size {
                    tokio::spawn(batch.take().unwrap().spawn(self.executor.clone()));
                }
            }
        }

        // Await every oneshot receiver; a closed channel (response dropped
        // without sending) yields `None`.
        let mut results = Vec::with_capacity(responses.len());
        for response in responses {
            results.push(response.await.ok());
        }

        results
    }
}
240
// Background task that flushes the pending batch once it has waited at least
// `delay` since creation.
//
// The sleep duration is adaptive: a full `delay` when no batch is pending,
// otherwise only the remainder of the current batch's window. Runs forever;
// see the NOTE on `Batcher::new` about task lifetime.
async fn batch_loop<E>(
    executor: Arc<E>,
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    delay: std::time::Duration,
) where
    E: BatchExecutor + Send + Sync + 'static,
{
    let mut delay_delta = delay;
    loop {
        tokio::time::sleep(delay_delta).await;

        let mut batch = current_batch.lock().await;
        // No pending batch: go back to sleeping for a full window.
        let Some(created_at) = batch.as_ref().map(|b| b.created_at) else {
            delay_delta = delay;
            continue;
        };

        let remaining = delay.saturating_sub(created_at.elapsed());
        if remaining == std::time::Duration::ZERO {
            // The batch has aged past `delay`: hand it off for execution.
            tokio::spawn(batch.take().unwrap().spawn(executor.clone()));
            delay_delta = delay;
        } else {
            // Woke early (batch was created mid-sleep): sleep only the
            // remainder of its window.
            delay_delta = remaining;
        }
    }
}
267
268impl<E> Batch<E>
269where
270    E: BatchExecutor + Send + Sync + 'static,
271{
272    fn new(semaphore: Arc<tokio::sync::Semaphore>) -> Self {
273        Self {
274            created_at: std::time::Instant::now(),
275            items: Vec::new(),
276            semaphore,
277        }
278    }
279
280    async fn spawn(self, executor: Arc<E>) {
281        let _ticket = self.semaphore.acquire_owned().await;
282        executor.execute(self.items).await;
283    }
284}
285
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::atomic::AtomicUsize;

    use super::*;

    // Test executor backed by a fixed key->value map. Sleeps `delay` per
    // batch and counts how many batches it was asked to execute.
    struct TestExecutor<K, V> {
        values: HashMap<K, V>,
        delay: std::time::Duration,
        requests: Arc<AtomicUsize>,
        // Largest batch this executor expects; asserted inside `execute`.
        capacity: usize,
    }

    impl<K, V> BatchExecutor for TestExecutor<K, V>
    where
        K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
        V: Clone + Send + Sync + 'static,
    {
        type Request = K;
        type Response = V;

        async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
            tokio::time::sleep(self.delay).await;

            assert!(requests.len() <= self.capacity);

            self.requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            for (request, response) in requests {
                // Unknown keys get no response at all, so the caller sees `None`.
                if let Some(value) = self.values.get(&request) {
                    response.send(value.clone());
                }
            }
        }
    }

    // Single and small multi-request lookups, including a miss.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn basic() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder().batch_size(2).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let a = loader.execute("a").await;
        assert_eq!(a, Some(1));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);

        let start = std::time::Instant::now();
        let b = loader.execute("b").await;
        assert_eq!(b, Some(2));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        let start = std::time::Instant::now();
        let c = loader.execute("c").await;
        assert_eq!(c, Some(3));
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 3);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a", "b"]).await;
        assert_eq!(ab, vec![Some(1), Some(2)]);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 4);

        let start = std::time::Instant::now();
        let unknown = loader.execute("unknown").await;
        assert_eq!(unknown, None);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // 10 items at batch_size 2 => 5 batches, which can run in parallel at
    // concurrency 10.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn concurrency_high() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder().batch_size(2).concurrency(10).build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() < std::time::Duration::from_millis(100));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // Concurrency 1 serializes the 5 batches, so total time is ~5 * 5ms.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delay_low() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = Batcher::builder()
            .batch_size(2)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() < std::time::Duration::from_millis(35));
        assert!(start.elapsed() >= std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // Batch size larger than the item count: one partially-filled batch,
    // flushed only after the 10ms delay elapses.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn batch_size() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = BatcherBuilder::default()
            .batch_size(100)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(10));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    // 1134 items / batch_size 100 => 12 batches (11 full + 1 partial) run
    // across 10 concurrent slots.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn high_concurrency() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = BatcherBuilder::default()
            .batch_size(100)
            .concurrency(10)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(0..1134).await;
        assert_eq!(ab, (0..1134).map(|i| Some(i * 2 + 5)).collect::<Vec<_>>());
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1134 / 100 + 1);
    }

    // An idle background loop must not delay a batch submitted later.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delayed_start() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = BatcherBuilder::default()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let start = std::time::Instant::now();
        let ab = loader
            .execute_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await;
        assert_eq!(ab, vec![Some(1), Some(2), Some(3), None, None, None, None, None, None, None]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
    }

    // A single item submitted mid-window still waits the full delay before
    // the partial batch is flushed.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delayed_start_single() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = BatcherBuilder::default()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a"]).await;
        assert_eq!(ab, vec![Some(1)]);
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Duplicate keys are NOT deduplicated: each occurrence is a separate
    // request and gets its own response.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn no_deduplication() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestExecutor {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 4,
        };

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["a", "a", "b", "b", "c", "c"]).await;
        assert_eq!(ab, vec![Some(1), Some(1), Some(2), Some(2), Some(3), Some(3)]);
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Exercises `send_ok` / `send_err` with a Result response type.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn result() {
        let requests = Arc::new(AtomicUsize::new(0));

        struct TestExecutor(Arc<AtomicUsize>);

        impl BatchExecutor for TestExecutor {
            type Request = &'static str;
            type Response = Result<usize, ()>;

            async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                for (request, response) in requests {
                    match request.parse() {
                        Ok(value) => response.send_ok(value),
                        Err(_) => response.send_err(()),
                    }
                }
            }
        }

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(TestExecutor(requests.clone()));

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["1", "1", "2", "2", "3", "3", "hello"]).await;
        assert_eq!(
            ab,
            vec![
                Some(Ok(1)),
                Some(Ok(1)),
                Some(Ok(2)),
                Some(Ok(2)),
                Some(Ok(3)),
                Some(Ok(3)),
                Some(Err(()))
            ]
        );
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Exercises `send_some` / `send_none` with an Option response type.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn option() {
        let requests = Arc::new(AtomicUsize::new(0));

        struct TestExecutor(Arc<AtomicUsize>);

        impl BatchExecutor for TestExecutor {
            type Request = &'static str;
            type Response = Option<usize>;

            async fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                for (request, response) in requests {
                    match request.parse() {
                        Ok(value) => response.send_some(value),
                        Err(_) => response.send_none(),
                    }
                }
            }
        }

        let loader = BatcherBuilder::default()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(TestExecutor(requests.clone()));

        let start = std::time::Instant::now();
        let ab = loader.execute_many(vec!["1", "1", "2", "2", "3", "3", "hello"]).await;
        assert_eq!(
            ab,
            vec![
                Some(Some(1)),
                Some(Some(1)),
                Some(Some(2)),
                Some(Some(2)),
                Some(Some(3)),
                Some(Some(3)),
                Some(None)
            ]
        );
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }
}