diff --git a/storage/sqlite3.go b/storage/sqlite3.go index 1dfe5a5..e149272 100644 --- a/storage/sqlite3.go +++ b/storage/sqlite3.go @@ -1,16 +1,149 @@ package storage -import "database/sql" +import ( + "database/sql" + "fmt" -func createBlockTable(conn *sql.DB) error { - tableCreation := ` + protobuf "github.com/golang/protobuf/proto" + "github.com/gtank/ctxd/proto" + "github.com/pkg/errors" +) + +var ( + ErrBadRange = errors.New("no blocks in specified range") +) + +func createTables(conn *sql.DB) error { + stateTable := ` + CREATE TABLE IF NOT EXISTS state ( + current_height INTEGER, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (current_height) REFERENCES blocks (height) + ); + ` + _, err := conn.Exec(stateTable) + if err != nil { + return err + } + + blockTable := ` CREATE TABLE IF NOT EXISTS blocks ( height INTEGER PRIMARY KEY, hash TEXT, has_sapling_tx BOOL, + encoding_version INTEGER, compact_encoding BLOB ); ` - _, err := conn.Exec(tableCreation) + _, err = conn.Exec(blockTable) + return err +} + +func CurrentHeight(conn *sql.DB) (int, error) { + var height int + query := "SELECT current_height FROM state WHERE rowid = 1" + err := conn.QueryRow(query).Scan(&height) + return height, err +} + +func SetCurrentHeight(conn *sql.DB, height int) error { + update := "UPDATE state SET current_height=?, timestamp=CURRENT_TIMESTAMP WHERE rowid = 1" + result, err := conn.Exec(update, height) + if err != nil { + return errors.Wrap(err, "updating state row") + } + rowCount, err := result.RowsAffected() + if err != nil { + return errors.Wrap(err, "checking if state row exists") + } + if rowCount == 0 { + // row does not yet exist + insert := "INSERT OR IGNORE INTO state (rowid, current_height) VALUES (1, ?)" + result, err = conn.Exec(insert, height) + if err != nil { + return errors.Wrap(err, "on state row insert") + } + rowCount, err = result.RowsAffected() + if err != nil { + return errors.Wrap(err, "checking if state row exists") + } + if rowCount != 1 { + return errors.New("totally failed to update current height state") + } + } + return nil +} + +func GetBlock(conn *sql.DB, height int) (*proto.CompactBlock, error) { + var blockBytes []byte // avoid a copy with *RawBytes + query := "SELECT compact_encoding from blocks WHERE height = ?" + err := conn.QueryRow(query, height).Scan(&blockBytes) + if err != nil { + return nil, err + } + compactBlock := &proto.CompactBlock{} + err = protobuf.Unmarshal(blockBytes, compactBlock) + return compactBlock, err +} + +// [start, end] +func GetBlockRange(conn *sql.DB, start, end int) ([]*proto.CompactBlock, error) { + // TODO sanity check range bounds + query := "SELECT compact_encoding from blocks WHERE (height BETWEEN ? AND ?)" + result, err := conn.Query(query, start, end) + if err != nil { + return nil, err + } + defer result.Close() + + compactBlocks := make([]*proto.CompactBlock, 0, (end-start)+1) + for result.Next() { + var blockBytes []byte // avoid a copy with *RawBytes + err = result.Scan(&blockBytes) + if err != nil { + return nil, err + } + newBlock := &proto.CompactBlock{} + err = protobuf.Unmarshal(blockBytes, newBlock) + if err != nil { + return nil, err + } + compactBlocks = append(compactBlocks, newBlock) + } + + err = result.Err() + if err != nil { + return nil, err + } + + if len(compactBlocks) == 0 { + return nil, ErrBadRange + } + return compactBlocks, nil +} + +func GetBlockByHash(conn *sql.DB, hash string) (*proto.CompactBlock, error) { + var blockBytes []byte // avoid a copy with *RawBytes + query := "SELECT compact_encoding from blocks WHERE hash = ?" + err := conn.QueryRow(query, hash).Scan(&blockBytes) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("getting block with hash %s", hash)) + } + compactBlock := &proto.CompactBlock{} + err = protobuf.Unmarshal(blockBytes, compactBlock) + return compactBlock, err +} + +func StoreBlock(conn *sql.DB, height int, hash string, sapling bool, version int, encoded []byte) error { + insertBlock := "INSERT INTO blocks (height, hash, has_sapling_tx, encoding_version, compact_encoding) values (?, ?, ?, ?, ?)" + _, err := conn.Exec(insertBlock, height, hash, sapling, version, encoded) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("storing compact block %d", height)) + } + + currentHeight, err := CurrentHeight(conn) + if err != nil || height > currentHeight { + err = SetCurrentHeight(conn, height) + } return err } diff --git a/storage/sqlite3_test.go b/storage/sqlite3_test.go index 1a4b552..415e42b 100644 --- a/storage/sqlite3_test.go +++ b/storage/sqlite3_test.go @@ -14,7 +14,7 @@ import ( "github.com/pkg/errors" ) -func TestFillDB(t *testing.T) { +func TestSqliteStorage(t *testing.T) { type compactTest struct { BlockHeight int `json:"block"` BlockHash string `json:"hash"` @@ -38,7 +38,8 @@ func TestFillDB(t *testing.T) { t.Fatal(err) } defer conn.Close() - err = createBlockTable(conn) + + err = createTables(conn) if err != nil { t.Fatal(err) } @@ -55,19 +56,20 @@ func TestFillDB(t *testing.T) { height := block.GetHeight() hash := hex.EncodeToString(block.GetHash()) hasSapling := block.HasSaplingTransactions() - marshaled, _ := protobuf.Marshal(block.ToCompact()) + protoBlock := block.ToCompact() + version := 1 + marshaled, _ := protobuf.Marshal(protoBlock) - insertBlock := "INSERT INTO blocks (height, hash, has_sapling_tx, compact_encoding) values (?, ?, ?, ?)" - _, err := conn.Exec(insertBlock, height, hash, hasSapling, marshaled) + err = StoreBlock(conn, height, hash, hasSapling, version, marshaled) if err != nil { - t.Error(errors.Wrap(err, fmt.Sprintf("storing compact block %d", height))) + t.Error(err) continue } } var count int countBlocks := "SELECT count(*) FROM blocks" - conn.QueryRow(countBlocks).Scan(&count) + err = conn.QueryRow(countBlocks).Scan(&count) if err != nil { t.Error(errors.Wrap(err, fmt.Sprintf("counting compact blocks"))) } @@ -75,4 +77,44 @@ func TestFillDB(t *testing.T) { if count != len(compactTests) { t.Errorf("Wrong row count, want %d got %d", len(compactTests), count) } + + blockHeight, err := CurrentHeight(conn) + if err != nil { + t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height"))) + } + + lastBlockTest := compactTests[len(compactTests)-1] + if blockHeight != lastBlockTest.BlockHeight { + t.Errorf("Wrong block height, got: %d", blockHeight) + } + + retBlock, err := GetBlock(conn, blockHeight) + if err != nil { + t.Error(errors.Wrap(err, "retrieving stored block")) + } + + if int(retBlock.BlockID.BlockHeight) != lastBlockTest.BlockHeight { + t.Error("incorrect retrieval") + } + + blockRange, err := GetBlockRange(conn, 289460, 289465) + if err != nil { + t.Error(err) + } + if len(blockRange) != 6 { + t.Error("failed to retrieve full range") + } + + blockRange, err = GetBlockRange(conn, 289462, 289465) + if err != nil { + t.Error(err) + } + if len(blockRange) != 4 { + t.Error("failed to retrieve partial range") + } + + blockRange, err = GetBlockRange(conn, 1337, 1338) + if err != ErrBadRange { + t.Error("Somehow retrieved nonexistent blocks!") + } }