425 lines
14 KiB
Rust
425 lines
14 KiB
Rust
use std::collections::HashMap;
|
|
use std::path::{Path, PathBuf};
|
|
|
|
use tantivy::collector::TopDocs;
|
|
use tantivy::query::{
|
|
BooleanQuery, BoostQuery, FuzzyTermQuery, Occur, Query, QueryParser, TermQuery,
|
|
};
|
|
use tantivy::schema::{IndexRecordOption, Value};
|
|
use tantivy::{Index, TantivyDocument, Term};
|
|
use tonic::{Request, Response, Status};
|
|
|
|
use common::proto::komp_ac::search::searcher_server::Searcher;
|
|
use common::proto::komp_ac::search::{search_response::Hit, SearchRequest, SearchResponse};
|
|
pub use common::proto::komp_ac::search::searcher_server::SearcherServer;
|
|
use common::search::{register_slovak_tokenizers, search_index_path};
|
|
use sqlx::{PgPool, Row};
|
|
use tracing::info;
|
|
|
|
/// Root directory handed to `search_index_path`, together with the profile and table
/// names, to locate a table's Tantivy index on disk.
const INDEX_ROOT: &str = "./tantivy_indexes";

/// How many rows (highest `id` first) an empty query returns for a single table.
const DEFAULT_RESULT_LIMIT: usize = 5;

