From 2a3098287703b76cf5a52ccb31d07d0b146196d6 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 22 Jan 2026 11:08:25 -0800 Subject: [PATCH 1/4] Default to omni --- integration/go/go_pgx/sharded_test.go | 5 +- integration/rust/tests/integration/avg.rs | 19 +- integration/rust/tests/integration/explain.rs | 7 +- .../tests/integration/per_stmt_routing.rs | 4 +- .../tests/integration/shard_consistency.rs | 28 +- integration/rust/tests/integration/stddev.rs | 28 +- integration/schema_inference/pgdog.toml | 21 ++ integration/schema_inference/users.toml | 5 + pgdog/src/backend/pool/cluster.rs | 107 +++++--- pgdog/src/backend/pool/shard/mod.rs | 26 +- .../src/backend/replication/sharded_tables.rs | 4 + pgdog/src/backend/schema/mod.rs | 251 ++++++++++++++++-- pgdog/src/backend/schema/relation.rs | 11 +- pgdog/src/backend/schema/relations.sql | 3 +- pgdog/src/backend/schema/setup.sql | 81 ++++++ .../client/query_engine/route_query.rs | 20 +- pgdog/src/frontend/error.rs | 3 + pgdog/src/frontend/router/context.rs | 5 +- .../frontend/router/parser/multi_tenant.rs | 4 +- .../frontend/router/parser/query/select.rs | 135 +++++----- .../frontend/router/parser/query/shared.rs | 46 ++-- .../frontend/router/parser/query/test/mod.rs | 5 +- .../parser/query/test/test_schema_sharding.rs | 11 +- .../parser/query/test/test_search_path.rs | 19 +- .../router/parser/query/test/test_sharding.rs | 2 +- pgdog/src/frontend/router/parser/statement.rs | 60 ++++- 26 files changed, 713 insertions(+), 197 deletions(-) create mode 100644 integration/schema_inference/pgdog.toml create mode 100644 integration/schema_inference/users.toml diff --git a/integration/go/go_pgx/sharded_test.go b/integration/go/go_pgx/sharded_test.go index d5997f438..c4892ad6d 100644 --- a/integration/go/go_pgx/sharded_test.go +++ b/integration/go/go_pgx/sharded_test.go @@ -151,10 +151,11 @@ func TestShardedTwoPc(t *testing.T) { assert.NoError(t, err) } + // +3 is for schema sync assertShowField(t, "SHOW STATS", "total_xact_2pc_count", 200, "pgdog_2pc", "pgdog_sharded", 0, "primary") assertShowField(t, "SHOW STATS", "total_xact_2pc_count", 200, "pgdog_2pc", "pgdog_sharded", 1, "primary") - assertShowField(t, "SHOW STATS", "total_xact_count", 401, "pgdog_2pc", "pgdog_sharded", 0, "primary") // PREPARE, COMMIT for each transaction + TRUNCATE - assertShowField(t, "SHOW STATS", "total_xact_count", 401, "pgdog_2pc", "pgdog_sharded", 1, "primary") + assertShowField(t, "SHOW STATS", "total_xact_count", 401+3, "pgdog_2pc", "pgdog_sharded", 0, "primary") // PREPARE, COMMIT for each transaction + TRUNCATE + assertShowField(t, "SHOW STATS", "total_xact_count", 401+3, "pgdog_2pc", "pgdog_sharded", 1, "primary") for i := range 200 { rows, err := conn.Query( diff --git a/integration/rust/tests/integration/avg.rs b/integration/rust/tests/integration/avg.rs index ef45e2c17..f44e99c40 100644 --- a/integration/rust/tests/integration/avg.rs +++ b/integration/rust/tests/integration/avg.rs @@ -1,4 +1,4 @@ -use rust::setup::connections_sqlx; +use rust::setup::{admin_sqlx, connections_sqlx}; use sqlx::{Connection, Executor, PgConnection, Row}; #[tokio::test] @@ -17,12 +17,15 @@ async fn avg_merges_with_helper_count() -> Result<(), Box for shard in [0, 1] { let comment = format!( - "/* pgdog_shard: {} */ CREATE TABLE avg_reduce_test(price DOUBLE PRECISION)", + "/* pgdog_shard: {} */ CREATE TABLE avg_reduce_test(price DOUBLE PRECISION, customer_id BIGINT)", shard ); sharded.execute(comment.as_str()).await?; } + // Make sure sharded table is loaded in schema. 
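+ // RELOAD makes PgDog re-read its configuration and re-run schema
+ // inference, so the table above (with customer_id, presumably the
+ // sharding key in this test setup) is visible to the router first.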
+ admin_sqlx().await.execute("RELOAD").await?; + // Insert data on each shard so the query spans multiple shards. sharded .execute("/* pgdog_shard: 0 */ INSERT INTO avg_reduce_test(price) VALUES (10.0), (14.0)") @@ -73,12 +76,14 @@ async fn avg_without_helper_should_still_merge() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { ) .await?; + admin_sqlx().await.execute("RELOAD").await?; + sharded.execute("TRUNCATE TABLE per_stmt_routing").await?; for i in 0..50 { diff --git a/integration/rust/tests/integration/shard_consistency.rs b/integration/rust/tests/integration/shard_consistency.rs index 738c2a58d..562fdfe0d 100644 --- a/integration/rust/tests/integration/shard_consistency.rs +++ b/integration/rust/tests/integration/shard_consistency.rs @@ -1,4 +1,4 @@ -use rust::setup::connections_sqlx; +use rust::setup::{admin_sqlx, connections_sqlx}; use sqlx::{Executor, Row}; #[tokio::test] @@ -15,14 +15,16 @@ async fn shard_consistency_validator() -> Result<(), Box> // Create different table schemas on each shard to trigger validator errors // Shard 0: table with 2 columns (id, name) sharded - .execute("/* pgdog_shard: 0 */ CREATE TABLE shard_test (id BIGINT PRIMARY KEY, name VARCHAR(100))") + .execute("/* pgdog_shard: 0 */ CREATE TABLE shard_test (id BIGINT PRIMARY KEY, name VARCHAR(100), customer_id BIGINT)") .await?; // Shard 1: table with 3 columns (id, name, extra) - different column count sharded - .execute("/* pgdog_shard: 1 */ CREATE TABLE shard_test (id BIGINT PRIMARY KEY, name VARCHAR(100), extra TEXT)") + .execute("/* pgdog_shard: 1 */ CREATE TABLE shard_test (id BIGINT PRIMARY KEY, name VARCHAR(100), extra TEXT, customer_id BIGINT)") .await?; + admin_sqlx().await.execute("RELOAD").await?; + // Insert some test data on each shard sharded .execute("/* pgdog_shard: 0 */ INSERT INTO shard_test (id, name) VALUES (1, 'shard0_row1'), (2, 'shard0_row2')") @@ -76,14 +78,16 @@ async fn shard_consistency_validator_column_names() -> Result<(), Box Result<(), Box $1 ORDER BY id") diff --git a/integration/rust/tests/integration/stddev.rs b/integration/rust/tests/integration/stddev.rs index 9263b93d6..9a121a016 100644 --- a/integration/rust/tests/integration/stddev.rs +++ b/integration/rust/tests/integration/stddev.rs @@ -1,7 +1,7 @@ use std::collections::BTreeSet; use ordered_float::OrderedFloat; -use rust::setup::{connections_sqlx, connections_tokio}; +use rust::setup::{admin_sqlx, connections_sqlx, connections_tokio}; use sqlx::{Connection, Executor, PgConnection, Row, postgres::PgPool}; const SHARD_URLS: [&str; 2] = [ @@ -55,7 +55,7 @@ async fn stddev_pop_merges_with_helpers() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), sql pool.execute(create_stmt.as_str()).await?; } + admin_sqlx().await.execute("RELOAD").await?; + Ok(()) } diff --git a/integration/schema_inference/pgdog.toml b/integration/schema_inference/pgdog.toml new file mode 100644 index 000000000..0c68d623c --- /dev/null +++ b/integration/schema_inference/pgdog.toml @@ -0,0 +1,21 @@ +[general] +expanded_explain = true + +[[databases]] +name = "pgdog" +host = "127.0.0.1" +database_name = "shard_0" +shard = 0 + +[[databases]] +name = "pgdog" +host = "127.0.0.1" +database_name = "shard_1" +shard = 1 + +[[sharded_tables]] +column = "user_id" +database = "pgdog" + +[admin] +password = "pgdog" diff --git a/integration/schema_inference/users.toml b/integration/schema_inference/users.toml new file mode 100644 index 000000000..ddd797005 --- 
/dev/null +++ b/integration/schema_inference/users.toml @@ -0,0 +1,5 @@ +[[users]] +name = "pgdog" +password = "pgdog" +database = "pgdog" +schema_admin = true diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 8fb2e2b1f..8f3bfbb5b 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -1,15 +1,15 @@ //! A collection of replicas and a primary. -use parking_lot::{Mutex, RwLock}; +use parking_lot::Mutex; use pgdog_config::{PreparedStatements, QueryParserEngine, QueryParserLevel, Rewrite, RewriteMode}; use std::{ sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, }, time::Duration, }; -use tokio::spawn; +use tokio::{spawn, sync::Notify}; use tracing::{error, info}; use crate::{ @@ -37,6 +37,14 @@ pub struct PoolConfig { pub(crate) config: Config, } +#[derive(Default, Debug)] +struct Readiness { + online: AtomicBool, + schema_loading_started: AtomicBool, + schemas_loaded: AtomicUsize, + schemas_ready: Notify, +} + /// A collection of sharded replicas and primaries /// belonging to the same database cluster. #[derive(Clone, Default, Debug)] @@ -48,7 +56,6 @@ pub struct Cluster { sharded_tables: ShardedTables, sharded_schemas: ShardedSchemas, replication_sharding: Option, - schema: Arc>, multi_tenant: Option, rw_strategy: ReadWriteStrategy, schema_admin: bool, @@ -56,7 +63,7 @@ pub struct Cluster { cross_shard_disabled: bool, two_phase_commit: bool, two_phase_commit_auto: bool, - online: Arc, + readiness: Arc, rewrite: Rewrite, prepared_statements: PreparedStatements, dry_run: bool, @@ -246,7 +253,6 @@ impl Cluster { sharded_tables, sharded_schemas, replication_sharding, - schema: Arc::new(RwLock::new(Schema::default())), multi_tenant: multi_tenant.clone(), rw_strategy, schema_admin, @@ -254,7 +260,7 @@ impl Cluster { cross_shard_disabled, two_phase_commit: two_pc && shards.len() > 1, two_phase_commit_auto: two_pc_auto && shards.len() > 1, - online: Arc::new(AtomicBool::new(false)), + readiness: Arc::new(Readiness::default()), rewrite: rewrite.clone(), prepared_statements: *prepared_statements, dry_run, @@ -460,26 +466,16 @@ impl Cluster { } } - /// Update schema from primary. - async fn update_schema(&self) -> Result<(), crate::backend::Error> { - let mut server = self.primary(0, &Request::default()).await?; - let schema = Schema::load(&mut server).await?; - info!( - "loaded {} tables from schema [{}]", - schema.tables().len(), - server.addr() - ); - *self.schema.write() = schema; - Ok(()) - } - fn load_schema(&self) -> bool { - self.multi_tenant.is_some() + self.shards.len() > 1 || self.multi_tenant().is_some() } - /// Get currently loaded schema. + /// Get currently loaded schema from shard 0. pub fn schema(&self) -> Schema { - self.schema.read().clone() + self.shards + .first() + .map(|shard| shard.schema()) + .unwrap_or_default() } /// Read/write strategy @@ -509,16 +505,36 @@ impl Cluster { shard.launch(); } - if self.load_schema() { - let me = self.clone(); - spawn(async move { - if let Err(err) = me.update_schema().await { - error!("error loading schema: {}", err); - } - }); + // Only spawn schema loading once per cluster, even if launch() is called multiple times. 
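+ // swap() returns the previous value: the first caller observes `false`
+ // and spawns the loaders; every later launch() sees `true` and skips.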
+ let already_started = self + .readiness + .schema_loading_started + .swap(true, Ordering::SeqCst); + + if self.load_schema() && !already_started { + for shard in self.shards() { + let readiness = self.readiness.clone(); + let shard = shard.clone(); + let shards = self.shards.len(); + + spawn(async move { + if let Err(err) = shard.update_schema().await { + error!("error loading schema for shard {}: {}", shard.number(), err); + } + + let done = readiness.schemas_loaded.fetch_add(1, Ordering::SeqCst); + + info!("loaded schema from {}/{} shards", done + 1, shards); + + // Loaded schema on all shards. + if done >= shards - 1 { + readiness.schemas_ready.notify_waiters(); + } + }); + } } - self.online.store(true, Ordering::Relaxed); + self.readiness.online.store(true, Ordering::Relaxed); } /// Shutdown the connection pools. @@ -527,11 +543,36 @@ impl Cluster { shard.shutdown(); } - self.online.store(false, Ordering::Relaxed); + self.readiness.online.store(false, Ordering::Relaxed); } + /// Is the cluster online? pub(crate) fn online(&self) -> bool { - self.online.load(Ordering::Relaxed) + self.readiness.online.load(Ordering::Relaxed) + } + + /// Schema loaded for all shards? + pub(crate) async fn wait_schema_loaded(&self) { + if !self.load_schema() { + return; + } + + fn check_loaded(cluster: &Cluster) -> bool { + cluster.readiness.schemas_loaded.load(Ordering::Acquire) == cluster.shards.len() + } + + // Fast path. + if check_loaded(self) { + return; + } + + // Queue up. + let notified = self.readiness.schemas_ready.notified(); + // Race condition check. + if check_loaded(self) { + return; + } + notified.await; } /// Execute a query on every primary in the cluster. diff --git a/pgdog/src/backend/pool/shard/mod.rs b/pgdog/src/backend/pool/shard/mod.rs index 9429b431d..fca067eb7 100644 --- a/pgdog/src/backend/pool/shard/mod.rs +++ b/pgdog/src/backend/pool/shard/mod.rs @@ -3,13 +3,14 @@ use std::ops::Deref; use std::sync::Arc; use std::time::Duration; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, OnceCell}; use tokio::{select, spawn, sync::Notify}; -use tracing::debug; +use tracing::{debug, info}; use crate::backend::databases::User; use crate::backend::pool::lb::ban::Ban; use crate::backend::PubSubListener; +use crate::backend::Schema; use crate::config::{config, LoadBalancingStrategy, ReadWriteSplit, Role}; use crate::net::messages::BackendKeyData; use crate::net::NotificationResponse; @@ -124,6 +125,20 @@ impl Shard { } } + /// Load schema from the shard's primary. + pub async fn update_schema(&self) -> Result<(), crate::backend::Error> { + let mut server = self.primary_or_replica(&Request::default()).await?; + let schema = Schema::load(&mut server).await?; + info!( + "loaded schema for {} tables on shard {} [{}]", + schema.tables().len(), + self.number(), + server.addr() + ); + let _ = self.schema.set(schema); + Ok(()) + } + /// Bring every pool online. pub fn launch(&self) { self.lb.launch(); @@ -206,6 +221,11 @@ impl Shard { &self.identifier } + /// Get currently loaded schema for this shard. + pub fn schema(&self) -> Schema { + self.schema.get().cloned().unwrap_or_default() + } + /// Re-detect primary/replica roles and re-build /// the shard routing logic. 
pub fn redetect_roles(&self) -> bool { @@ -230,6 +250,7 @@ pub struct ShardInner { comms: Arc, pub_sub: Option, identifier: Arc, + schema: Arc>, } impl ShardInner { @@ -261,6 +282,7 @@ impl ShardInner { comms, pub_sub, identifier, + schema: Arc::new(OnceCell::new()), } } } diff --git a/pgdog/src/backend/replication/sharded_tables.rs b/pgdog/src/backend/replication/sharded_tables.rs index 33680f296..137fe1c50 100644 --- a/pgdog/src/backend/replication/sharded_tables.rs +++ b/pgdog/src/backend/replication/sharded_tables.rs @@ -91,6 +91,10 @@ impl ShardedTables { &self.inner.omnisharded } + pub fn is_omnisharded_sticky(&self, name: &str) -> Option { + self.omnishards().get(name).cloned() + } + /// The deployment has only one sharded table. pub fn common_mapping(&self) -> &Option { &self.inner.common_mapping diff --git a/pgdog/src/backend/schema/mod.rs b/pgdog/src/backend/schema/mod.rs index b8ad33eb1..203ede4ec 100644 --- a/pgdog/src/backend/schema/mod.rs +++ b/pgdog/src/backend/schema/mod.rs @@ -10,13 +10,18 @@ use tracing::debug; pub use relation::Relation; use super::{pool::Request, Cluster, Error, Server}; +use crate::frontend::router::parser::Table; +use crate::net::parameter::ParameterValue; static SETUP: &str = include_str!("setup.sql"); +/// Schema name -> Table name -> Relation +type Relations = HashMap>; + #[derive(Debug, Default)] struct Inner { search_path: Vec, - relations: HashMap<(String, String), Relation>, + relations: Relations, } /// Load schema from database. @@ -28,16 +33,13 @@ pub struct Schema { impl Schema { /// Load schema from a server connection. pub async fn load(server: &mut Server) -> Result { - let relations = Relation::load(server) - .await? - .into_iter() - .map(|relation| { - ( - (relation.schema().to_owned(), relation.name.clone()), - relation, - ) - }) - .collect(); + let mut relations: Relations = HashMap::new(); + for relation in Relation::load(server).await? { + relations + .entry(relation.schema().to_owned()) + .or_default() + .insert(relation.name.clone(), relation); + } let search_path = server .fetch_all::("SHOW search_path") @@ -58,25 +60,28 @@ impl Schema { }) } + /// The schema has been loaded from the database. + pub(crate) fn is_loaded(&self) -> bool { + !self.inner.relations.is_empty() + } + #[cfg(test)] pub(crate) fn from_parts( search_path: Vec, relations: HashMap<(String, String), Relation>, ) -> Self { + let mut nested: Relations = HashMap::new(); + for ((schema, name), relation) in relations { + nested.entry(schema).or_default().insert(name, relation); + } Self { inner: Arc::new(Inner { search_path, - relations, + relations: nested, }), } } - /// Load schema from primary database. - pub async fn from_cluster(cluster: &Cluster, shard: usize) -> Result { - let mut primary = cluster.primary(shard, &Request::default()).await?; - Self::load(&mut primary).await - } - /// Install PgDog functions and schema. pub async fn setup(server: &mut Server) -> Result<(), Error> { server.execute_checked(SETUP).await?; @@ -145,19 +150,58 @@ impl Schema { } /// Get table by name. - pub fn table(&self, name: &str, schema: Option<&str>) -> Option<&Relation> { - let schema = schema.unwrap_or("public"); + /// + /// If the table has an explicit schema, looks up in that schema directly. + /// Otherwise, iterates through the search_path to find the first match. 
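+ /// Mirrors Postgres name resolution: with search_path `"$user", public`,
+ /// `users` resolves to `alice.users` for user alice when that relation
+ /// exists, and falls back to `public.users` otherwise.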
+ pub fn table( + &self, + table: Table<'_>, + user: &str, + search_path: Option<&ParameterValue>, + ) -> Option<&Relation> { + if let Some(schema) = table.schema { + return self.get(schema, table.name); + } + + for schema in self.resolve_search_path(user, search_path) { + if let Some(relation) = self.get(schema, table.name) { + return Some(relation); + } + } + + None + } + + /// Get a relation by schema and table name. + pub fn get(&self, schema: &str, name: &str) -> Option<&Relation> { self.inner .relations - .get(&(name.to_string(), schema.to_string())) + .get(schema) + .and_then(|tables| tables.get(name)) } - /// Get all indices. + fn resolve_search_path<'a>( + &'a self, + user: &'a str, + search_path: Option<&'a ParameterValue>, + ) -> Vec<&'a str> { + let path: &[String] = match search_path { + Some(ParameterValue::Tuple(overriden)) => overriden.as_slice(), + _ => &self.inner.search_path, + }; + + path.iter() + .map(|p| if p == "$user" { user } else { p.as_str() }) + .collect() + } + + /// Get all tables. pub fn tables(&self) -> Vec<&Relation> { self.inner .relations .values() - .filter(|value| value.is_table()) + .flat_map(|tables| tables.values()) + .filter(|relation| relation.is_table()) .collect() } @@ -166,7 +210,8 @@ impl Schema { self.inner .relations .values() - .filter(|value| value.is_sequence()) + .flat_map(|tables| tables.values()) + .filter(|relation| relation.is_sequence()) .collect() } @@ -177,7 +222,7 @@ impl Schema { } impl Deref for Schema { - type Target = HashMap<(String, String), Relation>; + type Target = Relations; fn deref(&self) -> &Self::Target { &self.inner.relations @@ -186,7 +231,12 @@ impl Deref for Schema { #[cfg(test)] mod test { + use std::collections::HashMap; + use crate::backend::pool::Request; + use crate::backend::schema::relation::Relation; + use crate::frontend::router::parser::Table; + use crate::net::parameter::ParameterValue; use super::super::pool::test::pool; use super::Schema; @@ -207,7 +257,14 @@ mod test { .find(|seq| seq.schema() == "pgdog") .cloned() .unwrap(); - assert_eq!(seq.name, "validator_bigint_id_seq"); + assert!( + matches!( + seq.name.as_str(), + "unique_id_seq" | "validator_bigint_id_seq" + ), + "{}", + seq.name + ); let server_ok = conn.fetch_all::("SELECT 1 AS one").await.unwrap(); assert_eq!(server_ok.first().unwrap().clone(), 1); @@ -218,4 +275,146 @@ mod test { .unwrap(); assert!(debug.first().unwrap().contains("PgDog Debug")); } + + #[test] + fn test_resolve_search_path_default() { + let schema = Schema::from_parts(vec!["$user".into(), "public".into()], HashMap::new()); + + let resolved = schema.resolve_search_path("alice", None); + assert_eq!(resolved, vec!["alice", "public"]); + } + + #[test] + fn test_resolve_search_path_override() { + let schema = Schema::from_parts(vec!["$user".into(), "public".into()], HashMap::new()); + + let override_path = ParameterValue::Tuple(vec!["custom".into(), "other".into()]); + let resolved = schema.resolve_search_path("alice", Some(&override_path)); + assert_eq!(resolved, vec!["custom", "other"]); + } + + #[test] + fn test_resolve_search_path_override_with_user() { + let schema = Schema::from_parts(vec!["public".into()], HashMap::new()); + + let override_path = ParameterValue::Tuple(vec!["$user".into(), "app".into()]); + let resolved = schema.resolve_search_path("bob", Some(&override_path)); + assert_eq!(resolved, vec!["bob", "app"]); + } + + #[test] + fn test_table_with_explicit_schema() { + let relations: HashMap<(String, String), Relation> = HashMap::from([ + ( + 
("myschema".into(), "users".into()), + Relation::test_table("myschema", "users", HashMap::new()), + ), + ( + ("public".into(), "users".into()), + Relation::test_table("public", "users", HashMap::new()), + ), + ]); + let schema = Schema::from_parts(vec!["$user".into(), "public".into()], relations); + + let table = Table { + name: "users", + schema: Some("myschema"), + alias: None, + }; + + let result = schema.table(table, "alice", None); + assert!(result.is_some()); + assert_eq!(result.unwrap().schema(), "myschema"); + } + + #[test] + fn test_table_search_path_lookup() { + let relations: HashMap<(String, String), Relation> = HashMap::from([( + ("public".into(), "orders".into()), + Relation::test_table("public", "orders", HashMap::new()), + )]); + let schema = Schema::from_parts(vec!["$user".into(), "public".into()], relations); + + let table = Table { + name: "orders", + schema: None, + alias: None, + }; + + // User schema "alice" doesn't have "orders", but "public" does + let result = schema.table(table, "alice", None); + assert!(result.is_some()); + assert_eq!(result.unwrap().schema(), "public"); + } + + #[test] + fn test_table_found_in_user_schema() { + let relations: HashMap<(String, String), Relation> = HashMap::from([ + ( + ("alice".into(), "settings".into()), + Relation::test_table("alice", "settings", HashMap::new()), + ), + ( + ("public".into(), "settings".into()), + Relation::test_table("public", "settings", HashMap::new()), + ), + ]); + let schema = Schema::from_parts(vec!["$user".into(), "public".into()], relations); + + let table = Table { + name: "settings", + schema: None, + alias: None, + }; + + // Should find in "alice" schema first (due to $user) + let result = schema.table(table, "alice", None); + assert!(result.is_some()); + assert_eq!(result.unwrap().schema(), "alice"); + } + + #[test] + fn test_table_not_found() { + let relations: HashMap<(String, String), Relation> = HashMap::from([( + ("public".into(), "users".into()), + Relation::test_table("public", "users", HashMap::new()), + )]); + let schema = Schema::from_parts(vec!["$user".into(), "public".into()], relations); + + let table = Table { + name: "nonexistent", + schema: None, + alias: None, + }; + + let result = schema.table(table, "alice", None); + assert!(result.is_none()); + } + + #[test] + fn test_table_with_overridden_search_path() { + let relations: HashMap<(String, String), Relation> = HashMap::from([ + ( + ("custom".into(), "data".into()), + Relation::test_table("custom", "data", HashMap::new()), + ), + ( + ("public".into(), "data".into()), + Relation::test_table("public", "data", HashMap::new()), + ), + ]); + let schema = Schema::from_parts(vec!["$user".into(), "public".into()], relations); + + let table = Table { + name: "data", + schema: None, + alias: None, + }; + + // Override search_path to look in "custom" first + let override_path = ParameterValue::Tuple(vec!["custom".into(), "public".into()]); + let result = schema.table(table, "alice", Some(&override_path)); + assert!(result.is_some()); + assert_eq!(result.unwrap().schema(), "custom"); + } } diff --git a/pgdog/src/backend/schema/relation.rs b/pgdog/src/backend/schema/relation.rs index 9f914a308..a25b1e1dc 100644 --- a/pgdog/src/backend/schema/relation.rs +++ b/pgdog/src/backend/schema/relation.rs @@ -17,7 +17,6 @@ pub struct Relation { pub owner: String, pub persistence: String, pub access_method: String, - pub size: usize, pub description: String, pub oid: i32, pub columns: HashMap, @@ -32,9 +31,8 @@ impl From for Relation { owner: 
value.get_text(3).unwrap_or_default(), persistence: value.get_text(4).unwrap_or_default(), access_method: value.get_text(5).unwrap_or_default(), - size: value.get_int(6, true).unwrap_or_default() as usize, - description: value.get_text(7).unwrap_or_default(), - oid: value.get::(8, Format::Text).unwrap_or_default(), + description: value.get_text(6).unwrap_or_default(), + oid: value.get::(7, Format::Text).unwrap_or_default(), columns: HashMap::new(), } } @@ -96,6 +94,10 @@ impl Relation { pub fn columns(&self) -> &HashMap { &self.columns } + + pub fn has_column(&self, name: &str) -> bool { + self.columns.contains_key(name) + } } #[cfg(test)] @@ -108,7 +110,6 @@ impl Relation { owner: String::new(), persistence: String::new(), access_method: String::new(), - size: 0, description: String::new(), oid: 0, columns, diff --git a/pgdog/src/backend/schema/relations.sql b/pgdog/src/backend/schema/relations.sql index fa4b56170..ff489373b 100644 --- a/pgdog/src/backend/schema/relations.sql +++ b/pgdog/src/backend/schema/relations.sql @@ -18,8 +18,7 @@ WHEN 'u' THEN 'unlogged' end AS "persistence", am.amname AS "access_method", - pg_catalog.pg_table_size(c.oid) AS "size", - pg_catalog.obj_description(c.oid, 'pg_class') AS "description", + pg_catalog.obj_description(c.oid, 'pg_class') AS "description", c.oid::integer AS "oid" FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n diff --git a/pgdog/src/backend/schema/setup.sql b/pgdog/src/backend/schema/setup.sql index 37dd2d33b..25c59558c 100644 --- a/pgdog/src/backend/schema/setup.sql +++ b/pgdog/src/backend/schema/setup.sql @@ -318,5 +318,86 @@ BEGIN END; $body$ LANGUAGE plpgsql; +-- Globally unique 64-bit ID generator (Snowflake-like). +-- Bit allocation: 41 timestamp + 10 node + 12 sequence = 63 bits (keeps sign bit clear) +-- The sequence stores (elapsed_ms << 12) | sequence_within_ms, allowing +-- automatic reset of the sequence counter when the millisecond changes. 
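+--
+-- Worked example (illustrative values): 1000 ms after the epoch, shard 42,
+-- sequence 7:
+--   id = (1000 << 22) | (42 << 12) | 7
+--      = 4194304000 + 172032 + 7
+--      = 4194476039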
+CREATE SEQUENCE IF NOT EXISTS pgdog.unique_id_seq; + +CREATE OR REPLACE FUNCTION pgdog.unique_id(id_offset BIGINT DEFAULT 0) RETURNS BIGINT AS $body$ +DECLARE + sequence_bits CONSTANT INTEGER := 12; + node_bits CONSTANT INTEGER := 10; + max_node_id CONSTANT INTEGER := (1 << node_bits) - 1; -- 1023 + max_sequence CONSTANT INTEGER := (1 << sequence_bits) - 1; -- 4095 + max_timestamp CONSTANT BIGINT := (1::bigint << 41) - 1; + pgdog_epoch CONSTANT BIGINT := 1764184395000; -- Wednesday, November 26, 2025 11:13:15 AM GMT-08:00 + node_shift CONSTANT INTEGER := sequence_bits; -- 12 + timestamp_shift CONSTANT INTEGER := sequence_bits + node_bits; -- 22 + + node_id INTEGER; + now_ms BIGINT; + elapsed BIGINT; + min_combined BIGINT; + combined_seq BIGINT; + seq INTEGER; + timestamp_part BIGINT; + node_part BIGINT; + base_id BIGINT; +BEGIN + -- Get node_id from pgdog.config.shard + SELECT pgdog.config.shard INTO node_id FROM pgdog.config; + + IF node_id IS NULL THEN + RAISE EXCEPTION 'pgdog.config.shard not set'; + END IF; + + IF node_id < 0 OR node_id > max_node_id THEN + RAISE EXCEPTION 'shard must be between 0 and %', max_node_id; + END IF; + + LOOP + -- Get next combined sequence value + combined_seq := nextval('pgdog.unique_id_seq'); + + -- Get current time in milliseconds since Unix epoch + now_ms := (EXTRACT(EPOCH FROM clock_timestamp()) * 1000)::bigint; + elapsed := now_ms - pgdog_epoch; + + IF elapsed < 0 THEN + RAISE EXCEPTION 'Clock is before PgDog epoch (November 26, 2025)'; + END IF; + + -- Minimum valid combined value for current millisecond + min_combined := elapsed << 12; + + -- If sequence is at or ahead of current time, we're good + IF combined_seq >= min_combined THEN + EXIT; + END IF; + + -- Sequence is behind current time, advance it + PERFORM setval('pgdog.unique_id_seq', min_combined, false); + END LOOP; + + -- Decompose the combined sequence value + seq := (combined_seq & max_sequence)::integer; + elapsed := combined_seq >> 12; + + IF elapsed > max_timestamp THEN + RAISE EXCEPTION 'Timestamp overflow: % > %', elapsed, max_timestamp; + END IF; + + -- Compose the ID: timestamp | node | sequence + timestamp_part := elapsed << timestamp_shift; + node_part := node_id::bigint << node_shift; + base_id := timestamp_part | node_part | seq; + + RETURN base_id + id_offset; +END; +$body$ LANGUAGE plpgsql; + +GRANT USAGE ON SEQUENCE pgdog.unique_id_seq TO PUBLIC; + -- Allow functions to be executed by anyone. GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA pgdog TO PUBLIC; diff --git a/pgdog/src/frontend/client/query_engine/route_query.rs b/pgdog/src/frontend/client/query_engine/route_query.rs index ad24bf924..208057be8 100644 --- a/pgdog/src/frontend/client/query_engine/route_query.rs +++ b/pgdog/src/frontend/client/query_engine/route_query.rs @@ -1,4 +1,5 @@ use pgdog_config::PoolerMode; +use tokio::time::timeout; use tracing::trace; use super::*; @@ -21,7 +22,7 @@ impl QueryEngine { context: &mut QueryEngineContext<'_>, ) -> Result { // Admin doesn't have a cluster. - if let Ok(cluster) = self.backend.cluster() { + let res = if let Ok(cluster) = self.backend.cluster() { if !context.in_transaction() && !cluster.online() { let identifier = cluster.identifier(); @@ -43,6 +44,23 @@ impl QueryEngine { } } else { Ok(ClusterCheck::Ok) + }; + + if let Ok(ClusterCheck::Ok) = res { + // Make sure schema is loaded before we throw traffic + // at it. This matters for sharded deployments only. 
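+ // Routing consults cluster.schema() to decide whether a table is
+ // sharded or omnisharded, so traffic waits here until every shard
+ // has reported its schema.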
+ if let Ok(cluster) = self.backend.cluster() { + cluster.wait_schema_loaded().await; + // timeout( + // context.timeouts.query_timeout(&State::Active), + // cluster.wait_schema_loaded(), + // ) + // .await + // .map_err(|_| Error::SchemaLoad)?; + } + res + } else { + res } } diff --git a/pgdog/src/frontend/error.rs b/pgdog/src/frontend/error.rs index bfa1cac74..6bfeea9c9 100644 --- a/pgdog/src/frontend/error.rs +++ b/pgdog/src/frontend/error.rs @@ -45,6 +45,9 @@ pub enum Error { #[error("query timeout")] Timeout(#[from] tokio::time::error::Elapsed), + #[error("schema load timeout")] + SchemaLoad, + #[error("join error")] Join(#[from] tokio::task::JoinError), diff --git a/pgdog/src/frontend/router/context.rs b/pgdog/src/frontend/router/context.rs index 935522d0d..75507c60d 100644 --- a/pgdog/src/frontend/router/context.rs +++ b/pgdog/src/frontend/router/context.rs @@ -1,6 +1,6 @@ use super::{Error, ParameterHints}; use crate::{ - backend::Cluster, + backend::{Cluster, Schema}, frontend::{ client::{Sticky, TransactionType}, router::Ast, @@ -33,6 +33,8 @@ pub struct RouterContext<'a> { pub extended: bool, /// AST. pub ast: Option, + /// Schema. + pub schema: Schema, } impl<'a> RouterContext<'a> { @@ -59,6 +61,7 @@ impl<'a> RouterContext<'a> { extended: matches!(query, Some(BufferedQuery::Prepared(_))) || bind.is_some(), query, ast: buffer.ast.clone(), + schema: cluster.schema(), }) } diff --git a/pgdog/src/frontend/router/parser/multi_tenant.rs b/pgdog/src/frontend/router/parser/multi_tenant.rs index d9f7ef736..bcff95a08 100644 --- a/pgdog/src/frontend/router/parser/multi_tenant.rs +++ b/pgdog/src/frontend/router/parser/multi_tenant.rs @@ -83,9 +83,7 @@ impl<'a> MultiTenantCheck<'a> { let schemas = search_path.resolve(); for schema in schemas { - let schema_table = self - .schema - .get(&(schema.to_owned(), table.name.to_string())); + let schema_table = self.schema.get(schema, table.name); if let Some(schema_table) = schema_table { let has_tenant_id = schema_table.columns().contains_key(&self.config.column); if !has_tenant_id { diff --git a/pgdog/src/frontend/router/parser/query/select.rs b/pgdog/src/frontend/router/parser/query/select.rs index a69a40968..fc21c761c 100644 --- a/pgdog/src/frontend/router/parser/query/select.rs +++ b/pgdog/src/frontend/router/parser/query/select.rs @@ -40,25 +40,46 @@ impl QueryParser { let mut shards = HashSet::new(); - let shard = StatementParser::from_select( - stmt, - context.router_context.bind, - &context.sharding_schema, - self.recorder_mut(), - ) - .shard()?; + let (shard, is_sharded, tables) = { + let mut statement_parser = StatementParser::from_select( + stmt, + context.router_context.bind, + &context.sharding_schema, + self.recorder_mut(), + ); + + let shard = statement_parser.shard()?; + + if shard.is_some() { + (shard, true, vec![]) + } else { + ( + None, + statement_parser.is_sharded( + &context.router_context.schema, + context.router_context.cluster.user(), + context.router_context.parameter_hints.search_path, + ), + statement_parser.extract_tables(), + ) + } + }; if let Some(shard) = shard { shards.insert(shard); } - // `SELECT NOW()`, `SELECT 1`, etc. 
+ // SELECT NOW(), SELECT 1 if shards.is_empty() && stmt.from_clause.is_empty() { + let shard = Shard::Direct(round_robin::next() % context.shards); + + if let Some(recorder) = self.recorder_mut() { + recorder.record_entry(Some(shard.clone()), format!("SELECT omnishard no table")); + } + context .shards_calculator - .push(ShardWithPriority::new_rr_no_table(Shard::Direct( - round_robin::next() % context.shards, - ))); + .push(ShardWithPriority::new_rr_no_table(shard)); return Ok(Command::Query( Route::read(context.shards_calculator.shard().clone()).with_write(writes), @@ -97,64 +118,58 @@ impl QueryParser { let limit = LimitClause::new(stmt, context.router_context.bind).limit_offset()?; let distinct = Distinct::new(stmt).distinct()?; - context - .shards_calculator - .push(ShardWithPriority::new_table(shard)); + if let Some(shard) = shard { + debug!("direct-to-shard {}", shard); - let mut query = Route::select( - context.shards_calculator.shard().clone(), - order_by, - aggregates, - limit, - distinct, - ); + context + .shards_calculator + .push(ShardWithPriority::new_table(shard)); + } else if is_sharded { + debug!("table is sharded, but no sharding key detected"); - // Omnisharded tables check. - if query.is_all_shards() { - let tables = from_clause.tables(); - let mut sticky = false; - let omni = tables.iter().all(|table| { - let is_sticky = context.sharding_schema.tables.omnishards().get(table.name); + context + .shards_calculator + .push(ShardWithPriority::new_table(Shard::All)); + } else { + debug!( + "table is not sharded, defaulting to omnisharded (schema loaded: {})", + context.router_context.schema.is_loaded() + ); - if let Some(is_sticky) = is_sticky { - if *is_sticky { - sticky = true; - } - true - } else { - false - } + // Omnisharded by default. + let sticky = tables.iter().any(|table| { + context + .sharding_schema + .tables() + .is_omnisharded_sticky(table.name) + == Some(true) }); - if omni { - let shard = if sticky { - context.router_context.sticky.omni_index - } else { - round_robin::next() - } % context.shards; + let (rr_index, explain) = if sticky { + (context.router_context.sticky.omni_index, "sticky") + } else { + (round_robin::next(), "round robin") + }; - context - .shards_calculator - .push(ShardWithPriority::new_rr_omni(Shard::Direct(shard))); - - query.set_shard_mut(context.shards_calculator.shard().clone()); - - if let Some(recorder) = self.recorder_mut() { - recorder.record_entry( - Some(shard.into()), - format!( - "SELECT matched omnisharded tables: {}", - tables - .iter() - .map(|table| table.name) - .collect::>() - .join(", ") - ), - ); - } + let shard = Shard::Direct(rr_index % context.shards); + + if let Some(recorder) = self.recorder_mut() { + recorder.record_entry(Some(shard.clone()), format!("SELECT omnishard {}", explain)); } + + context + .shards_calculator + .push(ShardWithPriority::new_rr_omni(shard)); } + let mut query = Route::select( + context.shards_calculator.shard().clone(), + order_by, + aggregates, + limit, + distinct, + ); + // Only rewrite if query is cross-shard. 
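+ // Single-shard routes return rows unmodified; the AVG/STDDEV helper
+ // rewrites are only needed when rows from multiple shards are merged.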
if query.is_cross_shard() && context.shards > 1 { query.with_aggregate_rewrite_plan_mut(cached_ast.rewrite_plan.aggregates.clone()); diff --git a/pgdog/src/frontend/router/parser/query/shared.rs b/pgdog/src/frontend/router/parser/query/shared.rs index 433b72689..dd22c92d8 100644 --- a/pgdog/src/frontend/router/parser/query/shared.rs +++ b/pgdog/src/frontend/router/parser/query/shared.rs @@ -13,9 +13,11 @@ pub(super) enum ConvergeAlgorithm { impl QueryParser { /// Converge to a single route given multiple shards. - pub(super) fn converge(shards: &HashSet, algorithm: ConvergeAlgorithm) -> Shard { - let shard = if shards.len() == 1 { - shards.iter().next().cloned().unwrap() + pub(super) fn converge(shards: &HashSet, algorithm: ConvergeAlgorithm) -> Option { + let shard = if shards.is_empty() { + None + } else if shards.len() == 1 { + shards.iter().next().cloned() } else { let mut multi = HashSet::new(); let mut all = false; @@ -35,15 +37,15 @@ impl QueryParser { if algorithm == ConvergeAlgorithm::FirstDirect { let direct = shards.iter().find(|shard| shard.is_direct()); if let Some(direct) = direct { - return direct.clone(); + return Some(direct.clone()); } } - if all || shards.is_empty() { + Some(if all || shards.is_empty() { Shard::All } else { Shard::Multi(multi.into_iter().collect()) - } + }) }; shard @@ -60,10 +62,10 @@ mod tests { let shards = HashSet::from([Shard::Direct(5)]); let result = QueryParser::converge(&shards, ConvergeAlgorithm::AllFirstElseMulti); - assert_eq!(result, Shard::Direct(5)); + assert_eq!(result, Some(Shard::Direct(5))); let result = QueryParser::converge(&shards, ConvergeAlgorithm::FirstDirect); - assert_eq!(result, Shard::Direct(5)); + assert_eq!(result, Some(Shard::Direct(5))); } #[test] @@ -71,10 +73,10 @@ mod tests { let shards = HashSet::from([Shard::All]); let result = QueryParser::converge(&shards, ConvergeAlgorithm::AllFirstElseMulti); - assert_eq!(result, Shard::All); + assert_eq!(result, Some(Shard::All)); let result = QueryParser::converge(&shards, ConvergeAlgorithm::FirstDirect); - assert_eq!(result, Shard::All); + assert_eq!(result, Some(Shard::All)); } #[test] @@ -82,10 +84,10 @@ mod tests { let shards = HashSet::from([Shard::Multi(vec![1, 2, 3])]); let result = QueryParser::converge(&shards, ConvergeAlgorithm::AllFirstElseMulti); - assert_eq!(result, Shard::Multi(vec![1, 2, 3])); + assert_eq!(result, Some(Shard::Multi(vec![1, 2, 3]))); let result = QueryParser::converge(&shards, ConvergeAlgorithm::FirstDirect); - assert_eq!(result, Shard::Multi(vec![1, 2, 3])); + assert_eq!(result, Some(Shard::Multi(vec![1, 2, 3]))); } #[test] @@ -94,7 +96,7 @@ mod tests { let result = QueryParser::converge(&shards, ConvergeAlgorithm::AllFirstElseMulti); match result { - Shard::Multi(mut v) => { + Some(Shard::Multi(mut v)) => { v.sort(); assert_eq!(v, vec![1, 2]); } @@ -108,7 +110,7 @@ mod tests { let result = QueryParser::converge(&shards, ConvergeAlgorithm::FirstDirect); assert!( - matches!(result, Shard::Direct(1) | Shard::Direct(2)), + matches!(result, Some(Shard::Direct(1)) | Some(Shard::Direct(2))), "expected Direct(1) or Direct(2), got {:?}", result ); @@ -119,7 +121,7 @@ mod tests { let shards = HashSet::from([Shard::All, Shard::Direct(1)]); let result = QueryParser::converge(&shards, ConvergeAlgorithm::AllFirstElseMulti); - assert_eq!(result, Shard::All); + assert_eq!(result, Some(Shard::All)); } #[test] @@ -127,18 +129,18 @@ mod tests { let shards = HashSet::from([Shard::All, Shard::Direct(1)]); let result = QueryParser::converge(&shards, 
ConvergeAlgorithm::FirstDirect); - assert_eq!(result, Shard::Direct(1)); + assert_eq!(result, Some(Shard::Direct(1))); } #[test] - fn empty_set_returns_all() { + fn empty_set_returns_none() { let shards = HashSet::new(); let result = QueryParser::converge(&shards, ConvergeAlgorithm::AllFirstElseMulti); - assert_eq!(result, Shard::All); + assert_eq!(result, None); let result = QueryParser::converge(&shards, ConvergeAlgorithm::FirstDirect); - assert_eq!(result, Shard::All); + assert_eq!(result, None); } #[test] @@ -147,7 +149,7 @@ mod tests { let result = QueryParser::converge(&shards, ConvergeAlgorithm::AllFirstElseMulti); match result { - Shard::Multi(mut v) => { + Some(Shard::Multi(mut v)) => { v.sort(); assert_eq!(v, vec![1, 2, 3]); } @@ -160,7 +162,7 @@ mod tests { let shards = HashSet::from([Shard::Multi(vec![1, 2]), Shard::Direct(3)]); let result = QueryParser::converge(&shards, ConvergeAlgorithm::FirstDirect); - assert_eq!(result, Shard::Direct(3)); + assert_eq!(result, Some(Shard::Direct(3))); } #[test] @@ -168,6 +170,6 @@ mod tests { let shards = HashSet::from([Shard::All, Shard::Multi(vec![1, 2])]); let result = QueryParser::converge(&shards, ConvergeAlgorithm::FirstDirect); - assert_eq!(result, Shard::All); + assert_eq!(result, Some(Shard::All)); } } diff --git a/pgdog/src/frontend/router/parser/query/test/mod.rs b/pgdog/src/frontend/router/parser/query/test/mod.rs index b2fd43219..7ab00e6aa 100644 --- a/pgdog/src/frontend/router/parser/query/test/mod.rs +++ b/pgdog/src/frontend/router/parser/query/test/mod.rs @@ -334,8 +334,9 @@ fn test_omni() { assert!(matches!(shard, Shard::Direct(_))); } - // Test that all tables have to be omnisharded. - let q = "SELECT * FROM sharded_omni INNER JOIN not_sharded ON sharded_omni.id = not_sharded.id WHERE sharded_omni = $1"; + // If a sharded table is joined to an omnisharded table, + // the query goes to all shards, not round robin + let q = "SELECT * FROM sharded_omni INNER JOIN not_sharded ON sharded_omni.id = not_sharded.id INNER JOIN sharded ON sharded.id = sharded_omni.id WHERE sharded_omni = $1"; let route = query!(q); assert!(matches!(route.shard(), Shard::All)); } diff --git a/pgdog/src/frontend/router/parser/query/test/test_schema_sharding.rs b/pgdog/src/frontend/router/parser/query/test/test_schema_sharding.rs index 50d018541..85ce5ac3a 100644 --- a/pgdog/src/frontend/router/parser/query/test/test_schema_sharding.rs +++ b/pgdog/src/frontend/router/parser/query/test/test_schema_sharding.rs @@ -1,3 +1,4 @@ +use crate::frontend::router::parser::route::RoundRobinReason; use crate::frontend::router::parser::{route::ShardSource, Shard}; use crate::net::parameter::ParameterValue; @@ -30,15 +31,19 @@ fn test_select_from_shard_1_schema() { } #[test] -fn test_select_from_unsharded_schema_goes_to_all() { +fn test_select_from_unsharded_schema_goes_to_rr() { let mut test = QueryParserTest::new(); let command = test.execute(vec![ Query::new("SELECT * FROM public.users WHERE id = 1").into() ]); - // Unknown schema goes to all shards - assert_eq!(command.route().shard(), &Shard::All); + // Unknown schema goes to omnisharded + assert!(matches!(command.route().shard(), &Shard::Direct(_))); + assert_eq!( + command.route().shard_with_priority().source(), + &ShardSource::RoundRobin(RoundRobinReason::Omni) + ); } // --- INSERT queries with schema-qualified tables --- diff --git a/pgdog/src/frontend/router/parser/query/test/test_search_path.rs b/pgdog/src/frontend/router/parser/query/test/test_search_path.rs index bfd1a53e5..63f0b15de 100644 --- 
a/pgdog/src/frontend/router/parser/query/test/test_search_path.rs +++ b/pgdog/src/frontend/router/parser/query/test/test_search_path.rs @@ -1,4 +1,5 @@ -use crate::frontend::router::parser::Shard; +use crate::frontend::router::parser::route::RoundRobinReason; +use crate::frontend::router::parser::{route::ShardSource, Shard}; use crate::net::parameter::ParameterValue; use super::setup::{QueryParserTest, *}; @@ -186,7 +187,7 @@ fn test_search_path_shard_at_end_still_matches() { // --- search_path with no sharded schema routes to all --- #[test] -fn test_search_path_no_sharded_schema_routes_to_all() { +fn test_search_path_no_sharded_schema_routes_to_rr() { let mut test = QueryParserTest::new().with_param( "search_path", ParameterValue::Tuple(vec!["$user".into(), "public".into(), "pg_catalog".into()]), @@ -194,11 +195,15 @@ fn test_search_path_no_sharded_schema_routes_to_all() { let command = test.execute(vec![Query::new("SELECT * FROM users WHERE id = 1").into()]); - assert_eq!(command.route().shard(), &Shard::All); + assert!(matches!(command.route().shard(), &Shard::Direct(_))); + assert_eq!( + command.route().shard_with_priority().source(), + &ShardSource::RoundRobin(RoundRobinReason::Omni) + ); } #[test] -fn test_search_path_only_system_schemas_routes_to_all() { +fn test_search_path_only_system_schemas_routes_to_rr() { let mut test = QueryParserTest::new().with_param( "search_path", ParameterValue::Tuple(vec![ @@ -210,7 +215,11 @@ fn test_search_path_only_system_schemas_routes_to_all() { let command = test.execute(vec![Query::new("SELECT * FROM users WHERE id = 1").into()]); - assert_eq!(command.route().shard(), &Shard::All); + assert!(matches!(command.route().shard(), &Shard::Direct(_))); + assert_eq!( + command.route().shard_with_priority().source(), + &ShardSource::RoundRobin(RoundRobinReason::Omni) + ); } // --- search_path routing for DDL --- diff --git a/pgdog/src/frontend/router/parser/query/test/test_sharding.rs b/pgdog/src/frontend/router/parser/query/test/test_sharding.rs index 57064a5c5..bfc38d625 100644 --- a/pgdog/src/frontend/router/parser/query/test/test_sharding.rs +++ b/pgdog/src/frontend/router/parser/query/test/test_sharding.rs @@ -114,7 +114,7 @@ fn test_omni_sharded_table_takes_priority() { #[test] fn test_omni_all_tables_must_be_omnisharded() { - let q = "SELECT * FROM sharded_omni INNER JOIN not_sharded ON sharded_omni.id = not_sharded.id WHERE sharded_omni = $1"; + let q = "SELECT * FROM sharded_omni INNER JOIN not_sharded ON sharded_omni.id = not_sharded.id INNER JOIN sharded ON sharded.id = sharded_omni.id WHERE sharded_omni = $1"; let mut test = QueryParserTest::new(); let command = test.execute(vec![Query::new(q).into()]); diff --git a/pgdog/src/frontend/router/parser/statement.rs b/pgdog/src/frontend/router/parser/statement.rs index da3276333..be38885cc 100644 --- a/pgdog/src/frontend/router/parser/statement.rs +++ b/pgdog/src/frontend/router/parser/statement.rs @@ -12,9 +12,12 @@ use super::{ Value, }; use crate::{ - backend::ShardingSchema, - frontend::router::{parser::Shard, sharding::ContextBuilder, sharding::SchemaSharder}, - net::Bind, + backend::{Schema, ShardingSchema}, + frontend::router::{ + parser::Shard, + sharding::{ContextBuilder, SchemaSharder}, + }, + net::{parameter::ParameterValue, Bind}, }; /// Context for searching a SELECT statement, tracking table aliases. @@ -288,8 +291,57 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { Ok(None) } + /// Check that the query references a table that contains a sharded + /// column. 
This check is needed in case sharded tables config + /// doesn't specify a table name and should short-circuit if it does. + pub fn is_sharded( + &self, + db_schema: &Schema, + user: &str, + search_path: Option<&ParameterValue>, + ) -> bool { + let sharded_tables = self.schema.tables.tables(); + + // Separate configs with explicit table names from those without + let (named, nameless): (Vec<_>, Vec<_>) = + sharded_tables.iter().partition(|t| t.name.is_some()); + + let tables = self.extract_tables(); + + for table in &tables { + // Check named sharded table configs (fast path, no schema lookup needed) + for config in &named { + if let Some(ref name) = config.name { + if table.name == name { + // Also check schema match if specified in config + if let Some(ref config_schema) = config.schema { + if table.schema != Some(config_schema.as_str()) { + continue; + } + } + return true; + } + } + } + + // Check nameless configs by looking up the table in the db schema + // to see if it has the sharding column + if !nameless.is_empty() { + if let Some(relation) = db_schema.table(*table, user, search_path) { + for config in &nameless { + if relation.has_column(&config.column) { + return true; + } + } + } + } + } + + false + } + /// Extract all tables referenced in the statement. - fn extract_tables(&self) -> Vec> { + pub fn extract_tables(&self) -> Vec> { let mut tables = Vec::new(); match self.stmt { Statement::Select(stmt) => self.extract_tables_from_select(stmt, &mut tables), From 6fb1c0d362762b44e2e4ca0e9f378bdb5638c166 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 22 Jan 2026 11:40:51 -0800 Subject: [PATCH 2/4] increase timeout --- integration/go/go_pgx/pg_tests_test.go | 6 +++--- .../src/frontend/client/query_engine/route_query.rs | 13 ++++++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/integration/go/go_pgx/pg_tests_test.go b/integration/go/go_pgx/pg_tests_test.go index 12d55f0a0..bda5be1d0 100644 --- a/integration/go/go_pgx/pg_tests_test.go +++ b/integration/go/go_pgx/pg_tests_test.go @@ -222,7 +222,7 @@ func executeTimeoutTest(t *testing.T) { c := make(chan int, 1) go func() { - err := pgSleepOneSecond(conn, ctx) + err := pgSleepTwoSecond(conn, ctx) assert.NotNil(t, err) defer conn.Close(context.Background()) @@ -240,8 +240,8 @@ func executeTimeoutTest(t *testing.T) { } // Sleep for 1 second. -func pgSleepOneSecond(conn *pgx.Conn, ctx context.Context) (err error) { - _, err = conn.Exec(ctx, "SELECT pg_sleep(1)") +func pgSleepTwoSecond(conn *pgx.Conn, ctx context.Context) (err error) { + _, err = conn.Exec(ctx, "SELECT pg_sleep(2)") return err } diff --git a/pgdog/src/frontend/client/query_engine/route_query.rs b/pgdog/src/frontend/client/query_engine/route_query.rs index 208057be8..38875dcfa 100644 --- a/pgdog/src/frontend/client/query_engine/route_query.rs +++ b/pgdog/src/frontend/client/query_engine/route_query.rs @@ -50,13 +50,12 @@ impl QueryEngine { // Make sure schema is loaded before we throw traffic // at it. This matters for sharded deployments only. 
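+ // Bounding the wait with the query timeout turns a shard that never
+ // reports its schema into Error::SchemaLoad instead of a hung client.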
if let Ok(cluster) = self.backend.cluster() { - cluster.wait_schema_loaded().await; - // timeout( - // context.timeouts.query_timeout(&State::Active), - // cluster.wait_schema_loaded(), - // ) - // .await - // .map_err(|_| Error::SchemaLoad)?; + timeout( + context.timeouts.query_timeout(&State::Active), + cluster.wait_schema_loaded(), + ) + .await + .map_err(|_| Error::SchemaLoad)?; } res } else { From 0af09359097278b9bbe7eaeee2488d0129428281 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 22 Jan 2026 12:23:18 -0800 Subject: [PATCH 3/4] add more tests --- Cargo.lock | 2 +- pgdog/Cargo.toml | 2 +- pgdog/src/backend/pool/cluster.rs | 174 +++++++++++++++++++++++++++++- 3 files changed, 173 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e22666cd3..b20554fa0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2398,7 +2398,7 @@ dependencies = [ [[package]] name = "pgdog" -version = "0.1.25" +version = "0.1.26" dependencies = [ "arc-swap", "async-trait", diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index ba6079fe9..8573b6f7f 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgdog" -version = "0.1.25" +version = "0.1.26" edition = "2021" description = "Modern PostgreSQL proxy, pooler and load balancer." authors = ["PgDog "] diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 8f3bfbb5b..29b769122 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -467,7 +467,10 @@ impl Cluster { } fn load_schema(&self) -> bool { - self.shards.len() > 1 || self.multi_tenant().is_some() + self.shards.len() > 1 + && self.sharded_schemas.is_empty() + && !self.sharded_tables.tables().is_empty() + || self.multi_tenant().is_some() } /// Get currently loaded schema from shard 0. 
@@ -602,8 +605,8 @@ mod test { Shard, ShardedTables, }, config::{ - DataType, Hasher, LoadBalancingStrategy, ReadWriteSplit, ReadWriteStrategy, - ShardedTable, + DataType, Hasher, LoadBalancingStrategy, MultiTenant, ReadWriteSplit, + ReadWriteStrategy, ShardedTable, }, }; @@ -702,4 +705,169 @@ mod test { self.rw_strategy = rw_strategy; } } + + #[test] + fn test_load_schema_multiple_shards_empty_schemas_with_tables() { + let config = ConfigAndUsers::default(); + let mut cluster = Cluster::new_test(&config); + cluster.sharded_schemas = ShardedSchemas::default(); + + assert!(cluster.load_schema()); + } + + #[test] + fn test_load_schema_multiple_shards_with_schemas() { + let config = ConfigAndUsers::default(); + let cluster = Cluster::new_test(&config); + + assert!(!cluster.load_schema()); + } + + #[test] + fn test_load_schema_multiple_shards_empty_tables() { + let config = ConfigAndUsers::default(); + let mut cluster = Cluster::new_test(&config); + cluster.sharded_schemas = ShardedSchemas::default(); + cluster.sharded_tables = ShardedTables::default(); + + assert!(!cluster.load_schema()); + } + + #[test] + fn test_load_schema_single_shard() { + let config = ConfigAndUsers::default(); + let mut cluster = Cluster::new_test_single_shard(&config); + cluster.sharded_schemas = ShardedSchemas::default(); + + assert!(!cluster.load_schema()); + } + + #[test] + fn test_load_schema_with_multi_tenant() { + let config = ConfigAndUsers::default(); + let mut cluster = Cluster::new_test_single_shard(&config); + cluster.multi_tenant = Some(MultiTenant { + column: "tenant_id".into(), + }); + + assert!(cluster.load_schema()); + } + + #[test] + fn test_load_schema_multi_tenant_overrides_other_conditions() { + let config = ConfigAndUsers::default(); + let mut cluster = Cluster::new_test(&config); + cluster.sharded_tables = ShardedTables::default(); + cluster.multi_tenant = Some(MultiTenant { + column: "tenant_id".into(), + }); + + assert!(cluster.load_schema()); + } + + #[tokio::test] + async fn test_launch_sets_online() { + let config = ConfigAndUsers::default(); + let cluster = Cluster::new_test(&config); + + assert!(!cluster.online()); + cluster.launch(); + assert!(cluster.online()); + } + + #[tokio::test] + async fn test_shutdown_sets_offline() { + let config = ConfigAndUsers::default(); + let cluster = Cluster::new_test(&config); + + cluster.launch(); + assert!(cluster.online()); + cluster.shutdown(); + assert!(!cluster.online()); + } + + #[tokio::test] + async fn test_launch_schema_loading_idempotent() { + use std::sync::atomic::Ordering; + use tokio::time::{sleep, Duration}; + + let config = ConfigAndUsers::default(); + let mut cluster = Cluster::new_test(&config); + cluster.sharded_schemas = ShardedSchemas::default(); + + assert!(cluster.load_schema()); + + cluster.launch(); + cluster.wait_schema_loaded().await; + + let count_after_first = cluster.readiness.schemas_loaded.load(Ordering::SeqCst); + assert_eq!(count_after_first, cluster.shards.len()); + + // Second launch should not spawn additional schema loading tasks + cluster.launch(); + sleep(Duration::from_millis(50)).await; + + let count_after_second = cluster.readiness.schemas_loaded.load(Ordering::SeqCst); + assert_eq!(count_after_second, count_after_first); + } + + #[tokio::test] + async fn test_wait_schema_loaded_returns_immediately_when_not_needed() { + let config = ConfigAndUsers::default(); + let cluster = Cluster::new_test(&config); + + // load_schema() returns false because sharded_schemas is not empty + assert!(!cluster.load_schema()); + + 
// Should return immediately without waiting + cluster.wait_schema_loaded().await; + } + + #[tokio::test] + async fn test_wait_schema_loaded_fast_path_when_already_loaded() { + use std::sync::atomic::Ordering; + + let config = ConfigAndUsers::default(); + let mut cluster = Cluster::new_test(&config); + cluster.sharded_schemas = ShardedSchemas::default(); + + assert!(cluster.load_schema()); + + // Simulate that all schemas have been loaded + cluster + .readiness + .schemas_loaded + .store(cluster.shards.len(), Ordering::SeqCst); + + // Should return immediately via fast path + cluster.wait_schema_loaded().await; + } + + #[tokio::test] + async fn test_wait_schema_loaded_waits_for_notification() { + use std::sync::atomic::Ordering; + use tokio::time::{timeout, Duration}; + + let config = ConfigAndUsers::default(); + let mut cluster = Cluster::new_test(&config); + cluster.sharded_schemas = ShardedSchemas::default(); + + assert!(cluster.load_schema()); + + let readiness = cluster.readiness.clone(); + let shards_count = cluster.shards.len(); + + // Spawn a task that will complete schema loading after a short delay + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + readiness + .schemas_loaded + .store(shards_count, Ordering::SeqCst); + readiness.schemas_ready.notify_waiters(); + }); + + // Should wait for notification and complete within timeout + let result = timeout(Duration::from_millis(100), cluster.wait_schema_loaded()).await; + assert!(result.is_ok()); + } } From 24158f8275248b94a479c69e48ea8a02d7dc6114 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 22 Jan 2026 12:31:18 -0800 Subject: [PATCH 4/4] unique id test --- integration/rust/tests/sqlx/unique_id.rs | 154 ++++++++++++++++++++++- 1 file changed, 153 insertions(+), 1 deletion(-) diff --git a/integration/rust/tests/sqlx/unique_id.rs b/integration/rust/tests/sqlx/unique_id.rs index af553c673..f39772969 100644 --- a/integration/rust/tests/sqlx/unique_id.rs +++ b/integration/rust/tests/sqlx/unique_id.rs @@ -1,4 +1,5 @@ -use sqlx::{Postgres, pool::Pool, postgres::PgPoolOptions}; +use rust::setup::connection_sqlx_direct; +use sqlx::{Executor, Postgres, pool::Pool, postgres::PgPoolOptions}; async fn sharded_pool() -> Pool { PgPoolOptions::new() @@ -58,3 +59,154 @@ async fn test_unique_id_uniqueness() { conn.close().await; } + +/// Test that pgdog.unique_id() PL/pgSQL function produces IDs with the same +/// bit layout as Rust's unique_id.rs implementation. 
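+/// Layout under test: id = (elapsed_ms << 22) | (node_id << 12) | sequence.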
+#[tokio::test] +async fn test_unique_id_bit_layout_matches_rust() { + // Constants from Rust unique_id.rs - these must match the SQL implementation + const SEQUENCE_BITS: u64 = 12; + const NODE_BITS: u64 = 10; + const NODE_SHIFT: u64 = SEQUENCE_BITS; // 12 + const TIMESTAMP_SHIFT: u64 = SEQUENCE_BITS + NODE_BITS; // 22 + const MAX_NODE_ID: u64 = (1 << NODE_BITS) - 1; // 1023 + const MAX_SEQUENCE: u64 = (1 << SEQUENCE_BITS) - 1; // 4095 + const PGDOG_EPOCH: u64 = 1764184395000; + + let conn = connection_sqlx_direct().await; + + // Run schema setup to ensure pgdog schema exists + let setup_sql = include_str!("../../../../pgdog/src/backend/schema/setup.sql"); + conn.execute(setup_sql).await.expect("schema setup failed"); + + // Configure pgdog.config with a known shard value + let test_shard: i64 = 42; + conn.execute("DELETE FROM pgdog.config") + .await + .expect("clear config"); + conn.execute( + sqlx::query("INSERT INTO pgdog.config (shard, shards) VALUES ($1, 100)").bind(test_shard), + ) + .await + .expect("insert config"); + + // Generate an ID using the SQL function + let row: (i64,) = sqlx::query_as("SELECT pgdog.unique_id()") + .fetch_one(&conn) + .await + .expect("generate unique_id"); + let id = row.0 as u64; + + // Extract components using the same bit layout as Rust + let extracted_sequence = id & MAX_SEQUENCE; + let extracted_node = (id >> NODE_SHIFT) & MAX_NODE_ID; + let extracted_timestamp = id >> TIMESTAMP_SHIFT; + + // Verify node_id matches the configured shard + assert_eq!( + extracted_node, test_shard as u64, + "node_id in generated ID should match configured shard" + ); + + // Verify timestamp is reasonable (after epoch, within a day) + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + let expected_elapsed = now_ms - PGDOG_EPOCH; + + assert!( + extracted_timestamp > 0, + "timestamp should be positive (after epoch)" + ); + // Timestamp should be close to current time (within 5 seconds) + let diff = if extracted_timestamp > expected_elapsed { + extracted_timestamp - expected_elapsed + } else { + expected_elapsed - extracted_timestamp + }; + assert!( + diff < 5000, + "timestamp {} should be close to expected {} (diff: {}ms)", + extracted_timestamp, + expected_elapsed, + diff + ); + + // Verify sequence is within valid range + assert!( + extracted_sequence <= MAX_SEQUENCE, + "sequence {} should not exceed max {}", + extracted_sequence, + MAX_SEQUENCE + ); + + // Generate multiple IDs and verify they're monotonically increasing + let mut prev_id = id; + for _ in 0..100 { + let row: (i64,) = sqlx::query_as("SELECT pgdog.unique_id()") + .fetch_one(&conn) + .await + .unwrap(); + let new_id = row.0 as u64; + assert!( + new_id > prev_id, + "IDs should be monotonically increasing: {} > {}", + new_id, + prev_id + ); + prev_id = new_id; + } + + conn.close().await; + + // Also test through pgdog (sharded pool) and verify bit layout matches + let sharded = sharded_pool().await; + + let row: (i64,) = sqlx::query_as("SELECT pgdog.unique_id()") + .fetch_one(&sharded) + .await + .expect("generate unique_id through pgdog"); + let pgdog_id = row.0 as u64; + + // Extract components from pgdog-generated ID + let pgdog_sequence = pgdog_id & MAX_SEQUENCE; + let pgdog_node = (pgdog_id >> NODE_SHIFT) & MAX_NODE_ID; + let pgdog_timestamp = pgdog_id >> TIMESTAMP_SHIFT; + + // Verify node_id is valid (0 or 1 for sharded setup) + assert!( + pgdog_node <= MAX_NODE_ID, + "pgdog node_id {} should be valid", + pgdog_node + ); + + // 
Verify timestamp is close to current time + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + let expected_elapsed = now_ms - PGDOG_EPOCH; + let pgdog_diff = if pgdog_timestamp > expected_elapsed { + pgdog_timestamp - expected_elapsed + } else { + expected_elapsed - pgdog_timestamp + }; + assert!( + pgdog_diff < 5000, + "pgdog timestamp {} should be close to expected {} (diff: {}ms)", + pgdog_timestamp, + expected_elapsed, + pgdog_diff + ); + + // Verify sequence is valid + assert!( + pgdog_sequence <= MAX_SEQUENCE, + "pgdog sequence {} should not exceed max {}", + pgdog_sequence, + MAX_SEQUENCE + ); + + sharded.close().await; +}