Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add TCP keepalive for MySQL and PostgresSQL. #3559

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ hashlink = "0.9.0"
indexmap = "2.0"
event-listener = "5.2.0"
hashbrown = "0.14.5"
socket2 = "0.5.7"

[dev-dependencies]
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
Expand Down
3 changes: 2 additions & 1 deletion sqlx-core/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ mod socket;
pub mod tls;

pub use socket::{
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, TcpKeepalive, WithSocket,
WriteBuffer,
};
40 changes: 37 additions & 3 deletions sqlx-core/src/net/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::io;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use bytes::BufMut;
use futures_core::ready;
Expand Down Expand Up @@ -182,10 +183,18 @@ impl<S: Socket + ?Sized> Socket for Box<S> {
}
}

#[derive(Debug, Clone)]
pub struct TcpKeepalive {
pub time: Duration,
pub interval: Duration,
pub retries: u32,
}
Copy link
Contributor

@CommanderStorm CommanderStorm Oct 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://docs.rs/socket2/latest/socket2/struct.TcpKeepalive.html

I think we should use the builder pattern too and copy the platform support of socket2 given that this is explicitly done in their case.

Also adding documentation via docstrings is really helpful => copying their docs is likely fine.

Likely we should make this Copy too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for giving me a review!
That's reasonable, or we might have to maintain the consistency between TcpKeepalive defined in socket2 and in sqlx_core.
Should I re-export TcpKeepalive definition in socket2?

Copy link
Contributor

@CommanderStorm CommanderStorm Oct 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current code is a build failour on the following platforms (assuming one tries to use the code):

any(
    target_os = "openbsd",
    target_os = "redox",
    target_os = "solaris",
    target_os = "nto",
    target_os = "espidf",
    target_os = "vita",
    target_os = "haiku",
)

Please use the same platform cfg scoping as the new dependency.


pub async fn connect_tcp<Ws: WithSocket>(
host: &str,
port: u16,
with_socket: Ws,
keepalive: &Option<TcpKeepalive>,
Copy link
Contributor

@CommanderStorm CommanderStorm Oct 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
keepalive: &Option<TcpKeepalive>,
keepalive: Option<&TcpKeepalive>,

Would this not be a better API?
https://users.rust-lang.org/t/api-design-option-t-vs-option-t/34139/2

Given that socket2::TcpKeepalive is Copy, I think copying that derive and dropping the reference might be more ergonomic. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't notice that TcpKeepalive in socket2 is Copy.
I will adjust it. Thx.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't notice the API design preference in Rust. Thank you for pointing this out for me!

) -> crate::Result<Ws::Output> {
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
let host = host.trim_matches(&['[', ']'][..]);
Expand All @@ -197,6 +206,16 @@ pub async fn connect_tcp<Ws: WithSocket>(
let stream = TcpStream::connect((host, port)).await?;
stream.set_nodelay(true)?;

// set tcp keepalive
if let Some(keepalive) = keepalive {
let keepalive = socket2::TcpKeepalive::new()
.with_interval(keepalive.interval)
.with_retries(keepalive.retries)
.with_time(keepalive.time);
let sock_ref = socket2::SockRef::from(&stream);
sock_ref.set_tcp_keepalive(&keepalive)?;
}

return Ok(with_socket.with_socket(stream));
}

Expand All @@ -216,9 +235,24 @@ pub async fn connect_tcp<Ws: WithSocket>(
s.get_ref().set_nodelay(true)?;
Ok(s)
});
match stream {
Ok(stream) => return Ok(with_socket.with_socket(stream)),
Err(e) => last_err = Some(e),
let stream = match stream {
Ok(stream) => stream,
Err(e) => {
last_err = Some(e);
continue;
}
};
// set tcp keepalive
if let Some(keepalive) = keepalive {
let keepalive = socket2::TcpKeepalive::new()
.with_interval(keepalive.interval)
.with_retries(keepalive.retries)
.with_time(keepalive.time);
let sock_ref = socket2::SockRef::from(&stream);
match sock_ref.set_tcp_keepalive(&keepalive) {
Ok(_) => return Ok(with_socket.with_socket(stream)),
Err(e) => last_err = Some(e),
}
}
}

