With the last commit in place, we can simplify the logic and remove the old `2025_` prefix that was used when searching for tables.

This commit is contained in:
Priec
2025-10-17 22:55:59 +02:00
parent 241ab99584
commit 492f1f1e55
2 changed files with 174 additions and 165 deletions

View File

@@ -3,7 +3,7 @@
use steel::steel_vm::engine::Engine; use steel::steel_vm::engine::Engine;
use steel::steel_vm::register_fn::RegisterFn; use steel::steel_vm::register_fn::RegisterFn;
use steel::rvals::SteelVal; use steel::rvals::SteelVal;
use super::functions::{SteelContext, convert_row_data_for_steel}; use super::functions::SteelContext;
use steel_decimal::registry::FunctionRegistry; use steel_decimal::registry::FunctionRegistry;
use sqlx::PgPool; use sqlx::PgPool;
use std::sync::Arc; use std::sync::Arc;
@@ -62,6 +62,48 @@ fn auto_promote_with_index(
.into_owned() .into_owned()
} }
use common::proto::komp_ac::table_definition::ColumnDefinition;
// Converts row data boolean values to Steel script format during context initialization.
pub async fn convert_row_data_for_steel(
db_pool: &PgPool,
schema_id: i64,
table_name: &str,
row_data: &mut HashMap<String, String>,
) -> Result<(), sqlx::Error> {
let table_def = sqlx::query!(
r#"
SELECT columns FROM table_definitions
WHERE schema_id = $1 AND table_name = $2
"#,
schema_id,
table_name
)
.fetch_optional(db_pool)
.await?
.ok_or_else(|| sqlx::Error::RowNotFound)?;
// Parse column definitions to identify boolean columns
if let Ok(columns) = serde_json::from_value::<Vec<ColumnDefinition>>(table_def.columns) {
for col_def in columns {
let normalized_type =
col_def.field_type.to_uppercase().split('(').next().unwrap().to_string();
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
if let Some(value) = row_data.get_mut(&col_def.name) {
*value = match value.to_lowercase().as_str() {
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
_ => value.clone(),
};
}
}
}
}
Ok(())
}
pub async fn create_steel_context_with_boolean_conversion( pub async fn create_steel_context_with_boolean_conversion(
current_table: String, current_table: String,
schema_id: i64, schema_id: i64,

View File

@@ -22,10 +22,8 @@ pub enum FunctionError {
ProhibitedTypeAccess(String), ProhibitedTypeAccess(String),
} }
/// Data types that Steel scripts are prohibited from accessing for security reasons.
/// `is_prohibited_type` matches these by prefix against the normalized column type, and
/// `steel_query_sql` lists them in its rejection message.
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"]; const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
/// Execution context for Steel scripts with database access capabilities.
#[derive(Clone)] #[derive(Clone)]
pub struct SteelContext { pub struct SteelContext {
pub current_table: String, pub current_table: String,
@@ -36,26 +34,11 @@ pub struct SteelContext {
} }
impl SteelContext { impl SteelContext {
/// Resolves a base table name to its full qualified name in the current schema. async fn get_column_type(
/// Used for foreign key relationship traversal in Steel scripts. &self,
pub async fn get_related_table_name(&self, base_name: &str) -> Result<String, FunctionError> { table_name: &str,
let table_def = sqlx::query!( column_name: &str,
r#"SELECT table_name FROM table_definitions ) -> Result<String, FunctionError> {
WHERE schema_id = $1 AND table_name LIKE $2"#,
self.schema_id,
format!("%_{}", base_name)
)
.fetch_optional(&*self.db_pool)
.await
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?
.ok_or_else(|| FunctionError::TableNotFound(base_name.to_string()))?;
Ok(table_def.table_name)
}
/// Retrieves the SQL data type for a specific column in a table.
/// Parses the JSON column definitions to find type information.
async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
let table_def = sqlx::query!( let table_def = sqlx::query!(
r#"SELECT columns FROM table_definitions r#"SELECT columns FROM table_definitions
WHERE schema_id = $1 AND table_name = $2"#, WHERE schema_id = $1 AND table_name = $2"#,
@@ -68,7 +51,10 @@ impl SteelContext {
.ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?; .ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?;
let columns: Vec<ColumnDefinition> = serde_json::from_value(table_def.columns) let columns: Vec<ColumnDefinition> = serde_json::from_value(table_def.columns)
.map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?; .map_err(|e| FunctionError::DatabaseError(format!(
"Invalid column data: {}",
e
)))?;
for col_def in columns { for col_def in columns {
if col_def.name == column_name { if col_def.name == column_name {
@@ -78,33 +64,29 @@ impl SteelContext {
Err(FunctionError::ColumnNotFound(format!( Err(FunctionError::ColumnNotFound(format!(
"Column '{}' not found in table '{}'", "Column '{}' not found in table '{}'",
column_name, column_name, table_name
table_name
))) )))
} }
/// Converts database values to Steel script format based on column type.
/// Currently handles boolean conversion to Steel's #true/#false syntax.
fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String { fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String {
let normalized_type = normalize_data_type(column_type); let normalized_type = normalize_data_type(column_type);
match normalized_type.as_str() { match normalized_type.as_str() {
"BOOLEAN" | "BOOL" => { "BOOLEAN" | "BOOL" => match value.to_lowercase().as_str() {
// Convert database boolean representations to Steel boolean syntax
match value.to_lowercase().as_str() {
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(), "true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
"false" | "f" | "0" | "no" | "off" => "#false".to_string(), "false" | "f" | "0" | "no" | "off" => "#false".to_string(),
_ => value.to_string(), // Return as-is if not a recognized boolean _ => value.to_string(),
} },
}
"INTEGER" => value.to_string(), "INTEGER" => value.to_string(),
_ => value.to_string(), // Return as-is for other types _ => value.to_string(),
} }
} }
/// Validates that a column type is allowed for Steel script access. async fn validate_column_type_and_get_type(
/// Returns the column type if validation passes, error if prohibited. &self,
async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> { table_name: &str,
column_name: &str,
) -> Result<String, FunctionError> {
let column_type = self.get_column_type(table_name, column_name).await?; let column_type = self.get_column_type(table_name, column_name).await?;
if is_prohibited_type(&column_type) { if is_prohibited_type(&column_type) {
@@ -120,15 +102,13 @@ impl SteelContext {
Ok(column_type) Ok(column_type)
} }
/// Retrieves column value from current table or related tables via foreign keys. pub fn steel_get_column(
/// &self,
/// # Behavior table: &str,
/// - Current table: Returns value directly from row_data with type conversion column: &str,
/// - Related table: Follows foreign key relationship and queries database ) -> Result<SteelVal, SteelVal> {
/// - All accesses are subject to prohibited type validation
pub fn steel_get_column(&self, table: &str, column: &str) -> Result<SteelVal, SteelVal> {
if table == self.current_table { if table == self.current_table {
// Access current table data with type validation // current table
let column_type = tokio::task::block_in_place(|| { let column_type = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current(); let handle = tokio::runtime::Handle::current();
handle.block_on(async { handle.block_on(async {
@@ -141,71 +121,112 @@ impl SteelContext {
Err(e) => return Err(SteelVal::StringV(e.to_string().into())), Err(e) => return Err(SteelVal::StringV(e.to_string().into())),
}; };
return self.row_data.get(column) return self
.row_data
.get(column)
.map(|v| { .map(|v| {
let converted_value = self.convert_value_to_steel_format(v, &column_type); let converted =
SteelVal::StringV(converted_value.into()) self.convert_value_to_steel_format(v, &column_type);
SteelVal::StringV(converted.into())
}) })
.ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into())); .ok_or_else(|| {
SteelVal::StringV(
format!("Column {} not found", column).into(),
)
});
} }
// Access related table via foreign key relationship // Cross-table via FK: use exact table name FK convention: "<table>_id"
// TODO REDO REMOVE YEAR PREFIX DEPRECATION
let base_name = table.split_once('_')
.map(|(_, rest)| rest)
.unwrap_or(table);
let fk_column = format!("{}_id", base_name);
let fk_value = self.row_data.get(&fk_column)
.ok_or_else(|| SteelVal::StringV(format!("Foreign key {} not found", fk_column).into()))?;
let result = tokio::task::block_in_place(|| { let result = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current(); let handle = tokio::runtime::Handle::current();
handle.block_on(async { handle.block_on(async {
let actual_table = self.get_related_table_name(base_name).await let fk_column = format!("{}_id", table);
.map_err(|e| SteelVal::StringV(e.to_string().into()))?; let fk_value = self
.row_data
.get(&fk_column)
.ok_or_else(|| {
FunctionError::ForeignKeyNotFound(format!(
"Foreign key column '{}' not found on '{}'",
fk_column, self.current_table
))
})?;
// Validate column type and get type information let column_type =
let column_type = self.validate_column_type_and_get_type(&actual_table, column).await self.validate_column_type_and_get_type(table, column)
.map_err(|e| SteelVal::StringV(e.to_string().into()))?; .await?;
// Query the related table for the column value let raw_value = sqlx::query_scalar::<_, String>(&format!(
let raw_value = sqlx::query_scalar::<_, String>( "SELECT \"{}\" FROM \"{}\".\"{}\" WHERE id = $1",
&format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table) column, self.schema_name, table
))
.bind(
fk_value
.parse::<i64>()
.map_err(|_| {
FunctionError::DatabaseError(
"Invalid foreign key format".into(),
)
})?,
) )
.bind(fk_value.parse::<i64>().map_err(|_|
SteelVal::StringV("Invalid foreign key format".into()))?)
.fetch_one(&*self.db_pool) .fetch_one(&*self.db_pool)
.await .await
.map_err(|e| SteelVal::StringV(e.to_string().into()))?; .map_err(|e| FunctionError::DatabaseError(e.to_string()))?;
// Convert to appropriate Steel format let converted =
let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type); self.convert_value_to_steel_format(&raw_value, &column_type);
Ok(converted_value) Ok::<String, FunctionError>(converted)
}) })
}); });
result.map(|v| SteelVal::StringV(v.into())) match result {
Ok(v) => Ok(SteelVal::StringV(v.into())),
Err(e) => Err(SteelVal::StringV(e.to_string().into())),
}
} }
/// Retrieves a specific indexed element from a comma-separated column value.
/// Useful for accessing elements from array-like string representations.
pub fn steel_get_column_with_index( pub fn steel_get_column_with_index(
&self, &self,
table: &str, table: &str,
index: i64, index: i64,
column: &str column: &str,
) -> Result<SteelVal, SteelVal> { ) -> Result<SteelVal, SteelVal> {
// Get the full value with proper type conversion // Cross-table: interpret 'index' as the row id to fetch directly
let value = self.steel_get_column(table, column)?; if table != self.current_table {
let result = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current();
handle.block_on(async {
let column_type =
self.validate_column_type_and_get_type(table, column)
.await?;
let raw_value = sqlx::query_scalar::<_, String>(&format!(
"SELECT \"{}\" FROM \"{}\".\"{}\" WHERE id = $1",
column, self.schema_name, table
))
.bind(index)
.fetch_one(&*self.db_pool)
.await
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?;
let converted = self
.convert_value_to_steel_format(&raw_value, &column_type);
Ok::<String, FunctionError>(converted)
})
});
return match result {
Ok(v) => Ok(SteelVal::StringV(v.into())),
Err(e) => Err(SteelVal::StringV(e.to_string().into())),
};
}
// Current table: existing behavior (index in comma-separated string)
let value = self.steel_get_column(table, column)?;
if let SteelVal::StringV(s) = value { if let SteelVal::StringV(s) = value {
let parts: Vec<_> = s.split(',').collect(); let parts: Vec<_> = s.split(',').collect();
if let Some(part) = parts.get(index as usize) { if let Some(part) = parts.get(index as usize) {
let trimmed_part = part.trim(); let trimmed = part.trim();
// Apply type conversion to the indexed part based on original column type
let column_type = tokio::task::block_in_place(|| { let column_type = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current(); let handle = tokio::runtime::Handle::current();
handle.block_on(async { handle.block_on(async {
@@ -215,40 +236,35 @@ impl SteelContext {
match column_type { match column_type {
Ok(ct) => { Ok(ct) => {
let converted_part = self.convert_value_to_steel_format(trimmed_part, &ct); let converted =
Ok(SteelVal::StringV(converted_part.into())) self.convert_value_to_steel_format(trimmed, &ct);
} Ok(SteelVal::StringV(converted.into()))
Err(_) => {
// If type cannot be determined, return value as-is
Ok(SteelVal::StringV(trimmed_part.into()))
} }
Err(_) => Ok(SteelVal::StringV(trimmed.into())),
} }
} else { } else {
Err(SteelVal::StringV("Index out of bounds".into())) Err(SteelVal::StringV("Index out of bounds".into()))
} }
} else { } else {
Err(SteelVal::StringV("Expected comma-separated string".into())) Err(SteelVal::StringV(
"Expected comma-separated string".into(),
))
} }
} }
/// Executes read-only SQL queries from Steel scripts with safety restrictions.
///
/// # Security Features
/// - Only SELECT, SHOW, and EXPLAIN queries allowed
/// - Prohibited column type access validation
/// - Returns first column of all rows as comma-separated string
pub fn steel_query_sql(&self, query: &str) -> Result<SteelVal, SteelVal> { pub fn steel_query_sql(&self, query: &str) -> Result<SteelVal, SteelVal> {
if !is_read_only_query(query) { if !is_read_only_query(query) {
return Err(SteelVal::StringV( return Err(SteelVal::StringV("Only SELECT queries are allowed".into()));
"Only SELECT queries are allowed".into()
));
} }
if contains_prohibited_column_access(query) { if contains_prohibited_column_access(query) {
return Err(SteelVal::StringV(format!( return Err(SteelVal::StringV(
format!(
"SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}", "SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}",
PROHIBITED_TYPES.join(", ") PROHIBITED_TYPES.join(", ")
).into())); )
.into(),
));
} }
let pool = self.db_pool.clone(); let pool = self.db_pool.clone();
@@ -263,7 +279,8 @@ impl SteelContext {
let mut results = Vec::new(); let mut results = Vec::new();
for row in rows { for row in rows {
let val: String = row.try_get(0) let val: String = row
.try_get(0)
.map_err(|e| SteelVal::StringV(e.to_string().into()))?; .map_err(|e| SteelVal::StringV(e.to_string().into()))?;
results.push(val); results.push(val);
} }
@@ -276,80 +293,30 @@ impl SteelContext {
} }
} }
/// Checks if a data type is prohibited for Steel script access.
fn is_prohibited_type(data_type: &str) -> bool { fn is_prohibited_type(data_type: &str) -> bool {
let normalized_type = normalize_data_type(data_type); let normalized_type = normalize_data_type(data_type);
PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) PROHIBITED_TYPES
.iter()
.any(|&prohibited| normalized_type.starts_with(prohibited))
} }
/// Normalizes a SQL type string for comparison.
///
/// Keeps only the text before the first `(` (so `NUMERIC(10,2)` becomes `NUMERIC`),
/// trims surrounding whitespace and upper-cases the result.
fn normalize_data_type(data_type: &str) -> String {
    // `split` always yields at least one item; the fallback only satisfies the type.
    let base = data_type.split('(').next().unwrap_or(data_type);
    base.trim().to_uppercase()
}
/// Heuristically detects SQL text that is likely to touch a prohibited column type.
///
/// Performs a case-insensitive substring search for the date-extraction calls `EXTRACT(`
/// and `DATE_PART(` and for explicit casts `::DATE`, `::TIMESTAMPTZ` and `::BIGINT`.
/// The query is not parsed, so this can both miss real accesses and flag harmless text
/// (for example a string literal containing one of the patterns).
fn contains_prohibited_column_access(query: &str) -> bool {
    const SUSPICIOUS_PATTERNS: [&str; 5] = [
        "EXTRACT(",
        "DATE_PART(",
        "::DATE",
        "::TIMESTAMPTZ",
        "::BIGINT",
    ];

    let query_upper = query.to_uppercase();
    SUSPICIOUS_PATTERNS
        .iter()
        .any(|&pattern| query_upper.contains(pattern))
}
/// Reports whether `query` looks like a statement that only reads data.
///
/// Accepts statements whose text, after leading whitespace and ignoring case, starts with
/// `SELECT`, `SHOW` or `EXPLAIN`. An `EXPLAIN` that carries the `ANALYZE` / `ANALYSE`
/// option is rejected: PostgreSQL executes the explained statement in that mode, so
/// `EXPLAIN ANALYZE DELETE ...` would modify data despite passing a prefix check.
///
/// This is a prefix heuristic, not a SQL parser; it does not inspect the rest of a
/// `SELECT` statement.
fn is_read_only_query(query: &str) -> bool {
    let upper = query.trim_start().to_uppercase();

    if upper.starts_with("EXPLAIN") {
        // Match ANALYZE as a whole word so identifiers like `analyze_count` stay allowed.
        let executes_statement = upper
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .any(|word| word == "ANALYZE" || word == "ANALYSE");
        return !executes_statement;
    }

    upper.starts_with("SELECT") || upper.starts_with("SHOW")
}
/// Converts row data boolean values to Steel script format during context initialization.
pub async fn convert_row_data_for_steel(
db_pool: &PgPool,
schema_id: i64,
table_name: &str,
row_data: &mut HashMap<String, String>,
) -> Result<(), sqlx::Error> {
let table_def = sqlx::query!(
r#"SELECT columns FROM table_definitions
WHERE schema_id = $1 AND table_name = $2"#,
schema_id,
table_name
)
.fetch_optional(db_pool)
.await?
.ok_or_else(|| sqlx::Error::RowNotFound)?;
// Parse column definitions to identify boolean columns for conversion
if let Ok(columns) = serde_json::from_value::<Vec<ColumnDefinition>>(table_def.columns) {
for col_def in columns {
let normalized_type = normalize_data_type(&col_def.field_type);
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
if let Some(value) = row_data.get_mut(&col_def.name) {
*value = match value.to_lowercase().as_str() {
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
_ => value.clone(),
};
}
}
}
}
Ok(())
} }