// warp/server.rs

use std::future::Future;
use std::net::SocketAddr;
#[cfg(feature = "tls")]
use std::path::Path;

use futures_util::TryFuture;

use crate::filter::Filter;
use crate::reject::IsReject;
use crate::reply::Reply;
#[cfg(feature = "tls")]
use crate::tls::TlsConfigBuilder;

14/// Create a `Server` with the provided `Filter`.
15pub fn serve<F>(filter: F) -> Server<F, accept::LazyTcp, run::Standard>
16where
17    F: Filter + Clone + Send + Sync + 'static,
18    F::Extract: Reply,
19    F::Error: IsReject,
20{
21    Server {
22        acceptor: accept::LazyTcp,
23        pipeline: false,
24        filter,
25        runner: run::Standard,
26    }
27}
28
/// A warp Server ready to filter requests.
///
/// Construct this type using [`serve()`].
///
/// # Unnameable
///
/// This type is publicly available in the docs only.
///
/// It is not otherwise nameable, since it is a builder type using typestate
/// to allow for ergonomic configuration.
#[derive(Debug)]
#[must_use = "a `Server` does nothing until one of its `run` methods is awaited"]
pub struct Server<F, A, R> {
    /// Source of incoming connections. This is `accept::LazyTcp` until
    /// `bind`/`incoming` installs a real listener.
    acceptor: A,
    /// Filter cloned into a fresh service for every accepted connection.
    filter: F,
    /// Value passed to hyper's HTTP/1 `pipeline_flush` option for each connection.
    pipeline: bool,
    /// Strategy that drives the accept loop (`run::Standard` or `run::Graceful`).
    runner: R,
}

// ===== impl Server =====

