Skip to content
4 changes: 2 additions & 2 deletions adapter/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (r *GRPCServer) RawGet(ctx context.Context, req *pb.RawGetRequest) (*pb.Raw

v, err := r.store.GetAt(ctx, req.Key, readTS)
if errors.Is(err, store.ErrKeyNotFound) {
return &pb.RawGetResponse{Value: nil}, nil
return &pb.RawGetResponse{Value: nil, Exists: false}, nil
}
if err != nil {
return nil, errors.WithStack(err)
Expand All @@ -97,7 +97,7 @@ func (r *GRPCServer) RawGet(ctx context.Context, req *pb.RawGetRequest) (*pb.Raw
slog.String("key", string(req.Key)),
slog.String("value", string(v)))

return &pb.RawGetResponse{Value: v}, nil
return &pb.RawGetResponse{Value: v, Exists: true}, nil
}

func (r *GRPCServer) RawLatestCommitTS(ctx context.Context, req *pb.RawLatestCommitTSRequest) (*pb.RawLatestCommitTSResponse, error) {
Expand Down
22 changes: 21 additions & 1 deletion adapter/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,33 @@ func Test_value_can_be_deleted(t *testing.T) {
resp, err := c.RawGet(context.TODO(), &pb.RawGetRequest{Key: key})
assert.NoError(t, err, "Get RPC failed")
assert.Nil(t, err)
assert.True(t, resp.Exists)
assert.Equal(t, want, resp.Value)

_, err = c.RawDelete(context.TODO(), &pb.RawDeleteRequest{Key: key})
assert.NoError(t, err, "Delete RPC failed")

_, err = c.RawGet(context.TODO(), &pb.RawGetRequest{Key: key})
resp, err = c.RawGet(context.TODO(), &pb.RawGetRequest{Key: key})
assert.NoError(t, err, "Get RPC failed")
assert.False(t, resp.Exists)
}

func Test_grpc_raw_get_empty_value(t *testing.T) {
t.Parallel()
nodes, adders, _ := createNode(t, 3)
c := rawKVClient(t, adders)
defer shutdown(nodes)

key := []byte("empty-key")
empty := []byte{}

_, err := c.RawPut(context.Background(), &pb.RawPutRequest{Key: key, Value: empty})
assert.NoError(t, err, "Put RPC failed")

resp, err := c.RawGet(context.TODO(), &pb.RawGetRequest{Key: key})
assert.NoError(t, err, "Get RPC failed")
assert.True(t, resp.Exists)
assert.Equal(t, 0, len(resp.Value))
}

func Test_grpc_scan(t *testing.T) {
Expand Down
137 changes: 117 additions & 20 deletions adapter/internal.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapter

import (
"bytes"
"context"

"github.com/bootjp/elastickv/kv"
Expand Down Expand Up @@ -29,13 +30,19 @@ var _ pb.InternalServer = (*Internal)(nil)

var ErrNotLeader = errors.New("not leader")
var ErrLeaderNotFound = errors.New("leader not found")
var ErrTxnTimestampOverflow = errors.New("txn timestamp overflow")

func (i *Internal) Forward(_ context.Context, req *pb.ForwardRequest) (*pb.ForwardResponse, error) {
if i.raft.State() != raft.Leader {
return nil, errors.WithStack(ErrNotLeader)
}

i.stampTimestamps(req)
if err := i.stampTimestamps(req); err != nil {
return &pb.ForwardResponse{
Success: false,
CommitIndex: 0,
}, errors.WithStack(err)
}

r, err := i.transactionManager.Commit(req.Requests)
if err != nil {
Expand All @@ -51,35 +58,125 @@ func (i *Internal) Forward(_ context.Context, req *pb.ForwardRequest) (*pb.Forwa
}, nil
}

func (i *Internal) stampTimestamps(req *pb.ForwardRequest) {
func (i *Internal) stampTimestamps(req *pb.ForwardRequest) error {
if req == nil {
return
return nil
}
if req.IsTxn {
var startTs uint64
// All requests in a transaction must have the same timestamp.
// Find a timestamp from the requests, or generate a new one if none exist.
for _, r := range req.Requests {
if r.Ts != 0 {
startTs = r.Ts
break
}
return i.stampTxnTimestamps(req.Requests)
}

i.stampRawTimestamps(req.Requests)
return nil
}

func (i *Internal) stampRawTimestamps(reqs []*pb.Request) {
for _, r := range reqs {
if r == nil {
continue
}
if r.Ts != 0 {
continue
}
if i.clock == nil {
r.Ts = 1
continue
}
r.Ts = i.clock.Next()
}
}

if startTs == 0 && len(req.Requests) > 0 {
startTs = i.clock.Next()
func (i *Internal) stampTxnTimestamps(reqs []*pb.Request) error {
startTS := forwardedTxnStartTS(reqs)
if startTS == 0 {
if i.clock == nil {
startTS = 1
} else {
startTS = i.clock.Next()
}
}
if startTS == ^uint64(0) {
return errors.WithStack(ErrTxnTimestampOverflow)
}

// Assign the unified timestamp to all requests in the transaction.
for _, r := range reqs {
if r != nil {
r.Ts = startTS
}
}

return i.fillForwardedTxnCommitTS(reqs, startTS)
}

// Assign the unified timestamp to all requests in the transaction.
for _, r := range req.Requests {
r.Ts = startTs
func forwardedTxnStartTS(reqs []*pb.Request) uint64 {
for _, r := range reqs {
if r != nil && r.Ts != 0 {
return r.Ts
}
return
}
return 0
}

func forwardedTxnMetaMutation(r *pb.Request, metaPrefix []byte) (*pb.Mutation, bool) {
if r == nil {
return nil, false
}
if r.Phase != pb.Phase_COMMIT && r.Phase != pb.Phase_ABORT {
return nil, false
}
if len(r.Mutations) == 0 || r.Mutations[0] == nil {
return nil, false
}
if !bytes.HasPrefix(r.Mutations[0].Key, metaPrefix) {
return nil, false
}
return r.Mutations[0], true
}

func (i *Internal) fillForwardedTxnCommitTS(reqs []*pb.Request, startTS uint64) error {
type metaToUpdate struct {
m *pb.Mutation
meta kv.TxnMeta
}

for _, r := range req.Requests {
if r.Ts == 0 {
r.Ts = i.clock.Next()
metaMutations := make([]metaToUpdate, 0, len(reqs))
prefix := []byte(kv.TxnMetaPrefix)
for _, r := range reqs {
m, ok := forwardedTxnMetaMutation(r, prefix)
if !ok {
continue
}
meta, err := kv.DecodeTxnMeta(m.Value)
if err != nil {
continue
}
if meta.CommitTS != 0 {
continue
}
metaMutations = append(metaMutations, metaToUpdate{m: m, meta: meta})
}
if len(metaMutations) == 0 {
return nil
}

commitTS := startTS + 1
if commitTS == 0 {
// Overflow: can't choose a commit timestamp strictly greater than startTS.
return errors.WithStack(ErrTxnTimestampOverflow)
}
if i.clock != nil {
i.clock.Observe(startTS)
commitTS = i.clock.Next()
}
if commitTS <= startTS {
// Defensive: avoid writing an invalid CommitTS.
return errors.WithStack(ErrTxnTimestampOverflow)
}
Comment on lines 163 to 175
Copy link

Copilot AI Feb 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In commitTS = 0 forwarding scenario at line 154, if startTS = ^uint64(0), then commitTS = startTS + 1 overflows to 0. The code returns early at line 157, silently leaving CommitTS = 0 in the metadata. This causes the FSM to reject the commit request with ErrTxnCommitTSRequired, but the caller (Forward) doesn't distinguish this case and may return success=false without a clear error. Consider logging a warning or returning an explicit error when overflow is detected to aid debugging.

Copilot uses AI. Check for mistakes.

for _, item := range metaMutations {
item.meta.CommitTS = commitTS
item.m.Value = kv.EncodeTxnMeta(item.meta)
}
return nil
}
80 changes: 80 additions & 0 deletions adapter/internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package adapter

import (
"testing"

"github.com/bootjp/elastickv/kv"
pb "github.com/bootjp/elastickv/proto"
"github.com/stretchr/testify/require"
)

func TestStampTxnTimestamps_RejectsMaxStartTS(t *testing.T) {
t.Parallel()

i := &Internal{}
reqs := []*pb.Request{
{
IsTxn: true,
Phase: pb.Phase_COMMIT,
Ts: ^uint64(0),
Mutations: []*pb.Mutation{
{
Op: pb.Op_PUT,
Key: []byte(kv.TxnMetaPrefix),
Value: kv.EncodeTxnMeta(kv.TxnMeta{PrimaryKey: []byte("k"), CommitTS: 0}),
},
},
},
}

err := i.stampTxnTimestamps(reqs)
require.ErrorIs(t, err, ErrTxnTimestampOverflow)
}

func TestFillForwardedTxnCommitTS_RejectsOverflow(t *testing.T) {
t.Parallel()

i := &Internal{}
reqs := []*pb.Request{
{
IsTxn: true,
Phase: pb.Phase_COMMIT,
Mutations: []*pb.Mutation{
{
Op: pb.Op_PUT,
Key: []byte(kv.TxnMetaPrefix),
Value: kv.EncodeTxnMeta(kv.TxnMeta{PrimaryKey: []byte("k"), CommitTS: 0}),
},
},
},
}

err := i.fillForwardedTxnCommitTS(reqs, ^uint64(0))
require.ErrorIs(t, err, ErrTxnTimestampOverflow)
}

func TestFillForwardedTxnCommitTS_AssignsCommitTS(t *testing.T) {
t.Parallel()

i := &Internal{}
startTS := uint64(10)
reqs := []*pb.Request{
{
IsTxn: true,
Phase: pb.Phase_COMMIT,
Mutations: []*pb.Mutation{
{
Op: pb.Op_PUT,
Key: []byte(kv.TxnMetaPrefix),
Value: kv.EncodeTxnMeta(kv.TxnMeta{PrimaryKey: []byte("k"), CommitTS: 0}),
},
},
},
}

require.NoError(t, i.fillForwardedTxnCommitTS(reqs, startTS))

meta, err := kv.DecodeTxnMeta(reqs[0].Mutations[0].Value)
require.NoError(t, err)
require.Equal(t, startTS+1, meta.CommitTS)
}
27 changes: 24 additions & 3 deletions adapter/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ func (r *RedisServer) Run() error {

name := strings.ToUpper(string(cmd.Args[0]))
if state.inTxn && name != cmdExec && name != cmdDiscard && name != cmdMulti {
state.queue = append(state.queue, cmd)
// redcon reuses the underlying argument buffers; copy queued commands
// so MULTI/EXEC works reliably under concurrency and with -race.
state.queue = append(state.queue, cloneCommand(cmd))
conn.WriteString("QUEUED")
return
}
Expand All @@ -170,6 +172,17 @@ func (r *RedisServer) Run() error {
return errors.WithStack(err)
}

func cloneCommand(cmd redcon.Command) redcon.Command {
out := redcon.Command{
Raw: bytes.Clone(cmd.Raw),
Args: make([][]byte, len(cmd.Args)),
}
for i := range cmd.Args {
out.Args[i] = bytes.Clone(cmd.Args[i])
}
return out
}

func (r *RedisServer) Stop() {
_ = r.listen.Close()
}
Expand Down Expand Up @@ -233,8 +246,14 @@ func (r *RedisServer) get(conn redcon.Conn, cmd redcon.Command) {
return
}

key := cmd.Args[1]
readTS := r.readTS()
v, err := r.readValueAt(cmd.Args[1], readTS)
// When proxying reads to the leader, let the leader choose a safe snapshot.
// Our local store watermark may lag behind a just-committed transaction.
if !r.coordinator.IsLeaderForKey(key) {
readTS = 0
}
v, err := r.readValueAt(key, readTS)
if err != nil {
switch {
case errors.Is(err, store.ErrKeyNotFound):
Expand Down Expand Up @@ -1156,7 +1175,9 @@ func (r *RedisServer) tryLeaderGetAt(key []byte, ts uint64) ([]byte, error) {
if err != nil {
return nil, errors.WithStack(err)
}

if !resp.GetExists() {
return nil, errors.WithStack(store.ErrKeyNotFound)
}
return resp.Value, nil
}

Expand Down
2 changes: 1 addition & 1 deletion adapter/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestRedis_follower_redirect_node_set_get_deleted(t *testing.T) {
assert.Equal(t, int64(1), res3.Val())

res4 := rdb.Get(ctx, string(key))
assert.NoError(t, res4.Err())
assert.Equal(t, redis.Nil, res4.Err())
assert.Equal(t, "", res4.Val())
}

Expand Down
Loading
Loading