sal/src/postgresclient/postgresclient.rs
Mahmoud Emad 114d63e590
Some checks failed
Rhai Tests / Run Rhai Tests (pull_request) Has been cancelled
feat: Add PostgreSQL connection pooling support
- Implement connection pooling using `r2d2` and `r2d2_postgres`
- Add connection pool configuration options to `PostgresConfigBuilder`
- Introduce transaction functions with automatic commit/rollback
- Add functions for executing queries using the connection pool
- Add `QueryParams` struct for building parameterized queries
- Add tests for connection pooling and transaction functions
2025-05-09 10:45:53 +03:00

826 lines
26 KiB
Rust

use lazy_static::lazy_static;
use postgres::types::ToSql;
use postgres::{Client, Error as PostgresError, NoTls, Row};
use r2d2::Pool;
use r2d2_postgres::PostgresConnectionManager;
use std::env;
use std::sync::{Arc, Mutex, Once};
use std::time::Duration;
/// Produce a `postgres::Error` for failures that did not come from the driver itself.
///
/// `postgres::Error` has no public constructor, so a genuine error value is
/// obtained by asking the driver to connect with a connection string it cannot
/// accept and returning the resulting error.
///
/// NOTE: `_message` is currently discarded. Callers only ever see the driver's
/// error for the bogus connection string, not the descriptive text they passed.
fn create_postgres_error(_message: &str) -> PostgresError {
    // Presumably this fails while parsing the string, before any network I/O.
    if let Err(driver_error) = Client::connect("invalid-connection-string", NoTls) {
        return driver_error;
    }
    unreachable!()
}
// Process-wide PostgreSQL state, lazily initialised with `lazy_static`.
lazy_static! {
    // Shared single-connection client used by the non-pool helpers
    // (`execute`, `query`, `transaction`, ...). `None` until first use;
    // cleared by `reset()`.
    static ref POSTGRES_CLIENT: Mutex<Option<Arc<PostgresClientWrapper>>> = Mutex::new(None);
    // Shared r2d2 connection pool used by the `*_with_pool` helpers.
    // `None` until first use; cleared by `reset_pool()`.
    static ref POSTGRES_POOL: Mutex<Option<Arc<Pool<PostgresConnectionManager<NoTls>>>>> =
        Mutex::new(None);
    // NOTE(review): not referenced by any code visible in this file; it may be
    // used elsewhere (e.g. tests). Confirm before removing.
    static ref INIT: Once = Once::new();
}
/// PostgreSQL connection configuration builder.
///
/// Collects connection parameters and optional r2d2 pool settings using the
/// builder pattern. Turn it into a connection with `build` or into a pool with
/// `build_pool`.
///
/// The `Debug` output redacts `password`, so a configuration can be logged
/// without leaking credentials.
#[derive(Clone)]
pub struct PostgresConfigBuilder {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role to connect as.
    pub user: String,
    /// Password for `user`, if any.
    pub password: Option<String>,
    /// Database name (`dbname` in the connection string).
    pub database: String,
    /// Optional `application_name` reported to the server.
    pub application_name: Option<String>,
    /// Optional connection timeout, in seconds.
    pub connect_timeout: Option<u64>,
    /// Optional `sslmode` connection-string value.
    pub ssl_mode: Option<String>,
    // Connection pool settings (only read when building an r2d2 pool).
    /// Maximum number of connections in the pool.
    pub pool_max_size: Option<u32>,
    /// Minimum number of idle connections kept in the pool.
    pub pool_min_idle: Option<u32>,
    /// Idle timeout for pooled connections.
    pub pool_idle_timeout: Option<Duration>,
    /// Timeout for obtaining a connection from the pool.
    pub pool_connection_timeout: Option<Duration>,
    /// Maximum lifetime of a pooled connection.
    pub pool_max_lifetime: Option<Duration>,
    /// Whether pooling is requested.
    /// NOTE(review): not consulted by `build`/`build_pool` themselves.
    pub use_pool: bool,
}

