Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions integration/rust/tests/integration/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,77 @@ async fn update_expects_transactions() {
cleanup_table(&pool).await;
}

#[tokio::test]
async fn test_error_disconnects_and_update_works() -> Result<(), Box<dyn std::error::Error>> {
let conn = connections_sqlx().await.pop().unwrap();
let admin = admin_sqlx().await;

admin.execute("SET rewrite_enabled TO true").await?;
admin
.execute("SET rewrite_shard_key_updates TO 'rewrite'")
.await?;
admin
.execute("SET rewrite_split_inserts TO 'rewrite'")
.await?;
admin.execute("SET two_phase_commit TO true").await?;
admin.execute("SET two_phase_commit_auto TO true").await?;

conn.execute("TRUNCATE TABLE sharded").await?;

for _ in 0..250 {
conn.execute("INSERT INTO sharded (id, value) VALUES (pgdog.unique_id(), 'test')")
.await?;
}

let ids: Vec<(i64,)> = sqlx::query_as("SELECT id FROM sharded")
.fetch_all(&conn)
.await?;

let mut errors = 0;
for id in ids.iter() {
let num: i64 = rand::random();
let err = sqlx::query("UPDATE sharded SET id = $2 WHERE id = $1")
.bind(id.0)
.bind(num as i64)
.execute(&conn)
.await
.err();

if let Some(err) = err {
errors += 1;
assert!(
err.to_string()
.contains("sharding key update must be executed inside a transaction"),
"{}",
err.to_string(),
);
}

let mut transaction = conn.begin().await?;
sqlx::query("UPDATE sharded SET id = $2 WHERE id = $1")
.bind(id.0)
.bind(num as i64)
.execute(&mut *transaction)
.await?;
let _ = sqlx::query("SELECT * FROM sharded WHERE id = $1")
.bind(num as i64)
.fetch_one(&mut *transaction)
.await?;
transaction.commit().await?;

let _ = sqlx::query("SELECT * FROM sharded WHERE id = $1")
.bind(num as i64)
.fetch_one(&conn)
.await?;
}

assert!(errors > 0);

admin.execute("RELOAD").await?;

Ok(())
}

async fn prepare_table(pool: &Pool<Postgres>) {
for shard in [0, 1] {
let drop = format!("/* pgdog_shard: {shard} */ DROP TABLE IF EXISTS {TEST_TABLE}");
Expand Down
3 changes: 2 additions & 1 deletion pgdog/src/frontend/client/query_engine/multi_step/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,12 @@ impl<'a> UpdateMulti<'a> {
return Ok(());
}

if !context.in_transaction() && !self.engine.backend.is_multishard()
if !context.in_transaction() || !self.engine.backend.is_multishard()
// Do this check at the last possible moment.
// Just in case we change how transactions are
// routed in the future.
{
self.engine.cleanup_backend(context);
return Err(UpdateError::TransactionRequired.into());
}

Expand Down
Loading