49impl<F, R> Server<F, accept::LazyTcp, R>
50where
51    F: Filter + Clone + Send + Sync + 'static,
52    <F::Future as TryFuture>::Ok: Reply,
53    <F::Future as TryFuture>::Error: IsReject,
54    R: run::Run,
55{
56    /// Binds and runs this server.
57    ///
58    /// # Panics
59    ///
60    /// Panics if we are unable to bind to the provided address.
61    ///
62    /// To handle bind failures, bind a listener and call `incoming()`.
63    pub async fn run(self, addr: impl Into<SocketAddr>) {
64        self.bind(addr).await.run().await;
65    }
66
67    /// Binds this server.
68    ///
69    /// # Panics
70    ///
71    /// Panics if we are unable to bind to the provided address.
72    ///
73    /// To handle bind failures, bind a listener and call `incoming()`.
74    pub async fn bind(self, addr: impl Into<SocketAddr>) -> Server<F, tokio::net::TcpListener, R> {
75        let addr = addr.into();
76        let acceptor = tokio::net::TcpListener::bind(addr)
77            .await
78            .expect("failed to bind to address");
79
80        self.incoming(acceptor)
81    }
82
83    /// Configure the server with an acceptor of incoming connections.
84    pub fn incoming<A>(self, acceptor: A) -> Server<F, A, R> {
85        Server {
86            acceptor,
87            filter: self.filter,
88            pipeline: self.pipeline,
89            runner: self.runner,
90        }
91    }
92
93    // pub fn conn
94}
95
96impl<F, A, R> Server<F, A, R>
97where
98    F: Filter + Clone + Send + Sync + 'static,
99    <F::Future as TryFuture>::Ok: Reply,
100    <F::Future as TryFuture>::Error: IsReject,
101    A: accept::Accept,
102    R: run::Run,
103{
104    #[cfg(feature = "tls")]
105    pub fn tls(self) -> Server<F, accept::Tls<A>, R> {}
106
107    /// Add graceful shutdown support to this server.
108    ///
109    /// # Example
110    ///
111    /// ```
112    /// # async fn ex(addr: std::net::SocketAddr) {
113    /// # use warp::Filter;
114    /// # let filter = warp::any().map(|| "ok");
115    /// warp::serve(filter)
116    ///     .bind(addr).await
117    ///     .graceful(async {
118    ///         // some signal in here, such as ctrl_c
119    ///     })
120    ///     .run().await;
121    /// # }
122    /// ```
123    pub fn graceful<Fut>(self, shutdown_signal: Fut) -> Server<F, A, run::Graceful<Fut>>
124    where
125        Fut: Future<Output = ()> + Send + 'static,
126    {
127        Server {
128            acceptor: self.acceptor,
129            filter: self.filter,
130            pipeline: self.pipeline,
131            runner: run::Graceful(shutdown_signal),
132        }
133    }
134
135    /// Run this server.
136    pub async fn run(self) {
137        R::run(self).await;
138    }
139}
140
141// // ===== impl Tls =====
142
143#[cfg(feature = "tls")]
144impl<F, A, R> Server<F, accept::Tls<A>, R>
145where
146    F: Filter + Clone + Send + Sync + 'static,
147    <F::Future as TryFuture>::Ok: Reply,
148    <F::Future as TryFuture>::Error: IsReject,
149    A: accept::Accept,
150    R: run::Run,
151{
152    // TLS config methods
153
154    /// Specify the file path to read the private key.
155    ///
156    /// *This function requires the `"tls"` feature.*
157    pub fn key_path(self, path: impl AsRef<Path>) -> Self {
158        self.with_tls(|tls| tls.key_path(path))
159    }
160
161    /// Specify the file path to read the certificate.
162    ///
163    /// *This function requires the `"tls"` feature.*
164    pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
165        self.with_tls(|tls| tls.cert_path(path))
166    }
167
168    /// Specify the file path to read the trust anchor for optional client authentication.
169    ///
170    /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
171    /// of the `client_auth_` methods, then client authentication is disabled by default.
172    ///
173    /// *This function requires the `"tls"` feature.*
174    pub fn client_auth_optional_path(self, path: impl AsRef<Path>) -> Self {
175        self.with_tls(|tls| tls.client_auth_optional_path(path))
176    }
177
178    /// Specify the file path to read the trust anchor for required client authentication.
179    ///
180    /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
181    /// `client_auth_` methods, then client authentication is disabled by default.
182    ///
183    /// *This function requires the `"tls"` feature.*
184    pub fn client_auth_required_path(self, path: impl AsRef<Path>) -> Self {
185        self.with_tls(|tls| tls.client_auth_required_path(path))
186    }
187
188    /// Specify the in-memory contents of the private key.
189    ///
190    /// *This function requires the `"tls"` feature.*
191    pub fn key(self, key: impl AsRef<[u8]>) -> Self {
192        self.with_tls(|tls| tls.key(key.as_ref()))
193    }
194
195    /// Specify the in-memory contents of the certificate.
196    ///
197    /// *This function requires the `"tls"` feature.*
198    pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
199        self.with_tls(|tls| tls.cert(cert.as_ref()))
200    }
201
202    /// Specify the in-memory contents of the trust anchor for optional client authentication.
203    ///
204    /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
205    /// of the `client_auth_` methods, then client authentication is disabled by default.
206    ///
207    /// *This function requires the `"tls"` feature.*
208    pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self {
209        self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref()))
210    }
211
212    /// Specify the in-memory contents of the trust anchor for required client authentication.
213    ///
214    /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
215    /// `client_auth_` methods, then client authentication is disabled by default.
216    ///
217    /// *This function requires the `"tls"` feature.*
218    pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self {
219        self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref()))
220    }
221
222    /// Specify the DER-encoded OCSP response.
223    ///
224    /// *This function requires the `"tls"` feature.*
225    pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self {
226        self.with_tls(|tls| tls.ocsp_resp(resp.as_ref()))
227    }
228
229    fn with_tls<Func>(self, func: Func) -> Self
230    where
231        Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
232    {
233        let tls = func(tls);
234    }
235}
236
237mod accept {
238    pub trait Accept {
239        type IO: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static;
240        type AcceptError: std::fmt::Debug;
241        type Accepting: super::Future<Output = Result<Self::IO, Self::AcceptError>> + Send + 'static;
242        #[allow(async_fn_in_trait)]
243        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error>;
244    }
245
246    #[derive(Debug)]
247    pub struct LazyTcp;
248
249    impl Accept for tokio::net::TcpListener {
250        type IO = hyper_util::rt::TokioIo<tokio::net::TcpStream>;
251        type AcceptError = std::convert::Infallible;
252        type Accepting = std::future::Ready<Result<Self::IO, Self::AcceptError>>;
253        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
254            let (io, _addr) = <tokio::net::TcpListener>::accept(self).await?;
255            Ok(std::future::ready(Ok(hyper_util::rt::TokioIo::new(io))))
256        }
257    }
258
259    #[cfg(unix)]
260    impl Accept for tokio::net::UnixListener {
261        type IO = hyper_util::rt::TokioIo<tokio::net::UnixStream>;
262        type AcceptError = std::convert::Infallible;
263        type Accepting = std::future::Ready<Result<Self::IO, Self::AcceptError>>;
264        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
265            let (io, _addr) = <tokio::net::UnixListener>::accept(self).await?;
266            Ok(std::future::ready(Ok(hyper_util::rt::TokioIo::new(io))))
267        }
268    }
269
270    #[cfg(feature = "tls")]
271    #[derive(Debug)]
272    pub struct Tls<A>(pub(super) A);
273
274    #[cfg(feature = "tls")]
275    impl<A: Accept> Accept for Tls<A> {
276        type IO = hyper_util::rt::TokioIo<tokio::net::TcpStream>;
277        type AcceptError = std::convert::Infallible;
278        type Accepting = std::future::Ready<Result<Self::IO, Self::AcceptError>>;
279        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
280            let (io, _addr) = self.0.accept().await?;
281            Ok(std::future::ready(Ok(hyper_util::rt::TokioIo::new(io))))
282        }
283    }
284}
285
286mod run {
287    pub trait Run {
288        #[allow(async_fn_in_trait)]
289        async fn run<F, A>(server: super::Server<F, A, Self>)
290        where
291            F: super::Filter + Clone + Send + Sync + 'static,
292            <F::Future as super::TryFuture>::Ok: super::Reply,
293            <F::Future as super::TryFuture>::Error: super::IsReject,
294            A: super::accept::Accept,
295            Self: Sized;
296    }
297
298    #[derive(Debug)]
299    pub struct Standard;
300
301    impl Run for Standard {
302        async fn run<F, A>(mut server: super::Server<F, A, Self>)
303        where
304            F: super::Filter + Clone + Send + Sync + 'static,
305            <F::Future as super::TryFuture>::Ok: super::Reply,
306            <F::Future as super::TryFuture>::Error: super::IsReject,
307            A: super::accept::Accept,
308            Self: Sized,
309        {
310            let pipeline = server.pipeline;
311            loop {
312                let accepting = match server.acceptor.accept().await {
313                    Ok(fut) => fut,
314                    Err(err) => {
315                        handle_accept_error(err).await;
316                        continue;
317                    }
318                };
319                let svc = crate::service(server.filter.clone());
320                let svc = hyper_util::service::TowerToHyperService::new(svc);
321                tokio::spawn(async move {
322                    let io = match accepting.await {
323                        Ok(io) => io,
324                        Err(err) => {
325                            tracing::debug!("server accept error: {:?}", err);
326                            return;
327                        }
328                    };
329                    if let Err(err) = hyper_util::server::conn::auto::Builder::new(
330                        hyper_util::rt::TokioExecutor::new(),
331                    )
332                    .http1()
333                    .pipeline_flush(pipeline)
334                    .serve_connection_with_upgrades(io, svc)
335                    .await
336                    {
337                        tracing::error!("server connection error: {:?}", err)
338                    }
339                });
340            }
341        }
342    }
343
344    #[derive(Debug)]
345    pub struct Graceful<Fut>(pub(super) Fut);
346
347    impl<Fut> Run for Graceful<Fut>
348    where
349        Fut: super::Future<Output = ()> + Send + 'static,
350    {
351        async fn run<F, A>(mut server: super::Server<F, A, Self>)
352        where
353            F: super::Filter + Clone + Send + Sync + 'static,
354            <F::Future as super::TryFuture>::Ok: super::Reply,
355            <F::Future as super::TryFuture>::Error: super::IsReject,
356            A: super::accept::Accept,
357            Self: Sized,
358        {
359            use futures_util::future;
360
361            let pipeline = server.pipeline;
362            let graceful_util = hyper_util::server::graceful::GracefulShutdown::new();
363            let mut shutdown_signal = std::pin::pin!(server.runner.0);
364            loop {
365                let accept = std::pin::pin!(server.acceptor.accept());
366                let accepting = match future::select(accept, &mut shutdown_signal).await {
367                    future::Either::Left((Ok(fut), _)) => fut,
368                    future::Either::Left((Err(err), _)) => {
369                        handle_accept_error(err).await;
370                        continue;
371                    }
372                    future::Either::Right(((), _)) => {
373                        tracing::debug!("shutdown signal received, starting graceful shutdown");
374                        break;
375                    }
376                };
377                let svc = crate::service(server.filter.clone());
378                let svc = hyper_util::service::TowerToHyperService::new(svc);
379                let watcher = graceful_util.watcher();
380                tokio::spawn(async move {
381                    let io = match accepting.await {
382                        Ok(io) => io,
383                        Err(err) => {
384                            tracing::debug!("server accepting error: {:?}", err);
385                            return;
386                        }
387                    };
388                    let mut hyper = hyper_util::server::conn::auto::Builder::new(
389                        hyper_util::rt::TokioExecutor::new(),
390                    );
391                    hyper.http1().pipeline_flush(pipeline);
392                    let conn = hyper.serve_connection_with_upgrades(io, svc);
393                    let conn = watcher.watch(conn);
394                    if let Err(err) = conn.await {
395                        tracing::error!("server connection error: {:?}", err)
396                    }
397                });
398            }
399
400            drop(server.acceptor); // close listener
401            graceful_util.shutdown().await;
402        }
403    }
404
405    // TODO: allow providing your own handler
406    async fn handle_accept_error(e: std::io::Error) {
407        if is_connection_error(&e) {
408            return;
409        }
410        // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
411        //
412        // > A possible scenario is that the process has hit the max open files
413        // > allowed, and so trying to accept a new connection will fail with
414        // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
415        // > the application will likely close some files (or connections), and try
416        // > to accept the connection again. If this option is `true`, the error
417        // > will be logged at the `error` level, since it is still a big deal,
418        // > and then the listener will sleep for 1 second.
419        tracing::error!("accept error: {:?}", e);
420        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
421    }
422
423    fn is_connection_error(e: &std::io::Error) -> bool {
424        // some errors that occur on the TCP stream are emitted when
425        // accepting, they can be ignored.
426        matches!(
427            e.kind(),
428            std::io::ErrorKind::ConnectionRefused
429                | std::io::ErrorKind::ConnectionAborted
430                | std::io::ErrorKind::ConnectionReset
431        )
432    }
433}