with last commit, we can simplify the logic + remove old 2025_ prefix for search of the tables
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
use steel::steel_vm::engine::Engine;
|
||||
use steel::steel_vm::register_fn::RegisterFn;
|
||||
use steel::rvals::SteelVal;
|
||||
use super::functions::{SteelContext, convert_row_data_for_steel};
|
||||
use super::functions::SteelContext;
|
||||
use steel_decimal::registry::FunctionRegistry;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
@@ -62,6 +62,48 @@ fn auto_promote_with_index(
|
||||
.into_owned()
|
||||
}
|
||||
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
|
||||
// Converts row data boolean values to Steel script format during context initialization.
|
||||
pub async fn convert_row_data_for_steel(
|
||||
db_pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
row_data: &mut HashMap<String, String>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"
|
||||
SELECT columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2
|
||||
"#,
|
||||
schema_id,
|
||||
table_name
|
||||
)
|
||||
.fetch_optional(db_pool)
|
||||
.await?
|
||||
.ok_or_else(|| sqlx::Error::RowNotFound)?;
|
||||
|
||||
// Parse column definitions to identify boolean columns
|
||||
if let Ok(columns) = serde_json::from_value::<Vec<ColumnDefinition>>(table_def.columns) {
|
||||
for col_def in columns {
|
||||
let normalized_type =
|
||||
col_def.field_type.to_uppercase().split('(').next().unwrap().to_string();
|
||||
|
||||
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
|
||||
if let Some(value) = row_data.get_mut(&col_def.name) {
|
||||
*value = match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.clone(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_steel_context_with_boolean_conversion(
|
||||
current_table: String,
|
||||
schema_id: i64,
|
||||
|
||||
@@ -22,10 +22,8 @@ pub enum FunctionError {
|
||||
ProhibitedTypeAccess(String),
|
||||
}
|
||||
|
||||
/// Data types that Steel scripts are prohibited from accessing for security reasons
|
||||
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
|
||||
|
||||
/// Execution context for Steel scripts with database access capabilities.
|
||||
#[derive(Clone)]
|
||||
pub struct SteelContext {
|
||||
pub current_table: String,
|
||||
@@ -36,26 +34,11 @@ pub struct SteelContext {
|
||||
}
|
||||
|
||||
impl SteelContext {
|
||||
/// Resolves a base table name to its full qualified name in the current schema.
|
||||
/// Used for foreign key relationship traversal in Steel scripts.
|
||||
pub async fn get_related_table_name(&self, base_name: &str) -> Result<String, FunctionError> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT table_name FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name LIKE $2"#,
|
||||
self.schema_id,
|
||||
format!("%_{}", base_name)
|
||||
)
|
||||
.fetch_optional(&*self.db_pool)
|
||||
.await
|
||||
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?
|
||||
.ok_or_else(|| FunctionError::TableNotFound(base_name.to_string()))?;
|
||||
|
||||
Ok(table_def.table_name)
|
||||
}
|
||||
|
||||
/// Retrieves the SQL data type for a specific column in a table.
|
||||
/// Parses the JSON column definitions to find type information.
|
||||
async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
|
||||
async fn get_column_type(
|
||||
&self,
|
||||
table_name: &str,
|
||||
column_name: &str,
|
||||
) -> Result<String, FunctionError> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2"#,
|
||||
@@ -68,7 +51,10 @@ impl SteelContext {
|
||||
.ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?;
|
||||
|
||||
let columns: Vec<ColumnDefinition> = serde_json::from_value(table_def.columns)
|
||||
.map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?;
|
||||
.map_err(|e| FunctionError::DatabaseError(format!(
|
||||
"Invalid column data: {}",
|
||||
e
|
||||
)))?;
|
||||
|
||||
for col_def in columns {
|
||||
if col_def.name == column_name {
|
||||
@@ -78,33 +64,29 @@ impl SteelContext {
|
||||
|
||||
Err(FunctionError::ColumnNotFound(format!(
|
||||
"Column '{}' not found in table '{}'",
|
||||
column_name,
|
||||
table_name
|
||||
column_name, table_name
|
||||
)))
|
||||
}
|
||||
|
||||
/// Converts database values to Steel script format based on column type.
|
||||
/// Currently handles boolean conversion to Steel's #true/#false syntax.
|
||||
fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String {
|
||||
let normalized_type = normalize_data_type(column_type);
|
||||
|
||||
match normalized_type.as_str() {
|
||||
"BOOLEAN" | "BOOL" => {
|
||||
// Convert database boolean representations to Steel boolean syntax
|
||||
match value.to_lowercase().as_str() {
|
||||
"BOOLEAN" | "BOOL" => match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.to_string(), // Return as-is if not a recognized boolean
|
||||
}
|
||||
}
|
||||
_ => value.to_string(),
|
||||
},
|
||||
"INTEGER" => value.to_string(),
|
||||
_ => value.to_string(), // Return as-is for other types
|
||||
_ => value.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validates that a column type is allowed for Steel script access.
|
||||
/// Returns the column type if validation passes, error if prohibited.
|
||||
async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
|
||||
async fn validate_column_type_and_get_type(
|
||||
&self,
|
||||
table_name: &str,
|
||||
column_name: &str,
|
||||
) -> Result<String, FunctionError> {
|
||||
let column_type = self.get_column_type(table_name, column_name).await?;
|
||||
|
||||
if is_prohibited_type(&column_type) {
|
||||
@@ -120,15 +102,13 @@ impl SteelContext {
|
||||
Ok(column_type)
|
||||
}
|
||||
|
||||
/// Retrieves column value from current table or related tables via foreign keys.
|
||||
///
|
||||
/// # Behavior
|
||||
/// - Current table: Returns value directly from row_data with type conversion
|
||||
/// - Related table: Follows foreign key relationship and queries database
|
||||
/// - All accesses are subject to prohibited type validation
|
||||
pub fn steel_get_column(&self, table: &str, column: &str) -> Result<SteelVal, SteelVal> {
|
||||
pub fn steel_get_column(
|
||||
&self,
|
||||
table: &str,
|
||||
column: &str,
|
||||
) -> Result<SteelVal, SteelVal> {
|
||||
if table == self.current_table {
|
||||
// Access current table data with type validation
|
||||
// current table
|
||||
let column_type = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
@@ -141,71 +121,112 @@ impl SteelContext {
|
||||
Err(e) => return Err(SteelVal::StringV(e.to_string().into())),
|
||||
};
|
||||
|
||||
return self.row_data.get(column)
|
||||
return self
|
||||
.row_data
|
||||
.get(column)
|
||||
.map(|v| {
|
||||
let converted_value = self.convert_value_to_steel_format(v, &column_type);
|
||||
SteelVal::StringV(converted_value.into())
|
||||
let converted =
|
||||
self.convert_value_to_steel_format(v, &column_type);
|
||||
SteelVal::StringV(converted.into())
|
||||
})
|
||||
.ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into()));
|
||||
.ok_or_else(|| {
|
||||
SteelVal::StringV(
|
||||
format!("Column {} not found", column).into(),
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
// Access related table via foreign key relationship
|
||||
// TODO REDO REMOVE YEAR PREFIX DEPRECATION
|
||||
let base_name = table.split_once('_')
|
||||
.map(|(_, rest)| rest)
|
||||
.unwrap_or(table);
|
||||
|
||||
let fk_column = format!("{}_id", base_name);
|
||||
let fk_value = self.row_data.get(&fk_column)
|
||||
.ok_or_else(|| SteelVal::StringV(format!("Foreign key {} not found", fk_column).into()))?;
|
||||
|
||||
// Cross-table via FK: use exact table name FK convention: "<table>_id"
|
||||
let result = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
let actual_table = self.get_related_table_name(base_name).await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
let fk_column = format!("{}_id", table);
|
||||
let fk_value = self
|
||||
.row_data
|
||||
.get(&fk_column)
|
||||
.ok_or_else(|| {
|
||||
FunctionError::ForeignKeyNotFound(format!(
|
||||
"Foreign key column '{}' not found on '{}'",
|
||||
fk_column, self.current_table
|
||||
))
|
||||
})?;
|
||||
|
||||
// Validate column type and get type information
|
||||
let column_type = self.validate_column_type_and_get_type(&actual_table, column).await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
let column_type =
|
||||
self.validate_column_type_and_get_type(table, column)
|
||||
.await?;
|
||||
|
||||
// Query the related table for the column value
|
||||
let raw_value = sqlx::query_scalar::<_, String>(
|
||||
&format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table)
|
||||
let raw_value = sqlx::query_scalar::<_, String>(&format!(
|
||||
"SELECT \"{}\" FROM \"{}\".\"{}\" WHERE id = $1",
|
||||
column, self.schema_name, table
|
||||
))
|
||||
.bind(
|
||||
fk_value
|
||||
.parse::<i64>()
|
||||
.map_err(|_| {
|
||||
FunctionError::DatabaseError(
|
||||
"Invalid foreign key format".into(),
|
||||
)
|
||||
})?,
|
||||
)
|
||||
.bind(fk_value.parse::<i64>().map_err(|_|
|
||||
SteelVal::StringV("Invalid foreign key format".into()))?)
|
||||
.fetch_one(&*self.db_pool)
|
||||
.await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?;
|
||||
|
||||
// Convert to appropriate Steel format
|
||||
let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type);
|
||||
Ok(converted_value)
|
||||
let converted =
|
||||
self.convert_value_to_steel_format(&raw_value, &column_type);
|
||||
Ok::<String, FunctionError>(converted)
|
||||
})
|
||||
});
|
||||
|
||||
result.map(|v| SteelVal::StringV(v.into()))
|
||||
match result {
|
||||
Ok(v) => Ok(SteelVal::StringV(v.into())),
|
||||
Err(e) => Err(SteelVal::StringV(e.to_string().into())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a specific indexed element from a comma-separated column value.
|
||||
/// Useful for accessing elements from array-like string representations.
|
||||
pub fn steel_get_column_with_index(
|
||||
&self,
|
||||
table: &str,
|
||||
index: i64,
|
||||
column: &str
|
||||
column: &str,
|
||||
) -> Result<SteelVal, SteelVal> {
|
||||
// Get the full value with proper type conversion
|
||||
let value = self.steel_get_column(table, column)?;
|
||||
// Cross-table: interpret 'index' as the row id to fetch directly
|
||||
if table != self.current_table {
|
||||
let result = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
let column_type =
|
||||
self.validate_column_type_and_get_type(table, column)
|
||||
.await?;
|
||||
|
||||
let raw_value = sqlx::query_scalar::<_, String>(&format!(
|
||||
"SELECT \"{}\" FROM \"{}\".\"{}\" WHERE id = $1",
|
||||
column, self.schema_name, table
|
||||
))
|
||||
.bind(index)
|
||||
.fetch_one(&*self.db_pool)
|
||||
.await
|
||||
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?;
|
||||
|
||||
let converted = self
|
||||
.convert_value_to_steel_format(&raw_value, &column_type);
|
||||
Ok::<String, FunctionError>(converted)
|
||||
})
|
||||
});
|
||||
|
||||
return match result {
|
||||
Ok(v) => Ok(SteelVal::StringV(v.into())),
|
||||
Err(e) => Err(SteelVal::StringV(e.to_string().into())),
|
||||
};
|
||||
}
|
||||
|
||||
// Current table: existing behavior (index in comma-separated string)
|
||||
let value = self.steel_get_column(table, column)?;
|
||||
if let SteelVal::StringV(s) = value {
|
||||
let parts: Vec<_> = s.split(',').collect();
|
||||
|
||||
if let Some(part) = parts.get(index as usize) {
|
||||
let trimmed_part = part.trim();
|
||||
let trimmed = part.trim();
|
||||
|
||||
// Apply type conversion to the indexed part based on original column type
|
||||
let column_type = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
@@ -215,40 +236,35 @@ impl SteelContext {
|
||||
|
||||
match column_type {
|
||||
Ok(ct) => {
|
||||
let converted_part = self.convert_value_to_steel_format(trimmed_part, &ct);
|
||||
Ok(SteelVal::StringV(converted_part.into()))
|
||||
}
|
||||
Err(_) => {
|
||||
// If type cannot be determined, return value as-is
|
||||
Ok(SteelVal::StringV(trimmed_part.into()))
|
||||
let converted =
|
||||
self.convert_value_to_steel_format(trimmed, &ct);
|
||||
Ok(SteelVal::StringV(converted.into()))
|
||||
}
|
||||
Err(_) => Ok(SteelVal::StringV(trimmed.into())),
|
||||
}
|
||||
} else {
|
||||
Err(SteelVal::StringV("Index out of bounds".into()))
|
||||
}
|
||||
} else {
|
||||
Err(SteelVal::StringV("Expected comma-separated string".into()))
|
||||
Err(SteelVal::StringV(
|
||||
"Expected comma-separated string".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes read-only SQL queries from Steel scripts with safety restrictions.
|
||||
///
|
||||
/// # Security Features
|
||||
/// - Only SELECT, SHOW, and EXPLAIN queries allowed
|
||||
/// - Prohibited column type access validation
|
||||
/// - Returns first column of all rows as comma-separated string
|
||||
pub fn steel_query_sql(&self, query: &str) -> Result<SteelVal, SteelVal> {
|
||||
if !is_read_only_query(query) {
|
||||
return Err(SteelVal::StringV(
|
||||
"Only SELECT queries are allowed".into()
|
||||
));
|
||||
return Err(SteelVal::StringV("Only SELECT queries are allowed".into()));
|
||||
}
|
||||
|
||||
if contains_prohibited_column_access(query) {
|
||||
return Err(SteelVal::StringV(format!(
|
||||
return Err(SteelVal::StringV(
|
||||
format!(
|
||||
"SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}",
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
).into()));
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = self.db_pool.clone();
|
||||
@@ -263,7 +279,8 @@ impl SteelContext {
|
||||
|
||||
let mut results = Vec::new();
|
||||
for row in rows {
|
||||
let val: String = row.try_get(0)
|
||||
let val: String = row
|
||||
.try_get(0)
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
results.push(val);
|
||||
}
|
||||
@@ -276,80 +293,30 @@ impl SteelContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if a data type is prohibited for Steel script access.
|
||||
fn is_prohibited_type(data_type: &str) -> bool {
|
||||
let normalized_type = normalize_data_type(data_type);
|
||||
PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited))
|
||||
PROHIBITED_TYPES
|
||||
.iter()
|
||||
.any(|&prohibited| normalized_type.starts_with(prohibited))
|
||||
}
|
||||
|
||||
/// Normalizes data type strings for consistent comparison.
|
||||
/// Handles variations like NUMERIC(10,2) by extracting base type.
|
||||
fn normalize_data_type(data_type: &str) -> String {
|
||||
data_type.to_uppercase()
|
||||
.split('(') // Remove precision/scale from NUMERIC(x,y)
|
||||
data_type
|
||||
.to_uppercase()
|
||||
.split('(')
|
||||
.next()
|
||||
.unwrap_or(data_type)
|
||||
.trim()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Performs basic heuristic check for prohibited column type access in SQL queries.
|
||||
/// Looks for common patterns that might indicate access to restricted types.
|
||||
fn contains_prohibited_column_access(query: &str) -> bool {
|
||||
let query_upper = query.to_uppercase();
|
||||
|
||||
let patterns = [
|
||||
"EXTRACT(", // Common with DATE/TIMESTAMPTZ
|
||||
"DATE_PART(", // Common with DATE/TIMESTAMPTZ
|
||||
"::DATE",
|
||||
"::TIMESTAMPTZ",
|
||||
"::BIGINT",
|
||||
];
|
||||
|
||||
patterns.iter().any(|pattern| query_upper.contains(pattern))
|
||||
let patterns = ["EXTRACT(", "DATE_PART(", "::DATE", "::TIMESTAMPTZ", "::BIGINT"];
|
||||
patterns.iter().any(|p| query_upper.contains(p))
|
||||
}
|
||||
|
||||
/// Validates that a query is read-only and safe for Steel script execution.
|
||||
fn is_read_only_query(query: &str) -> bool {
|
||||
let query = query.trim_start().to_uppercase();
|
||||
query.starts_with("SELECT") ||
|
||||
query.starts_with("SHOW") ||
|
||||
query.starts_with("EXPLAIN")
|
||||
}
|
||||
|
||||
/// Converts row data boolean values to Steel script format during context initialization.
|
||||
pub async fn convert_row_data_for_steel(
|
||||
db_pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
row_data: &mut HashMap<String, String>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2"#,
|
||||
schema_id,
|
||||
table_name
|
||||
)
|
||||
.fetch_optional(db_pool)
|
||||
.await?
|
||||
.ok_or_else(|| sqlx::Error::RowNotFound)?;
|
||||
|
||||
// Parse column definitions to identify boolean columns for conversion
|
||||
if let Ok(columns) = serde_json::from_value::<Vec<ColumnDefinition>>(table_def.columns) {
|
||||
for col_def in columns {
|
||||
let normalized_type = normalize_data_type(&col_def.field_type);
|
||||
|
||||
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
|
||||
if let Some(value) = row_data.get_mut(&col_def.name) {
|
||||
*value = match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.clone(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
query.starts_with("SELECT") || query.starts_with("SHOW") || query.starts_with("EXPLAIN")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user