impl std::fmt::Debug for PostgresConfigBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresConfigBuilder")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            // Never print the secret itself; only whether one is set.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .field("application_name", &self.application_name)
            .field("connect_timeout", &self.connect_timeout)
            .field("ssl_mode", &self.ssl_mode)
            .field("pool_max_size", &self.pool_max_size)
            .field("pool_min_idle", &self.pool_min_idle)
            .field("pool_idle_timeout", &self.pool_idle_timeout)
            .field("pool_connection_timeout", &self.pool_connection_timeout)
            .field("pool_max_lifetime", &self.pool_max_lifetime)
            .field("use_pool", &self.use_pool)
            .finish()
    }
}
impl Default for PostgresConfigBuilder {
fn default() -> Self {
Self {
host: "localhost".to_string(),
port: 5432,
user: "postgres".to_string(),
password: None,
database: "postgres".to_string(),
application_name: None,
connect_timeout: None,
ssl_mode: None,
// Default pool settings
pool_max_size: Some(10),
pool_min_idle: Some(1),
pool_idle_timeout: Some(Duration::from_secs(300)),
pool_connection_timeout: Some(Duration::from_secs(30)),
pool_max_lifetime: Some(Duration::from_secs(1800)),
use_pool: false,
}
}
}
impl PostgresConfigBuilder {
    /// Create a new PostgreSQL connection configuration builder with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the host for the PostgreSQL connection.
    pub fn host(mut self, host: &str) -> Self {
        self.host = host.to_string();
        self
    }

    /// Set the port for the PostgreSQL connection.
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Set the user for the PostgreSQL connection.
    pub fn user(mut self, user: &str) -> Self {
        self.user = user.to_string();
        self
    }

    /// Set the password for the PostgreSQL connection.
    pub fn password(mut self, password: &str) -> Self {
        self.password = Some(password.to_string());
        self
    }

    /// Set the database for the PostgreSQL connection.
    pub fn database(mut self, database: &str) -> Self {
        self.database = database.to_string();
        self
    }

    /// Set the application name for the PostgreSQL connection.
    pub fn application_name(mut self, application_name: &str) -> Self {
        self.application_name = Some(application_name.to_string());
        self
    }

    /// Set the connection timeout in seconds.
    pub fn connect_timeout(mut self, seconds: u64) -> Self {
        self.connect_timeout = Some(seconds);
        self
    }

    /// Set the SSL mode (`sslmode` connection-string value).
    ///
    /// NOTE(review): `build`/`build_pool` always connect with `NoTls`, so a mode
    /// that requires TLS presumably cannot succeed. Confirm this.
    pub fn ssl_mode(mut self, ssl_mode: &str) -> Self {
        self.ssl_mode = Some(ssl_mode.to_string());
        self
    }

    /// Enable or disable connection pooling.
    pub fn use_pool(mut self, use_pool: bool) -> Self {
        self.use_pool = use_pool;
        self
    }

    /// Set the maximum size of the connection pool.
    pub fn pool_max_size(mut self, size: u32) -> Self {
        self.pool_max_size = Some(size);
        self
    }

    /// Set the minimum number of idle connections in the pool.
    pub fn pool_min_idle(mut self, size: u32) -> Self {
        self.pool_min_idle = Some(size);
        self
    }

    /// Set the idle timeout for connections in the pool.
    pub fn pool_idle_timeout(mut self, timeout: Duration) -> Self {
        self.pool_idle_timeout = Some(timeout);
        self
    }

    /// Set the connection timeout for the pool.
    pub fn pool_connection_timeout(mut self, timeout: Duration) -> Self {
        self.pool_connection_timeout = Some(timeout);
        self
    }

    /// Set the maximum lifetime of connections in the pool.
    pub fn pool_max_lifetime(mut self, lifetime: Duration) -> Self {
        self.pool_max_lifetime = Some(lifetime);
        self
    }

    /// Build a `key=value` connection string from the configuration.
    ///
    /// Values that are empty or contain whitespace, `'` or `\` are single-quoted
    /// with `'` and `\` backslash-escaped. All other values are emitted verbatim.
    pub fn build_connection_string(&self) -> String {
        let mut conn_string = format!(
            "host={} port={} user={} dbname={}",
            Self::quote_conn_value(&self.host),
            self.port,
            Self::quote_conn_value(&self.user),
            Self::quote_conn_value(&self.database)
        );
        if let Some(password) = &self.password {
            conn_string.push_str(&format!(" password={}", Self::quote_conn_value(password)));
        }
        if let Some(app_name) = &self.application_name {
            conn_string.push_str(&format!(
                " application_name={}",
                Self::quote_conn_value(app_name)
            ));
        }
        if let Some(timeout) = self.connect_timeout {
            conn_string.push_str(&format!(" connect_timeout={}", timeout));
        }
        if let Some(ssl_mode) = &self.ssl_mode {
            conn_string.push_str(&format!(" sslmode={}", Self::quote_conn_value(ssl_mode)));
        }
        conn_string
    }

