use std::fmt::{self, Debug, Formatter};
use std::future::Future;
use std::io::ErrorKind;
use std::net::{SocketAddr, TcpListener};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, slice};
use axum_server::accept::Accept;
use axum_server::tls_rustls::{RustlsAcceptor, RustlsConfig};
use axum_server::Server;
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::{Either as BodyEither, Empty};
use pin_project::pin_project;
use tokio::io::ReadBuf;
use tokio::net::TcpStream;
use tokio_rustls::server::TlsStream;
use tokio_util::either::Either as TokioEither;
use tower_service::Service as TowerService;
use crate::UpgradeHttp;
/// Creates a [`Server`] for `address` that serves TLS and plain connections
/// on the same port.
///
/// TLS connections are handled with `config`. Upgrading starts out disabled
/// (see [`ServerExt::set_upgrade`]).
#[must_use]
pub fn bind_dual_protocol(
    address: SocketAddr,
    config: RustlsConfig,
) -> Server<DualProtocolAcceptor> {
    let dual = DualProtocolAcceptor::new(config);
    let server = Server::bind(address);
    server.acceptor(dual)
}
/// Creates a [`Server`] from an existing standard-library [`TcpListener`] that
/// serves TLS and plain connections on it.
///
/// TLS connections are handled with `config`. Upgrading starts out disabled
/// (see [`ServerExt::set_upgrade`]).
#[must_use]
pub fn from_tcp_dual_protocol(
    listener: TcpListener,
    config: RustlsConfig,
) -> Server<DualProtocolAcceptor> {
    let dual = DualProtocolAcceptor::new(config);
    let server = Server::from_tcp(listener);
    server.acceptor(dual)
}
/// Extension methods for a [`Server`] using a [`DualProtocolAcceptor`].
pub trait ServerExt {
/// Sets whether each connection's service is wrapped in [`UpgradeHttp`]
/// (`true`) or used directly (`false`).
///
/// Consumes and returns the server so the call can be chained.
#[must_use]
fn set_upgrade(self, upgrade: bool) -> Self;
}
impl ServerExt for Server<DualProtocolAcceptor> {
    /// Forwards the flag to the server's [`DualProtocolAcceptor`] and hands the
    /// server back.
    fn set_upgrade(mut self, upgrade: bool) -> Self {
        let acceptor = self.get_mut();
        acceptor.set_upgrade(upgrade);
        self
    }
}
/// The protocol a connection was accepted with.
///
/// [`DualProtocolService`] inserts a copy of this value into the extensions of
/// every request it handles.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// The first byte peeked from the connection was `0x16`; the stream went
    /// through the rustls acceptor.
    Tls,
    /// Any other first byte; the raw TCP stream is served unencrypted.
    Plain,
}
/// Acceptor that serves each TCP connection either over TLS or as plain TCP,
/// depending on the first byte the client sends.
#[derive(Clone, Debug)]
pub struct DualProtocolAcceptor {
    /// Acceptor used for connections detected as TLS.
    rustls: RustlsAcceptor,
    /// Whether connection services are wrapped in [`UpgradeHttp`].
    upgrade: bool,
}
impl DualProtocolAcceptor {
    /// Builds an acceptor whose TLS side is configured by `config`.
    ///
    /// Upgrading is disabled until [`DualProtocolAcceptor::set_upgrade`] is
    /// called with `true`.
    #[must_use]
    pub fn new(config: RustlsConfig) -> Self {
        let rustls = RustlsAcceptor::new(config);
        Self {
            upgrade: false,
            rustls,
        }
    }

