Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions pgdog/src/frontend/router/parser/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ pub struct CopyParser {
sharded_column: usize,
/// Schema shard.
schema_shard: Option<Shard>,
/// String representing NULL values in text/CSV format.
null_string: String,
}

impl Default for CopyParser {
Expand All @@ -86,6 +88,7 @@ impl Default for CopyParser {
sharded_table: None,
sharded_column: 0,
schema_shard: None,
null_string: "\\N".to_owned(),
}
}
}
Expand Down Expand Up @@ -187,6 +190,7 @@ impl CopyParser {
)))
};
parser.sharding_schema = cluster.sharding_schema();
parser.null_string = null_string;

Ok(parser)
}
Expand Down Expand Up @@ -234,12 +238,16 @@ impl CopyParser {
.get(self.sharded_column)
.ok_or(Error::NoShardingColumn)?;

let ctx = ContextBuilder::new(table)
.data(key)
.shards(self.sharding_schema.shards)
.build()?;
if key == self.null_string {
Shard::All
} else {
let ctx = ContextBuilder::new(table)
.data(key)
.shards(self.sharding_schema.shards)
.build()?;

ctx.apply()?
ctx.apply()?
}
} else if let Some(schema_shard) = self.schema_shard.clone() {
schema_shard
} else {
Expand Down Expand Up @@ -443,6 +451,38 @@ mod test {
assert_eq!(sharded[2].shard(), &Shard::All);
}

#[test]
fn test_copy_text_null_sharding_key() {
// pg_dump text format uses `\N` to represent NULL values.
// When the sharding key is NULL, route to all shards.
// When a non-sharding column is NULL, route normally based on the key.
let copy = "COPY sharded (id, value) FROM STDIN";
let stmt = parse(copy).unwrap();
let stmt = stmt.protobuf.stmts.first().unwrap();
let copy = match stmt.stmt.clone().unwrap().node.unwrap() {
NodeEnum::CopyStmt(copy) => copy,
_ => panic!("not a copy"),
};

let mut copy = CopyParser::new(&copy, &Cluster::new_test(&config())).unwrap();

let one = CopyData::new("1\tAlice\n".as_bytes());
let two = CopyData::new("\\N\tBob\n".as_bytes());
let three = CopyData::new("11\tCharlie\n".as_bytes());
let four = CopyData::new("6\t\\N\n".as_bytes());

let sharded = copy.shard(&[one, two, three, four]).unwrap();
assert_eq!(sharded.len(), 4);
assert_eq!(sharded[0].message().data(), b"1\tAlice\n");
assert_eq!(sharded[0].shard(), &Shard::Direct(0));
assert_eq!(sharded[1].message().data(), b"\\N\tBob\n");
assert_eq!(sharded[1].shard(), &Shard::All);
assert_eq!(sharded[2].message().data(), b"11\tCharlie\n");
assert_eq!(sharded[2].shard(), &Shard::Direct(1));
assert_eq!(sharded[3].message().data(), b"6\t\\N\n");
assert_eq!(sharded[3].shard(), &Shard::Direct(1));
}

#[test]
fn test_copy_binary() {
let copy = "COPY sharded (id, value) FROM STDIN (FORMAT 'binary')";
Expand Down
Loading