1use std::borrow::Cow;
4use std::fmt;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use bytes::Bytes;
10use futures_util::{future, ready, FutureExt, Sink, Stream, TryFutureExt};
11use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
12use hyper::upgrade::OnUpgrade;
13use tokio_tungstenite::{
14 tungstenite::protocol::{self, frame::Utf8Bytes, WebSocketConfig},
15 WebSocketStream,
16};
17
18use super::header;
19use crate::filter::{filter_fn_one, Filter, One};
20use crate::reject::Rejection;
21use crate::reply::{Reply, Response};
22
23pub fn ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy {
45 let connection_has_upgrade = header::header2()
46 .and_then(|conn: ::headers::Connection| {
47 if conn.contains("upgrade") {
48 future::ok(())
49 } else {
50 future::err(crate::reject::known(MissingConnectionUpgrade))
51 }
52 })
53 .untuple_one();
54
55 crate::get()
56 .and(connection_has_upgrade)
57 .and(header::exact_ignore_case("upgrade", "websocket"))
58 .and(header::exact("sec-websocket-version", "13"))
59 .and(header::header2::<SecWebsocketKey>())
62 .and(on_upgrade())
63 .map(
64 move |key: SecWebsocketKey, on_upgrade: Option<OnUpgrade>| Ws {
65 config: None,
66 key,
67 on_upgrade,
68 },
69 )
70}
71
72pub struct Ws {
74 config: Option<WebSocketConfig>,
75 key: SecWebsocketKey,
76 on_upgrade: Option<OnUpgrade>,
77}
78
79impl Ws {
80 pub fn on_upgrade<F, U>(self, func: F) -> impl Reply
84 where
85 F: FnOnce(WebSocket) -> U + Send + 'static,
86 U: Future<Output = ()> + Send + 'static,
87 {
88 WsReply {
89 ws: self,
90 on_upgrade: func,
91 }
92 }
93
94 #[deprecated = "use max_write_buffer_size instead"]
102 pub fn max_send_queue(self, _max: usize) -> Self {
103 self
104 }
105
106 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
108 self.config
109 .get_or_insert_with(WebSocketConfig::default)
110 .max_write_buffer_size = max;
111 self
112 }
113
114 pub fn max_message_size(mut self, max: usize) -> Self {
116 self.config
117 .get_or_insert_with(WebSocketConfig::default)
118 .max_message_size = Some(max);
119 self
120 }
121
122 pub fn max_frame_size(mut self, max: usize) -> Self {
124 self.config
125 .get_or_insert_with(WebSocketConfig::default)
126 .max_frame_size = Some(max);
127 self
128 }
129}
130
131impl fmt::Debug for Ws {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 f.debug_struct("Ws").finish()
134 }
135}
136
137#[allow(missing_debug_implementations)]
138struct WsReply<F> {
139 ws: Ws,
140 on_upgrade: F,
141}
142
143impl<F, U> Reply for WsReply<F>
144where
145 F: FnOnce(WebSocket) -> U + Send + 'static,
146 U: Future<Output = ()> + Send + 'static,
147{
148 fn into_response(self) -> Response {
149 if let Some(on_upgrade) = self.ws.on_upgrade {
150 let on_upgrade_cb = self.on_upgrade;
151 let config = self.ws.config;
152 let fut = on_upgrade
153 .and_then(move |upgraded| {
154 tracing::trace!("websocket upgrade complete");
155 WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
156 })
157 .and_then(move |socket| on_upgrade_cb(socket).map(Ok))
158 .map(|result| {
159 if let Err(err) = result {
160 tracing::debug!("ws upgrade error: {}", err);
161 }
162 });
163 ::tokio::task::spawn(fut);
164 } else {
165 tracing::debug!("ws couldn't be upgraded since no upgrade state was present");
166 }
167
168 let mut res = http::Response::default();
169
170 *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
171
172 res.headers_mut().typed_insert(Connection::upgrade());
173 res.headers_mut().typed_insert(Upgrade::websocket());
174 res.headers_mut()
175 .typed_insert(SecWebsocketAccept::from(self.ws.key));
176
177 res
178 }
179}
180
181fn on_upgrade() -> impl Filter<Extract = (Option<OnUpgrade>,), Error = Rejection> + Copy {
183 filter_fn_one(|route| future::ready(Ok(route.extensions_mut().remove::<OnUpgrade>())))
184}
185
186pub struct WebSocket {
195 inner: WebSocketStream<hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>>,
196}
197
198impl WebSocket {
199 pub(crate) async fn from_raw_socket(
200 upgraded: hyper::upgrade::Upgraded,
201 role: protocol::Role,
202 config: Option<protocol::WebSocketConfig>,
203 ) -> Self {
204 let upgraded = hyper_util::rt::TokioIo::new(upgraded);
205 WebSocketStream::from_raw_socket(upgraded, role, config)
206 .map(|inner| WebSocket { inner })
207 .await
208 }
209
210 pub async fn close(mut self) -> Result<(), crate::Error> {
212 future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await
213 }
214}
215
216impl Stream for WebSocket {
217 type Item = Result<Message, crate::Error>;
218
219 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
221 Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
222 Some(Err(e)) => {
223 tracing::debug!("websocket poll error: {}", e);
224 Poll::Ready(Some(Err(crate::Error::new(e))))
225 }
226 None => {
227 tracing::trace!("websocket closed");
228 Poll::Ready(None)
229 }
230 }
231 }
232}
233
234impl Sink<Message> for WebSocket {
235 type Error = crate::Error;
236
237 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
238 match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
239 Ok(()) => Poll::Ready(Ok(())),
240 Err(e) => Poll::Ready(Err(crate::Error::new(e))),
241 }
242 }
243
244 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
245 match Pin::new(&mut self.inner).start_send(item.inner) {
246 Ok(()) => Ok(()),
247 Err(e) => {
248 tracing::debug!("websocket start_send error: {}", e);
249 Err(crate::Error::new(e))
250 }
251 }
252 }
253
254 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255 match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
256 Ok(()) => Poll::Ready(Ok(())),
257 Err(e) => Poll::Ready(Err(crate::Error::new(e))),
258 }
259 }
260
261 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262 match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
263 Ok(()) => Poll::Ready(Ok(())),
264 Err(err) => {
265 tracing::debug!("websocket close error: {}", err);
266 Poll::Ready(Err(crate::Error::new(err)))
267 }
268 }
269 }
270}
271
272impl fmt::Debug for WebSocket {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 f.debug_struct("WebSocket").finish()
275 }
276}
277
278#[derive(Eq, PartialEq, Clone)]
283pub struct Message {
284 inner: protocol::Message,
285}
286
287impl Message {
288 pub fn text<B: Into<String>>(bytes: B) -> Message {
290 Message {
291 inner: protocol::Message::text(bytes.into()),
292 }
293 }
294
295 pub fn binary<B: Into<Bytes>>(bytes: B) -> Message {
297 Message {
298 inner: protocol::Message::binary(bytes),
299 }
300 }
301
302 pub fn ping<B: Into<Bytes>>(bytes: B) -> Message {
304 Message {
305 inner: protocol::Message::Ping(bytes.into()),
306 }
307 }
308
309 pub fn pong<B: Into<Bytes>>(bytes: B) -> Message {
315 Message {
316 inner: protocol::Message::Pong(bytes.into()),
317 }
318 }
319
320 pub fn close() -> Message {
322 Message {
323 inner: protocol::Message::Close(None),
324 }
325 }
326
327 pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message {
329 Message {
330 inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
331 code: protocol::frame::coding::CloseCode::from(code.into()),
332 reason: match reason.into() {
333 Cow::Borrowed(s) => Utf8Bytes::from_static(s),
334 Cow::Owned(s) => s.into(),
335 },
336 })),
337 }
338 }
339
340 pub fn is_text(&self) -> bool {
342 self.inner.is_text()
343 }
344
345 pub fn is_binary(&self) -> bool {
347 self.inner.is_binary()
348 }
349
350 pub fn is_close(&self) -> bool {
352 self.inner.is_close()
353 }
354
355 pub fn is_ping(&self) -> bool {
357 self.inner.is_ping()
358 }
359
360 pub fn is_pong(&self) -> bool {
362 self.inner.is_pong()
363 }
364
365 pub fn close_frame(&self) -> Option<(u16, &str)> {
367 if let protocol::Message::Close(Some(ref close_frame)) = self.inner {
368 Some((close_frame.code.into(), close_frame.reason.as_ref()))
369 } else {
370 None
371 }
372 }
373
374 pub fn to_str(&self) -> Result<&str, ()> {
376 match self.inner {
377 protocol::Message::Text(ref s) => Ok(s),
378 _ => Err(()),
379 }
380 }
381
382 pub fn as_bytes(&self) -> &[u8] {
384 match self.inner {
385 protocol::Message::Text(ref s) => s.as_bytes(),
386 protocol::Message::Binary(ref v) => v,
387 protocol::Message::Ping(ref v) => v,
388 protocol::Message::Pong(ref v) => v,
389 protocol::Message::Close(_) => &[],
390 protocol::Message::Frame(ref frame) => frame.payload(),
391 }
392 }
393
394 pub fn into_bytes(self) -> Bytes {
396 self.inner.into_data()
397 }
398}
399
400impl fmt::Debug for Message {
401 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402 fmt::Debug::fmt(&self.inner, f)
403 }
404}
405
406impl From<Message> for Bytes {
407 fn from(m: Message) -> Self {
408 m.into_bytes()
409 }
410}
411
/// Rejection cause produced by the `ws` filter when the request's
/// `Connection` header is present but does not contain `upgrade`.
#[derive(Debug)]
pub struct MissingConnectionUpgrade;

impl fmt::Display for MissingConnectionUpgrade {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Connection header did not include 'upgrade'")
    }
}

// No source error: this is a leaf rejection reason.
impl ::std::error::Error for MissingConnectionUpgrade {}