use std::fmt::{self, Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use http::header::{HOST, LOCATION, UPGRADE};
use http::uri::{Authority, Scheme};
use http::{HeaderValue, Request, Response, StatusCode, Uri};
use http_body_util::{Either, Empty};
use pin_project::pin_project;
use tower_layer::Layer;
use tower_service::Service as TowerService;
use crate::Protocol;
/// [`Layer`] that wraps an inner service in an [`UpgradeHttp`].
#[derive(Clone, Copy, Debug)]
pub struct UpgradeHttpLayer;

impl<S> Layer<S> for UpgradeHttpLayer {
    type Service = UpgradeHttp<S>;

    /// Wraps `inner` in an [`UpgradeHttp`].
    fn layer(&self, inner: S) -> Self::Service {
        UpgradeHttp::<S>::new(inner)
    }
}
/// Service wrapper that hands TLS requests to the wrapped service and answers
/// plain-HTTP requests itself with a redirect (or `400 Bad Request`).
#[derive(Clone, Debug)]
pub struct UpgradeHttp<S> {
    /// The wrapped service.
    service: S,
}

impl<S> UpgradeHttp<S> {
    /// Wraps `service`.
    pub const fn new(service: S) -> Self {
        Self { service }
    }

    /// Consumes the wrapper and returns the wrapped service.
    pub fn into_inner(self) -> S {
        let Self { service } = self;
        service
    }

    /// Borrows the wrapped service.
    pub const fn get_ref(&self) -> &S {
        &self.service
    }

    /// Mutably borrows the wrapped service.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { service } = self;
        service
    }
}
impl<Service, RequestBody, ResponseBody> TowerService<Request<RequestBody>> for UpgradeHttp<Service>
where
    Service: TowerService<Request<RequestBody>, Response = Response<ResponseBody>>,
{
    type Response = Response<Either<ResponseBody, Empty<Bytes>>>;
    type Error = Service::Error;
    type Future = UpgradeHttpFuture<Service, Request<RequestBody>>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Forwards requests received over TLS to the wrapped service. Requests
    /// received over plain HTTP never reach the wrapped service; they are
    /// answered directly with a redirect or `400 Bad Request`.
    ///
    /// # Panics
    ///
    /// Panics if the request carries no `Protocol` extension.
    fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
        match req
            .extensions()
            .get::<Protocol>()
            .expect("`Protocol` should always be set by `DualProtocolService`")
        {
            Protocol::Tls => UpgradeHttpFuture::new_service(self.service.call(req)),
            Protocol::Plain => UpgradeHttpFuture::new_upgrade(plain_http_response(&req)),
        }
    }
}

/// Builds the response for a request received over plain HTTP.
///
/// Returns `301 Moved Permanently` whose `Location` keeps the request's
/// authority and path but switches the scheme to `wss` (WebSocket requests) or
/// `https` (requests whose URI scheme is `http` or absent). Returns
/// `400 Bad Request` when no authority can be determined or the request URI
/// uses any other scheme.
fn plain_http_response<Body>(req: &Request<Body>) -> Response<Empty<Bytes>> {
    let uri = req.uri();

    let target = extract_authority(req).and_then(|authority| {
        // URI schemes are case-insensitive (RFC 3986 §3.1).
        let ws_scheme = uri
            .scheme_str()
            .map_or(false, |scheme| scheme.eq_ignore_ascii_case("ws"));
        let scheme = if ws_scheme || is_websocket_upgrade(req) {
            Scheme::try_from("wss").expect("ASCII string is valid")
        } else if uri.scheme() == Some(&Scheme::HTTP) || uri.scheme_str().is_none() {
            Scheme::HTTPS
        } else {
            return None;
        };
        Some((authority, scheme))
    });

    let response = Response::builder();
    let response = if let Some((authority, scheme)) = target {
        let location = Uri::builder().scheme(scheme).authority(authority);
        // `Uri::from_parts` rejects a URI that has a scheme but no path.
        // Authority-form targets (e.g. `CONNECT host:443`) carry no path, so
        // fall back to `/` instead of panicking in `build()`.
        let location = match uri.path_and_query() {
            Some(path_and_query) => location.path_and_query(path_and_query.clone()),
            None => location.path_and_query("/"),
        };
        let location = location
            .build()
            .expect("scheme, authority and path are always set");
        response
            .status(StatusCode::MOVED_PERMANENTLY)
            .header(LOCATION, location.to_string())
    } else {
        response.status(StatusCode::BAD_REQUEST)
    };

    response.body(Empty::new()).expect("invalid header or body")
}

