use std::collections::HashMap; use std::path::{Path, PathBuf}; use tantivy::collector::TopDocs; use tantivy::query::{ BooleanQuery, BoostQuery, FuzzyTermQuery, Occur, Query, QueryParser, TermQuery, }; use tantivy::schema::{IndexRecordOption, Value}; use tantivy::{Index, TantivyDocument, Term}; use tonic::{Request, Response, Status}; use common::proto::komp_ac::search::searcher_server::Searcher; use common::proto::komp_ac::search::{search_response::Hit, SearchRequest, SearchResponse}; pub use common::proto::komp_ac::search::searcher_server::SearcherServer; use common::search::{register_slovak_tokenizers, search_index_path}; use sqlx::{PgPool, Row}; use tracing::info; const INDEX_ROOT: &str = "./tantivy_indexes"; const DEFAULT_RESULT_LIMIT: usize = 5; const SEARCH_RESULT_LIMIT: usize = 100; pub struct SearcherService { pub pool: PgPool, } struct SearchTarget { table_name: String, qualified_table: String, index_path: PathBuf, } fn normalize_slovak_text(text: &str) -> String { text.chars() .map(|c| match c { 'á' | 'à' | 'â' | 'ä' | 'ă' | 'ā' => 'a', 'Á' | 'À' | 'Â' | 'Ä' | 'Ă' | 'Ā' => 'A', 'é' | 'è' | 'ê' | 'ë' | 'ě' | 'ē' => 'e', 'É' | 'È' | 'Ê' | 'Ë' | 'Ě' | 'Ē' => 'E', 'í' | 'ì' | 'î' | 'ï' | 'ī' => 'i', 'Í' | 'Ì' | 'Î' | 'Ï' | 'Ī' => 'I', 'ó' | 'ò' | 'ô' | 'ö' | 'ō' | 'ő' => 'o', 'Ó' | 'Ò' | 'Ô' | 'Ö' | 'Ō' | 'Ő' => 'O', 'ú' | 'ù' | 'û' | 'ü' | 'ū' | 'ű' => 'u', 'Ú' | 'Ù' | 'Û' | 'Ü' | 'Ū' | 'Ű' => 'U', 'ý' | 'ỳ' | 'ŷ' | 'ÿ' => 'y', 'Ý' | 'Ỳ' | 'Ŷ' | 'Ÿ' => 'Y', 'č' => 'c', 'Č' => 'C', 'ď' => 'd', 'Ď' => 'D', 'ľ' => 'l', 'Ľ' => 'L', 'ň' => 'n', 'Ň' => 'N', 'ř' => 'r', 'Ř' => 'R', 'š' => 's', 'Š' => 'S', 'ť' => 't', 'Ť' => 'T', 'ž' => 'z', 'Ž' => 'Z', _ => c, }) .collect() } fn validate_identifier(value: &str, field_name: &str) -> Result<(), Status> { let mut chars = value.chars(); let Some(first) = chars.next() else { return Err(Status::invalid_argument(format!( "{field_name} must not be empty" ))); }; if !(first.is_ascii_alphabetic() || first == '_') || !chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') { return Err(Status::invalid_argument(format!( "{field_name} contains invalid characters" ))); } Ok(()) } fn qualify_profile_table(profile_name: &str, table_name: &str) -> String { format!("\"{}\".\"{}\"", profile_name, table_name) } async fn profile_exists(pool: &PgPool, profile_name: &str) -> Result { let exists = sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM schemas WHERE name = $1)") .bind(profile_name) .fetch_one(pool) .await .map_err(|e| Status::internal(format!("Profile lookup failed: {}", e)))?; Ok(exists) } // Scope resolution async fn resolve_search_targets( pool: &PgPool, profile_name: &str, requested_table: Option<&str>, ) -> Result, Status> { validate_identifier(profile_name, "profile_name")?; if !profile_exists(pool, profile_name).await? { return Err(Status::not_found(format!( "Profile '{}' was not found", profile_name ))); } let tables = if let Some(table_name) = requested_table.filter(|value| !value.trim().is_empty()) { validate_identifier(table_name, "table_name")?; let row = sqlx::query_scalar::<_, String>( r#" SELECT td.table_name FROM table_definitions td JOIN schemas s ON td.schema_id = s.id WHERE s.name = $1 AND td.table_name = $2 "#, ) .bind(profile_name) .bind(table_name) .fetch_optional(pool) .await .map_err(|e| Status::internal(format!("Table lookup failed: {}", e)))?; let table_name = row.ok_or_else(|| { Status::not_found(format!( "Table '{}' was not found in profile '{}'", table_name, profile_name )) })?; vec![table_name] } else { sqlx::query_scalar::<_, String>( r#" SELECT td.table_name FROM table_definitions td JOIN schemas s ON td.schema_id = s.id WHERE s.name = $1 ORDER BY td.table_name "#, ) .bind(profile_name) .fetch_all(pool) .await .map_err(|e| Status::internal(format!("Profile table lookup failed: {}", e)))? }; Ok(tables .into_iter() .map(|table_name| SearchTarget { qualified_table: qualify_profile_table(profile_name, &table_name), index_path: search_index_path(Path::new(INDEX_ROOT), profile_name, &table_name), table_name, }) .collect()) } // Query building fn build_query(index: &Index, normalized_query: &str) -> Result, Status> { let schema = index.schema(); let prefix_edge_field = schema .get_field("prefix_edge") .map_err(|_| Status::internal("Schema is missing the 'prefix_edge' field."))?; let prefix_full_field = schema .get_field("prefix_full") .map_err(|_| Status::internal("Schema is missing the 'prefix_full' field."))?; let text_ngram_field = schema .get_field("text_ngram") .map_err(|_| Status::internal("Schema is missing the 'text_ngram' field."))?; let words: Vec<&str> = normalized_query.split_whitespace().collect(); if words.is_empty() { return Ok(None); } let mut query_layers: Vec<(Occur, Box)> = Vec::new(); // Layer 1: prefix { let mut must_clauses: Vec<(Occur, Box)> = Vec::new(); for word in &words { let edge_term = Term::from_field_text(prefix_edge_field, word); let full_term = Term::from_field_text(prefix_full_field, word); let per_word_query = BooleanQuery::new(vec![ ( Occur::Should, Box::new(TermQuery::new(edge_term, IndexRecordOption::Basic)), ), ( Occur::Should, Box::new(TermQuery::new(full_term, IndexRecordOption::Basic)), ), ]); must_clauses.push((Occur::Must, Box::new(per_word_query))); } if !must_clauses.is_empty() { let prefix_query = BooleanQuery::new(must_clauses); let boosted_query = BoostQuery::new(Box::new(prefix_query), 4.0); query_layers.push((Occur::Should, Box::new(boosted_query))); } } // Layer 2: fuzzy { let last_word = words .last() .ok_or_else(|| Status::internal("Query normalization lost all tokens"))?; let fuzzy_term = Term::from_field_text(prefix_full_field, last_word); let fuzzy_query = FuzzyTermQuery::new(fuzzy_term, 2, true); let boosted_query = BoostQuery::new(Box::new(fuzzy_query), 3.0); query_layers.push((Occur::Should, Box::new(boosted_query))); } // Layer 3: phrase if words.len() > 1 { let slop_parser = QueryParser::for_index(index, vec![prefix_full_field]); let slop_query_str = format!("\"{}\"~3", normalized_query); if let Ok(slop_query) = slop_parser.parse_query(&slop_query_str) { let boosted_query = BoostQuery::new(slop_query, 2.0); query_layers.push((Occur::Should, Box::new(boosted_query))); } } // Layer 4: ngram { let ngram_parser = QueryParser::for_index(index, vec![text_ngram_field]); if let Ok(ngram_query) = ngram_parser.parse_query(normalized_query) { let boosted_query = BoostQuery::new(ngram_query, 1.0); query_layers.push((Occur::Should, Box::new(boosted_query))); } } Ok(Some(BooleanQuery::new(query_layers))) } // Empty query async fn fetch_default_hits(pool: &PgPool, target: &SearchTarget) -> Result, Status> { let sql = format!( "SELECT id, to_jsonb(t) AS data FROM {} t WHERE deleted = FALSE ORDER BY id DESC LIMIT {}", target.qualified_table, DEFAULT_RESULT_LIMIT ); let rows = sqlx::query(&sql) .fetch_all(pool) .await .map_err(|e| Status::internal(format!("DB query for default results failed: {}", e)))?; Ok(rows .into_iter() .map(|row| { let id: i64 = row.try_get("id").unwrap_or_default(); let json_data: serde_json::Value = row.try_get("data").unwrap_or_default(); Hit { id, score: 0.0, content_json: json_data.to_string(), table_name: target.table_name.clone(), } }) .collect()) } // Search + hydrate async fn search_target( pool: &PgPool, target: &SearchTarget, query_str: &str, ) -> Result, Status> { if !target.index_path.exists() { return Ok(vec![]); } let index = Index::open_in_dir(&target.index_path) .map_err(|e| Status::internal(format!("Failed to open index: {}", e)))?; register_slovak_tokenizers(&index) .map_err(|e| Status::internal(format!("Failed to register Slovak tokenizers: {}", e)))?; let Some(master_query) = build_query(&index, &normalize_slovak_text(query_str))? else { return Ok(vec![]); }; let reader = index .reader() .map_err(|e| Status::internal(format!("Failed to create index reader: {}", e)))?; let searcher = reader.searcher(); let schema = index.schema(); let pg_id_field = schema .get_field("pg_id") .map_err(|_| Status::internal("Schema is missing the 'pg_id' field."))?; let top_docs = searcher .search(&master_query, &TopDocs::with_limit(SEARCH_RESULT_LIMIT)) .map_err(|e| Status::internal(format!("Search failed: {}", e)))?; if top_docs.is_empty() { return Ok(vec![]); } let mut scored_ids: Vec<(f32, u64)> = Vec::new(); for (score, doc_address) in top_docs { let doc: TantivyDocument = searcher .doc(doc_address) .map_err(|e| Status::internal(format!("Failed to retrieve document: {}", e)))?; if let Some(pg_id_value) = doc.get_first(pg_id_field) { if let Some(pg_id) = pg_id_value.as_u64() { scored_ids.push((score, pg_id)); } } } if scored_ids.is_empty() { return Ok(vec![]); } let pg_ids: Vec = scored_ids.iter().map(|(_, id)| *id as i64).collect(); let sql = format!( "SELECT id, to_jsonb(t) AS data FROM {} t WHERE deleted = FALSE AND id = ANY($1)", target.qualified_table ); let rows = sqlx::query(&sql) .bind(&pg_ids) .fetch_all(pool) .await .map_err(|e| Status::internal(format!("Database query failed: {}", e)))?; let mut content_map: HashMap = HashMap::new(); for row in rows { let id: i64 = row.try_get("id").unwrap_or_default(); let json_data: serde_json::Value = row.try_get("data").unwrap_or_default(); content_map.insert(id, json_data.to_string()); } Ok(scored_ids .into_iter() .filter_map(|(score, pg_id)| { content_map.get(&(pg_id as i64)).map(|content_json| Hit { id: pg_id as i64, score, content_json: content_json.clone(), table_name: target.table_name.clone(), }) }) .collect()) } #[tonic::async_trait] impl Searcher for SearcherService { async fn search_table( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let profile_name = req.profile_name.trim(); if profile_name.is_empty() { return Err(Status::invalid_argument("profile_name is required")); } // Request scope let requested_table = req.table_name.as_deref().map(str::trim); let targets = resolve_search_targets(&self.pool, profile_name, requested_table).await?; if targets.is_empty() { return Ok(Response::new(SearchResponse { hits: vec![] })); } let query = req.query.trim(); if query.is_empty() { // Empty query if targets.len() != 1 { return Err(Status::invalid_argument( "table_name is required when query is empty", )); } let hits = fetch_default_hits(&self.pool, &targets[0]).await?; info!( "Empty query for profile '{}' table '{}'. Returning {} default hits.", profile_name, targets[0].table_name, hits.len() ); return Ok(Response::new(SearchResponse { hits })); } if requested_table.is_some() && targets.len() == 1 && !targets[0].index_path.exists() { return Err(Status::not_found(format!( "No search index found for table '{}'", targets[0].table_name ))); } // Merge per-table hits let mut hits = Vec::new(); for target in &targets { hits.extend(search_target(&self.pool, target, query).await?); } hits.sort_by(|left, right| right.score.total_cmp(&left.score)); if hits.len() > SEARCH_RESULT_LIMIT { hits.truncate(SEARCH_RESULT_LIMIT); } info!( "Processed search for profile '{}' (table scope: {}). Returning {} hits.", profile_name, requested_table.unwrap_or("*"), hits.len() ); Ok(Response::new(SearchResponse { hits })) } }