diff --git a/pkg/reflector/dml_source.go b/pkg/reflector/dml_source.go index fc3879ec..439119a3 100644 --- a/pkg/reflector/dml_source.go +++ b/pkg/reflector/dml_source.go @@ -45,7 +45,7 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState } // table layout is: seq, leader_ts, statement - qs := sqlgen.SqlSprintf("SELECT seq, leader_ts, statement FROM $1 WHERE seq > ? ORDER BY seq LIMIT $2", + qs := sqlgen.SqlSprintf("SELECT seq, leader_ts, statement, family_name, table_name FROM $1 WHERE seq > ? ORDER BY seq LIMIT $2", source.ledgerTableName, fmt.Sprintf("%d", blocksize)) @@ -62,9 +62,11 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState defer rows.Close() row := struct { - seq int64 - leaderTs string // this is a string b/c the driver errors when trying to Scan into a *time.Time. - statement string + seq int64 + leaderTs string // this is a string b/c the driver errors when trying to Scan into a *time.Time. + statement string + familyName string + tableName string }{} for { @@ -76,7 +78,7 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState break } - err = rows.Scan(&row.seq, &row.leaderTs, &row.statement) + err = rows.Scan(&row.seq, &row.leaderTs, &row.statement, &row.familyName, &row.tableName) if err != nil { return statement, errors.Wrap(err, "scan row") } @@ -91,9 +93,11 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState } dmlst := schema.DMLStatement{ - Sequence: schema.DMLSequence(row.seq), - Statement: row.statement, - Timestamp: timestamp, + Sequence: schema.DMLSequence(row.seq), + Statement: row.statement, + Timestamp: timestamp, + FamilyName: schema.FamilyName{Name: row.familyName}, + TableName: schema.TableName{Name: row.tableName}, } source.buffer = append(source.buffer, dmlst) diff --git a/pkg/reflector/dml_source_test.go b/pkg/reflector/dml_source_test.go index ccf90249..777f5de0 100644 --- a/pkg/reflector/dml_source_test.go +++ b/pkg/reflector/dml_source_test.go @@ -23,7 +23,9 @@ func (u *sqlDmlSourceTestUtil) InitializeDB() { CREATE TABLE ctlstore_dml_ledger ( seq INTEGER PRIMARY KEY AUTOINCREMENT, leader_ts INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP, - statement VARCHAR($1) + statement VARCHAR($1), + family_name VARCHAR(191) NOT NULL DEFAULT '', + table_name VARCHAR(191) NOT NULL DEFAULT '' ); INSERT INTO ctlstore_dml_ledger (statement) VALUES(''); DELETE FROM ctlstore_dml_ledger; diff --git a/pkg/reflector/reflector_test.go b/pkg/reflector/reflector_test.go index 9d8fda93..3182ff0e 100644 --- a/pkg/reflector/reflector_test.go +++ b/pkg/reflector/reflector_test.go @@ -168,7 +168,9 @@ func TestReflector(t *testing.T) { CREATE TABLE ctlstore_dml_ledger ( seq INTEGER PRIMARY KEY AUTOINCREMENT, leader_ts INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP, - statement VARCHAR(786432) + statement VARCHAR(786432), + family_name VARCHAR(191) NOT NULL DEFAULT '', + table_name VARCHAR(191) NOT NULL DEFAULT '' ); `) require.NoError(t, err) diff --git a/pkg/schema/dml.go b/pkg/schema/dml.go index 0357a60e..36ed74a6 100644 --- a/pkg/schema/dml.go +++ b/pkg/schema/dml.go @@ -15,9 +15,11 @@ var currentTestDmlSeq int64 type DMLSequence int64 type DMLStatement struct { - Sequence DMLSequence - Timestamp time.Time - Statement string + Sequence DMLSequence + Timestamp time.Time + Statement string + FamilyName FamilyName + TableName TableName } func (seq DMLSequence) Int() int64 { @@ -33,6 +35,17 @@ func NewTestDMLStatement(statement string) DMLStatement { } } +func NewTestDMLStatementWithSharding(statement string, familyName FamilyName, tableName TableName) DMLStatement { + return DMLStatement{ + Statement: statement, + Sequence: nextTestDmlSeq(), + Timestamp: time.Now(), + FamilyName: FamilyName(familyName), + TableName: TableName(tableName), + } + +} + func nextTestDmlSeq() DMLSequence { return DMLSequence(atomic.AddInt64(¤tTestDmlSeq, 1)) } diff --git a/pkg/schema/dml_test.go b/pkg/schema/dml_test.go new file mode 100644 index 00000000..2ab34a32 --- /dev/null +++ b/pkg/schema/dml_test.go @@ -0,0 +1,59 @@ +package schema + +import ( + "reflect" + "testing" +) + +func TestDMLSequence_Int(t *testing.T) { + tests := []struct { + name string + seq DMLSequence + want int64 + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.seq.Int(); got != tt.want { + t.Errorf("Int() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewTestDMLStatement(t *testing.T) { + type args struct { + statement string + } + tests := []struct { + name string + args args + want DMLStatement + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewTestDMLStatement(tt.args.statement); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTestDMLStatement() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_nextTestDmlSeq(t *testing.T) { + tests := []struct { + name string + want DMLSequence + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := nextTestDmlSeq(); got != tt.want { + t.Errorf("nextTestDmlSeq() = %v, want %v", got, tt.want) + } + }) + } +}