    /// Render `value` so it survives as a single connection-string value.
    ///
    /// Without quoting, a value containing a space (e.g. a password) would be
    /// split into a bogus extra option. An empty value would also swallow the
    /// following `key=value` pair.
    fn quote_conn_value(value: &str) -> String {
        let needs_quotes = value.is_empty()
            || value
                .chars()
                .any(|c| c.is_whitespace() || c == '\'' || c == '\\');
        if !needs_quotes {
            return value.to_string();
        }
        let mut quoted = String::with_capacity(value.len() + 2);
        quoted.push('\'');
        for c in value.chars() {
            if c == '\'' || c == '\\' {
                quoted.push('\\');
            }
            quoted.push(c);
        }
        quoted.push('\'');
        quoted
    }

    /// Open a single, non-pooled connection (without TLS) from the configuration.
    ///
    /// # Errors
    /// Returns the driver error if the connection string is rejected or the
    /// connection cannot be established.
    pub fn build(&self) -> Result<Client, PostgresError> {
        let conn_string = self.build_connection_string();
        Client::connect(&conn_string, NoTls)
    }

    /// Build an r2d2 connection pool (without TLS) from the configuration.
    ///
    /// Only the `pool_*` settings that are `Some` are applied; the rest keep
    /// r2d2's defaults.
    ///
    /// # Errors
    /// Returns the r2d2 error if the pool cannot be built.
    ///
    /// # Panics
    /// Panics if the generated connection string cannot be parsed, for example
    /// an unsupported `ssl_mode` value. NOTE(review): r2d2's builder is also
    /// documented to panic for `pool_max_size == 0` or
    /// `pool_min_idle > pool_max_size`. Confirm this for the pinned version.
    pub fn build_pool(&self) -> Result<Pool<PostgresConnectionManager<NoTls>>, r2d2::Error> {
        let conn_string = self.build_connection_string();
        let manager = PostgresConnectionManager::new(
            conn_string
                .parse()
                .expect("PostgresConfigBuilder produced an unparsable connection string"),
            NoTls,
        );
        let mut pool_builder = r2d2::Pool::builder();
        if let Some(max_size) = self.pool_max_size {
            pool_builder = pool_builder.max_size(max_size);
        }
        if let Some(min_idle) = self.pool_min_idle {
            pool_builder = pool_builder.min_idle(Some(min_idle));
        }
        if let Some(idle_timeout) = self.pool_idle_timeout {
            pool_builder = pool_builder.idle_timeout(Some(idle_timeout));
        }
        if let Some(connection_timeout) = self.pool_connection_timeout {
            pool_builder = pool_builder.connection_timeout(connection_timeout);
        }
        if let Some(max_lifetime) = self.pool_max_lifetime {
            pool_builder = pool_builder.max_lifetime(Some(max_lifetime));
        }
        pool_builder.build(manager)
    }
}
/// Wrapper around a single, lazily opened PostgreSQL connection.
///
/// The connection is opened on first use from `connection_string` and kept
/// behind a mutex, so one wrapper can be shared (via `Arc`) between callers.
pub struct PostgresClientWrapper {
    // Connection string used to open the connection when none is present.
    connection_string: String,
    // The open connection, or `None` before the first successful connect.
    client: Mutex<Option<Client>>,
}
/// Run `operations` inside a transaction on the shared global client.
///
/// `BEGIN` is issued before `operations` runs.
/// - If it returns `Ok`, the transaction is committed.
/// - If it returns `Err`, the transaction is rolled back and the error is returned.
/// - If it panics, the transaction is rolled back and the panic is resumed.
///
/// Rollback failures are ignored, so the caller always sees the original error.
///
/// The shared client's lock is held for the whole transaction. `operations`
/// must therefore use the `&mut Client` it receives and must NOT call the
/// module-level helpers (`execute`, `query`, ...). Those lock the same client
/// and would deadlock (or panic).
///
/// # Errors
/// Returns an error if the shared client cannot be obtained, if `BEGIN` or
/// `COMMIT` fails, or whatever error `operations` returns.
///
/// # Example
/// ```no_run
/// use sal::postgresclient::transaction;
///
/// let result = transaction(|client| {
///     // Execute queries within the transaction
///     client.execute("INSERT INTO users (name) VALUES ($1)", &[&"John"])?;
///     client.execute("UPDATE users SET active = true WHERE name = $1", &[&"John"])?;
///
///     // Return a result from the transaction
///     Ok(())
/// });
/// ```
pub fn transaction<F, T>(operations: F) -> Result<T, PostgresError>
where
    F: FnOnce(&mut Client) -> Result<T, PostgresError>,
{
    let wrapper = get_postgres_client()?;
    let client_mutex = wrapper.get_client()?;
    let mut client_guard = client_mutex.lock().unwrap();
    let client = match client_guard.as_mut() {
        Some(client) => client,
        None => return Err(create_postgres_error("Failed to get PostgreSQL client")),
    };

    // Transaction-control statements need no parameters, so send them via the
    // simple query protocol (no prepare round trip).
    client.batch_execute("BEGIN")?;

    // Catch a panic from `operations` so the shared connection is not left
    // inside an open transaction.
    let outcome =
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| operations(&mut *client)));
    match outcome {
        Ok(Ok(value)) => {
            client.batch_execute("COMMIT")?;
            Ok(value)
        }
        Ok(Err(err)) => {
            // Best effort: the caller cares about the original error.
            let _ = client.batch_execute("ROLLBACK");
            Err(err)
        }
        Err(panic_payload) => {
            let _ = client.batch_execute("ROLLBACK");
            // Release the lock before resuming, so the global client's mutex is
            // not poisoned for every later caller.
            drop(client_guard);
            std::panic::resume_unwind(panic_payload)
        }
    }
}
/// Run `operations` inside a transaction on a connection from the shared pool.
///
/// `BEGIN` is issued before `operations` runs.
/// - If it returns `Ok`, the transaction is committed.
/// - If it returns `Err`, the transaction is rolled back and the error is returned.
/// - If it panics, the transaction is rolled back before the connection goes
///   back to the pool, and the panic is resumed.
///
/// Rollback failures are ignored, so the caller always sees the original error.
///
/// Queries issued inside `operations` through the `*_with_pool` helpers check
/// out a different pooled connection, so they are NOT part of this transaction.
/// Use the `&mut Client` argument instead.
///
/// # Errors
/// Returns an error if no pooled connection can be obtained, if `BEGIN` or
/// `COMMIT` fails, or whatever error `operations` returns.
///
/// # Example
/// ```no_run
/// use sal::postgresclient::transaction_with_pool;
///
/// let result = transaction_with_pool(|client| {
///     // Execute queries within the transaction
///     client.execute("INSERT INTO users (name) VALUES ($1)", &[&"John"])?;
///     client.execute("UPDATE users SET active = true WHERE name = $1", &[&"John"])?;
///
///     // Return a result from the transaction
///     Ok(())
/// });
/// ```
pub fn transaction_with_pool<F, T>(operations: F) -> Result<T, PostgresError>
where
    F: FnOnce(&mut Client) -> Result<T, PostgresError>,
{
    let pool = get_postgres_pool()?;
    let mut client = pool.get().map_err(|e| {
        create_postgres_error(&format!("Failed to get connection from pool: {}", e))
    })?;

    // Transaction-control statements need no parameters, so send them via the
    // simple query protocol (no prepare round trip).
    client.batch_execute("BEGIN")?;

    // Without this, a panic would return the connection to the pool while it
    // is still inside the open transaction, and the next borrower would inherit it.
    let outcome =
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| operations(&mut client)));
    match outcome {
        Ok(Ok(value)) => {
            client.batch_execute("COMMIT")?;
            Ok(value)
        }
        Ok(Err(err)) => {
            // Best effort: the caller cares about the original error.
            let _ = client.batch_execute("ROLLBACK");
            Err(err)
        }
        Err(panic_payload) => {
            let _ = client.batch_execute("ROLLBACK");
            std::panic::resume_unwind(panic_payload)
        }
    }
}
impl PostgresClientWrapper {
    /// Create a wrapper that connects lazily using `connection_string`.
    fn new(connection_string: String) -> Self {
        PostgresClientWrapper {
            connection_string,
            client: Mutex::new(None),
        }
    }