    /// Enables (`true`) or disables (`false`) wrapping each connection's service
    /// in [`UpgradeHttp`].
    pub fn set_upgrade(&mut self, upgrade: bool) {
        let Self { upgrade: flag, .. } = self;
        *flag = upgrade;
    }
}
impl<Service: Clone> Accept<TcpStream, Service> for DualProtocolAcceptor {
type Stream = TokioEither<TlsStream<TcpStream>, TcpStream>;
type Service = DualProtocolService<Service>;
type Future = DualProtocolAcceptorFuture<Service>;
fn accept(&self, stream: TcpStream, service: Service) -> Self::Future {
let service = if self.upgrade {
DualProtocolServiceBuilder::new_upgrade(service)
} else {
DualProtocolServiceBuilder::new_service(service)
};
DualProtocolAcceptorFuture::new(stream, service, self.rustls.clone())
}
}
/// Future returned by [`DualProtocolAcceptor`]'s `accept`.
///
/// It resolves to the accepted stream together with a [`DualProtocolService`]
/// tagged with the detected [`Protocol`]. The stream is `Left` for TLS and
/// `Right` for plain TCP.
#[derive(Debug)]
#[pin_project(project = DualProtocolAcceptorFutureProj)]
pub struct DualProtocolAcceptorFuture<Service: Clone>(
// Current stage of detection. It is pinned because the `Https` stage holds the
// rustls accept future.
#[pin]
FutureState<Service>,
);
/// Stages of [`DualProtocolAcceptorFuture`].
#[derive(Debug)]
// NOTE(review): `FutuereStateProj` is a misspelling of `FutureStateProj`.
// Renaming it also requires updating the matching arms in `poll`.
#[pin_project(project = FutuereStateProj)]
enum FutureState<Service: Clone> {
/// Waiting to peek the connection's first byte. Becomes `None` once its
/// contents have been taken out.
Peek(Option<PeekState<Service>>),
/// Driving the rustls accept future for a connection detected as TLS.
Https(#[pin] <RustlsAcceptor as Accept<TcpStream, DualProtocolService<Service>>>::Future),
}
/// Everything needed to finish accepting a connection once its first byte is
/// known.
#[derive(Debug)]
struct PeekState<Service> {
/// The accepted connection. It is only peeked, never read, at this stage.
stream: TcpStream,
/// Service wrapper waiting for the detected [`Protocol`].
service: DualProtocolServiceBuilder<Service>,
/// Used to start the rustls accept future if the connection is TLS.
rustls: RustlsAcceptor,
}
impl<Service: Clone> DualProtocolAcceptorFuture<Service> {
    /// Creates the future in its initial stage, waiting to peek the first byte
    /// of `stream`.
    const fn new(
        stream: TcpStream,
        service: DualProtocolServiceBuilder<Service>,
        rustls: RustlsAcceptor,
    ) -> Self {
        let pending = PeekState {
            stream,
            service,
            rustls,
        };
        Self(FutureState::Peek(Some(pending)))
    }
}
impl<Service: Clone> DualProtocolAcceptorFutureProj<'_, Service> {
    /// Replaces the current stage, in place, with the TLS stage that drives
    /// `handshake`.
    fn upgrade(
        &mut self,
        handshake: <RustlsAcceptor as Accept<TcpStream, DualProtocolService<Service>>>::Future,
    ) {
        let next = FutureState::Https(handshake);
        self.0.set(next);
    }
}
impl<Service: Clone> Future for DualProtocolAcceptorFuture<Service> {
type Output = io::Result<(
TokioEither<TlsStream<TcpStream>, TcpStream>,
DualProtocolService<Service>,
)>;
/// Detects the protocol from the first byte and, for TLS, drives the rustls
/// accept future.
///
/// Resolves to `Left(tls_stream)` for TLS connections and `Right(tcp_stream)`
/// for plain ones. Either is paired with the service tagged with the matching
/// [`Protocol`].
///
/// # Errors
///
/// Yields [`ErrorKind::UnexpectedEof`] if the peek returns zero bytes. Any I/O
/// error from the peek is passed through, as is any error produced by the
/// rustls accept future.
///
/// # Panics
///
/// Panics if polled again after resolving with a plain connection, because
/// the peek state has already been taken.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
// Loop so that switching from `Peek` to `Https` polls the TLS future in the
// same call, instead of returning without any inner future having been polled.
loop {
match this.0.as_mut().project() {
FutuereStateProj::Peek(inner) => {
let peek = inner.as_mut().expect("polled again after `Poll::Ready`");
// Peek (rather than read) a single byte so it stays in the socket
// for whoever reads the stream next.
let mut byte = 0;
let mut buffer = ReadBuf::new(slice::from_mut(&mut byte));
match peek.stream.poll_peek(cx, &mut buffer) {
// Zero bytes peeked is treated as the peer closing before sending anything.
Poll::Ready(Ok(0)) => {
return Poll::Ready(Err(ErrorKind::UnexpectedEof.into()))
}
Poll::Ready(Ok(_)) => {
let PeekState {
stream,
service,
rustls,
} = inner.take().expect("`inner` was already consumed");
// 0x16 (22) is the TLS record `ContentType` for `handshake`
// (RFC 8446 §5.1). A TLS client's first record carries its
// ClientHello.
if byte == 0x16 {
this.upgrade(rustls.accept(stream, service.build(Protocol::Tls)));
} else {
// Any other first byte: serve the raw TCP stream as-is.
return Poll::Ready(Ok((
TokioEither::Right(stream),
service.build(Protocol::Plain),
)));
}
}
// The state is left as `Peek(Some(..))`, so polling again retries the peek.
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
}
// TLS stage: forward the rustls result, tagging the stream as `Left`.
FutuereStateProj::Https(future) => {
return future
.poll(cx)
.map_ok(|(stream, service)| (TokioEither::Left(stream), service))
}
}
}
}
}
/// A connection's service wrapper that is still waiting for the detected
/// [`Protocol`]. Its `build` method turns it into a [`DualProtocolService`].
#[derive(Debug)]
struct DualProtocolServiceBuilder<Service>(ServiceServe<Service>);
/// Per-connection service produced by [`DualProtocolAcceptor`].
///
/// It inserts the connection's [`Protocol`] into every request's extensions,
/// then delegates to either the wrapped service or its [`UpgradeHttp`] wrapper.
#[derive(Clone, Debug)]
pub struct DualProtocolService<Service: Clone> {
/// The service that requests are forwarded to.
service: ServiceServe<Service>,
/// Protocol detected for this connection.
protocol: Protocol,
}
/// The service that actually handles requests.
#[derive(Clone, Debug)]
enum ServiceServe<Service> {
/// The user's service, called directly.
Service(Service),
/// The user's service wrapped in [`UpgradeHttp`] (upgrading enabled).
Upgrade(UpgradeHttp<Service>),
}
impl<Service: Clone> DualProtocolServiceBuilder<Service> {
const fn new_service(service: Service) -> Self {
Self(ServiceServe::Service(service))
}
const fn new_upgrade(service: Service) -> Self {
Self(ServiceServe::Upgrade(UpgradeHttp::new(service)))
}
fn build(self, protocol: Protocol) -> DualProtocolService<Service> {
DualProtocolService {
service: self.0,
protocol,
}
}
}
impl<Service, RequestBody, ResponseBody> TowerService<Request<RequestBody>>
    for DualProtocolService<Service>
