1use std::borrow::Cow;
2use std::time::Duration;
3
4use bytes::BytesMut;
5use scuffle_amf0::Amf0Value;
6use scuffle_bytes_util::BytesCursorExt;
7use scuffle_future_ext::FutureExt;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::sync::oneshot;
10
11use super::define::RtmpCommand;
12use super::errors::SessionError;
13use crate::channels::{ChannelData, DataProducer, PublishRequest, UniqueID};
14use crate::chunk::{ChunkDecoder, ChunkEncoder, CHUNK_SIZE};
15use crate::handshake::{HandshakeServer, ServerHandshakeState};
16use crate::messages::{MessageParser, RtmpMessageData};
17use crate::netconnection::NetConnection;
18use crate::netstream::NetStreamWriter;
19use crate::protocol_control_messages::ProtocolControlMessagesWriter;
20use crate::user_control_messages::EventMessagesWriter;
21use crate::{handshake, PublishProducer};
22
/// State for one server-side RTMP connection.
///
/// Generic over the transport `S`; the methods that drive I/O require
/// `S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin`.
pub struct Session<S> {
    /// Application name taken from the `app` property of the connect command;
    /// `None` until a connect command has been handled.
    app_name: Option<String>,

    /// Identifier received when a publish request is accepted; `None` until then.
    uid: Option<UniqueID>,

    /// Transport the session reads from and writes to.
    io: S,

    /// Bytes read from `io` that have not yet been consumed by the handshake
    /// or the chunk decoder.
    read_buf: BytesMut,
    /// Outgoing bytes queued by the handlers, written to `io` by `flush`.
    write_buf: Vec<u8>,

    /// Set when the handshake read more bytes than it consumed: the next
    /// ready-state iteration then parses `read_buf` without reading from `io`.
    skip_read: bool,

    /// Reassembles incoming chunks.
    chunk_decoder: ChunkDecoder,
    /// Frames outgoing messages into chunks.
    chunk_encoder: ChunkEncoder,

    /// Message stream id of the accepted publish; reset to 0 by a matching
    /// delete-stream command.
    stream_id: u32,

    /// Destination for audio, video and metadata received while publishing.
    data_producer: DataProducer,

    /// Whether a publish has been accepted and not yet undone by delete-stream.
    is_publishing: bool,

