Skip to content

Commit

Permalink
feat(rust): added the possibility to overwrite http headers in inlets
Browse files Browse the repository at this point in the history
  • Loading branch information
davide-baldo committed Jan 8, 2025
1 parent bbcd779 commit babc7e1
Show file tree
Hide file tree
Showing 21 changed files with 717 additions and 194 deletions.
304 changes: 304 additions & 0 deletions implementations/rust/ockam/ockam_api/src/http/interceptor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
use crate::http::state::{ClientRequestWriter, RequestState};
use crate::nodes::models::services::{DeleteServiceRequest, StartServiceRequest};
use crate::nodes::registry::HttpHeaderInterceptorInfo;
use crate::nodes::{NodeManager, NodeManagerWorker};
use crate::DefaultAddress;
use httparse::Request;
use minicbor::{CborLen, Decode, Encode};
use ockam_abac::{Action, PolicyAccessControl, Resource, ResourceType};
use ockam_core::api::Response;
use ockam_core::errcode::{Kind, Origin};
use ockam_core::{async_trait, Address, AllowAll, IncomingAccessControl, OutgoingAccessControl};
use ockam_node::Context;
use ockam_transport_tcp::{
read_portal_payload_length, Direction, PortalInletInterceptor, PortalInterceptor,
PortalInterceptorFactory,
};
use std::io::Write;
use std::sync::{Arc, Mutex as SyncMutex};

struct HttpHeadersInterceptorFactory {
headers: Arc<Vec<(String, String)>>,
}

impl HttpHeadersInterceptorFactory {
pub async fn create(
context: &Context,
listener_address: Address,
headers: Vec<(String, String)>,
policy_access_control: Option<PolicyAccessControl>,
) -> ockam_core::Result<()> {
let flow_control_id = context
.flow_controls()
.get_flow_control_with_spawner(&Address::from_string(
DefaultAddress::SECURE_CHANNEL_LISTENER,
))
.ok_or_else(|| {
ockam_core::Error::new(
Origin::Channel,
Kind::NotFound,
"Secure channel listener not found",
)
})?;

context
.flow_controls()
.add_consumer(listener_address.clone(), &flow_control_id);

let incoming_access_control: Arc<dyn IncomingAccessControl>;
let outgoing_access_control: Arc<dyn OutgoingAccessControl>;
if let Some(policy_access_control) = policy_access_control {
incoming_access_control = Arc::new(policy_access_control.create_incoming());
outgoing_access_control =
Arc::new(policy_access_control.create_outgoing(context).await?);
} else {
incoming_access_control = Arc::new(AllowAll);
outgoing_access_control = Arc::new(AllowAll);
}

PortalInletInterceptor::create(
context,
listener_address,
Arc::new(HttpHeadersInterceptorFactory {
headers: Arc::new(headers),
}),
incoming_access_control,
outgoing_access_control,
read_portal_payload_length(),
)
.await
}
}

impl PortalInterceptorFactory for HttpHeadersInterceptorFactory {
fn create(&self) -> Arc<dyn PortalInterceptor> {
Arc::new(HttpHeadersInterceptor {
headers: self.headers.clone(),
state: SyncMutex::new(RequestState::ParsingHeader(None)),
})
}
}

struct HttpHeadersInterceptor {
headers: Arc<Vec<(String, String)>>,
state: SyncMutex<RequestState>,
}

#[async_trait]
impl PortalInterceptor for HttpHeadersInterceptor {
async fn intercept(
&self,
_context: &mut Context,
direction: Direction,
buffer: &[u8],
) -> ockam_core::Result<Option<Vec<u8>>> {
match direction {
Direction::FromOutletToInlet => Ok(Some(buffer.to_vec())),
Direction::FromInletToOutlet => {
let mut guard = self.state.lock().unwrap();
Ok(Some(guard.process_http_buffer(buffer, self)?))
}
}
}
}

impl ClientRequestWriter for &HttpHeadersInterceptor {
fn write_headers(&self, request: &Request, buffer: &mut Vec<u8>) -> ockam_core::Result<()> {
write!(
buffer,
"{} {} HTTP/1.{}\r\n",
request.method.unwrap(),
request.path.unwrap(),
request.version.unwrap()
)
.unwrap();

for (name, value) in self.headers.iter() {
write!(buffer, "{}: {}\r\n", name, value).unwrap();
}

for h in &*request.headers {
if !self
.headers
.iter()
.any(|(name, _)| name.eq_ignore_ascii_case(h.name))
{
write!(buffer, "{}: ", h.name).unwrap();
buffer.extend_from_slice(h.value);
buffer.extend_from_slice(b"\r\n");
}
}

buffer.extend_from_slice(b"\r\n");
Ok(())
}
}