Expand Down
10 changes: 9 additions & 1 deletion sqlx-mysql/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ impl MySqlConnection {

let handshake = match &options.socket {
Some(path) => crate::net::connect_uds(path, do_handshake).await?,
None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
None => {
crate::net::connect_tcp(
&options.host,
options.port,
do_handshake,
&options.tcp_keep_alive,
)
.await?
}
};

let stream = handshake.await?;
Expand Down
10 changes: 9 additions & 1 deletion sqlx-mysql/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod connect;
mod parse;
mod ssl_mode;

use crate::{connection::LogSettings, net::tls::CertificateInput};
use crate::{connection::LogSettings, net::tls::CertificateInput, net::TcpKeepalive};
pub use ssl_mode::MySqlSslMode;

/// Options and flags which can be used to configure a MySQL connection.
Expand Down Expand Up @@ -80,6 +80,7 @@ pub struct MySqlConnectOptions {
pub(crate) no_engine_substitution: bool,
pub(crate) timezone: Option<String>,
pub(crate) set_names: bool,
pub(crate) tcp_keep_alive: Option<TcpKeepalive>,
}

impl Default for MySqlConnectOptions {
Expand Down Expand Up @@ -111,6 +112,7 @@ impl MySqlConnectOptions {
no_engine_substitution: true,
timezone: Some(String::from("+00:00")),
set_names: true,
tcp_keep_alive: None,
}
}

Expand Down Expand Up @@ -403,6 +405,12 @@ impl MySqlConnectOptions {
self.set_names = flag_val;
self
}

/// Sets the TCP keepalive configuration for the connection.
pub fn tcp_keep_alive(mut self, tcp_keep_alive: TcpKeepalive) -> Self {
self.tcp_keep_alive = Some(tcp_keep_alive);
self
}
}

impl MySqlConnectOptions {
Expand Down
10 changes: 9 additions & 1 deletion sqlx-postgres/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,15 @@ impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
let socket_future = match options.fetch_socket() {
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
None => {
net::connect_tcp(
&options.host,
options.port,
MaybeUpgradeTls(options),
&options.tcp_keep_alive,
)
.await?
}
};

let socket = socket_future.await?;
Expand Down
10 changes: 9 additions & 1 deletion sqlx-postgres/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::path::{Path, PathBuf};

pub use ssl_mode::PgSslMode;

use crate::{connection::LogSettings, net::tls::CertificateInput};
use crate::{connection::LogSettings, net::tls::CertificateInput, net::TcpKeepalive};

mod connect;
mod parse;
Expand Down Expand Up @@ -102,6 +102,7 @@ pub struct PgConnectOptions {
pub(crate) application_name: Option<String>,
pub(crate) log_settings: LogSettings,
pub(crate) extra_float_digits: Option<Cow<'static, str>>,
pub(crate) tcp_keep_alive: Option<TcpKeepalive>,
pub(crate) options: Option<String>,
}

Expand Down Expand Up @@ -168,6 +169,7 @@ impl PgConnectOptions {
application_name: var("PGAPPNAME").ok(),
extra_float_digits: Some("2".into()),
log_settings: Default::default(),
tcp_keep_alive: None,
options: var("PGOPTIONS").ok(),
}
}
Expand Down Expand Up @@ -493,6 +495,12 @@ impl PgConnectOptions {
self
}

/// Sets the TCP keepalive configuration for the connection.
pub fn tcp_keep_alive(mut self, tcp_keep_alive: TcpKeepalive) -> Self {
self.tcp_keep_alive = Some(tcp_keep_alive);
self
}

/// Set additional startup options for the connection as a list of key-value pairs.
///
/// # Example
Expand Down