/// Returns whether any `Upgrade` header on the request lists the `websocket`
/// protocol.
///
/// RFC 6455 §4.2.1 requires the token to be matched ASCII case-insensitively.
/// `Upgrade` is a comma-separated list, so each entry is checked.
fn is_websocket_upgrade<Body>(req: &Request<Body>) -> bool {
    req.headers()
        .get_all(UPGRADE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|protocol| protocol.trim().eq_ignore_ascii_case("websocket"))
}
/// Future returned by `UpgradeHttp::call`: it drives either the wrapped
/// service's future (TLS requests) or yields a response that was already built
/// (plain-HTTP requests).
#[pin_project]
pub struct UpgradeHttpFuture<Service, Request>(#[pin] FutureServe<Service, Request>)
where
Service: TowerService<Request>;
/// Internal state of an `UpgradeHttpFuture`.
#[derive(Debug)]
#[pin_project(project = UpgradeHttpFutureProj)]
enum FutureServe<Service, Request>
where
Service: TowerService<Request>,
{
/// Awaiting the wrapped service's response future.
Service(#[pin] Service::Future),
/// Pre-built response; taken (left as `None`) the first time it is polled.
Upgrade(Option<Response<Empty<Bytes>>>),
}
impl<Service, Request> Debug for UpgradeHttpFuture<Service, Request>
where
    Service: TowerService<Request>,
    FutureServe<Service, Request>: Debug,
{
    /// Formats as a tuple struct wrapping the internal state.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(state) = self;
        let mut tuple = f.debug_tuple("UpgradeHttpFuture");
        tuple.field(state);
        tuple.finish()
    }
}
impl<Service, Request> UpgradeHttpFuture<Service, Request>
where
    Service: TowerService<Request>,
{
    /// Creates a future that resolves once the wrapped service's `future` does.
    const fn new_service(future: Service::Future) -> Self {
        let state = FutureServe::Service(future);
        Self(state)
    }

    /// Creates a future that resolves immediately to `response`.
    const fn new_upgrade(response: Response<Empty<Bytes>>) -> Self {
        let state = FutureServe::Upgrade(Some(response));
        Self(state)
    }
}
impl<Service, Request, ResponseBody> Future for UpgradeHttpFuture<Service, Request>
where
    Service: TowerService<Request, Response = Response<ResponseBody>>,
{
    type Output = Result<Response<Either<ResponseBody, Empty<Bytes>>>, Service::Error>;

    /// Polls the wrapped service's future (body wrapped in `Either::Left`), or
    /// yields the pre-built response (body wrapped in `Either::Right`).
    ///
    /// # Panics
    ///
    /// Panics if polled again after the pre-built response was returned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let state = self.project().0;
        match state.project() {
            UpgradeHttpFutureProj::Upgrade(slot) => {
                let response = slot.take().expect("polled again after `Poll::Ready`");
                Poll::Ready(Ok(response.map(Either::Right)))
            }
            UpgradeHttpFutureProj::Service(future) => match future.poll(cx) {
                Poll::Ready(result) => {
                    Poll::Ready(result.map(|response| response.map(Either::Left)))
                }
                Poll::Pending => Poll::Pending,
            },
        }
    }
}
/// Determines the authority to redirect a request to.
///
/// The first header present out of `X-Forwarded-Host` and `Host` is used. If
/// neither header is present, or the chosen header is not valid UTF-8, the host
/// component of the request URI is used instead. Returns `None` when no source
/// yields a value, or when the chosen value is not a valid [`Authority`].
fn extract_authority<Body>(request: &Request<Body>) -> Option<Authority> {
    const X_FORWARDED_HOST: &str = "x-forwarded-host";

    let headers = request.headers();
    let chosen_header = match headers.get(X_FORWARDED_HOST) {
        Some(forwarded) => Some(forwarded),
        None => headers.get(HOST),
    };
    let header_host = chosen_header.and_then(|value| value.to_str().ok());

    let host = match header_host {
        Some(host) => host,
        None => request.uri().host()?,
    };
    Authority::try_from(host).ok()
}