diff --git a/integration/rust/tests/integration/rewrite.rs b/integration/rust/tests/integration/rewrite.rs index 1141a9dd9..0d44fb649 100644 --- a/integration/rust/tests/integration/rewrite.rs +++ b/integration/rust/tests/integration/rewrite.rs @@ -280,6 +280,77 @@ async fn update_expects_transactions() { cleanup_table(&pool).await; } +#[tokio::test] +async fn test_error_disconnects_and_update_works() -> Result<(), Box> { + let conn = connections_sqlx().await.pop().unwrap(); + let admin = admin_sqlx().await; + + admin.execute("SET rewrite_enabled TO true").await?; + admin + .execute("SET rewrite_shard_key_updates TO 'rewrite'") + .await?; + admin + .execute("SET rewrite_split_inserts TO 'rewrite'") + .await?; + admin.execute("SET two_phase_commit TO true").await?; + admin.execute("SET two_phase_commit_auto TO true").await?; + + conn.execute("TRUNCATE TABLE sharded").await?; + + for _ in 0..250 { + conn.execute("INSERT INTO sharded (id, value) VALUES (pgdog.unique_id(), 'test')") + .await?; + } + + let ids: Vec<(i64,)> = sqlx::query_as("SELECT id FROM sharded") + .fetch_all(&conn) + .await?; + + let mut errors = 0; + for id in ids.iter() { + let num: i64 = rand::random(); + let err = sqlx::query("UPDATE sharded SET id = $2 WHERE id = $1") + .bind(id.0) + .bind(num as i64) + .execute(&conn) + .await + .err(); + + if let Some(err) = err { + errors += 1; + assert!( + err.to_string() + .contains("sharding key update must be executed inside a transaction"), + "{}", + err.to_string(), + ); + } + + let mut transaction = conn.begin().await?; + sqlx::query("UPDATE sharded SET id = $2 WHERE id = $1") + .bind(id.0) + .bind(num as i64) + .execute(&mut *transaction) + .await?; + let _ = sqlx::query("SELECT * FROM sharded WHERE id = $1") + .bind(num as i64) + .fetch_one(&mut *transaction) + .await?; + transaction.commit().await?; + + let _ = sqlx::query("SELECT * FROM sharded WHERE id = $1") + .bind(num as i64) + .fetch_one(&conn) + .await?; + } + + assert!(errors > 0); + + admin.execute("RELOAD").await?; + + Ok(()) +} + async fn prepare_table(pool: &Pool) { for shard in [0, 1] { let drop = format!("/* pgdog_shard: {shard} */ DROP TABLE IF EXISTS {TEST_TABLE}"); diff --git a/pgdog/src/frontend/client/query_engine/multi_step/update.rs b/pgdog/src/frontend/client/query_engine/multi_step/update.rs index 2d92b1fed..ee006914b 100644 --- a/pgdog/src/frontend/client/query_engine/multi_step/update.rs +++ b/pgdog/src/frontend/client/query_engine/multi_step/update.rs @@ -126,11 +126,12 @@ impl<'a> UpdateMulti<'a> { return Ok(()); } - if !context.in_transaction() && !self.engine.backend.is_multishard() + if !context.in_transaction() || !self.engine.backend.is_multishard() // Do this check at the last possible moment. // Just in case we change how transactions are // routed in the future. { + self.engine.cleanup_backend(context); return Err(UpdateError::TransactionRequired.into()); }