warp/filters/
sse.rs

1//! Server-Sent Events (SSE)
2//!
3//! # Example
4//!
5//! ```
6//!
7//! use std::time::Duration;
8//! use std::convert::Infallible;
9//! use warp::{Filter, sse::Event};
10//! use futures_util::{stream::iter, Stream};
11//!
12//! fn sse_events() -> impl Stream<Item = Result<Event, Infallible>> {
13//!     iter(vec![
14//!         Ok(Event::default().data("unnamed event")),
15//!         Ok(
16//!             Event::default().event("chat")
17//!             .data("chat message")
18//!         ),
19//!         Ok(
20//!             Event::default().id(13.to_string())
21//!             .event("chat")
22//!             .data("other chat message\nwith next line")
23//!             .retry(Duration::from_millis(5000))
24//!         )
25//!     ])
26//! }
27//!
28//! let app = warp::path("push-notifications")
29//!     .and(warp::get())
30//!     .map(|| {
31//!         warp::sse::reply(warp::sse::keep_alive().stream(sse_events()))
32//!     });
33//! ```
34//!
//! Events are built from `Event::default()` using the builder methods
//! [`Event::data`], [`Event::json_data`], [`Event::event`], [`Event::id`],
//! [`Event::comment`] and [`Event::retry`]; each method sets one field of the event.
37//!
38//! See also the [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API,
39//! which specifies the expected behavior of Server Sent Events.
40//!
41
42#![allow(rustdoc::invalid_html_tags)]
43
44use serde::Serialize;
45use std::borrow::Cow;
46use std::error::Error as StdError;
47use std::fmt::{self, Write};
48use std::future::Future;
49use std::pin::Pin;
50use std::str::FromStr;
51use std::task::{Context, Poll};
52use std::time::Duration;
53
54use crate::bodyt::Body;
55use futures_util::{future, Stream, TryStream, TryStreamExt};
56use http::header::{HeaderValue, CACHE_CONTROL, CONTENT_TYPE};
57use pin_project::pin_project;
58use serde_json::Error;
59use tokio::time::{self, Sleep};
60
61use self::sealed::SseError;
62use super::header;
63use crate::filter::One;
64use crate::reply::Response;
65use crate::{Filter, Rejection, Reply};
66
/// Payload carried in the `data` field of an `Event`.
///
/// `Text` holds caller-supplied text, while `Json` holds a string produced by
/// `serde_json::to_string`.
#[derive(Debug)]
enum DataType {
    /// Plain text set via `Event::data`.
    Text(String),
    /// Serialized JSON set via `Event::json_data`.
    Json(String),
}
73
74/// Server-sent event
75#[derive(Default, Debug)]
76pub struct Event {
77    id: Option<String>,
78    data: Option<DataType>,
79    event: Option<String>,
80    comment: Option<String>,
81    retry: Option<Duration>,
82}
83
84impl Event {
85    /// Set Server-sent event data
86    /// data field(s) ("data:<content>")
87    pub fn data<T: Into<String>>(mut self, data: T) -> Event {
88        self.data = Some(DataType::Text(data.into()));
89        self
90    }
91
92    /// Set Server-sent event data
93    /// data field(s) ("data:<content>")
94    pub fn json_data<T: Serialize>(mut self, data: T) -> Result<Event, Error> {
95        self.data = Some(DataType::Json(serde_json::to_string(&data)?));
96        Ok(self)
97    }
98
99    /// Set Server-sent event comment
100    /// Comment field (":<comment-text>")
101    pub fn comment<T: Into<String>>(mut self, comment: T) -> Event {
102        self.comment = Some(comment.into());
103        self
104    }
105
106    /// Set Server-sent event event
107    /// Event name field ("event:<event-name>")
108    pub fn event<T: Into<String>>(mut self, event: T) -> Event {
109        self.event = Some(event.into());
110        self
111    }
112
113    /// Set Server-sent event retry
114    /// Retry timeout field ("retry:<timeout>")
115    pub fn retry(mut self, duration: Duration) -> Event {
116        self.retry = Some(duration);
117        self
118    }
119
120    /// Set Server-sent event id
121    /// Identifier field ("id:<identifier>")
122    pub fn id<T: Into<String>>(mut self, id: T) -> Event {
123        self.id = Some(id.into());
124        self
125    }
126}
127
128impl fmt::Display for Event {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        if let Some(ref comment) = &self.comment {
131            ":".fmt(f)?;
132            comment.fmt(f)?;
133            f.write_char('\n')?;
134        }
135
136        if let Some(ref event) = &self.event {
137            "event:".fmt(f)?;
138            event.fmt(f)?;
139            f.write_char('\n')?;
140        }
141
142        match self.data {
143            Some(DataType::Text(ref data)) => {
144                for line in data.split('\n') {
145                    "data:".fmt(f)?;
146                    line.fmt(f)?;
147                    f.write_char('\n')?;
148                }
149            }
150            Some(DataType::Json(ref data)) => {
151                "data:".fmt(f)?;
152                data.fmt(f)?;
153                f.write_char('\n')?;
154            }
155            None => {}
156        }
157
158        if let Some(ref id) = &self.id {
159            "id:".fmt(f)?;
160            id.fmt(f)?;
161            f.write_char('\n')?;
162        }
163
164        if let Some(ref duration) = &self.retry {
165            "retry:".fmt(f)?;
166
167            let secs = duration.as_secs();
168            let millis = duration.subsec_millis();
169
170            if secs > 0 {
171                // format seconds
172                secs.fmt(f)?;
173
174                // pad milliseconds
175                if millis < 10 {
176                    f.write_str("00")?;
177                } else if millis < 100 {
178                    f.write_char('0')?;
179                }
180            }
181
182            // format milliseconds
183            millis.fmt(f)?;
184
185            f.write_char('\n')?;
186        }
187
188        f.write_char('\n')?;
189        Ok(())
190    }
191}
192
193/// Gets the optional last event id from request.
194/// Typically this identifier represented as number or string.
195///
196/// ```
197/// let app = warp::sse::last_event_id::<u32>();
198///
199/// // The identifier is present
200/// # #[cfg(feature = "test")]
201/// async {
202///     assert_eq!(
203///         warp::test::request()
204///            .header("Last-Event-ID", "12")
205///            .filter(&app)
206///            .await
207///            .unwrap(),
208///         Some(12)
209///     );
210///
211///     // The identifier is missing
212///     assert_eq!(
213///        warp::test::request()
214///            .filter(&app)
215///            .await
216///            .unwrap(),
217///         None
218///     );
219///
220///     // The identifier is not a valid
221///     assert!(
222///        warp::test::request()
223///            .header("Last-Event-ID", "abc")
224///            .filter(&app)
225///            .await
226///            .is_err(),
227///     );
228///};
229/// ```
230pub fn last_event_id<T>() -> impl Filter<Extract = One<Option<T>>, Error = Rejection> + Copy
231where
232    T: FromStr + Send + Sync + 'static,
233{
234    header::optional("last-event-id")
235}
236
237/// Server-sent events reply
238///
239/// This function converts stream of server events into a `Reply` with:
240///
241/// - Status of `200 OK`
242/// - Header `content-type: text/event-stream`
243/// - Header `cache-control: no-cache`.
244///
245/// # Example
246///
247/// ```
248/// use std::time::Duration;
249/// use futures_util::Stream;
250/// use futures_util::stream::iter;
251/// use std::convert::Infallible;
252/// use warp::{Filter, sse::Event};
253/// use serde_derive::Serialize;
254///
255/// #[derive(Serialize)]
256/// struct Msg {
257///     from: u32,
258///     text: String,
259/// }
260///
261/// fn event_stream() -> impl Stream<Item = Result<Event, Infallible>> {
262///         iter(vec![
263///             // Unnamed event with data only
264///             Ok(Event::default().data("payload")),
265///             // Named event with ID and retry timeout
266///             Ok(
267///                 Event::default().data("other message\nwith next line")
268///                 .event("chat")
269///                 .id(1.to_string())
270///                 .retry(Duration::from_millis(15000))
271///             ),
272///             // Event with JSON data
273///             Ok(
274///                 Event::default().id(2.to_string())
275///                 .json_data(Msg {
276///                     from: 2,
277///                     text: "hello".into(),
278///                 }).unwrap(),
279///             )
280///         ])
281/// }
282///
283/// # #[cfg(feature = "test")]
284/// async {
285///     let app = warp::path("sse").and(warp::get()).map(|| {
286///        warp::sse::reply(event_stream())
287///     });
288///
289///     let res = warp::test::request()
290///         .method("GET")
291///         .header("Connection", "Keep-Alive")
292///         .path("/sse")
293///         .reply(&app)
294///         .await
295///         .into_body();
296///
297///     assert_eq!(
298///         res,
299///         r#"data:payload
300///
301/// event:chat
302/// data:other message
303/// data:with next line
304/// id:1
305/// retry:15000
306///
307/// data:{"from":2,"text":"hello"}
308/// id:2
309///
310/// "#
311///     );
312/// };
313/// ```
314pub fn reply<S>(event_stream: S) -> impl Reply
315where
316    S: TryStream<Ok = Event> + Send + Sync + 'static,
317    S::Error: StdError + Send + Sync + 'static,
318{
319    SseReply { event_stream }
320}
321
/// Internal `Reply` wrapper around an event stream, returned by `reply`.
// No `Debug` impl is provided for this internal wrapper, so the crate-level
// `missing_debug_implementations` lint is silenced explicitly.
#[allow(missing_debug_implementations)]
struct SseReply<S> {
    /// The events to serialize into the response body.
    event_stream: S,
}
326
327impl<S> Reply for SseReply<S>
328where
329    S: TryStream<Ok = Event> + Send + Sync + 'static,
330    S::Error: StdError + Send + Sync + 'static,
331{
332    #[inline]
333    fn into_response(self) -> Response {
334        let body_stream = self
335            .event_stream
336            .map_err(|error| {
337                // FIXME: error logging
338                log::error!("sse stream error: {}", error);
339                SseError
340            })
341            .into_stream()
342            .and_then(|event| future::ready(Ok(event.to_string())));
343
344        let mut res = Response::new(Body::wrap_stream(body_stream));
345        // Set appropriate content type
346        res.headers_mut()
347            .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
348        // Disable response body caching
349        res.headers_mut()
350            .insert(CACHE_CONTROL, HeaderValue::from_static("no-cache"));
351        res
352    }
353}
354
/// Builder for the keep-alive behavior of an event stream.
///
/// Created with [`keep_alive`]. It configures how long a stream may stay
/// silent before a keep-alive comment is sent, and the text of that comment.
/// [`KeepAlive::stream`] then applies the configuration to an event stream.
#[derive(Debug)]
pub struct KeepAlive {
    /// Text sent in each keep-alive comment.
    comment_text: Cow<'static, str>,
    /// Maximum period of silence before a keep-alive comment is emitted.
    max_interval: Duration,
}
362
363impl KeepAlive {
364    /// Customize the interval between keep-alive messages.
365    ///
366    /// Default is 15 seconds.
367    pub fn interval(mut self, time: Duration) -> Self {
368        self.max_interval = time;
369        self
370    }
371
372    /// Customize the text of the keep-alive message.
373    ///
374    /// Default is an empty comment.
375    pub fn text(mut self, text: impl Into<Cow<'static, str>>) -> Self {
376        self.comment_text = text.into();
377        self
378    }
379
380    /// Wrap an event stream with keep-alive functionality.
381    ///
382    /// See [`keep_alive`] for more.
383    pub fn stream<S>(
384        self,
385        event_stream: S,
386    ) -> impl TryStream<Ok = Event, Error = impl StdError + Send + Sync + 'static> + Send + 'static
387    where
388        S: TryStream<Ok = Event> + Send + 'static,
389        S::Error: StdError + Send + Sync + 'static,
390    {
391        let alive_timer = time::sleep(self.max_interval);
392        SseKeepAlive {
393            event_stream,
394            comment_text: self.comment_text,
395            max_interval: self.max_interval,
396            alive_timer,
397        }
398    }
399}
400
401#[allow(missing_debug_implementations)]
402#[pin_project]
403struct SseKeepAlive<S> {
404    #[pin]
405    event_stream: S,
406    comment_text: Cow<'static, str>,
407    max_interval: Duration,
408    #[pin]
409    alive_timer: Sleep,
410}
411
412/// Keeps event source connection alive when no events sent over a some time.
413///
414/// Some proxy servers may drop HTTP connection after a some timeout of inactivity.
415/// This function helps to prevent such behavior by sending comment events every
416/// `keep_interval` of inactivity.
417///
418/// By default the comment is `:` (an empty comment) and the time interval between
419/// events is 15 seconds. Both may be customized using the builder pattern
420/// as shown below.
421///
422/// ```
423/// use std::time::Duration;
424/// use std::convert::Infallible;
425/// use futures_util::StreamExt;
426/// use tokio::time::interval;
427/// use tokio_stream::wrappers::IntervalStream;
428/// use warp::{Filter, Stream, sse::Event};
429///
430/// // create server-sent event
431/// fn sse_counter(counter: u64) ->  Result<Event, Infallible> {
432///     Ok(Event::default().data(counter.to_string()))
433/// }
434///
435/// fn main() {
436///     let routes = warp::path("ticks")
437///         .and(warp::get())
438///         .map(|| {
439///             let mut counter: u64 = 0;
440///             let interval = interval(Duration::from_secs(15));
441///             let stream = IntervalStream::new(interval);
442///             let event_stream = stream.map(move |_| {
443///                 counter += 1;
444///                 sse_counter(counter)
445///             });
446///             // reply using server-sent events
447///             let stream = warp::sse::keep_alive()
448///                 .interval(Duration::from_secs(5))
449///                 .text("thump".to_string())
450///                 .stream(event_stream);
451///             warp::sse::reply(stream)
452///         });
453/// }
454/// ```
455///
456/// See [notes](https://www.w3.org/TR/2009/WD-eventsource-20090421/#notes).
457pub fn keep_alive() -> KeepAlive {
458    KeepAlive {
459        comment_text: Cow::Borrowed(""),
460        max_interval: Duration::from_secs(15),
461    }
462}
463
464impl<S> Stream for SseKeepAlive<S>
465where
466    S: TryStream<Ok = Event> + Send + 'static,
467    S::Error: StdError + Send + Sync + 'static,
468{
469    type Item = Result<Event, SseError>;
470
471    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
472        let mut pin = self.project();
473        match pin.event_stream.try_poll_next(cx) {
474            Poll::Pending => match Pin::new(&mut pin.alive_timer).poll(cx) {
475                Poll::Pending => Poll::Pending,
476                Poll::Ready(_) => {
477                    // restart timer
478                    pin.alive_timer
479                        .reset(tokio::time::Instant::now() + *pin.max_interval);
480                    let comment_str = pin.comment_text.clone();
481                    let event = Event::default().comment(comment_str);
482                    Poll::Ready(Some(Ok(event)))
483                }
484            },
485            Poll::Ready(Some(Ok(event))) => {
486                // restart timer
487                pin.alive_timer
488                    .reset(tokio::time::Instant::now() + *pin.max_interval);
489                Poll::Ready(Some(Ok(event)))
490            }
491            Poll::Ready(None) => Poll::Ready(None),
492            Poll::Ready(Some(Err(error))) => {
493                log::error!("sse::keep error: {}", error);
494                Poll::Ready(Some(Err(SseError)))
495            }
496        }
497    }
498}
499
mod sealed {
    use super::{fmt, StdError};

    /// Opaque error yielded by the SSE streams after the underlying error has
    /// been logged.
    ///
    /// This module is private, so `SseError` cannot be named outside the
    /// parent module.
    #[derive(Debug)]
    pub struct SseError;

    impl fmt::Display for SseError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("sse error")
        }
    }

    impl StdError for SseError {}
}