    /// Channel used to request permission to publish.
    publish_request_producer: PublishProducer,
}
71
72impl<S> Session<S> {
73 pub fn new(io: S, data_producer: DataProducer, publish_request_producer: PublishProducer) -> Self {
74 Self {
75 uid: None,
76 app_name: None,
77 io,
78 skip_read: false,
79 chunk_decoder: ChunkDecoder::default(),
80 chunk_encoder: ChunkEncoder::default(),
81 read_buf: BytesMut::new(),
82 write_buf: Vec::new(),
83 data_producer,
84 stream_id: 0,
85 is_publishing: false,
86 publish_request_producer,
87 }
88 }
89
90 pub fn uid(&self) -> Option<UniqueID> {
91 self.uid
92 }
93}
94
95impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> Session<S> {
96 pub async fn run(&mut self) -> Result<bool, SessionError> {
102 let mut handshaker = HandshakeServer::default();
103 while !self.do_handshake(&mut handshaker).await? {
105 self.flush().await?;
106 }
107
108 drop(handshaker);
111
112 tracing::debug!("Handshake complete");
113
114 while match self.do_ready().await {
116 Ok(v) => v,
117 Err(err) if err.is_client_closed() => {
118 tracing::debug!("Client closed the connection");
121 false
122 }
123 Err(e) => {
124 return Err(e);
125 }
126 } {
127 self.flush().await?;
128 }
129
130 Ok(!self.is_publishing)
135 }
136
137 async fn do_handshake(&mut self, handshaker: &mut HandshakeServer) -> Result<bool, SessionError> {
142 const READ_SIZE: usize = handshake::RTMP_HANDSHAKE_SIZE + 1;
144 self.read_buf.reserve(READ_SIZE);
145
146 let mut bytes_read = 0;
147 while bytes_read < READ_SIZE {
148 let n = self
149 .io
150 .read_buf(&mut self.read_buf)
151 .with_timeout(Duration::from_secs(2))
152 .await??;
153 bytes_read += n;
154 }
155
156 let mut cursor = std::io::Cursor::new(self.read_buf.split().freeze());
157
158 handshaker.handshake(&mut cursor, &mut self.write_buf)?;
159
160 if handshaker.state() == ServerHandshakeState::Finish {
161 let over_read = cursor.extract_remaining();
162
163 if !over_read.is_empty() {
164 self.skip_read = true;
165 self.read_buf.extend_from_slice(&over_read);
166 }
167
168 self.send_set_chunk_size().await?;
169
170 Ok(true)
174 } else {
175 Ok(false)
179 }
180 }
181
182 async fn do_ready(&mut self) -> Result<bool, SessionError> {
186 if self.skip_read {
188 self.skip_read = false;
189 } else {
190 self.read_buf.reserve(CHUNK_SIZE);
191
192 let n = self
193 .io
194 .read_buf(&mut self.read_buf)
195 .with_timeout(Duration::from_millis(2500))
196 .await??;
197
198 if n == 0 {
199 return Ok(false);
200 }
201 }
202
203 self.parse_chunks().await?;
204
205 Ok(true)
206 }
207
208 async fn parse_chunks(&mut self) -> Result<(), SessionError> {
210 while let Some(chunk) = self.chunk_decoder.read_chunk(&mut self.read_buf)? {
211 let timestamp = chunk.message_header.timestamp;
212 let msg_stream_id = chunk.message_header.msg_stream_id;
213
214 if let Some(msg) = MessageParser::parse(&chunk)? {
215 self.process_messages(msg, msg_stream_id, timestamp).await?;
216 }
217 }
218
219 Ok(())
220 }
221
222 async fn process_messages(
224 &mut self,
225 rtmp_msg: RtmpMessageData<'_>,
226 stream_id: u32,
227 timestamp: u32,
228 ) -> Result<(), SessionError> {
229 match rtmp_msg {
230 RtmpMessageData::Amf0Command {
231 command_name,
232 transaction_id,
233 command_object,
234 others,
235 } => {
236 self.on_amf0_command_message(stream_id, command_name, transaction_id, command_object, others)
237 .await?
238 }
239 RtmpMessageData::SetChunkSize { chunk_size } => {
240 self.on_set_chunk_size(chunk_size as usize)?;
241 }
242 RtmpMessageData::AudioData { data } => {
243 self.on_data(stream_id, ChannelData::Audio { timestamp, data }).await?;
244 }
245 RtmpMessageData::VideoData { data } => {
246 self.on_data(stream_id, ChannelData::Video { timestamp, data }).await?;
247 }
248 RtmpMessageData::AmfData { data } => {
249 self.on_data(stream_id, ChannelData::Metadata { timestamp, data }).await?;
250 }
251 }
252
253 Ok(())
254 }
255
256 async fn send_set_chunk_size(&mut self) -> Result<(), SessionError> {
258 ProtocolControlMessagesWriter::write_set_chunk_size(&self.chunk_encoder, &mut self.write_buf, CHUNK_SIZE as u32)?;
259 self.chunk_encoder.set_chunk_size(CHUNK_SIZE);
260
261 Ok(())
262 }
263
264 async fn on_data(&self, stream_id: u32, data: ChannelData) -> Result<(), SessionError> {
268 if stream_id != self.stream_id || !self.is_publishing {
269 return Err(SessionError::UnknownStreamID(stream_id));
270 };
271
272 if matches!(
273 self.data_producer.send(data).with_timeout(Duration::from_secs(2)).await,
274 Err(_) | Ok(Err(_))
275 ) {
276 tracing::debug!("Publisher dropped");
277 return Err(SessionError::PublisherDropped);
278 }
279
280 Ok(())
281 }
282
283 async fn on_amf0_command_message(
286 &mut self,
287 stream_id: u32,
288 command_name: Amf0Value<'_>,
289 transaction_id: Amf0Value<'_>,
290 command_object: Amf0Value<'_>,
291 others: Vec<Amf0Value<'_>>,
292 ) -> Result<(), SessionError> {
293 let cmd = RtmpCommand::from(match command_name {
294 Amf0Value::String(ref s) => s,
295 _ => "",
296 });
297
298 let transaction_id = match transaction_id {
299 Amf0Value::Number(number) => number,
300 _ => 0.0,
301 };
302
303 let obj = match command_object {
304 Amf0Value::Object(obj) => obj,
305 _ => Cow::Owned(Vec::new()),
306 };
307
308 match cmd {
309 RtmpCommand::Connect => {
310 self.on_command_connect(transaction_id, stream_id, &obj, others).await?;
311 }
312 RtmpCommand::CreateStream => {
313 self.on_command_create_stream(transaction_id, stream_id, &obj, others).await?;
314 }
315 RtmpCommand::DeleteStream => {
316 self.on_command_delete_stream(transaction_id, stream_id, &obj, others).await?;
317 }
318 RtmpCommand::Play => {
319 return Err(SessionError::PlayNotSupported);
320 }
321 RtmpCommand::Publish => {
322 self.on_command_publish(transaction_id, stream_id, &obj, others).await?;
323 }
324 RtmpCommand::CloseStream | RtmpCommand::ReleaseStream => {
325 }
327 RtmpCommand::Unknown(_) => {}
328 }
329
330 Ok(())
331 }
332
333 fn on_set_chunk_size(&mut self, chunk_size: usize) -> Result<(), SessionError> {
336 if self.chunk_decoder.update_max_chunk_size(chunk_size) {
337 Ok(())
338 } else {
339 Err(SessionError::InvalidChunkSize(chunk_size))
340 }
341 }
342
343 async fn on_command_connect(
347 &mut self,
348 transaction_id: f64,
349 _stream_id: u32,
350 command_obj: &[(Cow<'_, str>, Amf0Value<'_>)],
351 _others: Vec<Amf0Value<'_>>,
352 ) -> Result<(), SessionError> {
353 ProtocolControlMessagesWriter::write_window_acknowledgement_size(
354 &self.chunk_encoder,
355 &mut self.write_buf,
356 CHUNK_SIZE as u32,
357 )?;
358
359 ProtocolControlMessagesWriter::write_set_peer_bandwidth(
360 &self.chunk_encoder,
361 &mut self.write_buf,
362 CHUNK_SIZE as u32,
363 2, )?;
365
366 let app_name = command_obj.iter().find(|(key, _)| key == "app");
367 let app_name = match app_name {
368 Some((_, Amf0Value::String(app))) => app,
369 _ => {
370 return Err(SessionError::NoAppName);
371 }
372 };
373
374 self.app_name = Some(app_name.to_string());
375
376 NetConnection::write_connect_response(
386 &self.chunk_encoder,
387 &mut self.write_buf,
388 transaction_id,
389 "FMS/3,0,1,123", 31.0, "NetConnection.Connect.Success",
392 "status", "Connection Succeeded.",
394 0.0,
395 )?;
396
397 Ok(())
398 }
399
400 async fn on_command_create_stream(
405 &mut self,
406 transaction_id: f64,
407 _stream_id: u32,
408 _command_obj: &[(Cow<'_, str>, Amf0Value<'_>)],
409 _others: Vec<Amf0Value<'_>>,
410 ) -> Result<(), SessionError> {
411 NetConnection::write_create_stream_response(&self.chunk_encoder, &mut self.write_buf, transaction_id, 1.0)?;
413
414 Ok(())
415 }
416
417 async fn on_command_delete_stream(
422 &mut self,
423 transaction_id: f64,
424 _stream_id: u32,
425 _command_obj: &[(Cow<'_, str>, Amf0Value<'_>)],
426 others: Vec<Amf0Value<'_>>,
427 ) -> Result<(), SessionError> {
428 let stream_id = match others.first() {
429 Some(Amf0Value::Number(stream_id)) => *stream_id,
430 _ => 0.0,
431 } as u32;
432
433 if self.stream_id == stream_id && self.is_publishing {
434 self.stream_id = 0;
435 self.is_publishing = false;
436 }
437
438 NetStreamWriter::write_on_status(
439 &self.chunk_encoder,
440 &mut self.write_buf,
441 transaction_id,
442 "status",
443 "NetStream.DeleteStream.Suceess",
444 "",
445 )?;
446
447 Ok(())
448 }
449
450 async fn on_command_publish(
454 &mut self,
455 transaction_id: f64,
456 stream_id: u32,
457 _command_obj: &[(Cow<'_, str>, Amf0Value<'_>)],
458 others: Vec<Amf0Value<'_>>,
459 ) -> Result<(), SessionError> {
460 let stream_name = match others.first() {
461 Some(Amf0Value::String(val)) => val,
462 _ => {
463 return Err(SessionError::NoStreamName);
464 }
465 };
466
467 let Some(app_name) = &self.app_name else {
468 return Err(SessionError::NoAppName);
469 };
470
471 let (response, waiter) = oneshot::channel();
472
473 if self
474 .publish_request_producer
475 .send(PublishRequest {
476 app_name: app_name.clone(),
477 stream_name: stream_name.to_string(),
478 response,
479 })
480 .await
481 .is_err()
482 {
483 return Err(SessionError::PublishRequestDenied);
484 }
485
486 let Ok(uid) = waiter.await else {
487 return Err(SessionError::PublishRequestDenied);
488 };
489
490 self.uid = Some(uid);
491
492 self.is_publishing = true;
493 self.stream_id = stream_id;
494
495 EventMessagesWriter::write_stream_begin(&self.chunk_encoder, &mut self.write_buf, stream_id)?;
496
497 NetStreamWriter::write_on_status(
498 &self.chunk_encoder,
499 &mut self.write_buf,
500 transaction_id,
501 "status",
502 "NetStream.Publish.Start",
503 "",
504 )?;
505
506 Ok(())
507 }
508
509 async fn flush(&mut self) -> Result<(), SessionError> {
510 if !self.write_buf.is_empty() {
511 self.io
512 .write_all(self.write_buf.as_ref())
513 .with_timeout(Duration::from_secs(2))
514 .await??;
515 self.write_buf.clear();
516 }
517
518 Ok(())
519 }
520}