where
    Service: Clone + TowerService<Request<RequestBody>, Response = Response<ResponseBody>>,
{
    type Response = Response<BodyEither<ResponseBody, BodyEither<ResponseBody, Empty<Bytes>>>>;
    type Error = Service::Error;
    type Future = DualProtocolServiceFuture<Service, RequestBody, ResponseBody>;

    /// Reports the readiness of whichever inner service this connection uses.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.service {
            ServiceServe::Service(ref mut direct) => direct.poll_ready(cx),
            ServiceServe::Upgrade(ref mut upgrading) => upgrading.poll_ready(cx),
        }
    }

    /// Tags `req` with this connection's [`Protocol`] and forwards it to the
    /// inner service.
    fn call(&mut self, mut req: Request<RequestBody>) -> Self::Future {
        let protocol = self.protocol;
        // Any `Protocol` already present in the extensions is replaced; the old
        // value is discarded.
        let _ = req.extensions_mut().insert(protocol);

        match self.service {
            ServiceServe::Service(ref mut direct) => {
                let pending = direct.call(req);
                DualProtocolServiceFuture::new_service(pending)
            }
            ServiceServe::Upgrade(ref mut upgrading) => {
                let pending = upgrading.call(req);
                DualProtocolServiceFuture::new_upgrade(pending)
            }
        }
    }
}
/// Response future of [`DualProtocolService`].
///
/// It wraps the future of whichever inner service handled the request and maps
/// that service's response body into the combined `BodyEither` body type.
#[pin_project]
pub struct DualProtocolServiceFuture<Service, RequestBody, ResponseBody>(
#[pin] FutureServe<Service, RequestBody, ResponseBody>,
)
where
Service: TowerService<Request<RequestBody>, Response = Response<ResponseBody>>;
/// The inner future being driven by [`DualProtocolServiceFuture`].
#[derive(Debug)]
#[pin_project(project = DualProtocolServiceFutureProj)]
enum FutureServe<Service, RequestBody, ResponseBody>
where
Service: TowerService<Request<RequestBody>, Response = Response<ResponseBody>>,
{
/// Future returned by the user's service called directly.
Service(#[pin] Service::Future),
/// Future returned by the [`UpgradeHttp`] wrapper.
Upgrade(#[pin] <UpgradeHttp<Service> as TowerService<Request<RequestBody>>>::Future),
}
impl<Service, RequestBody, ResponseBody> Debug
    for DualProtocolServiceFuture<Service, RequestBody, ResponseBody>
where
    Service: TowerService<Request<RequestBody>, Response = Response<ResponseBody>>,
    FutureServe<Service, RequestBody, ResponseBody>: Debug,
{
    /// Formats as a single-field tuple containing the inner future state.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let Self(state) = self;
        let mut tuple = formatter.debug_tuple("DualProtocolServiceFuture");
        tuple.field(state);
        tuple.finish()
    }
}
impl<Service, RequestBody, ResponseBody>
    DualProtocolServiceFuture<Service, RequestBody, ResponseBody>
where
    Service: TowerService<Request<RequestBody>, Response = Response<ResponseBody>>,
{
    /// Wraps a future returned by calling the user's service directly.
    const fn new_service(future: Service::Future) -> Self {
        let serve = FutureServe::Service(future);
        Self(serve)
    }

    /// Wraps a future returned by the [`UpgradeHttp`] wrapper.
    const fn new_upgrade(
        future: <UpgradeHttp<Service> as TowerService<Request<RequestBody>>>::Future,
    ) -> Self {
        let serve = FutureServe::Upgrade(future);
        Self(serve)
    }
}
impl<Service, RequestBody, ResponseBody> Future
    for DualProtocolServiceFuture<Service, RequestBody, ResponseBody>
where
    Service: TowerService<Request<RequestBody>, Response = Response<ResponseBody>>,
{
    type Output = Result<
        Response<BodyEither<ResponseBody, BodyEither<ResponseBody, Empty<Bytes>>>>,
        Service::Error,
    >;

    /// Polls the inner future and lifts its response body into the combined
    /// body type.
    ///
    /// The body becomes `Left` when the user's service was called directly and
    /// `Right` when it came from the [`UpgradeHttp`] wrapper.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let state = self.project().0;
        match state.project() {
            DualProtocolServiceFutureProj::Service(direct) => {
                direct.poll(cx).map_ok(|response| response.map(BodyEither::Left))
            }
            DualProtocolServiceFutureProj::Upgrade(upgrading) => {
                upgrading.poll(cx).map_ok(|response| response.map(BodyEither::Right))
            }
        }
    }
}