frontend, storage: improve GetBlockRange, fix tests

This commit is contained in:
George Tankersley
2018-12-14 20:33:50 -05:00
parent abca4335ec
commit 0d84493db3
3 changed files with 171 additions and 81 deletions

View File

@@ -5,7 +5,9 @@ import (
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
"errors" "errors"
"time"
"github.com/golang/protobuf/proto"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/gtank/ctxd/rpc" "github.com/gtank/ctxd/rpc"
@@ -77,22 +79,28 @@ func (s *SqlStreamer) GetBlock(ctx context.Context, id *rpc.BlockID) (*rpc.Compa
} }
func (s *SqlStreamer) GetBlockRange(span *rpc.BlockRange, resp rpc.CompactTxStreamer_GetBlockRangeServer) error { func (s *SqlStreamer) GetBlockRange(span *rpc.BlockRange, resp rpc.CompactTxStreamer_GetBlockRangeServer) error {
blocks := make(chan []byte) blockChan := make(chan []byte)
errors := make(chan error) errChan := make(chan error)
done := make(chan bool)
timeout := resp.Context().WithTimeout(1 * time.Second) // TODO configure or stress-test this timeout
go GetBlockRange(timeout, s.db, blocks, errors, done, span.Start, span.End) timeout, cancel := context.WithTimeout(resp.Context(), 1*time.Second)
defer cancel()
go storage.GetBlockRange(timeout,
s.db,
blockChan,
errChan,
int(span.Start.Height),
int(span.End.Height),
)
for { for {
select { select {
case <-timeout.Done(): case err := <-errChan:
return timeout.Err() // this will also catch context.DeadlineExceeded from the timeout
case err := <-errors:
return err return err
case blockBytes := <-blocks: case blockBytes := <-blockChan:
cBlock := &rpc.CompactBlock{} cBlock := &rpc.CompactBlock{}
err = proto.Unmarshal(blockBytes, cBlock) err := proto.Unmarshal(blockBytes, cBlock)
if err != nil { if err != nil {
return err // TODO really need better logging in this whole service return err // TODO really need better logging in this whole service
} }

View File

@@ -9,7 +9,7 @@ import (
) )
var ( var (
ErrBadRange = errors.New("no blocks in specified range") ErrLotsOfBlocks = errors.New("requested >10k blocks at once")
) )
func CreateTables(conn *sql.DB) error { func CreateTables(conn *sql.DB) error {
@@ -84,35 +84,41 @@ func GetBlockByHash(ctx context.Context, db *sql.DB, hash string) ([]byte, error
} }
// [start, end] inclusive // [start, end] inclusive
func GetBlockRange(ctx context.Context, db *sql.DB, blocks chan<- []byte, errors chan<- error, start, end int) { func GetBlockRange(ctx context.Context, db *sql.DB, blockOut chan<- []byte, errOut chan<- error, start, end int) {
// TODO sanity check ranges, rate limit? // TODO sanity check ranges, this limit, etc
numBlocks := (end - start) + 1 numBlocks := (end - start) + 1
query := "SELECT height, compact_encoding from blocks WHERE (height BETWEEN ? AND ?)" if numBlocks > 10000 {
errOut <- ErrLotsOfBlocks
return
}
query := "SELECT compact_encoding from blocks WHERE (height BETWEEN ? AND ?)"
result, err := db.QueryContext(ctx, query, start, end) result, err := db.QueryContext(ctx, query, start, end)
if err != nil { if err != nil {
errors <- err errOut <- err
return return
} }
defer result.Close() defer result.Close()
// My assumption here is that if the context is cancelled then result.Next() will fail. // My assumption here is that if the context is cancelled then result.Next() will fail.
var blockBytes []byte // avoid a copy with *RawBytes var blockBytes []byte
for result.Next() { for result.Next() {
err = result.Scan(&blockBytes) err = result.Scan(&blockBytes)
if err != nil { if err != nil {
errors <- err errOut <- err
return return
} }
blocks <- blockBytes blockOut <- blockBytes
} }
if err := result.Err(); err != nil { if err := result.Err(); err != nil {
errors <- err errOut <- err
return
} }
// done // done
errors <- nil errOut <- nil
} }
func StoreBlock(conn *sql.DB, height int, hash string, sapling bool, encoded []byte) error { func StoreBlock(conn *sql.DB, height int, hash string, sapling bool, encoded []byte) error {

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"testing" "testing"
"time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@@ -36,88 +37,163 @@ func TestSqliteStorage(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer conn.Close()
err = CreateTables(conn) db, err := sql.Open("sqlite3", ":memory:")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer db.Close()
for _, test := range compactTests { // Fill tables
blockData, _ := hex.DecodeString(test.Full) {
block := parser.NewBlock() err = CreateTables(db)
blockData, err = block.ParseFromSlice(blockData)
if err != nil { if err != nil {
t.Error(errors.Wrap(err, fmt.Sprintf("parsing testnet block %d", test.BlockHeight))) t.Fatal(err)
continue
} }
height := block.GetHeight() for _, test := range compactTests {
hash := hex.EncodeToString(block.GetEncodableHash()) blockData, _ := hex.DecodeString(test.Full)
hasSapling := block.HasSaplingTransactions() block := parser.NewBlock()
protoBlock := block.ToCompact() blockData, err = block.ParseFromSlice(blockData)
marshaled, _ := proto.Marshal(protoBlock) if err != nil {
t.Error(errors.Wrap(err, fmt.Sprintf("parsing testnet block %d", test.BlockHeight)))
continue
}
err = StoreBlock(conn, height, hash, hasSapling, marshaled) height := block.GetHeight()
if err != nil { hash := hex.EncodeToString(block.GetEncodableHash())
t.Error(err) hasSapling := block.HasSaplingTransactions()
continue protoBlock := block.ToCompact()
marshaled, _ := proto.Marshal(protoBlock)
err = StoreBlock(db, height, hash, hasSapling, marshaled)
if err != nil {
t.Error(err)
continue
}
} }
} }
var count int // Count the blocks
countBlocks := "SELECT count(*) FROM blocks" {
err = conn.QueryRow(countBlocks).Scan(&count) var count int
if err != nil { countBlocks := "SELECT count(*) FROM blocks"
t.Error(errors.Wrap(err, fmt.Sprintf("counting compact blocks"))) err = db.QueryRow(countBlocks).Scan(&count)
if err != nil {
t.Error(errors.Wrap(err, fmt.Sprintf("counting compact blocks")))
}
if count != len(compactTests) {
t.Errorf("Wrong row count, want %d got %d", len(compactTests), count)
}
} }
if count != len(compactTests) { ctx := context.Background()
t.Errorf("Wrong row count, want %d got %d", len(compactTests), count)
// Check height state is as expected
{
blockHeight, err := GetCurrentHeight(ctx, db)
if err != nil {
t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height")))
}
lastBlockTest := compactTests[len(compactTests)-1]
if blockHeight != lastBlockTest.BlockHeight {
t.Errorf("Wrong block height, got: %d", blockHeight)
}
retBlock, err := GetBlock(ctx, db, blockHeight)
if err != nil {
t.Error(errors.Wrap(err, "retrieving stored block"))
}
cblock := &rpc.CompactBlock{}
err = proto.Unmarshal(retBlock, cblock)
if err != nil {
t.Fatal(err)
}
if int(cblock.Height) != lastBlockTest.BlockHeight {
t.Error("incorrect retrieval")
}
} }
blockHeight, err := GetCurrentHeight(context.Background(), conn) // Block ranges
if err != nil { {
t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height"))) blockOut := make(chan []byte)
} errOut := make(chan error)
lastBlockTest := compactTests[len(compactTests)-1] count := 0
if blockHeight != lastBlockTest.BlockHeight { go GetBlockRange(ctx, db, blockOut, errOut, 289460, 289465)
t.Errorf("Wrong block height, got: %d", blockHeight) recvLoop0:
} for {
select {
case <-blockOut:
count++
case err := <-errOut:
if err != nil {
t.Error(errors.Wrap(err, "in full blockrange"))
}
break recvLoop0
}
}
retBlock, err := GetBlock(context.Background(), conn, blockHeight) if count != 6 {
if err != nil { t.Error("failed to retrieve full range")
t.Error(errors.Wrap(err, "retrieving stored block")) }
}
cblock := &rpc.CompactBlock{}
err = proto.Unmarshal(retBlock, cblock)
if err != nil {
t.Fatal(err)
}
if int(cblock.Height) != lastBlockTest.BlockHeight { // Test timeout
t.Error("incorrect retrieval") timeout, _ := context.WithTimeout(ctx, 0*time.Second)
} go GetBlockRange(timeout, db, blockOut, errOut, 289460, 289465)
recvLoop1:
for {
select {
case err := <-errOut:
if err != context.DeadlineExceeded {
t.Errorf("got the wrong error: %v", err)
}
break recvLoop1
}
}
blockRange, err := GetBlockRange(conn, 289460, 289465) // Test a smaller range
if err != nil { count = 0
t.Error(err) go GetBlockRange(ctx, db, blockOut, errOut, 289462, 289465)
} recvLoop2:
if len(blockRange) != 6 { for {
t.Error("failed to retrieve full range") select {
} case <-blockOut:
count++
case err := <-errOut:
if err != nil {
t.Error(errors.Wrap(err, "in short blockrange"))
}
break recvLoop2
}
}
blockRange, err = GetBlockRange(conn, 289462, 289465) if count != 4 {
if err != nil { t.Errorf("failed to retrieve the shorter range")
t.Error(err) }
}
if len(blockRange) != 4 { // Test a nonsense range
t.Error("failed to retrieve partial range") count = 0
} go GetBlockRange(ctx, db, blockOut, errOut, 1, 2)
recvLoop3:
for {
select {
case <-blockOut:
count++
case err := <-errOut:
if err != nil {
t.Error(errors.Wrap(err, "in invalid blockrange"))
}
break recvLoop3
}
}
if count > 0 {
t.Errorf("got some blocks that shouldn't be there")
}
blockRange, err = GetBlockRange(conn, 1337, 1338)
if err != ErrBadRange {
t.Error("Somehow retrieved nonexistent blocks!")
} }
} }