Skip to content

Commit

Permalink
chore(cubesql): ProtocolDetails - support session variables (#8587)
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr authored Aug 16, 2024
1 parent 9927ba2 commit a86dc09
Show file tree
Hide file tree
Showing 13 changed files with 99 additions and 84 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{any::Any, sync::Arc};

use crate::compile::engine::context::TableName;
use crate::compile::{engine::context::TableName, DatabaseVariables};
use async_trait::async_trait;
use datafusion::{
arrow::{
Expand All @@ -14,8 +14,6 @@ use datafusion::{
physical_plan::{memory::MemoryExec, ExecutionPlan},
};

use crate::sql::database_variables::DatabaseVariables;

pub struct PerfSchemaVariablesProvider {
table_name: String,
variables: DatabaseVariables,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{any::Any, sync::Arc};

use crate::compile::DatabaseVariables;
use async_trait::async_trait;
use datafusion::{
arrow::{
Expand All @@ -13,8 +14,6 @@ use datafusion::{
physical_plan::{memory::MemoryExec, ExecutionPlan},
};

use crate::sql::database_variables::DatabaseVariables;

pub struct PgCatalogSettingsProvider {
vars: DatabaseVariables,
}
Expand Down
2 changes: 2 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod query_engine;
pub mod rewrite;
pub mod router;
pub mod service;
pub mod session;

// Internal API
pub mod test;
Expand All @@ -22,6 +23,7 @@ pub use protocol::*;
pub use query_engine::*;
pub use rewrite::rewriter::Rewriter;
pub use router::*;
pub use session::*;

// Re-export base deps to minimise version maintenance for crate users such as cloud
pub use datafusion::{self, arrow};
Expand Down
36 changes: 34 additions & 2 deletions rust/cubesql/cubesql/src/compile/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use crate::{compile::CubeContext, CubeError};
use crate::{
compile::{CubeContext, DatabaseVariable, DatabaseVariables},
CubeError,
};
use datafusion::datasource;
use log::error;
use std::{
fmt::Debug,
hash::{Hash, Hasher},
Expand All @@ -13,11 +17,17 @@ pub trait DatabaseProtocolDetails: Debug + Send + Sync {

fn support_transactions(&self) -> bool;

/// Get default state for session variables
fn get_session_default_variables(&self) -> DatabaseVariables;

/// Get default value for specific session variable
fn get_session_variable_default(&self, name: &str) -> Option<DatabaseVariable>;

fn get_provider(
&self,
context: &CubeContext,
tr: datafusion::catalog::TableReference,
) -> Option<std::sync::Arc<dyn datasource::TableProvider>>;
) -> Option<Arc<dyn datasource::TableProvider>>;

fn table_name_by_table_provider(
&self,
Expand Down Expand Up @@ -70,6 +80,28 @@ impl DatabaseProtocolDetails for DatabaseProtocol {
}
}

fn get_session_default_variables(&self) -> DatabaseVariables {
match &self {
DatabaseProtocol::MySQL => {
// TODO(ovr): Should we move it from session?
error!("get_session_default_variables was called on MySQL protocol");

DatabaseVariables::default()
}
DatabaseProtocol::PostgreSQL => {
// TODO(ovr): Should we move it from session?
error!("get_session_default_variables was called on PostgreSQL protocol");

DatabaseVariables::default()
}
DatabaseProtocol::Extension(ext) => ext.get_session_default_variables(),
}
}

fn get_session_variable_default(&self, name: &str) -> Option<DatabaseVariable> {
self.get_session_default_variables().get(name).cloned()
}

fn get_provider(
&self,
context: &CubeContext,
Expand Down
2 changes: 1 addition & 1 deletion rust/cubesql/cubesql/src/compile/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ use crate::{
engine::{df::planner::CubeQueryPlanner, udf::*, VariablesProvider},
error::{CompilationError, CompilationResult},
parser::parse_sql_to_statement,
DatabaseVariable, DatabaseVariablesToUpdate,
},
sql::{
database_variables::{DatabaseVariable, DatabaseVariablesToUpdate},
dataframe,
statement::{
ApproximateCountDistinctVisitor, CastReplacer, RedshiftDatePartReplacer,
Expand Down
45 changes: 45 additions & 0 deletions rust/cubesql/cubesql/src/compile/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use datafusion::{scalar::ScalarValue, variable::VarType};
use std::collections::HashMap;

pub type DatabaseVariablesToUpdate = Vec<DatabaseVariable>;
pub type DatabaseVariables = HashMap<String, DatabaseVariable>;

#[derive(Debug, Clone)]
pub struct DatabaseVariable {
pub name: String,
pub value: ScalarValue,
pub var_type: VarType,
pub readonly: bool,
// Postgres schema includes a range of additional parameters
pub additional_params: Option<HashMap<String, ScalarValue>>,
}

impl DatabaseVariable {
pub fn system(
name: String,
value: ScalarValue,
additional_params: Option<HashMap<String, ScalarValue>>,
) -> Self {
Self {
name: name,
value: value,
var_type: VarType::System,
readonly: false,
additional_params,
}
}

pub fn user_defined(
name: String,
value: ScalarValue,
additional_params: Option<HashMap<String, ScalarValue>>,
) -> Self {
Self {
name: name,
value: value,
var_type: VarType::UserDefined,
readonly: false,
additional_params,
}
}
}
47 changes: 1 addition & 46 deletions rust/cubesql/cubesql/src/sql/database_variables/mod.rs
Original file line number Diff line number Diff line change
@@ -1,53 +1,8 @@
use std::collections::HashMap;

use datafusion::{scalar::ScalarValue, variable::VarType};
use crate::compile::DatabaseVariables;

pub mod mysql;
pub mod postgres;

pub type DatabaseVariablesToUpdate = Vec<DatabaseVariable>;
pub type DatabaseVariables = HashMap<String, DatabaseVariable>;

#[derive(Debug, Clone)]
pub struct DatabaseVariable {
pub name: String,
pub value: ScalarValue,
pub var_type: VarType,
pub readonly: bool,
// Postgres schema includes a range of additional parameters
pub additional_params: Option<HashMap<String, ScalarValue>>,
}

impl DatabaseVariable {
pub fn system(
name: String,
value: ScalarValue,
additional_params: Option<HashMap<String, ScalarValue>>,
) -> Self {
Self {
name: name,
value: value,
var_type: VarType::System,
readonly: false,
additional_params,
}
}

pub fn user_defined(
name: String,
value: ScalarValue,
additional_params: Option<HashMap<String, ScalarValue>>,
) -> Self {
Self {
name: name,
value: value,
var_type: VarType::UserDefined,
readonly: false,
additional_params,
}
}
}

pub fn mysql_default_session_variables() -> DatabaseVariables {
mysql::session_vars::defaults()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::collections::HashMap;

use crate::compile::{DatabaseVariable, DatabaseVariables};
use datafusion::scalar::ScalarValue;

use crate::sql::database_variables::{DatabaseVariable, DatabaseVariables};

pub fn defaults() -> DatabaseVariables {
let mut variables: DatabaseVariables = HashMap::new();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::collections::HashMap;

use crate::compile::{DatabaseVariable, DatabaseVariables};
use datafusion::scalar::ScalarValue;

use crate::sql::database_variables::{DatabaseVariable, DatabaseVariables};

pub fn defaults() -> DatabaseVariables {
let mut variables: DatabaseVariables = HashMap::new();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;

use datafusion::scalar::ScalarValue;

use crate::sql::database_variables::{DatabaseVariable, DatabaseVariables};
use crate::compile::{DatabaseVariable, DatabaseVariables};

pub fn defaults() -> DatabaseVariables {
let mut variables: DatabaseVariables = HashMap::new();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use datafusion::scalar::ScalarValue;
use std::collections::HashMap;

use crate::sql::database_variables::{DatabaseVariable, DatabaseVariables};
use crate::compile::{DatabaseVariable, DatabaseVariables};

pub fn defaults() -> DatabaseVariables {
let mut variables: DatabaseVariables = HashMap::new();
Expand Down
9 changes: 2 additions & 7 deletions rust/cubesql/cubesql/src/sql/server_manager.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
use crate::{
compile::DatabaseProtocol,
compile::{DatabaseProtocol, DatabaseVariables, DatabaseVariablesToUpdate},
config::ConfigObj,
sql::{
compiler_cache::CompilerCache,
database_variables::{
mysql_default_global_variables, postgres_default_global_variables,
DatabaseVariablesToUpdate,
},
database_variables::{mysql_default_global_variables, postgres_default_global_variables},
SqlAuthService,
},
transport::TransportService,
CubeError,
};
use std::sync::{Arc, RwLock as RwLockSync, RwLockReadGuard, RwLockWriteGuard};

use super::database_variables::DatabaseVariables;

#[derive(Debug)]
pub struct ServerConfiguration {
/// Max number of prepared statements which can be allocated per connection
Expand Down
25 changes: 8 additions & 17 deletions rust/cubesql/cubesql/src/sql/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@ use std::{
};
use tokio_util::sync::CancellationToken;

use super::{
database_variables::DatabaseVariables, server_manager::ServerManager,
session_manager::SessionManager, AuthContextRef,
};
use super::{server_manager::ServerManager, session_manager::SessionManager, AuthContextRef};
use crate::{
compile::{DatabaseProtocol, DatabaseProtocolDetails},
compile::{
DatabaseProtocol, DatabaseProtocolDetails, DatabaseVariable, DatabaseVariables,
DatabaseVariablesToUpdate,
},
sql::{
database_variables::{
mysql_default_session_variables, postgres_default_session_variables, DatabaseVariable,
DatabaseVariablesToUpdate,
},
database_variables::{mysql_default_session_variables, postgres_default_session_variables},
extended::PreparedStatement,
temp_tables::TempTableManager,
},
Expand Down Expand Up @@ -328,10 +325,7 @@ impl SessionState {
_ => match &self.protocol {
DatabaseProtocol::MySQL => return MYSQL_DEFAULT_VARIABLES.clone(),
DatabaseProtocol::PostgreSQL => return POSTGRES_DEFAULT_VARIABLES.clone(),
DatabaseProtocol::Extension(ext) => unimplemented!(
"Session.all_variables is not implemented for custom protocol: {:?}",
ext
),
DatabaseProtocol::Extension(ext) => ext.get_session_default_variables(),
},
}
}
Expand All @@ -349,10 +343,7 @@ impl SessionState {
DatabaseProtocol::PostgreSQL => {
POSTGRES_DEFAULT_VARIABLES.get(name).map(|v| v.clone())
}
DatabaseProtocol::Extension(ext) => unimplemented!(
"Session.get_variable is not implemented for custom protocol: {:?}",
ext
),
DatabaseProtocol::Extension(ext) => ext.get_session_variable_default(name),
},
}
}
Expand Down

0 comments on commit a86dc09

Please sign in to comment.