    /// Ensure `slot` holds a usable connection and return it.
    ///
    /// A new connection is opened when the slot is empty or the previous
    /// connection reports itself closed (for example after a server restart).
    /// Otherwise a dead connection would keep failing every later call.
    fn ensure_connected<'a>(
        &self,
        slot: &'a mut Option<Client>,
    ) -> Result<&'a mut Client, PostgresError> {
        if slot.as_ref().map_or(true, Client::is_closed) {
            *slot = Some(Client::connect(&self.connection_string, NoTls)?);
        }
        slot.as_mut()
            .ok_or_else(|| create_postgres_error("Failed to get PostgreSQL client"))
    }

    /// Get the mutex guarding the connection, (re)connecting first if needed.
    fn get_client(&self) -> Result<&Mutex<Option<Client>>, PostgresError> {
        let mut client_guard = self.client.lock().unwrap();
        self.ensure_connected(&mut *client_guard)?;
        Ok(&self.client)
    }

    /// Run `op` with exclusive access to a live connection, under a single lock.
    fn with_client<R>(
        &self,
        op: impl FnOnce(&mut Client) -> Result<R, PostgresError>,
    ) -> Result<R, PostgresError> {
        let mut client_guard = self.client.lock().unwrap();
        let client = self.ensure_connected(&mut *client_guard)?;
        op(client)
    }

    /// Execute a statement and return the number of rows affected.
    pub fn execute(
        &self,
        query: &str,
        params: &[&(dyn postgres::types::ToSql + Sync)],
    ) -> Result<u64, PostgresError> {
        self.with_client(|client| client.execute(query, params))
    }

    /// Execute a query and return all resulting rows.
    pub fn query(
        &self,
        query: &str,
        params: &[&(dyn postgres::types::ToSql + Sync)],
    ) -> Result<Vec<Row>, PostgresError> {
        self.with_client(|client| client.query(query, params))
    }

    /// Execute a query that must return exactly one row.
    pub fn query_one(
        &self,
        query: &str,
        params: &[&(dyn postgres::types::ToSql + Sync)],
    ) -> Result<Row, PostgresError> {
        self.with_client(|client| client.query_one(query, params))
    }

    /// Execute a query that returns at most one row.
    pub fn query_opt(
        &self,
        query: &str,
        params: &[&(dyn postgres::types::ToSql + Sync)],
    ) -> Result<Option<Row>, PostgresError> {
        self.with_client(|client| client.query_opt(query, params))
    }

    /// Check the connection by running `SELECT 1`.
    ///
    /// Returns `Ok(true)` on success; any failure is returned as `Err`.
    pub fn ping(&self) -> Result<bool, PostgresError> {
        self.query("SELECT 1", &[]).map(|_| true)
    }
}
/// Get the shared PostgreSQL client, creating and connecting it on first use.
///
/// Connection details come from the `POSTGRES_*` environment variables read by
/// `create_postgres_client`. If several threads race to create the first
/// client, the first one stored wins. Later finishers drop their own client
/// and return the cached one, so every caller shares a single instance.
///
/// # Errors
/// Returns the driver error if creating or pinging the client fails. Nothing
/// is cached in that case, so the next call retries.
pub fn get_postgres_client() -> Result<Arc<PostgresClientWrapper>, PostgresError> {
    if let Some(client) = POSTGRES_CLIENT.lock().unwrap().as_ref() {
        return Ok(Arc::clone(client));
    }
    // Create outside the lock so a slow or hanging connect does not serialise
    // every caller behind it.
    let client = create_postgres_client()?;
    let mut guard = POSTGRES_CLIENT.lock().unwrap();
    Ok(Arc::clone(guard.get_or_insert(client)))
}
/// Build a new, verified client wrapper from environment variables.
///
/// Reads the following variables:
/// - `POSTGRES_HOST` (default `localhost`)
/// - `POSTGRES_PORT` (default `5432`, also used when the value is not a valid `u16`)
/// - `POSTGRES_USER` (default `postgres`)
/// - `POSTGRES_PASSWORD` (optional)
/// - `POSTGRES_DB` (default `postgres`)
///
/// It then pings the server, so a bad configuration is reported here.
fn create_postgres_client() -> Result<Arc<PostgresClientWrapper>, PostgresError> {
    let var_or = |name: &str, fallback: &str| {
        env::var(name).unwrap_or_else(|_| String::from(fallback))
    };
    let port = match env::var("POSTGRES_PORT").map(|raw| raw.parse::<u16>()) {
        Ok(Ok(parsed)) => parsed,
        _ => 5432,
    };

    let base = PostgresConfigBuilder::new()
        .host(&var_or("POSTGRES_HOST", "localhost"))
        .port(port)
        .user(&var_or("POSTGRES_USER", "postgres"))
        .database(&var_or("POSTGRES_DB", "postgres"));
    let config = match env::var("POSTGRES_PASSWORD") {
        Ok(secret) => base.password(&secret),
        Err(_) => base,
    };

    let wrapper = Arc::new(PostgresClientWrapper::new(config.build_connection_string()));
    // Only hand the wrapper out once the server has answered.
    wrapper.ping().map(|_| wrapper)
}
/// Discard the cached global client and immediately connect a fresh one.
///
/// Callers still holding an `Arc` to the old wrapper keep using it until they
/// drop it.
///
/// # Errors
/// Returns the error from creating the new client. The global slot stays
/// empty in that case.
pub fn reset() -> Result<(), PostgresError> {
    // The lock is held only for this statement.
    drop(POSTGRES_CLIENT.lock().unwrap().take());
    get_postgres_client().map(|_| ())
}
/// Execute a statement on the shared global client and return the number of rows affected.
pub fn execute(
    query: &str,
    params: &[&(dyn postgres::types::ToSql + Sync)],
) -> Result<u64, PostgresError> {
    get_postgres_client().and_then(|shared| shared.execute(query, params))
}

