// warp/filters/ws.rs

1//! Websockets Filters
2
3use std::borrow::Cow;
4use std::fmt;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use bytes::Bytes;
10use futures_util::{future, ready, FutureExt, Sink, Stream, TryFutureExt};
11use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
12use hyper::upgrade::OnUpgrade;
13use tokio_tungstenite::{
14    tungstenite::protocol::{self, frame::Utf8Bytes, WebSocketConfig},
15    WebSocketStream,
16};
17
18use super::header;
19use crate::filter::{filter_fn_one, Filter, One};
20use crate::reject::Rejection;
21use crate::reply::{Reply, Response};
22
23/// Creates a Websocket Filter.
24///
25/// The yielded `Ws` is used to finish the websocket upgrade.
26///
27/// # Note
28///
29/// This filter combines multiple filters internally, so you don't need them:
30///
31/// - Method must be `GET`
32/// - Header `connection` must be `upgrade`
33/// - Header `upgrade` must be `websocket`
34/// - Header `sec-websocket-version` must be `13`
35/// - Header `sec-websocket-key` must be set.
36///
37/// If the filters are met, yields a `Ws`. Calling `Ws::on_upgrade` will
38/// return a reply with:
39///
40/// - Status of `101 Switching Protocols`
41/// - Header `connection: upgrade`
42/// - Header `upgrade: websocket`
43/// - Header `sec-websocket-accept` with the hash value of the received key.
44pub fn ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy {
45    let connection_has_upgrade = header::header2()
46        .and_then(|conn: ::headers::Connection| {
47            if conn.contains("upgrade") {
48                future::ok(())
49            } else {
50                future::err(crate::reject::known(MissingConnectionUpgrade))
51            }
52        })
53        .untuple_one();
54
55    crate::get()
56        .and(connection_has_upgrade)
57        .and(header::exact_ignore_case("upgrade", "websocket"))
58        .and(header::exact("sec-websocket-version", "13"))
59        //.and(header::exact2(Upgrade::websocket()))
60        //.and(header::exact2(SecWebsocketVersion::V13))
61        .and(header::header2::<SecWebsocketKey>())
62        .and(on_upgrade())
63        .map(
64            move |key: SecWebsocketKey, on_upgrade: Option<OnUpgrade>| Ws {
65                config: None,
66                key,
67                on_upgrade,
68            },
69        )
70}
71
72/// Extracted by the [`ws`] filter, and used to finish an upgrade.
73pub struct Ws {
74    config: Option<WebSocketConfig>,
75    key: SecWebsocketKey,
76    on_upgrade: Option<OnUpgrade>,
77}
78
79impl Ws {
80    /// Finish the upgrade, passing a function to handle the `WebSocket`.
81    ///
82    /// The passed function must return a `Future`.
83    pub fn on_upgrade<F, U>(self, func: F) -> impl Reply
84    where
85        F: FnOnce(WebSocket) -> U + Send + 'static,
86        U: Future<Output = ()> + Send + 'static,
87    {
88        WsReply {
89            ws: self,
90            on_upgrade: func,
91        }
92    }
93
94    // config
95
96    /// Does nothing.
97    ///
98    /// # Deprecated
99    ///
100    /// Use `max_write_buffer_size()` instead.
101    #[deprecated = "use max_write_buffer_size instead"]
102    pub fn max_send_queue(self, _max: usize) -> Self {
103        self
104    }
105
106    /// The max size of the write buffer, in bytes.
107    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
108        self.config
109            .get_or_insert_with(WebSocketConfig::default)
110            .max_write_buffer_size = max;
111        self
112    }
113
114    /// Set the maximum message size (defaults to 64 megabytes)
115    pub fn max_message_size(mut self, max: usize) -> Self {
116        self.config
117            .get_or_insert_with(WebSocketConfig::default)
118            .max_message_size = Some(max);
119        self
120    }
121
122    /// Set the maximum frame size (defaults to 16 megabytes)
123    pub fn max_frame_size(mut self, max: usize) -> Self {
124        self.config
125            .get_or_insert_with(WebSocketConfig::default)
126            .max_frame_size = Some(max);
127        self
128    }
129}
130
131impl fmt::Debug for Ws {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        f.debug_struct("Ws").finish()
134    }
135}
136
137#[allow(missing_debug_implementations)]
138struct WsReply<F> {
139    ws: Ws,
140    on_upgrade: F,
141}
142
143impl<F, U> Reply for WsReply<F>
144where
145    F: FnOnce(WebSocket) -> U + Send + 'static,
146    U: Future<Output = ()> + Send + 'static,
147{
148    fn into_response(self) -> Response {
149        if let Some(on_upgrade) = self.ws.on_upgrade {
150            let on_upgrade_cb = self.on_upgrade;
151            let config = self.ws.config;
152            let fut = on_upgrade
153                .and_then(move |upgraded| {
154                    tracing::trace!("websocket upgrade complete");
155                    WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
156                })
157                .and_then(move |socket| on_upgrade_cb(socket).map(Ok))
158                .map(|result| {
159                    if let Err(err) = result {
160                        tracing::debug!("ws upgrade error: {}", err);
161                    }
162                });
163            ::tokio::task::spawn(fut);
164        } else {
165            tracing::debug!("ws couldn't be upgraded since no upgrade state was present");
166        }
167
168        let mut res = http::Response::default();
169
170        *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
171
172        res.headers_mut().typed_insert(Connection::upgrade());
173        res.headers_mut().typed_insert(Upgrade::websocket());
174        res.headers_mut()
175            .typed_insert(SecWebsocketAccept::from(self.ws.key));
176
177        res
178    }
179}
180
181// Extracts OnUpgrade state from the route.
182fn on_upgrade() -> impl Filter<Extract = (Option<OnUpgrade>,), Error = Rejection> + Copy {
183    filter_fn_one(|route| future::ready(Ok(route.extensions_mut().remove::<OnUpgrade>())))
184}
185
186/// A websocket `Stream` and `Sink`, provided to `ws` filters.
187///
188/// Ping messages sent from the client will be handled internally by replying with a Pong message.
189/// Close messages need to be handled explicitly: usually by closing the `Sink` end of the
190/// `WebSocket`.
191///
192/// **Note!**
193/// Due to rust futures nature, pings won't be handled until read part of `WebSocket` is polled
194pub struct WebSocket {
195    inner: WebSocketStream<hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>>,
196}
197
198impl WebSocket {
199    pub(crate) async fn from_raw_socket(
200        upgraded: hyper::upgrade::Upgraded,
201        role: protocol::Role,
202        config: Option<protocol::WebSocketConfig>,
203    ) -> Self {
204        let upgraded = hyper_util::rt::TokioIo::new(upgraded);
205        WebSocketStream::from_raw_socket(upgraded, role, config)
206            .map(|inner| WebSocket { inner })
207            .await
208    }
209
210    /// Gracefully close this websocket.
211    pub async fn close(mut self) -> Result<(), crate::Error> {
212        future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await
213    }
214}
215
216impl Stream for WebSocket {
217    type Item = Result<Message, crate::Error>;
218
219    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220        match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
221            Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
222            Some(Err(e)) => {
223                tracing::debug!("websocket poll error: {}", e);
224                Poll::Ready(Some(Err(crate::Error::new(e))))
225            }
226            None => {
227                tracing::trace!("websocket closed");
228                Poll::Ready(None)
229            }
230        }
231    }
232}
233
234impl Sink<Message> for WebSocket {
235    type Error = crate::Error;
236
237    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
238        match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
239            Ok(()) => Poll::Ready(Ok(())),
240            Err(e) => Poll::Ready(Err(crate::Error::new(e))),
241        }
242    }
243
244    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
245        match Pin::new(&mut self.inner).start_send(item.inner) {
246            Ok(()) => Ok(()),
247            Err(e) => {
248                tracing::debug!("websocket start_send error: {}", e);
249                Err(crate::Error::new(e))
250            }
251        }
252    }
253
254    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255        match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
256            Ok(()) => Poll::Ready(Ok(())),
257            Err(e) => Poll::Ready(Err(crate::Error::new(e))),
258        }
259    }
260
261    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262        match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
263            Ok(()) => Poll::Ready(Ok(())),
264            Err(err) => {
265                tracing::debug!("websocket close error: {}", err);
266                Poll::Ready(Err(crate::Error::new(err)))
267            }
268        }
269    }
270}
271
272impl fmt::Debug for WebSocket {
273    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274        f.debug_struct("WebSocket").finish()
275    }
276}
277
278/// A WebSocket message.
279///
280/// This will likely become a `non-exhaustive` enum in the future, once that
281/// language feature has stabilized.
282#[derive(Eq, PartialEq, Clone)]
283pub struct Message {
284    inner: protocol::Message,
285}
286
287impl Message {
288    /// Construct a new Text `Message`.
289    pub fn text<B: Into<String>>(bytes: B) -> Message {
290        Message {
291            inner: protocol::Message::text(bytes.into()),
292        }
293    }
294
295    /// Construct a new Binary `Message`.
296    pub fn binary<B: Into<Bytes>>(bytes: B) -> Message {
297        Message {
298            inner: protocol::Message::binary(bytes),
299        }
300    }
301
302    /// Construct a new Ping `Message`.
303    pub fn ping<B: Into<Bytes>>(bytes: B) -> Message {
304        Message {
305            inner: protocol::Message::Ping(bytes.into()),
306        }
307    }
308
309    /// Construct a new Pong `Message`.
310    ///
311    /// Note that one rarely needs to manually construct a Pong message because the underlying tungstenite socket
312    /// automatically responds to the Ping messages it receives. Manual construction might still be useful in some cases
313    /// like in tests or to send unidirectional heartbeats.
314    pub fn pong<B: Into<Bytes>>(bytes: B) -> Message {
315        Message {
316            inner: protocol::Message::Pong(bytes.into()),
317        }
318    }
319
320    /// Construct the default Close `Message`.
321    pub fn close() -> Message {
322        Message {
323            inner: protocol::Message::Close(None),
324        }
325    }
326
327    /// Construct a Close `Message` with a code and reason.
328    pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message {
329        Message {
330            inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
331                code: protocol::frame::coding::CloseCode::from(code.into()),
332                reason: match reason.into() {
333                    Cow::Borrowed(s) => Utf8Bytes::from_static(s),
334                    Cow::Owned(s) => s.into(),
335                },
336            })),
337        }
338    }
339
340    /// Returns true if this message is a Text message.
341    pub fn is_text(&self) -> bool {
342        self.inner.is_text()
343    }
344
345    /// Returns true if this message is a Binary message.
346    pub fn is_binary(&self) -> bool {
347        self.inner.is_binary()
348    }
349
350    /// Returns true if this message a is a Close message.
351    pub fn is_close(&self) -> bool {
352        self.inner.is_close()
353    }
354
355    /// Returns true if this message is a Ping message.
356    pub fn is_ping(&self) -> bool {
357        self.inner.is_ping()
358    }
359
360    /// Returns true if this message is a Pong message.
361    pub fn is_pong(&self) -> bool {
362        self.inner.is_pong()
363    }
364
365    /// Try to get the close frame (close code and reason)
366    pub fn close_frame(&self) -> Option<(u16, &str)> {
367        if let protocol::Message::Close(Some(ref close_frame)) = self.inner {
368            Some((close_frame.code.into(), close_frame.reason.as_ref()))
369        } else {
370            None
371        }
372    }
373
374    /// Try to get a reference to the string text, if this is a Text message.
375    pub fn to_str(&self) -> Result<&str, ()> {
376        match self.inner {
377            protocol::Message::Text(ref s) => Ok(s),
378            _ => Err(()),
379        }
380    }
381
382    /// Return the bytes of this message, if the message can contain data.
383    pub fn as_bytes(&self) -> &[u8] {
384        match self.inner {
385            protocol::Message::Text(ref s) => s.as_bytes(),
386            protocol::Message::Binary(ref v) => v,
387            protocol::Message::Ping(ref v) => v,
388            protocol::Message::Pong(ref v) => v,
389            protocol::Message::Close(_) => &[],
390            protocol::Message::Frame(ref frame) => frame.payload(),
391        }
392    }
393
394    /// Destructure this message into binary data.
395    pub fn into_bytes(self) -> Bytes {
396        self.inner.into_data()
397    }
398}
399
400impl fmt::Debug for Message {
401    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402        fmt::Debug::fmt(&self.inner, f)
403    }
404}
405
406impl From<Message> for Bytes {
407    fn from(m: Message) -> Self {
408        m.into_bytes()
409    }
410}
411
// ===== Rejections =====

/// Rejection produced by the `ws` filter when the `connection` header does not
/// include the `upgrade` option.
#[derive(Debug)]
pub struct MissingConnectionUpgrade;

impl fmt::Display for MissingConnectionUpgrade {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Connection header did not include 'upgrade'")
    }
}

impl ::std::error::Error for MissingConnectionUpgrade {}