/// Cap on hits collected from each index and on the merged response.
const SEARCH_RESULT_LIMIT: usize = 100;
|
|
|
|
pub struct SearcherService {
|
|
pub pool: PgPool,
|
|
}
|
|
|
|
/// One table a search request will cover, with everything needed to query and hydrate it.
struct SearchTarget {
    /// Unqualified table name, echoed back in every `Hit`.
    table_name: String,
    /// On-disk location of this table's Tantivy index.
    index_path: PathBuf,
    /// Double-quoted `"profile"."table"` name interpolated into hydration SQL.
    qualified_table: String,
}
|
|
|
|
/// Folds accented Latin letters (the full Slovak alphabet plus common Czech and Western
/// European variants) to their plain ASCII base letter, preserving case.
///
/// Characters without a mapping, including digits, punctuation and whitespace, pass
/// through unchanged.
///
/// Slovak `ĺ`/`Ĺ` and `ŕ`/`Ŕ` (as in "stĺp", "vŕba") and Czech `ů`/`Ů` are covered.
/// Without them, words containing these letters would keep their diacritics.
fn normalize_slovak_text(text: &str) -> String {
    text.chars()
        .map(|c| match c {
            'á' | 'à' | 'â' | 'ä' | 'ă' | 'ā' => 'a',
            'Á' | 'À' | 'Â' | 'Ä' | 'Ă' | 'Ā' => 'A',
            'é' | 'è' | 'ê' | 'ë' | 'ě' | 'ē' => 'e',
            'É' | 'È' | 'Ê' | 'Ë' | 'Ě' | 'Ē' => 'E',
            'í' | 'ì' | 'î' | 'ï' | 'ī' => 'i',
            'Í' | 'Ì' | 'Î' | 'Ï' | 'Ī' => 'I',
            'ó' | 'ò' | 'ô' | 'ö' | 'ō' | 'ő' => 'o',
            'Ó' | 'Ò' | 'Ô' | 'Ö' | 'Ō' | 'Ő' => 'O',
            'ú' | 'ù' | 'û' | 'ü' | 'ū' | 'ű' | 'ů' => 'u',
            'Ú' | 'Ù' | 'Û' | 'Ü' | 'Ū' | 'Ű' | 'Ů' => 'U',
            'ý' | 'ỳ' | 'ŷ' | 'ÿ' => 'y',
            'Ý' | 'Ỳ' | 'Ŷ' | 'Ÿ' => 'Y',
            'č' => 'c',
            'Č' => 'C',
            'ď' => 'd',
            'Ď' => 'D',
            'ĺ' | 'ľ' => 'l',
            'Ĺ' | 'Ľ' => 'L',
            'ň' => 'n',
            'Ň' => 'N',
            'ŕ' | 'ř' => 'r',
            'Ŕ' | 'Ř' => 'R',
            'š' => 's',
            'Š' => 'S',
            'ť' => 't',
            'Ť' => 'T',
            'ž' => 'z',
            'Ž' => 'Z',
            _ => c,
        })
        .collect()
}
|
|
|
|
fn validate_identifier(value: &str, field_name: &str) -> Result<(), Status> {
|
|
let mut chars = value.chars();
|
|
let Some(first) = chars.next() else {
|
|
return Err(Status::invalid_argument(format!(
|
|
"{field_name} must not be empty"
|
|
)));
|
|
};
|
|
|
|
if !(first.is_ascii_alphabetic() || first == '_')
|
|
|| !chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
|
|
{
|
|
return Err(Status::invalid_argument(format!(
|
|
"{field_name} contains invalid characters"
|
|
)));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Builds the schema-qualified, double-quoted Postgres name `"profile"."table"`.
///
/// The result is interpolated directly into SQL text. Any embedded `"` is therefore
/// doubled, per Postgres quoted-identifier rules, so a name can never terminate its quoted
/// identifier early. This matters because table names read back from `table_definitions`
/// are not re-validated before they reach this function.
fn qualify_profile_table(profile_name: &str, table_name: &str) -> String {
    /// Wraps one identifier in double quotes, escaping embedded quotes by doubling them.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    format!("{}.{}", quote_ident(profile_name), quote_ident(table_name))
}
|
|
|
|
async fn profile_exists(pool: &PgPool, profile_name: &str) -> Result<bool, Status> {
|
|
let exists = sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM schemas WHERE name = $1)")
|
|
.bind(profile_name)
|
|
.fetch_one(pool)
|
|
.await
|
|
.map_err(|e| Status::internal(format!("Profile lookup failed: {}", e)))?;
|
|
Ok(exists)
|
|
}
|
|
|
|
// Scope resolution
|
|
async fn resolve_search_targets(
|
|
pool: &PgPool,
|
|
profile_name: &str,
|
|
requested_table: Option<&str>,
|
|
) -> Result<Vec<SearchTarget>, Status> {
|
|
validate_identifier(profile_name, "profile_name")?;
|
|
|
|
if !profile_exists(pool, profile_name).await? {
|
|
return Err(Status::not_found(format!(
|
|
"Profile '{}' was not found",
|
|
profile_name
|
|
)));
|
|
}
|
|
|
|
let tables = if let Some(table_name) = requested_table.filter(|value| !value.trim().is_empty()) {
|
|
validate_identifier(table_name, "table_name")?;
|
|
|
|
let row = sqlx::query_scalar::<_, String>(
|
|
r#"
|
|
SELECT td.table_name
|
|
FROM table_definitions td
|
|
JOIN schemas s ON td.schema_id = s.id
|
|
WHERE s.name = $1 AND td.table_name = $2
|
|
"#,
|
|
)
|
|
.bind(profile_name)
|
|
.bind(table_name)
|
|
.fetch_optional(pool)
|
|
.await
|
|
.map_err(|e| Status::internal(format!("Table lookup failed: {}", e)))?;
|
|
|
|
let table_name = row.ok_or_else(|| {
|
|
Status::not_found(format!(
|
|
"Table '{}' was not found in profile '{}'",
|
|
table_name, profile_name
|
|
))
|
|
})?;
|
|
|
|
vec![table_name]
|
|
} else {
|
|
sqlx::query_scalar::<_, String>(
|
|
r#"
|
|
SELECT td.table_name
|
|
FROM table_definitions td
|
|
JOIN schemas s ON td.schema_id = s.id
|
|
WHERE s.name = $1
|
|
ORDER BY td.table_name
|
|
"#,
|
|
)
|
|
.bind(profile_name)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(|e| Status::internal(format!("Profile table lookup failed: {}", e)))?
|
|
};
|
|
|
|
Ok(tables
|
|
.into_iter()
|
|
.map(|table_name| SearchTarget {
|
|
qualified_table: qualify_profile_table(profile_name, &table_name),
|
|
index_path: search_index_path(Path::new(INDEX_ROOT), profile_name, &table_name),
|
|
table_name,
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
// Query building
|
|
fn build_query(index: &Index, normalized_query: &str) -> Result<Option<BooleanQuery>, Status> {
|
|
let schema = index.schema();
|
|
let prefix_edge_field = schema
|
|
.get_field("prefix_edge")
|
|
.map_err(|_| Status::internal("Schema is missing the 'prefix_edge' field."))?;
|
|
let prefix_full_field = schema
|
|
.get_field("prefix_full")
|
|
.map_err(|_| Status::internal("Schema is missing the 'prefix_full' field."))?;
|
|
let text_ngram_field = schema
|
|
.get_field("text_ngram")
|
|
.map_err(|_| Status::internal("Schema is missing the 'text_ngram' field."))?;
|
|
|
|
let words: Vec<&str> = normalized_query.split_whitespace().collect();
|
|
if words.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
|
|
let mut query_layers: Vec<(Occur, Box<dyn Query>)> = Vec::new();
|
|
|
|
// Layer 1: prefix
|
|
{
|
|
let mut must_clauses: Vec<(Occur, Box<dyn Query>)> = Vec::new();
|
|
for word in &words {
|
|
let edge_term = Term::from_field_text(prefix_edge_field, word);
|
|
let full_term = Term::from_field_text(prefix_full_field, word);
|
|
|
|
let per_word_query = BooleanQuery::new(vec![
|
|
(
|
|
Occur::Should,
|
|
Box::new(TermQuery::new(edge_term, IndexRecordOption::Basic)),
|
|
),
|
|
(
|
|
Occur::Should,
|
|
Box::new(TermQuery::new(full_term, IndexRecordOption::Basic)),
|
|
),
|
|
]);
|
|
must_clauses.push((Occur::Must, Box::new(per_word_query)));
|
|
}
|
|
|
|
if !must_clauses.is_empty() {
|
|
let prefix_query = BooleanQuery::new(must_clauses);
|
|
let boosted_query = BoostQuery::new(Box::new(prefix_query), 4.0);
|
|
query_layers.push((Occur::Should, Box::new(boosted_query)));
|
|
}
|
|
}
|
|
|
|
// Layer 2: fuzzy
|
|
{
|
|
let last_word = words
|
|
.last()
|
|
.ok_or_else(|| Status::internal("Query normalization lost all tokens"))?;
|
|
let fuzzy_term = Term::from_field_text(prefix_full_field, last_word);
|
|
let fuzzy_query = FuzzyTermQuery::new(fuzzy_term, 2, true);
|
|
let boosted_query = BoostQuery::new(Box::new(fuzzy_query), 3.0);
|
|
query_layers.push((Occur::Should, Box::new(boosted_query)));
|
|
}
|
|
|
|
// Layer 3: phrase
|
|
if words.len() > 1 {
|
|
let slop_parser = QueryParser::for_index(index, vec![prefix_full_field]);
|
|
let slop_query_str = format!("\"{}\"~3", normalized_query);
|
|
if let Ok(slop_query) = slop_parser.parse_query(&slop_query_str) {
|
|
let boosted_query = BoostQuery::new(slop_query, 2.0);
|
|
query_layers.push((Occur::Should, Box::new(boosted_query)));
|
|
}
|
|
}
|
|
|
|
// Layer 4: ngram
|
|
{
|
|
let ngram_parser = QueryParser::for_index(index, vec![text_ngram_field]);
|
|
if let Ok(ngram_query) = ngram_parser.parse_query(normalized_query) {
|
|
let boosted_query = BoostQuery::new(ngram_query, 1.0);
|
|
query_layers.push((Occur::Should, Box::new(boosted_query)));
|
|
}
|
|
}
|
|
|
|
Ok(Some(BooleanQuery::new(query_layers)))
|
|
}
|
|
|
|
// Empty query
|
|
async fn fetch_default_hits(pool: &PgPool, target: &SearchTarget) -> Result<Vec<Hit>, Status> {
|
|
let sql = format!(
|
|
"SELECT id, to_jsonb(t) AS data FROM {} t WHERE deleted = FALSE ORDER BY id DESC LIMIT {}",
|
|
target.qualified_table, DEFAULT_RESULT_LIMIT
|
|
);
|
|
|
|
let rows = sqlx::query(&sql)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(|e| Status::internal(format!("DB query for default results failed: {}", e)))?;
|
|
|
|
Ok(rows
|
|
.into_iter()
|
|
.map(|row| {
|
|
let id: i64 = row.try_get("id").unwrap_or_default();
|
|
let json_data: serde_json::Value = row.try_get("data").unwrap_or_default();
|
|
Hit {
|
|
id,
|
|
score: 0.0,
|
|
content_json: json_data.to_string(),
|
|
table_name: target.table_name.clone(),
|
|
}
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
// Search + hydrate
|
|
async fn search_target(
|
|
pool: &PgPool,
|
|
target: &SearchTarget,
|
|
query_str: &str,
|
|
) -> Result<Vec<Hit>, Status> {
|
|
if !target.index_path.exists() {
|
|
return Ok(vec![]);
|
|
}
|
|
|
|
let index = Index::open_in_dir(&target.index_path)
|
|
.map_err(|e| Status::internal(format!("Failed to open index: {}", e)))?;
|
|
register_slovak_tokenizers(&index)
|
|
.map_err(|e| Status::internal(format!("Failed to register Slovak tokenizers: {}", e)))?;
|
|
|
|
let Some(master_query) = build_query(&index, &normalize_slovak_text(query_str))? else {
|
|
return Ok(vec![]);
|
|
};
|
|
|
|
let reader = index
|
|
.reader()
|
|
.map_err(|e| Status::internal(format!("Failed to create index reader: {}", e)))?;
|
|
let searcher = reader.searcher();
|
|
let schema = index.schema();
|
|
let pg_id_field = schema
|
|
.get_field("pg_id")
|
|
.map_err(|_| Status::internal("Schema is missing the 'pg_id' field."))?;
|
|
|
|
let top_docs = searcher
|
|
.search(&master_query, &TopDocs::with_limit(SEARCH_RESULT_LIMIT))
|
|
.map_err(|e| Status::internal(format!("Search failed: {}", e)))?;
|
|
|
|
if top_docs.is_empty() {
|
|
return Ok(vec![]);
|
|
}
|
|
|
|
let mut scored_ids: Vec<(f32, u64)> = Vec::new();
|
|
for (score, doc_address) in top_docs {
|
|
let doc: TantivyDocument = searcher
|
|
.doc(doc_address)
|
|
.map_err(|e| Status::internal(format!("Failed to retrieve document: {}", e)))?;
|
|
if let Some(pg_id_value) = doc.get_first(pg_id_field) {
|
|
if let Some(pg_id) = pg_id_value.as_u64() {
|
|
scored_ids.push((score, pg_id));
|
|
}
|
|
}
|
|
}
|
|
|
|
if scored_ids.is_empty() {
|
|
return Ok(vec![]);
|
|
}
|
|
|
|
let pg_ids: Vec<i64> = scored_ids.iter().map(|(_, id)| *id as i64).collect();
|
|
let sql = format!(
|
|
"SELECT id, to_jsonb(t) AS data FROM {} t WHERE deleted = FALSE AND id = ANY($1)",
|
|
target.qualified_table
|
|
);
|
|
|
|
let rows = sqlx::query(&sql)
|
|
.bind(&pg_ids)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(|e| Status::internal(format!("Database query failed: {}", e)))?;
|
|
|
|
let mut content_map: HashMap<i64, String> = HashMap::new();
|
|
for row in rows {
|
|
let id: i64 = row.try_get("id").unwrap_or_default();
|
|
let json_data: serde_json::Value = row.try_get("data").unwrap_or_default();
|
|
content_map.insert(id, json_data.to_string());
|
|
}
|
|
|
|
Ok(scored_ids
|
|
.into_iter()
|
|
.filter_map(|(score, pg_id)| {
|
|
content_map.get(&(pg_id as i64)).map(|content_json| Hit {
|
|
id: pg_id as i64,
|
|
score,
|
|
content_json: content_json.clone(),
|
|
table_name: target.table_name.clone(),
|
|
})
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
#[tonic::async_trait]
|
|
impl Searcher for SearcherService {
|
|
async fn search_table(
|
|
&self,
|
|
request: Request<SearchRequest>,
|
|
) -> Result<Response<SearchResponse>, Status> {
|
|
let req = request.into_inner();
|
|
let profile_name = req.profile_name.trim();
|
|
if profile_name.is_empty() {
|
|
return Err(Status::invalid_argument("profile_name is required"));
|
|
}
|
|
|
|
// Request scope
|
|
let requested_table = req.table_name.as_deref().map(str::trim);
|
|
let targets = resolve_search_targets(&self.pool, profile_name, requested_table).await?;
|
|
|
|
if targets.is_empty() {
|
|
return Ok(Response::new(SearchResponse { hits: vec![] }));
|
|
}
|
|
|
|
let query = req.query.trim();
|
|
if query.is_empty() {
|
|
// Empty query
|
|
if targets.len() != 1 {
|
|
return Err(Status::invalid_argument(
|
|
"table_name is required when query is empty",
|
|
));
|
|
}
|
|
|
|
let hits = fetch_default_hits(&self.pool, &targets[0]).await?;
|
|
info!(
|
|
"Empty query for profile '{}' table '{}'. Returning {} default hits.",
|
|
profile_name,
|
|
targets[0].table_name,
|
|
hits.len()
|
|
);
|
|
return Ok(Response::new(SearchResponse { hits }));
|
|
}
|
|
|
|
if requested_table.is_some() && targets.len() == 1 && !targets[0].index_path.exists() {
|
|
return Err(Status::not_found(format!(
|
|
"No search index found for table '{}'",
|
|
targets[0].table_name
|
|
)));
|
|
}
|
|
|
|
// Merge per-table hits
|
|
let mut hits = Vec::new();
|
|
for target in &targets {
|
|
hits.extend(search_target(&self.pool, target, query).await?);
|
|
}
|
|
|
|
hits.sort_by(|left, right| right.score.total_cmp(&left.score));
|
|
if hits.len() > SEARCH_RESULT_LIMIT {
|
|
hits.truncate(SEARCH_RESULT_LIMIT);
|
|
}
|
|
|
|
info!(
|
|
"Processed search for profile '{}' (table scope: {}). Returning {} hits.",
|
|
profile_name,
|
|
requested_table.unwrap_or("*"),
|
|
hits.len()
|
|
);
|
|
|
|
Ok(Response::new(SearchResponse { hits }))
|
|
}
|
|
}
|