/// Execute a query on the shared global client and return all resulting rows.
pub fn query(
    query: &str,
    params: &[&(dyn postgres::types::ToSql + Sync)],
) -> Result<Vec<Row>, PostgresError> {
    get_postgres_client().and_then(|shared| shared.query(query, params))
}

/// Execute a query on the shared global client that must return exactly one row.
pub fn query_one(
    query: &str,
    params: &[&(dyn postgres::types::ToSql + Sync)],
) -> Result<Row, PostgresError> {
    get_postgres_client().and_then(|shared| shared.query_one(query, params))
}

/// Execute a query on the shared global client that returns at most one row.
pub fn query_opt(
    query: &str,
    params: &[&(dyn postgres::types::ToSql + Sync)],
) -> Result<Option<Row>, PostgresError> {
    get_postgres_client().and_then(|shared| shared.query_opt(query, params))
}
/// Create a new PostgreSQL client with custom configuration
pub fn with_config(config: PostgresConfigBuilder) -> Result<Client, PostgresError> {
config.build()
}
/// Create a new PostgreSQL connection pool with custom configuration
pub fn with_pool_config(
config: PostgresConfigBuilder,
) -> Result<Pool<PostgresConnectionManager<NoTls>>, r2d2::Error> {
config.build_pool()
}
/// Get the shared PostgreSQL connection pool, creating it on first use.
///
/// Connection details come from the `POSTGRES_*` environment variables read by
/// `create_postgres_pool`. If several threads race to create the first pool,
/// the first one stored wins. Later finishers drop their own pool, closing its
/// connections, and return the cached one, so every caller shares one pool.
///
/// # Errors
/// Returns an error if the pool cannot be created or no connection can be
/// established. Nothing is cached in that case, so the next call retries.
pub fn get_postgres_pool() -> Result<Arc<Pool<PostgresConnectionManager<NoTls>>>, PostgresError> {
    if let Some(pool) = POSTGRES_POOL.lock().unwrap().as_ref() {
        return Ok(Arc::clone(pool));
    }
    // Build outside the lock so a slow connection attempt does not serialise
    // every caller behind it.
    let pool = create_postgres_pool()?;
    let mut guard = POSTGRES_POOL.lock().unwrap();
    Ok(Arc::clone(guard.get_or_insert(pool)))
}
/// Build a new connection pool from environment variables and verify it.
///
/// Reads the same variables as the single-client path:
/// - `POSTGRES_HOST` (default `localhost`)
/// - `POSTGRES_PORT` (default `5432`, also used when the value is not a valid `u16`)
/// - `POSTGRES_USER` (default `postgres`)
/// - `POSTGRES_PASSWORD` (optional)
/// - `POSTGRES_DB` (default `postgres`)
///
/// It then checks out one connection so an unreachable server is reported here.
fn create_postgres_pool() -> Result<Arc<Pool<PostgresConnectionManager<NoTls>>>, PostgresError> {
    let var_or = |name: &str, fallback: &str| {
        env::var(name).unwrap_or_else(|_| String::from(fallback))
    };
    let port = match env::var("POSTGRES_PORT").map(|raw| raw.parse::<u16>()) {
        Ok(Ok(parsed)) => parsed,
        _ => 5432,
    };

    let mut config = PostgresConfigBuilder::new()
        .host(&var_or("POSTGRES_HOST", "localhost"))
        .port(port)
        .user(&var_or("POSTGRES_USER", "postgres"))
        .database(&var_or("POSTGRES_DB", "postgres"))
        .use_pool(true);
    if let Ok(secret) = env::var("POSTGRES_PASSWORD") {
        config = config.password(&secret);
    }

    let pool = config.build_pool().map_err(|err| {
        create_postgres_error(&format!(
            "Failed to create PostgreSQL connection pool: {}",
            err
        ))
    })?;
    // The checked-out connection goes straight back into the pool.
    pool.get().map_err(|err| {
        create_postgres_error(&format!("Failed to connect to PostgreSQL: {}", err))
    })?;
    Ok(Arc::new(pool))
}
/// Discard the cached global pool and immediately build a fresh one.
///
/// Callers still holding an `Arc` to the old pool keep using it until they
/// drop it.
///
/// # Errors
/// Returns the error from creating the new pool. The global slot stays empty
/// in that case.
pub fn reset_pool() -> Result<(), PostgresError> {
    // The lock is held only for this statement.
    drop(POSTGRES_POOL.lock().unwrap().take());
    get_postgres_pool().map(|_| ())
}
/// Execute a query using the connection pool
pub fn execute_with_pool(
query: &str,
params: &[&(dyn postgres::types::ToSql + Sync)],
) -> Result<u64, PostgresError> {
let pool = get_postgres_pool()?;
let mut client = pool.get().map_err(|e| {
create_postgres_error(&format!("Failed to get connection from pool: {}", e))
})?;
client.execute(query, params)
}
/// Execute a query using the connection pool and return the rows
pub fn query_with_pool(
query: &str,
params: &[&(dyn postgres::types::ToSql + Sync)],
) -> Result<Vec<Row>, PostgresError> {
let pool = get_postgres_pool()?;
let mut client = pool.get().map_err(|e| {
create_postgres_error(&format!("Failed to get connection from pool: {}", e))
})?;
client.query(query, params)
}
/// Execute a query using the connection pool and return a single row
pub fn query_one_with_pool(
query: &str,
params: &[&(dyn postgres::types::ToSql + Sync)],
) -> Result<Row, PostgresError> {
let pool = get_postgres_pool()?;
let mut client = pool.get().map_err(|e| {
create_postgres_error(&format!("Failed to get connection from pool: {}", e))
})?;
client.query_one(query, params)
}
/// Execute a query using the connection pool and return an optional row
pub fn query_opt_with_pool(
query: &str,
params: &[&(dyn postgres::types::ToSql + Sync)],
) -> Result<Option<Row>, PostgresError> {
let pool = get_postgres_pool()?;
let mut client = pool.get().map_err(|e| {
create_postgres_error(&format!("Failed to get connection from pool: {}", e))
})?;
client.query_opt(query, params)
}
/// Parameter builder for PostgreSQL queries.
///
/// Collects boxed, type-erased parameters so values of different Rust types
/// can be gathered in one list. Pass them to the `*_with_params` helpers, which
/// call [`QueryParams::as_slice`].
#[derive(Default)]
pub struct QueryParams {
    params: Vec<Box<dyn ToSql + Sync>>,
}

