Compare commits
11 Commits
ceb560c658
...
492f1f1e55
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
492f1f1e55 | ||
|
|
241ab99584 | ||
|
|
8bd5b5c62f | ||
|
|
7e21258d2e | ||
|
|
49277cfdd4 | ||
|
|
1f6dc3cd75 | ||
|
|
7350b0985c | ||
|
|
73bc6dc99c | ||
|
|
095645a209 | ||
|
|
532977056d | ||
|
|
2435f58256 |
@@ -29,6 +29,30 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
".komp_ac.table_validation.CountMode",
|
||||
"#[derive(serde::Serialize, serde::Deserialize)] #[serde(rename_all = \"SCREAMING_SNAKE_CASE\")]",
|
||||
)
|
||||
.type_attribute(
|
||||
".komp_ac.table_definition.ColumnDefinition",
|
||||
"#[derive(serde::Serialize, serde::Deserialize)]",
|
||||
)
|
||||
.type_attribute(
|
||||
".komp_ac.table_definition.TableLink",
|
||||
"#[derive(serde::Serialize, serde::Deserialize)]"
|
||||
)
|
||||
.type_attribute(
|
||||
".komp_ac.table_definition.PostTableDefinitionRequest",
|
||||
"#[derive(serde::Serialize, serde::Deserialize)]",
|
||||
)
|
||||
.type_attribute(
|
||||
".komp_ac.table_definition.TableDefinitionResponse",
|
||||
"#[derive(serde::Serialize, serde::Deserialize)]"
|
||||
)
|
||||
.type_attribute(
|
||||
".komp_ac.table_script.PostTableScriptRequest",
|
||||
"#[derive(serde::Serialize, serde::Deserialize)]",
|
||||
)
|
||||
.type_attribute(
|
||||
".komp_ac.table_script.TableScriptResponse",
|
||||
"#[derive(serde::Serialize, serde::Deserialize)]",
|
||||
)
|
||||
.compile_protos(
|
||||
&[
|
||||
"proto/common.proto",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// This file is @generated by prost-build.
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct TableLink {
|
||||
#[prost(string, tag = "1")]
|
||||
@@ -6,6 +7,7 @@ pub struct TableLink {
|
||||
#[prost(bool, tag = "2")]
|
||||
pub required: bool,
|
||||
}
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct PostTableDefinitionRequest {
|
||||
#[prost(string, tag = "1")]
|
||||
@@ -19,6 +21,7 @@ pub struct PostTableDefinitionRequest {
|
||||
#[prost(string, tag = "5")]
|
||||
pub profile_name: ::prost::alloc::string::String,
|
||||
}
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ColumnDefinition {
|
||||
#[prost(string, tag = "1")]
|
||||
@@ -26,6 +29,7 @@ pub struct ColumnDefinition {
|
||||
#[prost(string, tag = "2")]
|
||||
pub field_type: ::prost::alloc::string::String,
|
||||
}
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct TableDefinitionResponse {
|
||||
#[prost(bool, tag = "1")]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// This file is @generated by prost-build.
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct PostTableScriptRequest {
|
||||
#[prost(int64, tag = "1")]
|
||||
@@ -10,6 +11,7 @@ pub struct PostTableScriptRequest {
|
||||
#[prost(string, tag = "4")]
|
||||
pub description: ::prost::alloc::string::String,
|
||||
}
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct TableScriptResponse {
|
||||
#[prost(int64, tag = "1")]
|
||||
|
||||
@@ -3,21 +3,20 @@
|
||||
use steel::steel_vm::engine::Engine;
|
||||
use steel::steel_vm::register_fn::RegisterFn;
|
||||
use steel::rvals::SteelVal;
|
||||
use super::functions::{SteelContext, convert_row_data_for_steel};
|
||||
use super::functions::SteelContext;
|
||||
use steel_decimal::registry::FunctionRegistry;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, error};
|
||||
use regex::Regex; // NEW
|
||||
|
||||
/// Represents different types of values that can be returned from Steel script execution.
|
||||
#[derive(Debug)]
|
||||
pub enum Value {
|
||||
Strings(Vec<String>),
|
||||
}
|
||||
|
||||
/// Errors that can occur during Steel script execution.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ExecutionError {
|
||||
#[error("Script execution failed: {0}")]
|
||||
@@ -28,7 +27,83 @@ pub enum ExecutionError {
|
||||
UnsupportedType(String),
|
||||
}
|
||||
|
||||
/// Creates a Steel execution context with proper boolean value conversion.
|
||||
// NEW: upgrade steel_get_column -> steel_get_column_with_index using FK present in row_data
|
||||
fn auto_promote_with_index(
|
||||
script: &str,
|
||||
current_table: &str,
|
||||
row_data: &HashMap<String, String>,
|
||||
) -> String {
|
||||
// Matches: (steel_get_column "table" "column")
|
||||
let re = Regex::new(
|
||||
r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)"\s*\)"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
re.replace_all(script, |caps: ®ex::Captures| {
|
||||
let table = caps.get(1).unwrap().as_str();
|
||||
let column = caps.get(2).unwrap().as_str();
|
||||
|
||||
// Only upgrade cross-table calls, if FK is present in the request data
|
||||
if table != current_table {
|
||||
let fk_key = format!("{}_id", table);
|
||||
if let Some(id_str) = row_data.get(&fk_key) {
|
||||
if let Ok(_) = id_str.parse::<i64>() {
|
||||
return format!(
|
||||
r#"(steel_get_column_with_index "{}" {} "{}")"#,
|
||||
table, id_str, column
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default: keep original call
|
||||
caps.get(0).unwrap().as_str().to_string()
|
||||
})
|
||||
.into_owned()
|
||||
}
|
||||
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
|
||||
// Converts row data boolean values to Steel script format during context initialization.
|
||||
pub async fn convert_row_data_for_steel(
|
||||
db_pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
row_data: &mut HashMap<String, String>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"
|
||||
SELECT columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2
|
||||
"#,
|
||||
schema_id,
|
||||
table_name
|
||||
)
|
||||
.fetch_optional(db_pool)
|
||||
.await?
|
||||
.ok_or_else(|| sqlx::Error::RowNotFound)?;
|
||||
|
||||
// Parse column definitions to identify boolean columns
|
||||
if let Ok(columns) = serde_json::from_value::<Vec<ColumnDefinition>>(table_def.columns) {
|
||||
for col_def in columns {
|
||||
let normalized_type =
|
||||
col_def.field_type.to_uppercase().split('(').next().unwrap().to_string();
|
||||
|
||||
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
|
||||
if let Some(value) = row_data.get_mut(&col_def.name) {
|
||||
*value = match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.clone(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_steel_context_with_boolean_conversion(
|
||||
current_table: String,
|
||||
schema_id: i64,
|
||||
@@ -36,7 +111,6 @@ pub async fn create_steel_context_with_boolean_conversion(
|
||||
mut row_data: HashMap<String, String>,
|
||||
db_pool: Arc<PgPool>,
|
||||
) -> Result<SteelContext, ExecutionError> {
|
||||
// Convert boolean values in row_data to Steel format
|
||||
convert_row_data_for_steel(&db_pool, schema_id, ¤t_table, &mut row_data)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -53,7 +127,6 @@ pub async fn create_steel_context_with_boolean_conversion(
|
||||
})
|
||||
}
|
||||
|
||||
/// Executes a Steel script with database context and type-safe result processing.
|
||||
pub async fn execute_script(
|
||||
script: String,
|
||||
target_type: &str,
|
||||
@@ -65,42 +138,40 @@ pub async fn execute_script(
|
||||
) -> Result<Value, ExecutionError> {
|
||||
let mut vm = Engine::new();
|
||||
|
||||
// Create execution context with proper boolean value conversion
|
||||
// Upgrade to with_index based on FK presence in the posted data
|
||||
let script = auto_promote_with_index(&script, ¤t_table, &row_data);
|
||||
|
||||
let context = create_steel_context_with_boolean_conversion(
|
||||
current_table,
|
||||
current_table.clone(),
|
||||
schema_id,
|
||||
schema_name,
|
||||
row_data,
|
||||
row_data.clone(),
|
||||
db_pool.clone(),
|
||||
).await?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
let context = Arc::new(context);
|
||||
|
||||
// Register database access functions
|
||||
register_steel_functions(&mut vm, context.clone());
|
||||
|
||||
// Register decimal math operations
|
||||
register_decimal_math_functions(&mut vm);
|
||||
|
||||
// Register row data as variables in the Steel VM for get-var access
|
||||
let mut define_script = String::new();
|
||||
|
||||
for (key, value) in &context.row_data {
|
||||
// Register only bare variable names for get-var access
|
||||
define_script.push_str(&format!("(define {} \"{}\")\n", key, value));
|
||||
}
|
||||
|
||||
// Execute variable definitions if any exist
|
||||
if !define_script.is_empty() {
|
||||
vm.compile_and_run_raw_program(define_script)
|
||||
.map_err(|e| ExecutionError::RuntimeError(format!("Failed to register variables: {}", e)))?;
|
||||
.map_err(|e| ExecutionError::RuntimeError(format!(
|
||||
"Failed to register variables: {}",
|
||||
e
|
||||
)))?;
|
||||
}
|
||||
|
||||
// Also register variables using the decimal registry as backup method
|
||||
FunctionRegistry::register_variables(&mut vm, context.row_data.clone());
|
||||
|
||||
// Execute the main script
|
||||
let results = vm.compile_and_run_raw_program(script.clone())
|
||||
let results = vm
|
||||
.compile_and_run_raw_program(script.clone())
|
||||
.map_err(|e| {
|
||||
error!("Steel script execution failed: {}", e);
|
||||
error!("Script was: {}", script);
|
||||
@@ -108,22 +179,22 @@ pub async fn execute_script(
|
||||
ExecutionError::RuntimeError(e.to_string())
|
||||
})?;
|
||||
|
||||
// Convert results to the requested target type
|
||||
match target_type {
|
||||
"STRINGS" => process_string_results(results),
|
||||
_ => Err(ExecutionError::UnsupportedType(target_type.into()))
|
||||
_ => Err(ExecutionError::UnsupportedType(target_type.into())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers Steel functions for database access within the VM context.
|
||||
fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
|
||||
debug!("Registering Steel functions with context");
|
||||
|
||||
// Register column access function for current and related tables
|
||||
vm.register_fn("steel_get_column", {
|
||||
let ctx = context.clone();
|
||||
move |table: String, column: String| {
|
||||
debug!("steel_get_column called with table: '{}', column: '{}'", table, column);
|
||||
debug!(
|
||||
"steel_get_column called with table: '{}', column: '{}'",
|
||||
table, column
|
||||
);
|
||||
ctx.steel_get_column(&table, &column)
|
||||
.map_err(|e| {
|
||||
error!("steel_get_column failed: {:?}", e);
|
||||
@@ -132,11 +203,13 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
|
||||
}
|
||||
});
|
||||
|
||||
// Register indexed column access for comma-separated values
|
||||
vm.register_fn("steel_get_column_with_index", {
|
||||
let ctx = context.clone();
|
||||
move |table: String, index: i64, column: String| {
|
||||
debug!("steel_get_column_with_index called with table: '{}', index: {}, column: '{}'", table, index, column);
|
||||
debug!(
|
||||
"steel_get_column_with_index called with table: '{}', index: {}, column: '{}'",
|
||||
table, index, column
|
||||
);
|
||||
ctx.steel_get_column_with_index(&table, index, &column)
|
||||
.map_err(|e| {
|
||||
error!("steel_get_column_with_index failed: {:?}", e);
|
||||
@@ -145,27 +218,23 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
|
||||
}
|
||||
});
|
||||
|
||||
// Register safe SQL query execution
|
||||
vm.register_fn("steel_query_sql", {
|
||||
let ctx = context.clone();
|
||||
move |query: String| {
|
||||
debug!("steel_query_sql called with query: '{}'", query);
|
||||
ctx.steel_query_sql(&query)
|
||||
.map_err(|e| {
|
||||
error!("steel_query_sql failed: {:?}", e);
|
||||
e.to_string()
|
||||
})
|
||||
ctx.steel_query_sql(&query).map_err(|e| {
|
||||
error!("steel_query_sql failed: {:?}", e);
|
||||
e.to_string()
|
||||
})
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Registers decimal mathematics functions in the Steel VM.
|
||||
fn register_decimal_math_functions(vm: &mut Engine) {
|
||||
debug!("Registering decimal math functions");
|
||||
FunctionRegistry::register_all(vm);
|
||||
}
|
||||
|
||||
/// Processes Steel script results into string format for consistent output.
|
||||
fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionError> {
|
||||
let mut strings = Vec::new();
|
||||
|
||||
@@ -178,7 +247,7 @@ fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionErro
|
||||
_ => {
|
||||
error!("Unexpected result type: {:?}", result);
|
||||
return Err(ExecutionError::TypeConversionError(
|
||||
format!("Expected string-convertible type, got {:?}", result)
|
||||
format!("Expected string-convertible type, got {:?}", result),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
// src/steel/server/functions.rs
|
||||
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use steel::rvals::SteelVal;
|
||||
use sqlx::PgPool;
|
||||
use std::collections::HashMap;
|
||||
@@ -21,10 +22,8 @@ pub enum FunctionError {
|
||||
ProhibitedTypeAccess(String),
|
||||
}
|
||||
|
||||
/// Data types that Steel scripts are prohibited from accessing for security reasons
|
||||
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
|
||||
|
||||
/// Execution context for Steel scripts with database access capabilities.
|
||||
#[derive(Clone)]
|
||||
pub struct SteelContext {
|
||||
pub current_table: String,
|
||||
@@ -35,26 +34,11 @@ pub struct SteelContext {
|
||||
}
|
||||
|
||||
impl SteelContext {
|
||||
/// Resolves a base table name to its full qualified name in the current schema.
|
||||
/// Used for foreign key relationship traversal in Steel scripts.
|
||||
pub async fn get_related_table_name(&self, base_name: &str) -> Result<String, FunctionError> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT table_name FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name LIKE $2"#,
|
||||
self.schema_id,
|
||||
format!("%_{}", base_name)
|
||||
)
|
||||
.fetch_optional(&*self.db_pool)
|
||||
.await
|
||||
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?
|
||||
.ok_or_else(|| FunctionError::TableNotFound(base_name.to_string()))?;
|
||||
|
||||
Ok(table_def.table_name)
|
||||
}
|
||||
|
||||
/// Retrieves the SQL data type for a specific column in a table.
|
||||
/// Parses the JSON column definitions to find type information.
|
||||
async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
|
||||
async fn get_column_type(
|
||||
&self,
|
||||
table_name: &str,
|
||||
column_name: &str,
|
||||
) -> Result<String, FunctionError> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2"#,
|
||||
@@ -66,49 +50,43 @@ impl SteelContext {
|
||||
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?
|
||||
.ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?;
|
||||
|
||||
let columns: Vec<String> = serde_json::from_value(table_def.columns)
|
||||
.map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?;
|
||||
let columns: Vec<ColumnDefinition> = serde_json::from_value(table_def.columns)
|
||||
.map_err(|e| FunctionError::DatabaseError(format!(
|
||||
"Invalid column data: {}",
|
||||
e
|
||||
)))?;
|
||||
|
||||
// Parse column definitions to find the requested column type
|
||||
for column_def in columns {
|
||||
let mut parts = column_def.split_whitespace();
|
||||
if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
|
||||
let column_name_clean = name.trim_matches('"');
|
||||
if column_name_clean == column_name {
|
||||
return Ok(data_type.to_string());
|
||||
}
|
||||
for col_def in columns {
|
||||
if col_def.name == column_name {
|
||||
return Ok(col_def.field_type.to_uppercase());
|
||||
}
|
||||
}
|
||||
|
||||
Err(FunctionError::ColumnNotFound(format!(
|
||||
"Column '{}' not found in table '{}'",
|
||||
column_name,
|
||||
table_name
|
||||
column_name, table_name
|
||||
)))
|
||||
}
|
||||
|
||||
/// Converts database values to Steel script format based on column type.
|
||||
/// Currently handles boolean conversion to Steel's #true/#false syntax.
|
||||
fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String {
|
||||
let normalized_type = normalize_data_type(column_type);
|
||||
|
||||
match normalized_type.as_str() {
|
||||
"BOOLEAN" | "BOOL" => {
|
||||
// Convert database boolean representations to Steel boolean syntax
|
||||
match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.to_string(), // Return as-is if not a recognized boolean
|
||||
}
|
||||
}
|
||||
"BOOLEAN" | "BOOL" => match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.to_string(),
|
||||
},
|
||||
"INTEGER" => value.to_string(),
|
||||
_ => value.to_string(), // Return as-is for other types
|
||||
_ => value.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validates that a column type is allowed for Steel script access.
|
||||
/// Returns the column type if validation passes, error if prohibited.
|
||||
async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
|
||||
async fn validate_column_type_and_get_type(
|
||||
&self,
|
||||
table_name: &str,
|
||||
column_name: &str,
|
||||
) -> Result<String, FunctionError> {
|
||||
let column_type = self.get_column_type(table_name, column_name).await?;
|
||||
|
||||
if is_prohibited_type(&column_type) {
|
||||
@@ -124,15 +102,13 @@ impl SteelContext {
|
||||
Ok(column_type)
|
||||
}
|
||||
|
||||
/// Retrieves column value from current table or related tables via foreign keys.
|
||||
///
|
||||
/// # Behavior
|
||||
/// - Current table: Returns value directly from row_data with type conversion
|
||||
/// - Related table: Follows foreign key relationship and queries database
|
||||
/// - All accesses are subject to prohibited type validation
|
||||
pub fn steel_get_column(&self, table: &str, column: &str) -> Result<SteelVal, SteelVal> {
|
||||
pub fn steel_get_column(
|
||||
&self,
|
||||
table: &str,
|
||||
column: &str,
|
||||
) -> Result<SteelVal, SteelVal> {
|
||||
if table == self.current_table {
|
||||
// Access current table data with type validation
|
||||
// current table
|
||||
let column_type = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
@@ -145,70 +121,112 @@ impl SteelContext {
|
||||
Err(e) => return Err(SteelVal::StringV(e.to_string().into())),
|
||||
};
|
||||
|
||||
return self.row_data.get(column)
|
||||
return self
|
||||
.row_data
|
||||
.get(column)
|
||||
.map(|v| {
|
||||
let converted_value = self.convert_value_to_steel_format(v, &column_type);
|
||||
SteelVal::StringV(converted_value.into())
|
||||
let converted =
|
||||
self.convert_value_to_steel_format(v, &column_type);
|
||||
SteelVal::StringV(converted.into())
|
||||
})
|
||||
.ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into()));
|
||||
.ok_or_else(|| {
|
||||
SteelVal::StringV(
|
||||
format!("Column {} not found", column).into(),
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
// Access related table via foreign key relationship
|
||||
let base_name = table.split_once('_')
|
||||
.map(|(_, rest)| rest)
|
||||
.unwrap_or(table);
|
||||
|
||||
let fk_column = format!("{}_id", base_name);
|
||||
let fk_value = self.row_data.get(&fk_column)
|
||||
.ok_or_else(|| SteelVal::StringV(format!("Foreign key {} not found", fk_column).into()))?;
|
||||
|
||||
// Cross-table via FK: use exact table name FK convention: "<table>_id"
|
||||
let result = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
let actual_table = self.get_related_table_name(base_name).await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
let fk_column = format!("{}_id", table);
|
||||
let fk_value = self
|
||||
.row_data
|
||||
.get(&fk_column)
|
||||
.ok_or_else(|| {
|
||||
FunctionError::ForeignKeyNotFound(format!(
|
||||
"Foreign key column '{}' not found on '{}'",
|
||||
fk_column, self.current_table
|
||||
))
|
||||
})?;
|
||||
|
||||
// Validate column type and get type information
|
||||
let column_type = self.validate_column_type_and_get_type(&actual_table, column).await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
let column_type =
|
||||
self.validate_column_type_and_get_type(table, column)
|
||||
.await?;
|
||||
|
||||
// Query the related table for the column value
|
||||
let raw_value = sqlx::query_scalar::<_, String>(
|
||||
&format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table)
|
||||
let raw_value = sqlx::query_scalar::<_, String>(&format!(
|
||||
"SELECT \"{}\" FROM \"{}\".\"{}\" WHERE id = $1",
|
||||
column, self.schema_name, table
|
||||
))
|
||||
.bind(
|
||||
fk_value
|
||||
.parse::<i64>()
|
||||
.map_err(|_| {
|
||||
FunctionError::DatabaseError(
|
||||
"Invalid foreign key format".into(),
|
||||
)
|
||||
})?,
|
||||
)
|
||||
.bind(fk_value.parse::<i64>().map_err(|_|
|
||||
SteelVal::StringV("Invalid foreign key format".into()))?)
|
||||
.fetch_one(&*self.db_pool)
|
||||
.await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?;
|
||||
|
||||
// Convert to appropriate Steel format
|
||||
let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type);
|
||||
Ok(converted_value)
|
||||
let converted =
|
||||
self.convert_value_to_steel_format(&raw_value, &column_type);
|
||||
Ok::<String, FunctionError>(converted)
|
||||
})
|
||||
});
|
||||
|
||||
result.map(|v| SteelVal::StringV(v.into()))
|
||||
match result {
|
||||
Ok(v) => Ok(SteelVal::StringV(v.into())),
|
||||
Err(e) => Err(SteelVal::StringV(e.to_string().into())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a specific indexed element from a comma-separated column value.
|
||||
/// Useful for accessing elements from array-like string representations.
|
||||
pub fn steel_get_column_with_index(
|
||||
&self,
|
||||
table: &str,
|
||||
index: i64,
|
||||
column: &str
|
||||
column: &str,
|
||||
) -> Result<SteelVal, SteelVal> {
|
||||
// Get the full value with proper type conversion
|
||||
let value = self.steel_get_column(table, column)?;
|
||||
// Cross-table: interpret 'index' as the row id to fetch directly
|
||||
if table != self.current_table {
|
||||
let result = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
let column_type =
|
||||
self.validate_column_type_and_get_type(table, column)
|
||||
.await?;
|
||||
|
||||
let raw_value = sqlx::query_scalar::<_, String>(&format!(
|
||||
"SELECT \"{}\" FROM \"{}\".\"{}\" WHERE id = $1",
|
||||
column, self.schema_name, table
|
||||
))
|
||||
.bind(index)
|
||||
.fetch_one(&*self.db_pool)
|
||||
.await
|
||||
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?;
|
||||
|
||||
let converted = self
|
||||
.convert_value_to_steel_format(&raw_value, &column_type);
|
||||
Ok::<String, FunctionError>(converted)
|
||||
})
|
||||
});
|
||||
|
||||
return match result {
|
||||
Ok(v) => Ok(SteelVal::StringV(v.into())),
|
||||
Err(e) => Err(SteelVal::StringV(e.to_string().into())),
|
||||
};
|
||||
}
|
||||
|
||||
// Current table: existing behavior (index in comma-separated string)
|
||||
let value = self.steel_get_column(table, column)?;
|
||||
if let SteelVal::StringV(s) = value {
|
||||
let parts: Vec<_> = s.split(',').collect();
|
||||
|
||||
if let Some(part) = parts.get(index as usize) {
|
||||
let trimmed_part = part.trim();
|
||||
let trimmed = part.trim();
|
||||
|
||||
// Apply type conversion to the indexed part based on original column type
|
||||
let column_type = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
@@ -218,40 +236,35 @@ impl SteelContext {
|
||||
|
||||
match column_type {
|
||||
Ok(ct) => {
|
||||
let converted_part = self.convert_value_to_steel_format(trimmed_part, &ct);
|
||||
Ok(SteelVal::StringV(converted_part.into()))
|
||||
}
|
||||
Err(_) => {
|
||||
// If type cannot be determined, return value as-is
|
||||
Ok(SteelVal::StringV(trimmed_part.into()))
|
||||
let converted =
|
||||
self.convert_value_to_steel_format(trimmed, &ct);
|
||||
Ok(SteelVal::StringV(converted.into()))
|
||||
}
|
||||
Err(_) => Ok(SteelVal::StringV(trimmed.into())),
|
||||
}
|
||||
} else {
|
||||
Err(SteelVal::StringV("Index out of bounds".into()))
|
||||
}
|
||||
} else {
|
||||
Err(SteelVal::StringV("Expected comma-separated string".into()))
|
||||
Err(SteelVal::StringV(
|
||||
"Expected comma-separated string".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes read-only SQL queries from Steel scripts with safety restrictions.
|
||||
///
|
||||
/// # Security Features
|
||||
/// - Only SELECT, SHOW, and EXPLAIN queries allowed
|
||||
/// - Prohibited column type access validation
|
||||
/// - Returns first column of all rows as comma-separated string
|
||||
pub fn steel_query_sql(&self, query: &str) -> Result<SteelVal, SteelVal> {
|
||||
if !is_read_only_query(query) {
|
||||
return Err(SteelVal::StringV(
|
||||
"Only SELECT queries are allowed".into()
|
||||
));
|
||||
return Err(SteelVal::StringV("Only SELECT queries are allowed".into()));
|
||||
}
|
||||
|
||||
if contains_prohibited_column_access(query) {
|
||||
return Err(SteelVal::StringV(format!(
|
||||
"SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}",
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
).into()));
|
||||
return Err(SteelVal::StringV(
|
||||
format!(
|
||||
"SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}",
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = self.db_pool.clone();
|
||||
@@ -266,7 +279,8 @@ impl SteelContext {
|
||||
|
||||
let mut results = Vec::new();
|
||||
for row in rows {
|
||||
let val: String = row.try_get(0)
|
||||
let val: String = row
|
||||
.try_get(0)
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
results.push(val);
|
||||
}
|
||||
@@ -279,85 +293,30 @@ impl SteelContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if a data type is prohibited for Steel script access.
|
||||
fn is_prohibited_type(data_type: &str) -> bool {
|
||||
let normalized_type = normalize_data_type(data_type);
|
||||
PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited))
|
||||
PROHIBITED_TYPES
|
||||
.iter()
|
||||
.any(|&prohibited| normalized_type.starts_with(prohibited))
|
||||
}
|
||||
|
||||
/// Normalizes data type strings for consistent comparison.
|
||||
/// Handles variations like NUMERIC(10,2) by extracting base type.
|
||||
fn normalize_data_type(data_type: &str) -> String {
|
||||
data_type.to_uppercase()
|
||||
.split('(') // Remove precision/scale from NUMERIC(x,y)
|
||||
data_type
|
||||
.to_uppercase()
|
||||
.split('(')
|
||||
.next()
|
||||
.unwrap_or(data_type)
|
||||
.trim()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Performs basic heuristic check for prohibited column type access in SQL queries.
|
||||
/// Looks for common patterns that might indicate access to restricted types.
|
||||
fn contains_prohibited_column_access(query: &str) -> bool {
|
||||
let query_upper = query.to_uppercase();
|
||||
|
||||
let patterns = [
|
||||
"EXTRACT(", // Common with DATE/TIMESTAMPTZ
|
||||
"DATE_PART(", // Common with DATE/TIMESTAMPTZ
|
||||
"::DATE",
|
||||
"::TIMESTAMPTZ",
|
||||
"::BIGINT",
|
||||
];
|
||||
|
||||
patterns.iter().any(|pattern| query_upper.contains(pattern))
|
||||
let patterns = ["EXTRACT(", "DATE_PART(", "::DATE", "::TIMESTAMPTZ", "::BIGINT"];
|
||||
patterns.iter().any(|p| query_upper.contains(p))
|
||||
}
|
||||
|
||||
/// Validates that a query is read-only and safe for Steel script execution.
|
||||
fn is_read_only_query(query: &str) -> bool {
|
||||
let query = query.trim_start().to_uppercase();
|
||||
query.starts_with("SELECT") ||
|
||||
query.starts_with("SHOW") ||
|
||||
query.starts_with("EXPLAIN")
|
||||
}
|
||||
|
||||
/// Converts row data boolean values to Steel script format during context initialization.
|
||||
pub async fn convert_row_data_for_steel(
|
||||
db_pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
row_data: &mut HashMap<String, String>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2"#,
|
||||
schema_id,
|
||||
table_name
|
||||
)
|
||||
.fetch_optional(db_pool)
|
||||
.await?
|
||||
.ok_or_else(|| sqlx::Error::RowNotFound)?;
|
||||
|
||||
// Parse column definitions to identify boolean columns for conversion
|
||||
if let Ok(columns) = serde_json::from_value::<Vec<String>>(table_def.columns) {
|
||||
for column_def in columns {
|
||||
let mut parts = column_def.split_whitespace();
|
||||
if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
|
||||
let column_name = name.trim_matches('"');
|
||||
let normalized_type = normalize_data_type(data_type);
|
||||
|
||||
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
|
||||
if let Some(value) = row_data.get_mut(column_name) {
|
||||
// Convert boolean value to Steel format
|
||||
*value = match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.clone(), // Keep original if not recognized
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
query.starts_with("SELECT") || query.starts_with("SHOW") || query.starts_with("EXPLAIN")
|
||||
}
|
||||
|
||||
@@ -4,22 +4,8 @@ use tonic::Status;
|
||||
use sqlx::{PgPool, Transaction, Postgres};
|
||||
use serde_json::json;
|
||||
use common::proto::komp_ac::table_definition::{PostTableDefinitionRequest, TableDefinitionResponse};
|
||||
|
||||
// TODO CRITICAL add decimal with optional precision"
|
||||
const PREDEFINED_FIELD_TYPES: &[(&str, &str)] = &[
|
||||
("text", "TEXT"),
|
||||
("string", "TEXT"),
|
||||
("boolean", "BOOLEAN"),
|
||||
("timestamp", "TIMESTAMPTZ"),
|
||||
("timestamptz", "TIMESTAMPTZ"),
|
||||
("time", "TIMESTAMPTZ"),
|
||||
("money", "NUMERIC(14, 4)"),
|
||||
("integer", "INTEGER"),
|
||||
("int", "INTEGER"),
|
||||
("biginteger", "BIGINT"),
|
||||
("bigint", "BIGINT"),
|
||||
("date", "DATE"),
|
||||
];
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use crate::table_definition::models::map_field_type;
|
||||
|
||||
// NEW: Helper function to provide detailed error messages
|
||||
fn validate_identifier_format(s: &str, identifier_type: &str) -> Result<(), Status> {
|
||||
@@ -58,116 +44,6 @@ fn validate_identifier_format(s: &str, identifier_type: &str) -> Result<(), Stat
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_decimal_number_format(num_str: &str, param_name: &str) -> Result<(), Status> {
|
||||
if num_str.is_empty() {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} cannot be empty",
|
||||
param_name
|
||||
)));
|
||||
}
|
||||
|
||||
// Check for explicit signs
|
||||
if num_str.starts_with('+') || num_str.starts_with('-') {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} cannot have explicit positive or negative signs",
|
||||
param_name
|
||||
)));
|
||||
}
|
||||
|
||||
// Check for decimal points
|
||||
if num_str.contains('.') {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} must be a whole number (no decimal points)",
|
||||
param_name
|
||||
)));
|
||||
}
|
||||
|
||||
// Check for leading zeros (but allow "0" itself)
|
||||
if num_str.len() > 1 && num_str.starts_with('0') {
|
||||
let trimmed = num_str.trim_start_matches('0');
|
||||
let suggestion = if trimmed.is_empty() { "0" } else { trimmed };
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} cannot have leading zeros (use '{}' instead of '{}')",
|
||||
param_name,
|
||||
suggestion,
|
||||
num_str
|
||||
)));
|
||||
}
|
||||
|
||||
// Check that all characters are digits
|
||||
if !num_str.chars().all(|c| c.is_ascii_digit()) {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} contains invalid characters. Only digits 0-9 are allowed",
|
||||
param_name
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn map_field_type(field_type: &str) -> Result<String, Status> {
|
||||
let lower_field_type = field_type.to_lowercase();
|
||||
|
||||
// Special handling for "decimal(precision, scale)"
|
||||
if lower_field_type.starts_with("decimal(") && lower_field_type.ends_with(')') {
|
||||
// Extract the part inside the parentheses, e.g., "10, 2"
|
||||
let args = lower_field_type
|
||||
.strip_prefix("decimal(")
|
||||
.and_then(|s| s.strip_suffix(')'))
|
||||
.unwrap_or(""); // Should always succeed due to the checks above
|
||||
|
||||
// Split into precision and scale parts
|
||||
if let Some((p_str, s_str)) = args.split_once(',') {
|
||||
let precision_str = p_str.trim();
|
||||
let scale_str = s_str.trim();
|
||||
|
||||
// NEW: Validate format BEFORE parsing
|
||||
validate_decimal_number_format(precision_str, "precision")?;
|
||||
validate_decimal_number_format(scale_str, "scale")?;
|
||||
|
||||
// Parse precision, returning an error if it's not a valid number
|
||||
let precision = precision_str.parse::<u32>().map_err(|_| {
|
||||
Status::invalid_argument("Invalid precision in decimal type")
|
||||
})?;
|
||||
|
||||
// Parse scale, returning an error if it's not a valid number
|
||||
let scale = scale_str.parse::<u32>().map_err(|_| {
|
||||
Status::invalid_argument("Invalid scale in decimal type")
|
||||
})?;
|
||||
|
||||
// Add validation based on PostgreSQL rules
|
||||
if precision < 1 {
|
||||
return Err(Status::invalid_argument("Precision must be at least 1"));
|
||||
}
|
||||
if scale > precision {
|
||||
return Err(Status::invalid_argument(
|
||||
"Scale cannot be greater than precision",
|
||||
));
|
||||
}
|
||||
|
||||
// If everything is valid, build and return the NUMERIC type string
|
||||
return Ok(format!("NUMERIC({}, {})", precision, scale));
|
||||
} else {
|
||||
// The format was wrong, e.g., "decimal(10)" or "decimal()"
|
||||
return Err(Status::invalid_argument(
|
||||
"Invalid decimal format. Expected: decimal(precision, scale)",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// If not a decimal, fall back to the predefined list
|
||||
PREDEFINED_FIELD_TYPES
|
||||
.iter()
|
||||
.find(|(key, _)| *key == lower_field_type.as_str())
|
||||
.map(|(_, sql_type)| sql_type.to_string()) // Convert to an owned String
|
||||
.ok_or_else(|| {
|
||||
Status::invalid_argument(format!(
|
||||
"Invalid field type: {}",
|
||||
field_type
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn is_invalid_table_name(table_name: &str) -> bool {
|
||||
table_name.ends_with("_id") ||
|
||||
table_name == "id" ||
|
||||
@@ -299,34 +175,48 @@ async fn execute_table_definition(
|
||||
links.push((linked_id, link.required));
|
||||
}
|
||||
|
||||
let mut columns = Vec::new();
|
||||
let mut stored_columns = Vec::new();
|
||||
let mut sql_columns = Vec::new();
|
||||
|
||||
for col_def in request.columns.drain(..) {
|
||||
let col_name = col_def.name.trim().to_string();
|
||||
validate_identifier_format(&col_name, "Column name")?;
|
||||
|
||||
if col_name.ends_with("_id") || col_name == "id" || col_name == "deleted" || col_name == "created_at" {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"Column name '{}' cannot be 'id', 'deleted', 'created_at' or end with '_id'",
|
||||
col_name
|
||||
"Column name '{}' cannot be 'id', 'deleted', 'created_at' or end with '_id'",
|
||||
col_name
|
||||
)));
|
||||
}
|
||||
|
||||
let sql_type = map_field_type(&col_def.field_type)?;
|
||||
columns.push(format!("\"{}\" {}", col_name, sql_type));
|
||||
sql_columns.push(format!("\"{}\" {}", col_name, sql_type));
|
||||
|
||||
// push the proto type (serde serializable)
|
||||
stored_columns.push(ColumnDefinition {
|
||||
name: col_name,
|
||||
field_type: col_def.field_type,
|
||||
});
|
||||
}
|
||||
|
||||
// Indexes
|
||||
let mut stored_indexes = Vec::new();
|
||||
let mut indexes = Vec::new();
|
||||
for idx in request.indexes.drain(..) {
|
||||
let idx_name = idx.trim().to_string();
|
||||
validate_identifier_format(&idx_name, "Index name")?;
|
||||
|
||||
if !columns.iter().any(|c| c.starts_with(&format!("\"{}\"", idx_name))) {
|
||||
return Err(Status::invalid_argument(format!("Index column '{}' not found", idx_name)));
|
||||
if !sql_columns.iter().any(|c| c.starts_with(&format!("\"{}\"", idx_name))) {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"Index column '{}' not found", idx_name
|
||||
)));
|
||||
}
|
||||
|
||||
stored_indexes.push(idx_name.clone());
|
||||
indexes.push(idx_name);
|
||||
}
|
||||
|
||||
let (create_sql, index_sql) = generate_table_sql(tx, &profile_name, &table_name, &columns, &indexes, &links).await?;
|
||||
let (create_sql, index_sql) = generate_table_sql(tx, &profile_name, &table_name, &sql_columns, &indexes, &links).await?;
|
||||
|
||||
// Use schema_id instead of profile_id
|
||||
let table_def = sqlx::query!(
|
||||
@@ -336,8 +226,8 @@ async fn execute_table_definition(
|
||||
RETURNING id"#,
|
||||
schema.id,
|
||||
&table_name,
|
||||
json!(columns),
|
||||
json!(indexes)
|
||||
serde_json::to_value(&stored_columns).unwrap(),
|
||||
serde_json::to_value(&stored_indexes).unwrap()
|
||||
)
|
||||
.fetch_one(&mut **tx)
|
||||
.await
|
||||
@@ -351,7 +241,7 @@ async fn execute_table_definition(
|
||||
Status::internal(format!("Database error: {}", e))
|
||||
})?;
|
||||
|
||||
for col_def in &columns {
|
||||
for col_def in &sql_columns {
|
||||
// Column string looks like "\"name\" TYPE", split out identifier
|
||||
let col_name = col_def.split_whitespace().next().unwrap_or("");
|
||||
let clean_col = col_name.trim_matches('"');
|
||||
|
||||
@@ -2,3 +2,6 @@
|
||||
|
||||
pub mod models;
|
||||
pub mod handlers;
|
||||
pub mod repo;
|
||||
|
||||
pub use repo::*;
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
// src/table_definition/models.rs
|
||||
use tonic::Status;
|
||||
|
||||
/// Predefined static field mappings
|
||||
// TODO CRITICAL add decimal with optional precision"
|
||||
pub const PREDEFINED_FIELD_TYPES: &[(&str, &str)] = &[
|
||||
("text", "TEXT"),
|
||||
("string", "TEXT"),
|
||||
("boolean", "BOOLEAN"),
|
||||
("timestamp", "TIMESTAMPTZ"),
|
||||
("timestamptz", "TIMESTAMPTZ"),
|
||||
("time", "TIMESTAMPTZ"),
|
||||
("money", "NUMERIC(14, 4)"),
|
||||
("integer", "INTEGER"),
|
||||
("int", "INTEGER"),
|
||||
("biginteger", "BIGINT"),
|
||||
("bigint", "BIGINT"),
|
||||
("date", "DATE"),
|
||||
];
|
||||
|
||||
/// reusable decimal number validation
|
||||
pub fn validate_decimal_number_format(num_str: &str, param_name: &str) -> Result<(), Status> {
|
||||
if num_str.is_empty() {
|
||||
return Err(Status::invalid_argument(format!("{} cannot be empty", param_name)));
|
||||
}
|
||||
if num_str.starts_with('+') || num_str.starts_with('-') {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} cannot have explicit positive/negative signs", param_name
|
||||
)));
|
||||
}
|
||||
if num_str.contains('.') {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} must be a whole number (no decimal point)", param_name
|
||||
)));
|
||||
}
|
||||
if num_str.len() > 1 && num_str.starts_with('0') {
|
||||
let trimmed = num_str.trim_start_matches('0');
|
||||
let suggestion = if trimmed.is_empty() { "0" } else { trimmed };
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} cannot have leading zeros (use '{}' instead of '{}')",
|
||||
param_name, suggestion, num_str
|
||||
)));
|
||||
}
|
||||
if !num_str.chars().all(|c| c.is_ascii_digit()) {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"{} contains invalid characters. Only digits allowed", param_name
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// reusable field type mapper
|
||||
pub fn map_field_type(field_type: &str) -> Result<String, Status> {
|
||||
let lower_field_type = field_type.to_lowercase();
|
||||
|
||||
if lower_field_type.starts_with("decimal(") && lower_field_type.ends_with(')') {
|
||||
let args = lower_field_type.strip_prefix("decimal(").unwrap()
|
||||
.strip_suffix(')').unwrap();
|
||||
|
||||
if let Some((p_str, s_str)) = args.split_once(',') {
|
||||
let precision_str = p_str.trim();
|
||||
let scale_str = s_str.trim();
|
||||
|
||||
validate_decimal_number_format(precision_str, "precision")?;
|
||||
validate_decimal_number_format(scale_str, "scale")?;
|
||||
|
||||
let precision = precision_str.parse::<u32>()
|
||||
.map_err(|_| Status::invalid_argument("Invalid precision"))?;
|
||||
let scale = scale_str.parse::<u32>()
|
||||
.map_err(|_| Status::invalid_argument("Invalid scale"))?;
|
||||
|
||||
if precision < 1 {
|
||||
return Err(Status::invalid_argument("Precision must be >= 1"));
|
||||
}
|
||||
if scale > precision {
|
||||
return Err(Status::invalid_argument("Scale cannot be > precision"));
|
||||
}
|
||||
return Ok(format!("NUMERIC({}, {})", precision, scale));
|
||||
} else {
|
||||
return Err(Status::invalid_argument(
|
||||
"Invalid decimal format. Expected decimal(precision, scale)"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
PREDEFINED_FIELD_TYPES
|
||||
.iter()
|
||||
.find(|(key, _)| *key == lower_field_type.as_str())
|
||||
.map(|(_, sql_type)| sql_type.to_string())
|
||||
.ok_or_else(|| Status::invalid_argument(format!("Invalid field type: {}", field_type)))
|
||||
}
|
||||
|
||||
33
server/src/table_definition/repo.rs
Normal file
33
server/src/table_definition/repo.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
// src/table_definition/repo.rs
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use sqlx::PgPool;
|
||||
|
||||
pub struct TableDefRow {
|
||||
pub id: i64,
|
||||
pub table_name: String,
|
||||
pub columns: Vec<ColumnDefinition>,
|
||||
pub indexes: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn get_table_definition(
|
||||
db: &PgPool,
|
||||
id: i64,
|
||||
) -> Result<TableDefRow, anyhow::Error> {
|
||||
let rec = sqlx::query!(
|
||||
r#"
|
||||
SELECT id, table_name, columns, indexes
|
||||
FROM table_definitions
|
||||
WHERE id = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
|
||||
Ok(TableDefRow {
|
||||
id: rec.id,
|
||||
table_name: rec.table_name,
|
||||
columns: serde_json::from_value(rec.columns)?, // 🔑
|
||||
indexes: serde_json::from_value(rec.indexes)?,
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use tonic::Status;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Represents the state of a node during dependency graph traversal.
|
||||
@@ -40,18 +41,38 @@ impl DependencyType {
|
||||
DependencyType::SqlQuery { .. } => "sql_query",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates context JSON for database storage.
|
||||
pub fn context_json(&self) -> Value {
|
||||
/// Strongly-typed JSON for script_dependencies.context_info
|
||||
/// Using untagged so JSON stays minimal (no "type" field), and we can still
|
||||
/// deserialize it into a proper enum.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ScriptDependencyContext {
|
||||
ColumnAccess { column: String },
|
||||
IndexedAccess { column: String, index: i64 },
|
||||
SqlQuery { query_fragment: String },
|
||||
}
|
||||
|
||||
impl DependencyType {
|
||||
/// Convert this dependency into its JSON context struct.
|
||||
pub fn to_context(&self) -> ScriptDependencyContext {
|
||||
match self {
|
||||
DependencyType::ColumnAccess { column } => {
|
||||
json!({ "column": column })
|
||||
ScriptDependencyContext::ColumnAccess {
|
||||
column: column.clone(),
|
||||
}
|
||||
}
|
||||
DependencyType::IndexedAccess { column, index } => {
|
||||
json!({ "column": column, "index": index })
|
||||
ScriptDependencyContext::IndexedAccess {
|
||||
column: column.clone(),
|
||||
index: *index,
|
||||
}
|
||||
}
|
||||
DependencyType::SqlQuery { query_fragment } => {
|
||||
json!({ "query_fragment": query_fragment })
|
||||
ScriptDependencyContext::SqlQuery {
|
||||
query_fragment: query_fragment.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -554,7 +575,7 @@ impl DependencyAnalyzer {
|
||||
table_id,
|
||||
target_id,
|
||||
dep.dependency_type.as_str(),
|
||||
dep.dependency_type.context_json()
|
||||
serde_json::to_value(dep.dependency_type.to_context()).unwrap()
|
||||
)
|
||||
.execute(&mut **tx)
|
||||
.await
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
use tonic::Status;
|
||||
use sqlx::{PgPool, Error as SqlxError};
|
||||
use common::proto::komp_ac::table_script::{PostTableScriptRequest, TableScriptResponse};
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use serde_json::Value;
|
||||
use steel_decimal::SteelDecimal;
|
||||
use regex::Regex;
|
||||
@@ -303,16 +304,12 @@ async fn validate_math_operations_column_types(
|
||||
let mut table_column_types: HashMap<String, HashMap<String, String>> = HashMap::new();
|
||||
|
||||
for table_def in table_definitions {
|
||||
let columns: Vec<String> = serde_json::from_value(table_def.columns)
|
||||
let columns: Vec<ColumnDefinition> = serde_json::from_value(table_def.columns)
|
||||
.map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?;
|
||||
|
||||
let mut column_types = HashMap::new();
|
||||
for column_def in columns {
|
||||
let mut parts = column_def.split_whitespace();
|
||||
if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
|
||||
let column_name = name.trim_matches('"');
|
||||
column_types.insert(column_name.to_string(), data_type.to_string());
|
||||
}
|
||||
for col_def in columns {
|
||||
column_types.insert(col_def.name.clone(), col_def.field_type.clone());
|
||||
}
|
||||
table_column_types.insert(table_def.table_name, column_types);
|
||||
}
|
||||
@@ -363,25 +360,13 @@ fn validate_target_column(
|
||||
}
|
||||
|
||||
// Parse the columns JSON into a vector of strings
|
||||
let columns: Vec<String> = serde_json::from_value(table_columns.clone())
|
||||
let columns: Vec<ColumnDefinition> = serde_json::from_value(table_columns.clone())
|
||||
.map_err(|e| format!("Invalid column data: {}", e))?;
|
||||
|
||||
// Extract column names and types
|
||||
let column_info: Vec<(&str, &str)> = columns
|
||||
let column_type = columns
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
let mut parts = c.split_whitespace();
|
||||
let name = parts.next()?.trim_matches('"');
|
||||
let data_type = parts.next()?;
|
||||
Some((name, data_type))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Find the target column and return its type
|
||||
let column_type = column_info
|
||||
.iter()
|
||||
.find(|(name, _)| *name == target)
|
||||
.map(|(_, dt)| dt.to_string())
|
||||
.find(|c| c.name == target)
|
||||
.map(|c| c.field_type.clone())
|
||||
.ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))?;
|
||||
|
||||
// Check if the target column type is prohibited
|
||||
@@ -509,42 +494,29 @@ async fn validate_script_column_references(
|
||||
/// Validate that a referenced column doesn't have a prohibited type
|
||||
fn validate_referenced_column_type(table_name: &str, column_name: &str, table_columns: &Value) -> Result<(), String> {
|
||||
// Parse the columns JSON into a vector of strings
|
||||
let columns: Vec<String> = serde_json::from_value(table_columns.clone())
|
||||
.map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?;
|
||||
let columns: Vec<ColumnDefinition> = serde_json::from_value(table_columns.clone())
|
||||
.map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?;
|
||||
|
||||
// Extract column names and types
|
||||
let column_info: Vec<(&str, &str)> = columns
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
let mut parts = c.split_whitespace();
|
||||
let name = parts.next()?.trim_matches('"');
|
||||
let data_type = parts.next()?;
|
||||
Some((name, data_type))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Find the referenced column and check its type
|
||||
if let Some((_, column_type)) = column_info.iter().find(|(name, _)| *name == column_name) {
|
||||
if is_prohibited_type(column_type) {
|
||||
if let Some(col_def) = columns.iter().find(|c| c.name == column_name) {
|
||||
if is_prohibited_type(&col_def.field_type) {
|
||||
return Err(format!(
|
||||
"Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
|
||||
column_name,
|
||||
table_name,
|
||||
column_type,
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
"Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
|
||||
column_name,
|
||||
table_name,
|
||||
col_def.field_type,
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
// Log info for boolean columns
|
||||
let normalized_type = normalize_data_type(column_type);
|
||||
let normalized_type = normalize_data_type(&col_def.field_type);
|
||||
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
|
||||
println!("Info: Script references boolean column '{}' in table '{}'. Values will be converted to Steel format (#true/#false)", column_name, table_name);
|
||||
}
|
||||
} else {
|
||||
return Err(format!(
|
||||
"Script references column '{}' in table '{}' but this column does not exist",
|
||||
column_name,
|
||||
table_name
|
||||
"Script references column '{}' in table '{}' but this column does not exist",
|
||||
column_name,
|
||||
table_name
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
// src/table_script/mod.rs
|
||||
|
||||
pub mod handlers;
|
||||
pub mod repo;
|
||||
|
||||
pub use handlers::*;
|
||||
pub use repo::*;
|
||||
|
||||
49
server/src/table_script/repo.rs
Normal file
49
server/src/table_script/repo.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
// src/table_script/repo.rs
|
||||
use anyhow::Result;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::table_script::handlers::dependency_analyzer::ScriptDependencyContext;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScriptDependencyRecord {
|
||||
pub script_id: i64,
|
||||
pub source_table_id: i64,
|
||||
pub target_table_id: i64,
|
||||
pub dependency_type: String,
|
||||
pub context: Option<ScriptDependencyContext>,
|
||||
}
|
||||
|
||||
pub async fn get_dependencies_for_script(
|
||||
db: &PgPool,
|
||||
script_id: i64,
|
||||
) -> Result<Vec<ScriptDependencyRecord>> {
|
||||
let rows = sqlx::query!(
|
||||
r#"
|
||||
SELECT script_id, source_table_id, target_table_id, dependency_type, context_info
|
||||
FROM script_dependencies
|
||||
WHERE script_id = $1
|
||||
ORDER BY source_table_id, target_table_id
|
||||
"#,
|
||||
script_id
|
||||
)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
let mut out = Vec::new();
|
||||
for r in rows {
|
||||
let context = match r.context_info {
|
||||
Some(value) => Some(serde_json::from_value::<ScriptDependencyContext>(value)?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
out.push(ScriptDependencyRecord {
|
||||
script_id: r.script_id,
|
||||
source_table_id: r.source_table_id,
|
||||
target_table_id: r.target_table_id,
|
||||
dependency_type: r.dependency_type,
|
||||
context,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
// src/tables_data/handlers/get_table_data.rs
|
||||
|
||||
|
||||
use tonic::Status;
|
||||
use sqlx::{PgPool, Row};
|
||||
use std::collections::HashMap;
|
||||
use common::proto::komp_ac::tables_data::{GetTableDataRequest, GetTableDataResponse};
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use crate::shared::schema_qualifier::qualify_table_name_for_data;
|
||||
|
||||
pub async fn get_table_data(
|
||||
@@ -39,17 +41,13 @@ pub async fn get_table_data(
|
||||
let table_def = table_def.ok_or_else(|| Status::not_found("Table not found"))?;
|
||||
|
||||
// Parse user-defined columns from JSON
|
||||
let columns_json: Vec<String> = serde_json::from_value(table_def.columns.clone())
|
||||
let stored_columns: Vec<ColumnDefinition> = serde_json::from_value(table_def.columns.clone())
|
||||
.map_err(|e| Status::internal(format!("Column parsing error: {}", e)))?;
|
||||
|
||||
// Directly extract names, no split(" ") parsing anymore
|
||||
let mut user_columns = Vec::new();
|
||||
for col_def in columns_json {
|
||||
let parts: Vec<&str> = col_def.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(Status::internal("Invalid column format"));
|
||||
}
|
||||
let name = parts[0].trim_matches('"').to_string();
|
||||
user_columns.push(name);
|
||||
for col_def in stored_columns {
|
||||
user_columns.push(col_def.name.trim().to_string());
|
||||
}
|
||||
|
||||
// --- START OF FIX ---
|
||||
|
||||
@@ -5,6 +5,8 @@ use sqlx::{PgPool, Arguments};
|
||||
use sqlx::postgres::PgArguments;
|
||||
use chrono::{DateTime, Utc};
|
||||
use common::proto::komp_ac::tables_data::{PostTableDataRequest, PostTableDataResponse};
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use crate::table_definition::models::map_field_type;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use prost_types::value::Kind;
|
||||
@@ -56,18 +58,16 @@ pub async fn post_table_data(
|
||||
let table_def = table_def.ok_or_else(|| Status::not_found("Table not found"))?;
|
||||
|
||||
// Parse column definitions from JSON format
|
||||
let columns_json: Vec<String> = serde_json::from_value(table_def.columns.clone())
|
||||
let stored_columns: Vec<ColumnDefinition> = serde_json::from_value(table_def.columns.clone())
|
||||
.map_err(|e| Status::internal(format!("Column parsing error: {}", e)))?;
|
||||
|
||||
// convert ColumnDefinition -> (name, sql_type) using the same map_field_type logic
|
||||
let mut columns = Vec::new();
|
||||
for col_def in columns_json {
|
||||
let parts: Vec<&str> = col_def.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(Status::internal("Invalid column format"));
|
||||
}
|
||||
let name = parts[0].trim_matches('"').to_string();
|
||||
let sql_type = parts[1].to_string();
|
||||
columns.push((name, sql_type));
|
||||
for col_def in stored_columns {
|
||||
let col_name = col_def.name.trim().to_string();
|
||||
let sql_type = map_field_type(&col_def.field_type)
|
||||
.map_err(|e| Status::invalid_argument(format!("Invalid type for column '{}': {}", col_name, e)))?;
|
||||
columns.push((col_name, sql_type));
|
||||
}
|
||||
|
||||
// Build list of valid system columns (foreign keys and special columns)
|
||||
|
||||
@@ -5,6 +5,7 @@ use sqlx::{PgPool, Arguments, Row};
|
||||
use sqlx::postgres::PgArguments;
|
||||
use chrono::{DateTime, Utc};
|
||||
use common::proto::komp_ac::tables_data::{PutTableDataRequest, PutTableDataResponse};
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
|
||||
use std::sync::Arc;
|
||||
use prost_types::value::Kind;
|
||||
@@ -14,6 +15,7 @@ use std::collections::HashMap;
|
||||
|
||||
use crate::steel::server::execution::{self, Value};
|
||||
use crate::indexer::{IndexCommand, IndexCommandData};
|
||||
use crate::table_definition::models::map_field_type;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::error;
|
||||
|
||||
@@ -56,19 +58,20 @@ pub async fn put_table_data(
|
||||
.map_err(|e| Status::internal(format!("Table lookup error: {}", e)))?
|
||||
.ok_or_else(|| Status::not_found("Table not found"))?;
|
||||
|
||||
// Parse column definitions from JSON format
|
||||
let columns_json: Vec<String> = serde_json::from_value(table_def.columns.clone())
|
||||
// Parse column definitions from JSON format (now ColumnDefinition objects)
|
||||
let stored_columns: Vec<ColumnDefinition> = serde_json::from_value(table_def.columns.clone())
|
||||
.map_err(|e| Status::internal(format!("Column parsing error: {}", e)))?;
|
||||
|
||||
// Convert ColumnDefinition → (name, sql_type)
|
||||
let mut columns = Vec::new();
|
||||
for col_def in columns_json {
|
||||
let parts: Vec<&str> = col_def.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(Status::internal("Invalid column format"));
|
||||
}
|
||||
let name = parts[0].trim_matches('"').to_string();
|
||||
let sql_type = parts[1].to_string();
|
||||
columns.push((name, sql_type));
|
||||
for col_def in stored_columns {
|
||||
let col_name = col_def.name.trim().to_string();
|
||||
let sql_type = map_field_type(&col_def.field_type)
|
||||
.map_err(|e| Status::invalid_argument(format!(
|
||||
"Invalid type for column '{}': {}",
|
||||
col_name, e
|
||||
)))?;
|
||||
columns.push((col_name, sql_type));
|
||||
}
|
||||
|
||||
// Build list of valid system columns (foreign keys and special columns)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use crate::common::setup_isolated_db;
|
||||
use server::table_script::handlers::post_table_script::post_table_script; // Fixed import
|
||||
use common::proto::komp_ac::table_script::PostTableScriptRequest;
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use rstest::*;
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
@@ -12,15 +13,10 @@ async fn create_test_table(
|
||||
pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
columns: Vec<(&str, &str)>,
|
||||
columns: Vec<ColumnDefinition>,
|
||||
) -> i64 {
|
||||
let column_definitions: Vec<String> = columns
|
||||
.iter()
|
||||
.map(|(name, type_def)| format!("\"{}\" {}", name, type_def))
|
||||
.collect();
|
||||
|
||||
let columns_json = json!(column_definitions);
|
||||
let indexes_json = json!([]);
|
||||
let columns_json = serde_json::to_value(columns).unwrap();
|
||||
let indexes_json = serde_json::json!([]);
|
||||
|
||||
sqlx::query_scalar!(
|
||||
r#"INSERT INTO table_definitions (schema_id, table_name, columns, indexes)
|
||||
@@ -115,22 +111,17 @@ async fn test_comprehensive_error_scenarios(
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
// Create comprehensive error test table
|
||||
let columns = vec![
|
||||
// Valid types
|
||||
("valid_numeric", "NUMERIC(10, 2)"),
|
||||
("valid_integer", "INTEGER"),
|
||||
|
||||
// Invalid for math operations
|
||||
("text_col", "TEXT"),
|
||||
("boolean_col", "BOOLEAN"),
|
||||
("bigint_col", "BIGINT"),
|
||||
("date_col", "DATE"),
|
||||
("timestamp_col", "TIMESTAMPTZ"),
|
||||
|
||||
// Invalid target types
|
||||
("bigint_target", "BIGINT"),
|
||||
("date_target", "DATE"),
|
||||
("timestamp_target", "TIMESTAMPTZ"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "valid_numeric".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
ColumnDefinition { name: "valid_integer".into(), field_type: "INTEGER".into() },
|
||||
ColumnDefinition { name: "text_col".into(), field_type: "TEXT".into() },
|
||||
ColumnDefinition { name: "boolean_col".into(), field_type: "BOOLEAN".into() },
|
||||
ColumnDefinition { name: "bigint_col".into(), field_type: "BIGINT".into() },
|
||||
ColumnDefinition { name: "date_col".into(), field_type: "DATE".into() },
|
||||
ColumnDefinition { name: "timestamp_col".into(), field_type: "TIMESTAMPTZ".into() },
|
||||
ColumnDefinition { name: "bigint_target".into(), field_type: "BIGINT".into() },
|
||||
ColumnDefinition { name: "date_target".into(), field_type: "DATE".into() },
|
||||
ColumnDefinition { name: "timestamp_target".into(), field_type: "TIMESTAMPTZ".into() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "error_table", columns).await;
|
||||
@@ -169,7 +160,9 @@ async fn test_malformed_script_scenarios(
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![("result", "NUMERIC(10, 2)")];
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "result".into(), field_type: "NUMERIC(10, 2)".into() }
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "malformed_test", columns).await;
|
||||
|
||||
let request = PostTableScriptRequest {
|
||||
@@ -194,7 +187,9 @@ async fn test_advanced_validation_scenarios(
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![("result", "NUMERIC(10, 2)")];
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "result".into(), field_type: "NUMERIC(10, 2)".into() }
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "advanced_test", columns).await;
|
||||
|
||||
let request = PostTableScriptRequest {
|
||||
@@ -236,16 +231,16 @@ async fn test_dependency_cycle_detection() {
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
// Create table_b first
|
||||
let table_b_columns = vec![
|
||||
("value_b", "NUMERIC(10, 2)"),
|
||||
("result_b", "NUMERIC(10, 2)"),
|
||||
let table_b_columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "value_b".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
ColumnDefinition { name: "result_b".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
];
|
||||
let table_b_id = create_test_table(&pool, schema_id, "table_b", table_b_columns).await;
|
||||
|
||||
// Create table_a
|
||||
let table_a_columns = vec![
|
||||
("value_a", "NUMERIC(10, 2)"),
|
||||
("result_a", "NUMERIC(10, 2)"),
|
||||
let table_a_columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "value_a".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
ColumnDefinition { name: "result_a".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
];
|
||||
let table_a_id = create_test_table(&pool, schema_id, "table_a", table_a_columns).await;
|
||||
|
||||
@@ -305,7 +300,9 @@ async fn test_edge_case_identifiers(
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![("result", "NUMERIC(10, 2)")];
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "result".into(), field_type: "NUMERIC(10, 2)".into() }
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "identifier_test", columns).await;
|
||||
|
||||
// Test with edge case identifier in script
|
||||
@@ -342,7 +339,9 @@ async fn test_sql_injection_prevention() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![("result", "NUMERIC(10, 2)")];
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "result".into(), field_type: "NUMERIC(10, 2)".into() }
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "injection_test", columns).await;
|
||||
|
||||
// Attempt SQL injection through script content
|
||||
@@ -388,9 +387,9 @@ async fn test_performance_with_deeply_nested_expressions() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("x", "NUMERIC(15, 8)"),
|
||||
("performance_result", "NUMERIC(25, 12)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "x".into(), field_type: "NUMERIC(15, 8)".into() },
|
||||
ColumnDefinition { name: "performance_result".into(), field_type: "NUMERIC(25, 12)".into() },
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "performance_test", columns).await;
|
||||
|
||||
@@ -437,11 +436,11 @@ async fn test_concurrent_script_creation() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("value", "NUMERIC(10, 2)"),
|
||||
("result1", "NUMERIC(10, 2)"),
|
||||
("result2", "NUMERIC(10, 2)"),
|
||||
("result3", "NUMERIC(10, 2)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "value".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
ColumnDefinition { name: "result1".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
ColumnDefinition { name: "result2".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
ColumnDefinition { name: "result3".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "concurrent_test", columns).await;
|
||||
|
||||
@@ -500,9 +499,10 @@ async fn test_error_message_localization_and_clarity() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("text_col", "TEXT"),
|
||||
("result", "NUMERIC(10, 2)"),
|
||||
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "text_col".into(), field_type: "TEXT".into() },
|
||||
ColumnDefinition { name: "result".into(), field_type: "NUMERIC(10, 2)".into() },
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "error_clarity_test", columns).await;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use crate::common::setup_isolated_db;
|
||||
use server::table_script::handlers::post_table_script::post_table_script; // Fixed import
|
||||
use common::proto::komp_ac::table_script::PostTableScriptRequest;
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use rstest::*;
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
@@ -12,15 +13,10 @@ async fn create_test_table(
|
||||
pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
columns: Vec<(&str, &str)>,
|
||||
columns: Vec<ColumnDefinition>,
|
||||
) -> i64 {
|
||||
let column_definitions: Vec<String> = columns
|
||||
.iter()
|
||||
.map(|(name, type_def)| format!("\"{}\" {}", name, type_def))
|
||||
.collect();
|
||||
|
||||
let columns_json = json!(column_definitions);
|
||||
let indexes_json = json!([]);
|
||||
let columns_json = serde_json::to_value(columns).unwrap();
|
||||
let indexes_json = serde_json::json!([]);
|
||||
|
||||
sqlx::query_scalar!(
|
||||
r#"INSERT INTO table_definitions (schema_id, table_name, columns, indexes)
|
||||
@@ -97,7 +93,9 @@ async fn test_steel_decimal_literal_operations(
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![("result", "NUMERIC(30, 15)")];
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(30, 15)".to_string() }
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "literal_test", columns).await;
|
||||
|
||||
let script = format!(r#"({} "{}" "{}")"#, operation, value1, value2);
|
||||
@@ -133,9 +131,9 @@ async fn test_steel_decimal_column_operations(
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("test_value", column_type),
|
||||
("result", "NUMERIC(30, 15)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "test_value".to_string(), field_type: column_type.to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(30, 15)".to_string() },
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "column_test", columns).await;
|
||||
|
||||
@@ -179,12 +177,12 @@ async fn test_complex_financial_calculation(
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
// Create a realistic financial calculation table
|
||||
let columns = vec![
|
||||
("principal", "NUMERIC(16, 2)"), // Principal amount
|
||||
("annual_rate", "NUMERIC(6, 5)"), // Interest rate
|
||||
("years", "INTEGER"), // Time period
|
||||
("compounding_periods", "INTEGER"), // Compounding frequency
|
||||
("compound_interest", "NUMERIC(20, 8)"), // Result
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "principal".to_string(), field_type: "NUMERIC(16, 2)".to_string() }, // Principal amount
|
||||
ColumnDefinition { name: "annual_rate".to_string(), field_type: "NUMERIC(6, 5)".to_string() }, // Interest rate
|
||||
ColumnDefinition { name: "years".to_string(), field_type: "INTEGER".to_string() }, // Time period
|
||||
ColumnDefinition { name: "compounding_periods".to_string(), field_type: "INTEGER".to_string() }, // Compounding frequency
|
||||
ColumnDefinition { name: "compound_interest".to_string(), field_type: "NUMERIC(20, 8)".to_string() }, // Result
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "financial_calc", columns).await;
|
||||
@@ -217,11 +215,11 @@ async fn test_scientific_precision_calculations() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("measurement_a", "NUMERIC(25, 15)"),
|
||||
("measurement_b", "NUMERIC(25, 15)"),
|
||||
("coefficient", "NUMERIC(10, 8)"),
|
||||
("scientific_result", "NUMERIC(30, 18)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "measurement_a".to_string(), field_type: "NUMERIC(25, 15)".to_string() },
|
||||
ColumnDefinition { name: "measurement_b".to_string(), field_type: "NUMERIC(25, 15)".to_string() },
|
||||
ColumnDefinition { name: "coefficient".to_string(), field_type: "NUMERIC(10, 8)".to_string() },
|
||||
ColumnDefinition { name: "scientific_result".to_string(), field_type: "NUMERIC(30, 18)".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "scientific_data", columns).await;
|
||||
@@ -259,9 +257,9 @@ async fn test_precision_boundary_conditions(
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("boundary_value", numeric_type),
|
||||
("result", "NUMERIC(30, 15)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "boundary_value".to_string(), field_type: numeric_type.to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(30, 15)".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "boundary_test", columns).await;
|
||||
@@ -284,11 +282,11 @@ async fn test_mixed_integer_and_numeric_operations() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("integer_quantity", "INTEGER"),
|
||||
("numeric_price", "NUMERIC(10, 4)"),
|
||||
("numeric_tax_rate", "NUMERIC(5, 4)"),
|
||||
("total_with_tax", "NUMERIC(15, 4)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "integer_quantity".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "numeric_price".to_string(), field_type: "NUMERIC(10, 4)".to_string() },
|
||||
ColumnDefinition { name: "numeric_tax_rate".to_string(), field_type: "NUMERIC(5, 4)".to_string() },
|
||||
ColumnDefinition { name: "total_with_tax".to_string(), field_type: "NUMERIC(15, 4)".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "mixed_types_calc", columns).await;
|
||||
@@ -325,9 +323,9 @@ async fn test_mathematical_edge_cases(
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("test_value", "NUMERIC(15, 6)"),
|
||||
("result", "NUMERIC(20, 8)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "test_value".to_string(), field_type: "NUMERIC(15, 6)".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(20, 8)".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "edge_case_test", columns).await;
|
||||
@@ -381,10 +379,10 @@ async fn test_comparison_operations_with_valid_types() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("value_a", "NUMERIC(10, 2)"),
|
||||
("value_b", "INTEGER"),
|
||||
("comparison_result", "BOOLEAN"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "value_a".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "value_b".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "comparison_result".to_string(), field_type: "BOOLEAN".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "comparison_test", columns).await;
|
||||
@@ -419,11 +417,11 @@ async fn test_nested_mathematical_expressions() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("x", "NUMERIC(15, 8)"),
|
||||
("y", "NUMERIC(15, 8)"),
|
||||
("z", "INTEGER"),
|
||||
("nested_result", "NUMERIC(25, 12)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "x".to_string(), field_type: "NUMERIC(15, 8)".to_string() },
|
||||
ColumnDefinition { name: "y".to_string(), field_type: "NUMERIC(15, 8)".to_string() },
|
||||
ColumnDefinition { name: "z".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "nested_result".to_string(), field_type: "NUMERIC(25, 12)".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "nested_calc", columns).await;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use crate::common::setup_isolated_db;
|
||||
use server::table_script::handlers::post_table_script::post_table_script;
|
||||
use common::proto::komp_ac::table_script::{PostTableScriptRequest, TableScriptResponse};
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
|
||||
@@ -26,14 +27,9 @@ impl TableScriptTestHelper {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_table_with_types(&self, table_name: &str, column_definitions: Vec<(&str, &str)>) -> i64 {
|
||||
let columns: Vec<String> = column_definitions
|
||||
.iter()
|
||||
.map(|(name, type_def)| format!("\"{}\" {}", name, type_def))
|
||||
.collect();
|
||||
|
||||
let columns_json = json!(columns);
|
||||
let indexes_json = json!([]);
|
||||
pub async fn create_table_with_types(&self, table_name: &str, column_definitions: Vec<ColumnDefinition>) -> i64 {
|
||||
let columns_json = serde_json::to_value(column_definitions).unwrap();
|
||||
let indexes_json = serde_json::json!([]);
|
||||
|
||||
sqlx::query_scalar!(
|
||||
r#"INSERT INTO table_definitions (schema_id, table_name, columns, indexes)
|
||||
@@ -73,24 +69,24 @@ mod integration_tests {
|
||||
"comprehensive_table",
|
||||
vec![
|
||||
// Supported types for math operations
|
||||
("integer_col", "INTEGER"),
|
||||
("numeric_basic", "NUMERIC(10, 2)"),
|
||||
("numeric_high_precision", "NUMERIC(28, 15)"),
|
||||
("numeric_currency", "NUMERIC(14, 4)"),
|
||||
ColumnDefinition { name: "integer_col".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "numeric_basic".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "numeric_high_precision".to_string(), field_type: "NUMERIC(28, 15)".to_string() },
|
||||
ColumnDefinition { name: "numeric_currency".to_string(), field_type: "NUMERIC(14, 4)".to_string() },
|
||||
|
||||
// Supported but not for math operations
|
||||
("text_col", "TEXT"),
|
||||
("boolean_col", "BOOLEAN"),
|
||||
ColumnDefinition { name: "text_col".to_string(), field_type: "TEXT".to_string() },
|
||||
ColumnDefinition { name: "boolean_col".to_string(), field_type: "BOOLEAN".to_string() },
|
||||
|
||||
// Prohibited types entirely
|
||||
("bigint_col", "BIGINT"),
|
||||
("date_col", "DATE"),
|
||||
("timestamp_col", "TIMESTAMPTZ"),
|
||||
ColumnDefinition { name: "bigint_col".to_string(), field_type: "BIGINT".to_string() },
|
||||
ColumnDefinition { name: "date_col".to_string(), field_type: "DATE".to_string() },
|
||||
ColumnDefinition { name: "timestamp_col".to_string(), field_type: "TIMESTAMPTZ".to_string() },
|
||||
|
||||
// Result columns of various types
|
||||
("result_integer", "INTEGER"),
|
||||
("result_numeric", "NUMERIC(15, 5)"),
|
||||
("result_text", "TEXT"),
|
||||
ColumnDefinition { name: "result_integer".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "result_numeric".to_string(), field_type: "NUMERIC(15, 5)".to_string() },
|
||||
ColumnDefinition { name: "result_text".to_string(), field_type: "TEXT".to_string() },
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -150,13 +146,13 @@ mod integration_tests {
|
||||
let table_id = helper.create_table_with_types(
|
||||
"precision_table",
|
||||
vec![
|
||||
("low_precision", "NUMERIC(5, 2)"), // e.g., 999.99
|
||||
("medium_precision", "NUMERIC(10, 4)"), // e.g., 999999.9999
|
||||
("high_precision", "NUMERIC(28, 15)"), // Maximum PostgreSQL precision
|
||||
("currency", "NUMERIC(14, 4)"), // Standard currency precision
|
||||
("percentage", "NUMERIC(5, 4)"), // e.g., 0.9999 (99.99%)
|
||||
("integer_val", "INTEGER"),
|
||||
("result", "NUMERIC(30, 15)"),
|
||||
ColumnDefinition { name: "low_precision".to_string(), field_type: "NUMERIC(5, 2)".to_string() }, // e.g., 999.99
|
||||
ColumnDefinition { name: "medium_precision".to_string(), field_type: "NUMERIC(10, 4)".to_string() }, // e.g., 999999.9999
|
||||
ColumnDefinition { name: "high_precision".to_string(), field_type: "NUMERIC(28, 15)".to_string() }, // Maximum PostgreSQL precision
|
||||
ColumnDefinition { name: "currency".to_string(), field_type: "NUMERIC(14, 4)".to_string() }, // Standard currency precision
|
||||
ColumnDefinition { name: "percentage".to_string(), field_type: "NUMERIC(5, 4)".to_string() }, // e.g., 0.9999 (99.99%)
|
||||
ColumnDefinition { name: "integer_val".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(30, 15)".to_string() },
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -202,12 +198,12 @@ mod integration_tests {
|
||||
let table_id = helper.create_table_with_types(
|
||||
"financial_instruments",
|
||||
vec![
|
||||
("principal", "NUMERIC(16, 2)"), // Principal amount
|
||||
("annual_rate", "NUMERIC(6, 5)"), // Interest rate (e.g., 0.05250)
|
||||
("years", "INTEGER"), // Time period
|
||||
("compounding_periods", "INTEGER"), // Compounding frequency
|
||||
("fees", "NUMERIC(10, 2)"), // Transaction fees
|
||||
("compound_interest", "NUMERIC(20, 8)"), // Result column
|
||||
ColumnDefinition { name: "principal".to_string(), field_type: "NUMERIC(16, 2)".to_string() }, // Principal amount
|
||||
ColumnDefinition { name: "annual_rate".to_string(), field_type: "NUMERIC(6, 5)".to_string() }, // Interest rate (e.g., 0.05250)
|
||||
ColumnDefinition { name: "years".to_string(), field_type: "INTEGER".to_string() }, // Time period
|
||||
ColumnDefinition { name: "compounding_periods".to_string(), field_type: "INTEGER".to_string() }, // Compounding frequency
|
||||
ColumnDefinition { name: "fees".to_string(), field_type: "NUMERIC(10, 2)".to_string() }, // Transaction fees
|
||||
ColumnDefinition { name: "compound_interest".to_string(), field_type: "NUMERIC(20, 8)".to_string() }, // Result column
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -237,9 +233,9 @@ mod integration_tests {
|
||||
let table_id = helper.create_table_with_types(
|
||||
"scientific_data",
|
||||
vec![
|
||||
("large_number", "NUMERIC(30, 10)"),
|
||||
("small_number", "NUMERIC(30, 20)"),
|
||||
("result", "NUMERIC(35, 25)"),
|
||||
ColumnDefinition { name: "large_number".to_string(), field_type: "NUMERIC(30, 10)".to_string() },
|
||||
ColumnDefinition { name: "small_number".to_string(), field_type: "NUMERIC(30, 20)".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(35, 25)".to_string() },
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -265,8 +261,8 @@ mod integration_tests {
|
||||
let table_a_id = helper.create_table_with_types(
|
||||
"table_a",
|
||||
vec![
|
||||
("value_a", "NUMERIC(10, 2)"),
|
||||
("result_a", "NUMERIC(10, 2)"),
|
||||
ColumnDefinition { name: "value_a".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "result_a".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
]
|
||||
).await;
|
||||
println!("Created table_a with ID: {}", table_a_id);
|
||||
@@ -274,8 +270,8 @@ mod integration_tests {
|
||||
let table_b_id = helper.create_table_with_types(
|
||||
"table_b",
|
||||
vec![
|
||||
("value_b", "NUMERIC(10, 2)"),
|
||||
("result_b", "NUMERIC(10, 2)"),
|
||||
ColumnDefinition { name: "value_b".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "result_b".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
]
|
||||
).await;
|
||||
println!("Created table_b with ID: {}", table_b_id);
|
||||
@@ -354,10 +350,10 @@ mod integration_tests {
|
||||
let table_id = helper.create_table_with_types(
|
||||
"error_test_table",
|
||||
vec![
|
||||
("text_field", "TEXT"),
|
||||
("numeric_field", "NUMERIC(10, 2)"),
|
||||
("boolean_field", "BOOLEAN"),
|
||||
("bigint_field", "BIGINT"),
|
||||
ColumnDefinition { name: "text_field".to_string(), field_type: "TEXT".to_string() },
|
||||
ColumnDefinition { name: "numeric_field".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "boolean_field".to_string(), field_type: "BOOLEAN".to_string() },
|
||||
ColumnDefinition { name: "bigint_field".to_string(), field_type: "BIGINT".to_string() },
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -417,11 +413,11 @@ mod integration_tests {
|
||||
let table_id = helper.create_table_with_types(
|
||||
"performance_table",
|
||||
vec![
|
||||
("x", "NUMERIC(15, 8)"),
|
||||
("y", "NUMERIC(15, 8)"),
|
||||
("z", "NUMERIC(15, 8)"),
|
||||
("w", "NUMERIC(15, 8)"),
|
||||
("complex_result", "NUMERIC(25, 12)"),
|
||||
ColumnDefinition { name: "x".to_string(), field_type: "NUMERIC(15, 8)".to_string() },
|
||||
ColumnDefinition { name: "y".to_string(), field_type: "NUMERIC(15, 8)".to_string() },
|
||||
ColumnDefinition { name: "z".to_string(), field_type: "NUMERIC(15, 8)".to_string() },
|
||||
ColumnDefinition { name: "w".to_string(), field_type: "NUMERIC(15, 8)".to_string() },
|
||||
ColumnDefinition { name: "complex_result".to_string(), field_type: "NUMERIC(25, 12)".to_string() },
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -456,11 +452,11 @@ mod integration_tests {
|
||||
let table_id = helper.create_table_with_types(
|
||||
"boundary_table",
|
||||
vec![
|
||||
("min_numeric", "NUMERIC(1, 0)"), // Minimum: single digit, no decimal
|
||||
("max_numeric", "NUMERIC(1000, 999)"), // Maximum PostgreSQL allows
|
||||
("zero_scale", "NUMERIC(10, 0)"), // Integer-like numeric
|
||||
("max_scale", "NUMERIC(28, 28)"), // Maximum scale
|
||||
("result", "NUMERIC(1000, 999)"),
|
||||
ColumnDefinition { name: "min_numeric".to_string(), field_type: "NUMERIC(1, 0)".to_string() }, // Minimum: single digit, no decimal
|
||||
ColumnDefinition { name: "max_numeric".to_string(), field_type: "NUMERIC(1000, 999)".to_string() }, // Maximum PostgreSQL allows
|
||||
ColumnDefinition { name: "zero_scale".to_string(), field_type: "NUMERIC(10, 0)".to_string() }, // Integer-like numeric
|
||||
ColumnDefinition { name: "max_scale".to_string(), field_type: "NUMERIC(28, 28)".to_string() }, // Maximum scale
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(1000, 999)".to_string() },
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -495,10 +491,10 @@ mod steel_decimal_integration_tests {
|
||||
let table_id = helper.create_table_with_types(
|
||||
"test_execution_table",
|
||||
vec![
|
||||
("amount", "NUMERIC(10, 2)"),
|
||||
("quantity", "INTEGER"),
|
||||
("tax_rate", "NUMERIC(5, 4)"),
|
||||
("result", "NUMERIC(15, 4)"), // Add a result column
|
||||
ColumnDefinition { name: "amount".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "quantity".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "tax_rate".to_string(), field_type: "NUMERIC(5, 4)".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(15, 4)".to_string() }, // Add a result column
|
||||
]
|
||||
).await;
|
||||
println!("Created test table with ID: {}", table_id);
|
||||
@@ -575,9 +571,9 @@ mod steel_decimal_integration_tests {
|
||||
let table_id = helper.create_table_with_types(
|
||||
"precision_test_table",
|
||||
vec![
|
||||
("precise_value", "NUMERIC(20, 12)"),
|
||||
("multiplier", "NUMERIC(20, 12)"),
|
||||
("result", "NUMERIC(25, 15)"), // Add result column
|
||||
ColumnDefinition { name: "precise_value".to_string(), field_type: "NUMERIC(20, 12)".to_string() },
|
||||
ColumnDefinition { name: "multiplier".to_string(), field_type: "NUMERIC(20, 12)".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(25, 15)".to_string() }, // Add result column
|
||||
]
|
||||
).await;
|
||||
println!("Created precision test table with ID: {}", table_id);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use crate::common::setup_isolated_db;
|
||||
use server::table_script::handlers::post_table_script::post_table_script;
|
||||
use common::proto::komp_ac::table_script::PostTableScriptRequest;
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
|
||||
@@ -11,15 +12,10 @@ async fn create_test_table(
|
||||
pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
columns: Vec<(&str, &str)>,
|
||||
columns: Vec<ColumnDefinition>,
|
||||
) -> i64 {
|
||||
let column_definitions: Vec<String> = columns
|
||||
.iter()
|
||||
.map(|(name, type_def)| format!("\"{}\" {}", name, type_def))
|
||||
.collect();
|
||||
|
||||
let columns_json = json!(column_definitions);
|
||||
let indexes_json = json!([]);
|
||||
let columns_json = serde_json::to_value(columns).unwrap();
|
||||
let indexes_json = serde_json::json!([]);
|
||||
|
||||
sqlx::query_scalar!(
|
||||
r#"INSERT INTO table_definitions (schema_id, table_name, columns, indexes)
|
||||
@@ -67,7 +63,10 @@ async fn test_reject_bigint_target_column() {
|
||||
&pool,
|
||||
schema_id,
|
||||
"bigint_table",
|
||||
vec![("name", "TEXT"), ("big_number", "BIGINT")]
|
||||
vec![
|
||||
ColumnDefinition { name: "name".to_string(), field_type: "TEXT".to_string() },
|
||||
ColumnDefinition { name: "big_number".to_string(), field_type: "BIGINT".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
let request = PostTableScriptRequest {
|
||||
@@ -99,7 +98,10 @@ async fn test_reject_date_target_column() {
|
||||
&pool,
|
||||
schema_id,
|
||||
"date_table",
|
||||
vec![("name", "TEXT"), ("event_date", "DATE")]
|
||||
vec![
|
||||
ColumnDefinition { name: "name".to_string(), field_type: "TEXT".to_string() },
|
||||
ColumnDefinition { name: "event_date".to_string(), field_type: "DATE".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
let request = PostTableScriptRequest {
|
||||
@@ -131,7 +133,10 @@ async fn test_reject_timestamptz_target_column() {
|
||||
&pool,
|
||||
schema_id,
|
||||
"timestamp_table",
|
||||
vec![("name", "TEXT"), ("created_time", "TIMESTAMPTZ")]
|
||||
vec![
|
||||
ColumnDefinition { name: "name".to_string(), field_type: "TEXT".to_string() },
|
||||
ColumnDefinition { name: "created_time".to_string(), field_type: "TIMESTAMPTZ".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
let request = PostTableScriptRequest {
|
||||
@@ -164,9 +169,9 @@ async fn test_reject_text_in_mathematical_operations() {
|
||||
schema_id,
|
||||
"text_math_table",
|
||||
vec![
|
||||
("description", "TEXT"),
|
||||
("amount", "NUMERIC(10, 2)"),
|
||||
("result", "NUMERIC(10, 2)")
|
||||
ColumnDefinition { name: "description".to_string(), field_type: "TEXT".to_string() },
|
||||
ColumnDefinition { name: "amount".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(10, 2)".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -202,9 +207,9 @@ async fn test_reject_boolean_in_mathematical_operations() {
|
||||
schema_id,
|
||||
"boolean_math_table",
|
||||
vec![
|
||||
("is_active", "BOOLEAN"),
|
||||
("amount", "NUMERIC(10, 2)"),
|
||||
("result", "NUMERIC(10, 2)")
|
||||
ColumnDefinition { name: "is_active".to_string(), field_type: "BOOLEAN".to_string() },
|
||||
ColumnDefinition { name: "amount".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(10, 2)".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -240,8 +245,8 @@ async fn test_reject_bigint_in_mathematical_operations() {
|
||||
schema_id,
|
||||
"bigint_math_table",
|
||||
vec![
|
||||
("big_value", "BIGINT"),
|
||||
("result", "NUMERIC(10, 2)")
|
||||
ColumnDefinition { name: "big_value".to_string(), field_type: "BIGINT".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(10, 2)".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -277,10 +282,10 @@ async fn test_allow_valid_script_with_allowed_types() {
|
||||
schema_id,
|
||||
"allowed_types_table",
|
||||
vec![
|
||||
("name", "TEXT"),
|
||||
("count", "INTEGER"),
|
||||
("amount", "NUMERIC(10, 2)"),
|
||||
("computed_value", "TEXT")
|
||||
ColumnDefinition { name: "name".to_string(), field_type: "TEXT".to_string() },
|
||||
ColumnDefinition { name: "count".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "amount".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "computed_value".to_string(), field_type: "TEXT".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -312,9 +317,9 @@ async fn test_allow_integer_and_numeric_in_math_operations() {
|
||||
schema_id,
|
||||
"math_allowed_table",
|
||||
vec![
|
||||
("quantity", "INTEGER"),
|
||||
("price", "NUMERIC(10, 2)"),
|
||||
("total", "NUMERIC(12, 2)")
|
||||
ColumnDefinition { name: "quantity".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "price".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "total".to_string(), field_type: "NUMERIC(12, 2)".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
@@ -328,7 +333,7 @@ async fn test_allow_integer_and_numeric_in_math_operations() {
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await;
|
||||
|
||||
|
||||
println!("Table verification: {:?}", table_check);
|
||||
|
||||
let request = PostTableScriptRequest {
|
||||
@@ -341,7 +346,7 @@ async fn test_allow_integer_and_numeric_in_math_operations() {
|
||||
|
||||
println!("About to call post_table_script");
|
||||
let result = post_table_script(&pool, request).await;
|
||||
|
||||
|
||||
// SHOW THE ACTUAL ERROR
|
||||
if let Err(e) = &result {
|
||||
println!("ERROR: {}", e);
|
||||
@@ -363,14 +368,19 @@ async fn test_script_without_table_links_should_fail() {
|
||||
&pool,
|
||||
schema_id,
|
||||
"table_a",
|
||||
vec![("value_a", "INTEGER"), ("result", "INTEGER")]
|
||||
vec![
|
||||
ColumnDefinition { name: "value_a".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "INTEGER".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
let _table_b_id = create_test_table(
|
||||
&pool,
|
||||
schema_id,
|
||||
"table_b",
|
||||
vec![("value_b", "INTEGER")]
|
||||
"table_b",
|
||||
vec![
|
||||
ColumnDefinition { name: "value_b".to_string(), field_type: "INTEGER".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
// DON'T create a link between the tables
|
||||
@@ -404,14 +414,19 @@ async fn test_script_with_table_links_should_succeed() {
|
||||
&pool,
|
||||
schema_id,
|
||||
"linked_table_a",
|
||||
vec![("value_a", "INTEGER"), ("result", "INTEGER")]
|
||||
vec![
|
||||
ColumnDefinition { name: "value_a".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "INTEGER".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
let table_b_id = create_test_table(
|
||||
&pool,
|
||||
schema_id,
|
||||
"linked_table_b",
|
||||
vec![("value_b", "INTEGER")]
|
||||
vec![
|
||||
ColumnDefinition { name: "value_b".to_string(), field_type: "INTEGER".to_string() }
|
||||
]
|
||||
).await;
|
||||
|
||||
// Create a link between the tables (table_a can access table_b)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use crate::common::setup_isolated_db;
|
||||
use server::table_script::handlers::post_table_script::post_table_script;
|
||||
use common::proto::komp_ac::table_script::PostTableScriptRequest;
|
||||
use common::proto::komp_ac::table_definition::ColumnDefinition;
|
||||
use rstest::*;
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
@@ -76,15 +77,10 @@ async fn create_test_table(
|
||||
pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
columns: Vec<(&str, &str)>,
|
||||
columns: Vec<ColumnDefinition>,
|
||||
) -> i64 {
|
||||
let column_definitions: Vec<String> = columns
|
||||
.iter()
|
||||
.map(|(name, type_def)| format!("\"{}\" {}", name, type_def))
|
||||
.collect();
|
||||
|
||||
let columns_json = json!(column_definitions);
|
||||
let indexes_json = json!([]);
|
||||
let columns_json = serde_json::to_value(columns).unwrap();
|
||||
let indexes_json = serde_json::json!([]);
|
||||
|
||||
sqlx::query_scalar!(
|
||||
r#"INSERT INTO table_definitions (schema_id, table_name, columns, indexes)
|
||||
@@ -123,8 +119,17 @@ async fn test_allowed_types_in_math_operations(
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
// Create table with all allowed mathematical types plus a result column
|
||||
let mut columns = allowed_math_types.clone();
|
||||
columns.push(("result", "NUMERIC(30, 15)"));
|
||||
let mut columns: Vec<ColumnDefinition> = allowed_math_types
|
||||
.iter()
|
||||
.map(|(name, field_type)| ColumnDefinition {
|
||||
name: name.to_string(),
|
||||
field_type: field_type.to_string(),
|
||||
})
|
||||
.collect();
|
||||
columns.push(ColumnDefinition {
|
||||
name: "result".to_string(),
|
||||
field_type: "NUMERIC(30, 15)".to_string(),
|
||||
});
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "math_test_table", columns).await;
|
||||
|
||||
@@ -172,8 +177,17 @@ async fn test_prohibited_types_in_math_operations(
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
// Create table with prohibited types plus a valid result column
|
||||
let mut columns = prohibited_math_types.clone();
|
||||
columns.push(("result", "NUMERIC(15, 6)"));
|
||||
let mut columns: Vec<ColumnDefinition> = prohibited_math_types
|
||||
.iter()
|
||||
.map(|(name, field_type)| ColumnDefinition {
|
||||
name: name.to_string(),
|
||||
field_type: field_type.to_string(),
|
||||
})
|
||||
.collect();
|
||||
columns.push(ColumnDefinition {
|
||||
name: "result".to_string(),
|
||||
field_type: "NUMERIC(15, 6)".to_string(),
|
||||
});
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "prohibited_math_table", columns).await;
|
||||
|
||||
@@ -225,8 +239,17 @@ async fn test_prohibited_target_column_types(
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
// Create table with prohibited target types plus some valid source columns
|
||||
let mut columns = prohibited_target_types.clone();
|
||||
columns.push(("amount", "NUMERIC(10, 2)"));
|
||||
let mut columns: Vec<ColumnDefinition> = prohibited_target_types
|
||||
.iter()
|
||||
.map(|(name, field_type)| ColumnDefinition {
|
||||
name: name.to_string(),
|
||||
field_type: field_type.to_string(),
|
||||
})
|
||||
.collect();
|
||||
columns.push(ColumnDefinition {
|
||||
name: "amount".to_string(),
|
||||
field_type: "NUMERIC(10, 2)".to_string(),
|
||||
});
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "prohibited_target_table", columns).await;
|
||||
|
||||
@@ -245,7 +268,7 @@ async fn test_prohibited_target_column_types(
|
||||
|
||||
let error_message = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
error_message.to_lowercase().contains("cannot create script") ||
|
||||
error_message.to_lowercase().contains("cannot create script") ||
|
||||
error_message.contains("prohibited type"),
|
||||
"Error should mention prohibited type: {}",
|
||||
error_message
|
||||
@@ -261,7 +284,12 @@ async fn test_system_column_restrictions(#[case] target_column: &str, #[case] de
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![("amount", "NUMERIC(10, 2)")];
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition {
|
||||
name: "amount".to_string(),
|
||||
field_type: "NUMERIC(10, 2)".to_string(),
|
||||
}
|
||||
];
|
||||
let table_id = create_test_table(&pool, schema_id, "system_test_table", columns).await;
|
||||
|
||||
let script = r#"(+ "10" "20")"#;
|
||||
@@ -290,22 +318,22 @@ async fn test_comprehensive_type_matrix() {
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
// Create comprehensive table with all type categories
|
||||
let all_columns = vec![
|
||||
let all_columns: Vec<ColumnDefinition> = vec![
|
||||
// Allowed math types
|
||||
("integer_col", "INTEGER"),
|
||||
("numeric_col", "NUMERIC(10, 2)"),
|
||||
("high_precision", "NUMERIC(28, 15)"),
|
||||
|
||||
ColumnDefinition { name: "integer_col".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "numeric_col".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "high_precision".to_string(), field_type: "NUMERIC(28, 15)".to_string() },
|
||||
|
||||
// Prohibited math types
|
||||
("text_col", "TEXT"),
|
||||
("boolean_col", "BOOLEAN"),
|
||||
("bigint_col", "BIGINT"),
|
||||
("date_col", "DATE"),
|
||||
("timestamp_col", "TIMESTAMPTZ"),
|
||||
|
||||
ColumnDefinition { name: "text_col".to_string(), field_type: "TEXT".to_string() },
|
||||
ColumnDefinition { name: "boolean_col".to_string(), field_type: "BOOLEAN".to_string() },
|
||||
ColumnDefinition { name: "bigint_col".to_string(), field_type: "BIGINT".to_string() },
|
||||
ColumnDefinition { name: "date_col".to_string(), field_type: "DATE".to_string() },
|
||||
ColumnDefinition { name: "timestamp_col".to_string(), field_type: "TIMESTAMPTZ".to_string() },
|
||||
|
||||
// Result columns
|
||||
("result_numeric", "NUMERIC(20, 8)"),
|
||||
("result_text", "TEXT"),
|
||||
ColumnDefinition { name: "result_numeric".to_string(), field_type: "NUMERIC(20, 8)".to_string() },
|
||||
ColumnDefinition { name: "result_text".to_string(), field_type: "TEXT".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "comprehensive_table", all_columns).await;
|
||||
@@ -316,7 +344,7 @@ async fn test_comprehensive_type_matrix() {
|
||||
("integer_col", "+", "result_numeric", true),
|
||||
("numeric_col", "*", "result_numeric", true),
|
||||
("high_precision", "/", "result_numeric", true),
|
||||
|
||||
|
||||
// Invalid source types in math
|
||||
("text_col", "+", "result_numeric", false),
|
||||
("boolean_col", "*", "result_numeric", false),
|
||||
@@ -361,20 +389,20 @@ async fn test_complex_mathematical_expressions() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("principal", "NUMERIC(16, 2)"),
|
||||
("rate", "NUMERIC(6, 5)"),
|
||||
("years", "INTEGER"),
|
||||
("compound_result", "NUMERIC(20, 8)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "principal".to_string(), field_type: "NUMERIC(16, 2)".to_string() },
|
||||
ColumnDefinition { name: "rate".to_string(), field_type: "NUMERIC(6, 5)".to_string() },
|
||||
ColumnDefinition { name: "years".to_string(), field_type: "INTEGER".to_string() },
|
||||
ColumnDefinition { name: "compound_result".to_string(), field_type: "NUMERIC(20, 8)".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "financial_table", columns).await;
|
||||
|
||||
// Complex compound interest calculation - all using allowed types
|
||||
let complex_script = r#"
|
||||
(*
|
||||
(*
|
||||
(steel_get_column "financial_table" "principal")
|
||||
(pow
|
||||
(pow
|
||||
(+ "1" (steel_get_column "financial_table" "rate"))
|
||||
(steel_get_column "financial_table" "years")))
|
||||
"#;
|
||||
@@ -395,9 +423,9 @@ async fn test_nonexistent_column_reference() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("amount", "NUMERIC(10, 2)"),
|
||||
("result", "NUMERIC(10, 2)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "amount".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "simple_table", columns).await;
|
||||
@@ -427,9 +455,9 @@ async fn test_nonexistent_table_reference() {
|
||||
let pool = setup_isolated_db().await;
|
||||
let schema_id = get_default_schema_id(&pool).await;
|
||||
|
||||
let columns = vec![
|
||||
("amount", "NUMERIC(10, 2)"),
|
||||
("result", "NUMERIC(10, 2)"),
|
||||
let columns: Vec<ColumnDefinition> = vec![
|
||||
ColumnDefinition { name: "amount".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
ColumnDefinition { name: "result".to_string(), field_type: "NUMERIC(10, 2)".to_string() },
|
||||
];
|
||||
|
||||
let table_id = create_test_table(&pool, schema_id, "existing_table", columns).await;
|
||||
|
||||
@@ -67,10 +67,10 @@ async fn table_definition(#[future] schema: (PgPool, String, i64)) -> (PgPool, S
|
||||
|
||||
// Define columns and indexes for the table
|
||||
let columns = json!([
|
||||
"\"name\" TEXT",
|
||||
"\"age\" INTEGER",
|
||||
"\"email\" TEXT",
|
||||
"\"is_active\" BOOLEAN"
|
||||
{ "name": "name", "field_type": "text" },
|
||||
{ "name": "age", "field_type": "integer" },
|
||||
{ "name": "email", "field_type": "text" },
|
||||
{ "name": "is_active", "field_type": "boolean" }
|
||||
]);
|
||||
let indexes = json!([]);
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ async fn create_initial_record(
|
||||
// Set different initial values based on the test case to satisfy validation scripts
|
||||
match (profile_name, table_name) {
|
||||
("test_put_complex", "order") => {
|
||||
// For complex formula: (+ (* @price @quantity) (* (* @price @quantity) 0.08))
|
||||
// For complex formula: (+ (* $price $quantity) (* (* $price $quantity) 0.08))
|
||||
// With price=10.00, quantity=1: (10*1) + (10*1*0.08) = 10 + 0.8 = 10.8
|
||||
data.insert("price".to_string(), ProtoValue { kind: Some(Kind::StringValue("10.00".to_string())) });
|
||||
data.insert("quantity".to_string(), ProtoValue { kind: Some(Kind::NumberValue(1.0)) });
|
||||
@@ -99,7 +99,7 @@ async fn create_initial_record(
|
||||
data.insert("percentage".to_string(), ProtoValue { kind: Some(Kind::StringValue("100.00".to_string())) });
|
||||
},
|
||||
("test_put_division", "calculation") => {
|
||||
// For division: (/ @total @price)
|
||||
// For division: (/ $total $price)
|
||||
// With total=10.00, price=10.00: 10/10 = 1
|
||||
data.insert("price".to_string(), ProtoValue { kind: Some(Kind::StringValue("10.00".to_string())) });
|
||||
data.insert("quantity".to_string(), ProtoValue { kind: Some(Kind::NumberValue(1.0)) });
|
||||
@@ -142,7 +142,7 @@ async fn test_put_basic_arithmetic_validation_success(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(* @price @quantity)".to_string(),
|
||||
script: "(* $price $quantity)".to_string(),
|
||||
description: "Total = Price × Quantity".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -180,7 +180,7 @@ async fn test_put_basic_arithmetic_validation_failure(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(* @price @quantity)".to_string(),
|
||||
script: "(* $price $quantity)".to_string(),
|
||||
description: "Total = Price × Quantity".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -224,7 +224,7 @@ async fn test_put_complex_formula_validation(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(+ (* @price @quantity) (* (* @price @quantity) 0.08))".to_string(),
|
||||
script: "(+ (* $price $quantity) (* (* $price $quantity) 0.08))".to_string(),
|
||||
description: "Total with 8% tax".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -261,7 +261,7 @@ async fn test_put_division_with_precision(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "percentage".to_string(),
|
||||
script: "(/ @total @price)".to_string(),
|
||||
script: "(/ $total $price)".to_string(),
|
||||
description: "Percentage = Total / Price".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -326,7 +326,7 @@ async fn test_put_advanced_math_functions(pool: PgPool) {
|
||||
let sqrt_script = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "square_root".to_string(),
|
||||
script: "(sqrt @input)".to_string(),
|
||||
script: "(sqrt $input)".to_string(),
|
||||
description: "Square root validation".to_string(),
|
||||
};
|
||||
post_table_script(&pool, sqrt_script).await.unwrap();
|
||||
@@ -334,7 +334,7 @@ async fn test_put_advanced_math_functions(pool: PgPool) {
|
||||
let power_script = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "power_result".to_string(),
|
||||
script: "(^ @input 2.0)".to_string(),
|
||||
script: "(^ $input 2.0)".to_string(),
|
||||
description: "Power function validation".to_string(),
|
||||
};
|
||||
post_table_script(&pool, power_script).await.unwrap();
|
||||
@@ -389,7 +389,7 @@ async fn test_put_financial_calculations(pool: PgPool) {
|
||||
let compound_script = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "compound_result".to_string(),
|
||||
script: "(* @principal (^ (+ 1.0 @rate) @time))".to_string(),
|
||||
script: "(* $principal (^ (+ 1.0 $rate) $time))".to_string(),
|
||||
description: "Compound interest calculation".to_string(),
|
||||
};
|
||||
post_table_script(&pool, compound_script).await.unwrap();
|
||||
@@ -397,7 +397,7 @@ async fn test_put_financial_calculations(pool: PgPool) {
|
||||
let percentage_script = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "percentage_result".to_string(),
|
||||
script: "(* @principal @rate)".to_string(),
|
||||
script: "(* $principal $rate)".to_string(),
|
||||
description: "Percentage calculation".to_string(),
|
||||
};
|
||||
post_table_script(&pool, percentage_script).await.unwrap();
|
||||
@@ -441,15 +441,13 @@ async fn test_put_partial_update_with_validation(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(* @price @quantity)".to_string(),
|
||||
script: r#"( * (get-var "price") (get-var "quantity") )"#.to_string(),
|
||||
description: "Total = Price × Quantity".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
|
||||
let record_id = create_initial_record(&pool, "test_put_partial", "invoice", &indexer_tx).await;
|
||||
|
||||
// Partial update: only update quantity. The script detects this would change total
|
||||
// from 10.00 to 50.00 and requires the user to include 'total' in the update.
|
||||
let mut update_data = HashMap::new();
|
||||
update_data.insert("quantity".to_string(), ProtoValue {
|
||||
kind: Some(Kind::NumberValue(5.0)),
|
||||
@@ -462,16 +460,16 @@ async fn test_put_partial_update_with_validation(pool: PgPool) {
|
||||
data: update_data,
|
||||
};
|
||||
|
||||
// This should fail because script would change total value
|
||||
// This should fail because script would change total value (Case B: implicit change detection)
|
||||
let result = put_table_data(&pool, put_request, &indexer_tx).await;
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.code(), tonic::Code::FailedPrecondition);
|
||||
assert!(error.message().contains("Script for column 'total' was triggered"));
|
||||
assert!(error.message().contains("from '10.00' to '50.00'"));
|
||||
let msg = error.message();
|
||||
assert!(msg.contains("Script for column 'total' was triggered"));
|
||||
assert!(msg.contains("from '10.00' to '50.00'"));
|
||||
assert!(msg.contains("include 'total' in your update request")); // Full change detection msg
|
||||
|
||||
// Now, test a partial update that SHOULD fail validation.
|
||||
// We update quantity and provide an incorrect total.
|
||||
let mut failing_update_data = HashMap::new();
|
||||
failing_update_data.insert("quantity".to_string(), ProtoValue {
|
||||
kind: Some(Kind::NumberValue(3.0)),
|
||||
@@ -491,8 +489,9 @@ async fn test_put_partial_update_with_validation(pool: PgPool) {
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.code(), tonic::Code::InvalidArgument);
|
||||
assert!(error.message().contains("Script calculated '30.00'"));
|
||||
assert!(error.message().contains("but user provided '99.99'"));
|
||||
let msg = error.message();
|
||||
assert!(msg.contains("Script calculated '30.00'"));
|
||||
assert!(msg.contains("but user provided '99.99'"));
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
@@ -553,7 +552,7 @@ async fn test_put_steel_script_error_handling(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(/ @price 0.0)".to_string(),
|
||||
script: "(/ $price 0.0)".to_string(),
|
||||
description: "Error test".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -623,7 +622,7 @@ async fn test_decimal_precision_behavior(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "result".to_string(),
|
||||
script: "(/ @dividend @divisor)".to_string(),
|
||||
script: "(/ $dividend $divisor)".to_string(),
|
||||
description: "Division test for precision".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -816,7 +815,7 @@ async fn test_put_complex_formula_validation_via_handlers(pool: PgPool) {
|
||||
"test_put_complex_handlers",
|
||||
"order",
|
||||
"total",
|
||||
"(+ (* @price @quantity) (* (* @price @quantity) 0.08))", // Total with 8% tax
|
||||
"(+ (* $price $quantity) (* (* $price $quantity) 0.08))", // Total with 8% tax
|
||||
)
|
||||
.await
|
||||
.expect("Failed to add validation script");
|
||||
@@ -891,7 +890,7 @@ async fn test_put_basic_arithmetic_validation_via_handlers(pool: PgPool) {
|
||||
"test_put_arithmetic_handlers",
|
||||
"invoice",
|
||||
"total",
|
||||
"(* @price @quantity)", // Simple: Total = Price × Quantity
|
||||
"(* $price $quantity)", // Simple: Total = Price × Quantity
|
||||
)
|
||||
.await
|
||||
.expect("Failed to add validation script");
|
||||
@@ -955,7 +954,7 @@ async fn test_put_arithmetic_validation_failure_via_handlers(pool: PgPool) {
|
||||
"test_put_arithmetic_fail_handlers",
|
||||
"invoice",
|
||||
"total",
|
||||
"(* @price @quantity)",
|
||||
"(* $price $quantity)",
|
||||
)
|
||||
.await
|
||||
.expect("Failed to add validation script");
|
||||
|
||||
Reference in New Issue
Block a user