/// Request body to create a new HTTP rewrite headers interceptor
#[derive(Clone, Debug, Encode, Decode, CborLen)]
#[rustfmt::skip]
#[cbor(map)]
pub struct HttpHeadersInterceptorRequest {
#[n(0)] pub headers: Vec<(String, String)>,
}

impl NodeManagerWorker {
pub async fn start_http_header_service(
&self,
context: &Context,
request: StartServiceRequest<HttpHeadersInterceptorRequest>,
) -> ockam_core::Result<Response<()>, Response<ockam_core::api::Error>> {
let result = self
.node_manager
.start_http_header_service(
context,
Address::from_string(request.address()),
request.request().headers.clone(),
)
.await;

match result {
Ok(_) => Ok(Response::ok().body(())),
Err(e) => Err(Response::internal_error_no_request(&e.to_string())),
}
}

pub async fn delete_http_overwrite_header_service(
&self,
context: &Context,
request: DeleteServiceRequest,
) -> ockam_core::Result<Response<()>, Response<ockam_core::api::Error>> {
let result = self
.node_manager
.delete_http_overwrite_header_service(context, Address::from_string(request.address()))
.await;

match result {
Ok(_) => Ok(Response::ok().body(())),
Err(e) => Err(Response::internal_error_no_request(&e.to_string())),
}
}
}

impl NodeManager {
pub async fn start_http_header_service(
&self,
context: &Context,
listener_address: Address,
headers: Vec<(String, String)>,
) -> ockam_core::Result<()> {
let policy_access_control = self
.policy_access_control(
self.project_authority().clone(),
Resource::new(listener_address.to_string(), ResourceType::TcpInlet),
Action::HandleMessage,
None,
)
.await?;

HttpHeadersInterceptorFactory::create(
context,
listener_address.clone(),
headers,
Some(policy_access_control),
)
.await?;

self.registry
.http_headers_interceptors
.insert(listener_address, HttpHeaderInterceptorInfo {})
.await;

Ok(())
}

pub async fn delete_http_overwrite_header_service(
&self,
context: &Context,
listener_address: Address,
) -> ockam_core::Result<()> {
context.stop_worker(listener_address.clone()).await?;

self.registry
.http_headers_interceptors
.remove(&listener_address)
.await;

Ok(())
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::nodes::service::{NodeManagerCredentialRetrieverOptions, NodeManagerTrustOptions};
use crate::test_utils::start_manager_for_tests;
use ockam_core::NeutralMessage;
use ockam_transport_tcp::PortalMessage;

#[ockam::test]
async fn main(context: &mut Context) -> ockam::Result<()> {
let handler = start_manager_for_tests(
context,
None,
Some(NodeManagerTrustOptions::new(
NodeManagerCredentialRetrieverOptions::None,
NodeManagerCredentialRetrieverOptions::None,
None,
NodeManagerCredentialRetrieverOptions::None,
)),
)
.await?;

HttpHeadersInterceptorFactory::create(
context,
"http_interceptor".into(),
vec![("Host".to_string(), "ockam.io".to_string())],
None,
)
.await?;

let connection = handler
.node_manager
.make_connection(
context,
&format!(
"/service/http_interceptor/service/{}",
context.address_ref().address()
)
.parse()?,
handler.node_manager.identifier(),
None,
None,
)
.await?;

let route = connection.route()?;

context
.send(route.clone(), PortalMessage::Ping.to_neutral_message()?)
.await?;

let _ = context.receive::<NeutralMessage>().await?;

context
.send(
route.clone(),
PortalMessage::Payload(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", None)
.to_neutral_message()?,
)
.await?;

let message = context.receive::<NeutralMessage>().await?;
let message = PortalMessage::decode(message.payload())?;

if let PortalMessage::Payload(payload, _) = message {
let message = String::from_utf8(payload.to_vec()).unwrap();
assert_eq!(message, "GET / HTTP/1.1\r\nHost: ockam.io\r\n\r\n");
} else {
panic!("Decoded message is not a Payload");
}

Ok(())
}
}
2 changes: 2 additions & 0 deletions implementations/rust/ockam/ockam_api/src/http/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub(crate) mod interceptor;
pub mod state;
Loading

0 comments on commit babc7e1

Please sign in to comment.