impl QueryParams {
    /// Create a new empty parameter builder.
    pub fn new() -> Self {
        Self { params: Vec::new() }
    }

    /// Append a parameter of any `ToSql` type.
    pub fn add<T: 'static + ToSql + Sync>(&mut self, value: T) -> &mut Self {
        self.params.push(Box::new(value));
        self
    }

    /// Append a string parameter (stored as an owned `String`).
    pub fn add_str(&mut self, value: &str) -> &mut Self {
        self.add(value.to_string())
    }

    /// Append an `i32` parameter.
    pub fn add_int(&mut self, value: i32) -> &mut Self {
        self.add(value)
    }

    /// Append an `f64` parameter.
    pub fn add_float(&mut self, value: f64) -> &mut Self {
        self.add(value)
    }

    /// Append a boolean parameter.
    pub fn add_bool(&mut self, value: bool) -> &mut Self {
        self.add(value)
    }

    /// Append an optional parameter; `None` is bound as SQL `NULL`.
    pub fn add_opt<T: 'static + ToSql + Sync>(&mut self, value: Option<T>) -> &mut Self {
        match value {
            Some(v) => self.add(v),
            // Bind NULL as `Option<T>`, not a fixed `Option<String>`. The driver
            // type-checks every parameter, and a `String`-typed NULL is rejected
            // for non-text placeholders such as `int4`.
            None => self.add(None::<T>),
        }
    }

    /// Borrow the parameters in the form the query functions expect.
    pub fn as_slice(&self) -> Vec<&(dyn ToSql + Sync)> {
        self.params
            .iter()
            .map(|p| p.as_ref() as &(dyn ToSql + Sync))
            .collect()
    }
}
/// Execute a statement on the shared global client using a [`QueryParams`] builder.
pub fn execute_with_params(query_str: &str, params: &QueryParams) -> Result<u64, PostgresError> {
    let bound = params.as_slice();
    get_postgres_client()?.execute(query_str, &bound)
}

