diff --git a/server/src/steel/server/functions.rs b/server/src/steel/server/functions.rs index 6a68d8d..042df15 100644 --- a/server/src/steel/server/functions.rs +++ b/server/src/steel/server/functions.rs @@ -1,5 +1,6 @@ // src/steel/server/functions.rs +use common::proto::komp_ac::table_definition::ColumnDefinition; use steel::rvals::SteelVal; use sqlx::PgPool; use std::collections::HashMap; @@ -66,17 +67,12 @@ impl SteelContext { .map_err(|e| FunctionError::DatabaseError(e.to_string()))? .ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?; - let columns: Vec = serde_json::from_value(table_def.columns) + let columns: Vec = serde_json::from_value(table_def.columns) .map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?; - // Parse column definitions to find the requested column type - for column_def in columns { - let mut parts = column_def.split_whitespace(); - if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) { - let column_name_clean = name.trim_matches('"'); - if column_name_clean == column_name { - return Ok(data_type.to_string()); - } + for col_def in columns { + if col_def.name == column_name { + return Ok(col_def.field_type.to_uppercase()); } } @@ -338,22 +334,17 @@ pub async fn convert_row_data_for_steel( .ok_or_else(|| sqlx::Error::RowNotFound)?; // Parse column definitions to identify boolean columns for conversion - if let Ok(columns) = serde_json::from_value::>(table_def.columns) { - for column_def in columns { - let mut parts = column_def.split_whitespace(); - if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) { - let column_name = name.trim_matches('"'); - let normalized_type = normalize_data_type(data_type); + if let Ok(columns) = serde_json::from_value::>(table_def.columns) { + for col_def in columns { + let normalized_type = normalize_data_type(&col_def.field_type); - if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { - if let Some(value) = row_data.get_mut(column_name) { - // Convert boolean value to Steel format - *value = match value.to_lowercase().as_str() { - "true" | "t" | "1" | "yes" | "on" => "#true".to_string(), - "false" | "f" | "0" | "no" | "off" => "#false".to_string(), - _ => value.clone(), // Keep original if not recognized - }; - } + if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { + if let Some(value) = row_data.get_mut(&col_def.name) { + *value = match value.to_lowercase().as_str() { + "true" | "t" | "1" | "yes" | "on" => "#true".to_string(), + "false" | "f" | "0" | "no" | "off" => "#false".to_string(), + _ => value.clone(), + }; } } } diff --git a/server/src/table_script/handlers/post_table_script.rs b/server/src/table_script/handlers/post_table_script.rs index c95818b..b501534 100644 --- a/server/src/table_script/handlers/post_table_script.rs +++ b/server/src/table_script/handlers/post_table_script.rs @@ -4,6 +4,7 @@ use tonic::Status; use sqlx::{PgPool, Error as SqlxError}; use common::proto::komp_ac::table_script::{PostTableScriptRequest, TableScriptResponse}; +use common::proto::komp_ac::table_definition::ColumnDefinition; use serde_json::Value; use steel_decimal::SteelDecimal; use regex::Regex; @@ -303,16 +304,12 @@ async fn validate_math_operations_column_types( let mut table_column_types: HashMap> = HashMap::new(); for table_def in table_definitions { - let columns: Vec = serde_json::from_value(table_def.columns) + let columns: Vec = serde_json::from_value(table_def.columns) .map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?; let mut column_types = HashMap::new(); - for column_def in columns { - let mut parts = column_def.split_whitespace(); - if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) { - let column_name = name.trim_matches('"'); - column_types.insert(column_name.to_string(), data_type.to_string()); - } + for col_def in columns { + column_types.insert(col_def.name.clone(), col_def.field_type.clone()); } table_column_types.insert(table_def.table_name, column_types); } @@ -363,25 +360,13 @@ fn validate_target_column( } // Parse the columns JSON into a vector of strings - let columns: Vec = serde_json::from_value(table_columns.clone()) + let columns: Vec = serde_json::from_value(table_columns.clone()) .map_err(|e| format!("Invalid column data: {}", e))?; - // Extract column names and types - let column_info: Vec<(&str, &str)> = columns + let column_type = columns .iter() - .filter_map(|c| { - let mut parts = c.split_whitespace(); - let name = parts.next()?.trim_matches('"'); - let data_type = parts.next()?; - Some((name, data_type)) - }) - .collect(); - - // Find the target column and return its type - let column_type = column_info - .iter() - .find(|(name, _)| *name == target) - .map(|(_, dt)| dt.to_string()) + .find(|c| c.name == target) + .map(|c| c.field_type.clone()) .ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))?; // Check if the target column type is prohibited @@ -509,42 +494,29 @@ async fn validate_script_column_references( /// Validate that a referenced column doesn't have a prohibited type fn validate_referenced_column_type(table_name: &str, column_name: &str, table_columns: &Value) -> Result<(), String> { // Parse the columns JSON into a vector of strings - let columns: Vec = serde_json::from_value(table_columns.clone()) - .map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?; +let columns: Vec = serde_json::from_value(table_columns.clone()) + .map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?; - // Extract column names and types - let column_info: Vec<(&str, &str)> = columns - .iter() - .filter_map(|c| { - let mut parts = c.split_whitespace(); - let name = parts.next()?.trim_matches('"'); - let data_type = parts.next()?; - Some((name, data_type)) - }) - .collect(); - - // Find the referenced column and check its type - if let Some((_, column_type)) = column_info.iter().find(|(name, _)| *name == column_name) { - if is_prohibited_type(column_type) { + if let Some(col_def) = columns.iter().find(|c| c.name == column_name) { + if is_prohibited_type(&col_def.field_type) { return Err(format!( - "Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}", - column_name, - table_name, - column_type, - PROHIBITED_TYPES.join(", ") + "Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}", + column_name, + table_name, + col_def.field_type, + PROHIBITED_TYPES.join(", ") )); } - // Log info for boolean columns - let normalized_type = normalize_data_type(column_type); + let normalized_type = normalize_data_type(&col_def.field_type); if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { println!("Info: Script references boolean column '{}' in table '{}'. Values will be converted to Steel format (#true/#false)", column_name, table_name); } } else { return Err(format!( - "Script references column '{}' in table '{}' but this column does not exist", - column_name, - table_name + "Script references column '{}' in table '{}' but this column does not exist", + column_name, + table_name )); }