frontend, storage: improve GetBlockRange, fix tests
This commit is contained in:
@@ -5,7 +5,9 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
"github.com/gtank/ctxd/rpc"
|
"github.com/gtank/ctxd/rpc"
|
||||||
@@ -77,22 +79,28 @@ func (s *SqlStreamer) GetBlock(ctx context.Context, id *rpc.BlockID) (*rpc.Compa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStreamer) GetBlockRange(span *rpc.BlockRange, resp rpc.CompactTxStreamer_GetBlockRangeServer) error {
|
func (s *SqlStreamer) GetBlockRange(span *rpc.BlockRange, resp rpc.CompactTxStreamer_GetBlockRangeServer) error {
|
||||||
blocks := make(chan []byte)
|
blockChan := make(chan []byte)
|
||||||
errors := make(chan error)
|
errChan := make(chan error)
|
||||||
done := make(chan bool)
|
|
||||||
|
|
||||||
timeout := resp.Context().WithTimeout(1 * time.Second)
|
// TODO configure or stress-test this timeout
|
||||||
go GetBlockRange(timeout, s.db, blocks, errors, done, span.Start, span.End)
|
timeout, cancel := context.WithTimeout(resp.Context(), 1*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
go storage.GetBlockRange(timeout,
|
||||||
|
s.db,
|
||||||
|
blockChan,
|
||||||
|
errChan,
|
||||||
|
int(span.Start.Height),
|
||||||
|
int(span.End.Height),
|
||||||
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-timeout.Done():
|
case err := <-errChan:
|
||||||
return timeout.Err()
|
// this will also catch context.DeadlineExceeded from the timeout
|
||||||
case err := <-errors:
|
|
||||||
return err
|
return err
|
||||||
case blockBytes := <-blocks:
|
case blockBytes := <-blockChan:
|
||||||
cBlock := &rpc.CompactBlock{}
|
cBlock := &rpc.CompactBlock{}
|
||||||
err = proto.Unmarshal(blockBytes, cBlock)
|
err := proto.Unmarshal(blockBytes, cBlock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err // TODO really need better logging in this whole service
|
return err // TODO really need better logging in this whole service
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrBadRange = errors.New("no blocks in specified range")
|
ErrLotsOfBlocks = errors.New("requested >10k blocks at once")
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateTables(conn *sql.DB) error {
|
func CreateTables(conn *sql.DB) error {
|
||||||
@@ -84,35 +84,41 @@ func GetBlockByHash(ctx context.Context, db *sql.DB, hash string) ([]byte, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// [start, end] inclusive
|
// [start, end] inclusive
|
||||||
func GetBlockRange(ctx context.Context, db *sql.DB, blocks chan<- []byte, errors chan<- error, start, end int) {
|
func GetBlockRange(ctx context.Context, db *sql.DB, blockOut chan<- []byte, errOut chan<- error, start, end int) {
|
||||||
// TODO sanity check ranges, rate limit?
|
// TODO sanity check ranges, this limit, etc
|
||||||
numBlocks := (end - start) + 1
|
numBlocks := (end - start) + 1
|
||||||
query := "SELECT height, compact_encoding from blocks WHERE (height BETWEEN ? AND ?)"
|
if numBlocks > 10000 {
|
||||||
|
errOut <- ErrLotsOfBlocks
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
query := "SELECT compact_encoding from blocks WHERE (height BETWEEN ? AND ?)"
|
||||||
result, err := db.QueryContext(ctx, query, start, end)
|
result, err := db.QueryContext(ctx, query, start, end)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors <- err
|
errOut <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer result.Close()
|
defer result.Close()
|
||||||
|
|
||||||
// My assumption here is that if the context is cancelled then result.Next() will fail.
|
// My assumption here is that if the context is cancelled then result.Next() will fail.
|
||||||
|
|
||||||
var blockBytes []byte // avoid a copy with *RawBytes
|
var blockBytes []byte
|
||||||
for result.Next() {
|
for result.Next() {
|
||||||
err = result.Scan(&blockBytes)
|
err = result.Scan(&blockBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors <- err
|
errOut <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
blocks <- blockBytes
|
blockOut <- blockBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := result.Err(); err != nil {
|
if err := result.Err(); err != nil {
|
||||||
errors <- err
|
errOut <- err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// done
|
// done
|
||||||
errors <- nil
|
errOut <- nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func StoreBlock(conn *sql.DB, height int, hash string, sapling bool, encoded []byte) error {
|
func StoreBlock(conn *sql.DB, height int, hash string, sapling bool, encoded []byte) error {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
@@ -36,88 +37,163 @@ func TestSqliteStorage(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
err = CreateTables(conn)
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
for _, test := range compactTests {
|
// Fill tables
|
||||||
blockData, _ := hex.DecodeString(test.Full)
|
{
|
||||||
block := parser.NewBlock()
|
err = CreateTables(db)
|
||||||
blockData, err = block.ParseFromSlice(blockData)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(errors.Wrap(err, fmt.Sprintf("parsing testnet block %d", test.BlockHeight)))
|
t.Fatal(err)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
height := block.GetHeight()
|
for _, test := range compactTests {
|
||||||
hash := hex.EncodeToString(block.GetEncodableHash())
|
blockData, _ := hex.DecodeString(test.Full)
|
||||||
hasSapling := block.HasSaplingTransactions()
|
block := parser.NewBlock()
|
||||||
protoBlock := block.ToCompact()
|
blockData, err = block.ParseFromSlice(blockData)
|
||||||
marshaled, _ := proto.Marshal(protoBlock)
|
if err != nil {
|
||||||
|
t.Error(errors.Wrap(err, fmt.Sprintf("parsing testnet block %d", test.BlockHeight)))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
err = StoreBlock(conn, height, hash, hasSapling, marshaled)
|
height := block.GetHeight()
|
||||||
if err != nil {
|
hash := hex.EncodeToString(block.GetEncodableHash())
|
||||||
t.Error(err)
|
hasSapling := block.HasSaplingTransactions()
|
||||||
continue
|
protoBlock := block.ToCompact()
|
||||||
|
marshaled, _ := proto.Marshal(protoBlock)
|
||||||
|
|
||||||
|
err = StoreBlock(db, height, hash, hasSapling, marshaled)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var count int
|
// Count the blocks
|
||||||
countBlocks := "SELECT count(*) FROM blocks"
|
{
|
||||||
err = conn.QueryRow(countBlocks).Scan(&count)
|
var count int
|
||||||
if err != nil {
|
countBlocks := "SELECT count(*) FROM blocks"
|
||||||
t.Error(errors.Wrap(err, fmt.Sprintf("counting compact blocks")))
|
err = db.QueryRow(countBlocks).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(errors.Wrap(err, fmt.Sprintf("counting compact blocks")))
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != len(compactTests) {
|
||||||
|
t.Errorf("Wrong row count, want %d got %d", len(compactTests), count)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if count != len(compactTests) {
|
ctx := context.Background()
|
||||||
t.Errorf("Wrong row count, want %d got %d", len(compactTests), count)
|
|
||||||
|
// Check height state is as expected
|
||||||
|
{
|
||||||
|
blockHeight, err := GetCurrentHeight(ctx, db)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height")))
|
||||||
|
}
|
||||||
|
|
||||||
|
lastBlockTest := compactTests[len(compactTests)-1]
|
||||||
|
if blockHeight != lastBlockTest.BlockHeight {
|
||||||
|
t.Errorf("Wrong block height, got: %d", blockHeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
retBlock, err := GetBlock(ctx, db, blockHeight)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(errors.Wrap(err, "retrieving stored block"))
|
||||||
|
}
|
||||||
|
cblock := &rpc.CompactBlock{}
|
||||||
|
err = proto.Unmarshal(retBlock, cblock)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(cblock.Height) != lastBlockTest.BlockHeight {
|
||||||
|
t.Error("incorrect retrieval")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
blockHeight, err := GetCurrentHeight(context.Background(), conn)
|
// Block ranges
|
||||||
if err != nil {
|
{
|
||||||
t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height")))
|
blockOut := make(chan []byte)
|
||||||
}
|
errOut := make(chan error)
|
||||||
|
|
||||||
lastBlockTest := compactTests[len(compactTests)-1]
|
count := 0
|
||||||
if blockHeight != lastBlockTest.BlockHeight {
|
go GetBlockRange(ctx, db, blockOut, errOut, 289460, 289465)
|
||||||
t.Errorf("Wrong block height, got: %d", blockHeight)
|
recvLoop0:
|
||||||
}
|
for {
|
||||||
|
select {
|
||||||
|
case <-blockOut:
|
||||||
|
count++
|
||||||
|
case err := <-errOut:
|
||||||
|
if err != nil {
|
||||||
|
t.Error(errors.Wrap(err, "in full blockrange"))
|
||||||
|
}
|
||||||
|
break recvLoop0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
retBlock, err := GetBlock(context.Background(), conn, blockHeight)
|
if count != 6 {
|
||||||
if err != nil {
|
t.Error("failed to retrieve full range")
|
||||||
t.Error(errors.Wrap(err, "retrieving stored block"))
|
}
|
||||||
}
|
|
||||||
cblock := &rpc.CompactBlock{}
|
|
||||||
err = proto.Unmarshal(retBlock, cblock)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(cblock.Height) != lastBlockTest.BlockHeight {
|
// Test timeout
|
||||||
t.Error("incorrect retrieval")
|
timeout, _ := context.WithTimeout(ctx, 0*time.Second)
|
||||||
}
|
go GetBlockRange(timeout, db, blockOut, errOut, 289460, 289465)
|
||||||
|
recvLoop1:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case err := <-errOut:
|
||||||
|
if err != context.DeadlineExceeded {
|
||||||
|
t.Errorf("got the wrong error: %v", err)
|
||||||
|
}
|
||||||
|
break recvLoop1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
blockRange, err := GetBlockRange(conn, 289460, 289465)
|
// Test a smaller range
|
||||||
if err != nil {
|
count = 0
|
||||||
t.Error(err)
|
go GetBlockRange(ctx, db, blockOut, errOut, 289462, 289465)
|
||||||
}
|
recvLoop2:
|
||||||
if len(blockRange) != 6 {
|
for {
|
||||||
t.Error("failed to retrieve full range")
|
select {
|
||||||
}
|
case <-blockOut:
|
||||||
|
count++
|
||||||
|
case err := <-errOut:
|
||||||
|
if err != nil {
|
||||||
|
t.Error(errors.Wrap(err, "in short blockrange"))
|
||||||
|
}
|
||||||
|
break recvLoop2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
blockRange, err = GetBlockRange(conn, 289462, 289465)
|
if count != 4 {
|
||||||
if err != nil {
|
t.Errorf("failed to retrieve the shorter range")
|
||||||
t.Error(err)
|
}
|
||||||
}
|
|
||||||
if len(blockRange) != 4 {
|
// Test a nonsense range
|
||||||
t.Error("failed to retrieve partial range")
|
count = 0
|
||||||
}
|
go GetBlockRange(ctx, db, blockOut, errOut, 1, 2)
|
||||||
|
recvLoop3:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-blockOut:
|
||||||
|
count++
|
||||||
|
case err := <-errOut:
|
||||||
|
if err != nil {
|
||||||
|
t.Error(errors.Wrap(err, "in invalid blockrange"))
|
||||||
|
}
|
||||||
|
break recvLoop3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
t.Errorf("got some blocks that shouldn't be there")
|
||||||
|
}
|
||||||
|
|
||||||
blockRange, err = GetBlockRange(conn, 1337, 1338)
|
|
||||||
if err != ErrBadRange {
|
|
||||||
t.Error("Somehow retrieved nonexistent blocks!")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user