/// Run a query on the shared global client using a [`QueryParams`] builder and return all rows.
pub fn query_with_params(query_str: &str, params: &QueryParams) -> Result<Vec<Row>, PostgresError> {
    let bound = params.as_slice();
    get_postgres_client()?.query(query_str, &bound)
}

/// Run a query on the shared global client using a [`QueryParams`] builder; exactly one row expected.
pub fn query_one_with_params(query_str: &str, params: &QueryParams) -> Result<Row, PostgresError> {
    let bound = params.as_slice();
    get_postgres_client()?.query_one(query_str, &bound)
}

/// Run a query on the shared global client using a [`QueryParams`] builder; at most one row expected.
pub fn query_opt_with_params(
    query_str: &str,
    params: &QueryParams,
) -> Result<Option<Row>, PostgresError> {
    let bound = params.as_slice();
    get_postgres_client()?.query_opt(query_str, &bound)
}
/// Execute a statement on a pooled connection using a [`QueryParams`] builder.
pub fn execute_with_pool_params(
    query_str: &str,
    params: &QueryParams,
) -> Result<u64, PostgresError> {
    let bound = params.as_slice();
    execute_with_pool(query_str, &bound)
}

/// Run a query on a pooled connection using a [`QueryParams`] builder and return all rows.
pub fn query_with_pool_params(
    query_str: &str,
    params: &QueryParams,
) -> Result<Vec<Row>, PostgresError> {
    let bound = params.as_slice();
    query_with_pool(query_str, &bound)
}

/// Run a query on a pooled connection using a [`QueryParams`] builder; exactly one row expected.
pub fn query_one_with_pool_params(
    query_str: &str,
    params: &QueryParams,
) -> Result<Row, PostgresError> {
    let bound = params.as_slice();
    query_one_with_pool(query_str, &bound)
}

/// Run a query on a pooled connection using a [`QueryParams`] builder; at most one row expected.
pub fn query_opt_with_pool_params(
    query_str: &str,
    params: &QueryParams,
) -> Result<Option<Row>, PostgresError> {
    let bound = params.as_slice();
    query_opt_with_pool(query_str, &bound)
}
/// Build a `NOTIFY` statement with `payload` embedded as an escape-string literal.
///
/// `NOTIFY` cannot take bind parameters, so the payload has to be inlined.
/// Writing it as `E'...'` with `\` and `'` escaped means any payload (for
/// example `it's done` or one containing backslashes) is sent exactly as given,
/// instead of breaking or altering the statement.
///
/// `channel` is inserted verbatim as an SQL identifier and must come from
/// trusted code.
fn notify_statement(channel: &str, payload: &str) -> String {
    let mut escaped = String::with_capacity(payload.len() + 2);
    for c in payload.chars() {
        match c {
            '\\' => escaped.push_str("\\\\"),
            '\'' => escaped.push_str("''"),
            other => escaped.push(other),
        }
    }
    format!("NOTIFY {}, E'{}'", channel, escaped)
}

/// Send a notification on a channel using the shared global client.
///
/// `payload` is escaped safely; `channel` must be a trusted SQL identifier.
///
/// # Errors
/// Returns an error if the client cannot be obtained or the statement fails.
///
/// Example:
/// ```no_run
/// use sal::postgresclient::notify;
///
/// notify("my_channel", "Hello, world!").expect("Failed to send notification");
/// ```
pub fn notify(channel: &str, payload: &str) -> Result<(), PostgresError> {
    let client = get_postgres_client()?;
    client.execute(&notify_statement(channel, payload), &[])?;
    Ok(())
}

/// Send a notification on a channel using a pooled connection.
///
/// `payload` is escaped safely; `channel` must be a trusted SQL identifier.
///
/// # Errors
/// Returns an error if no pooled connection can be obtained or the statement fails.
///
/// Example:
/// ```no_run
/// use sal::postgresclient::notify_with_pool;
///
/// notify_with_pool("my_channel", "Hello, world!").expect("Failed to send notification");
/// ```
pub fn notify_with_pool(channel: &str, payload: &str) -> Result<(), PostgresError> {
    let pool = get_postgres_pool()?;
    let mut client = pool.get().map_err(|e| {
        create_postgres_error(&format!("Failed to get connection from pool: {}", e))
    })?;
    client.execute(&notify_statement(channel, payload), &[])?;
    Ok(())
}