diff --git a/frontend/service.go b/frontend/service.go index 054a99a..9195116 100644 --- a/frontend/service.go +++ b/frontend/service.go @@ -5,7 +5,9 @@ import ( "database/sql" "encoding/hex" "errors" + "time" + "github.com/golang/protobuf/proto" _ "github.com/mattn/go-sqlite3" "github.com/gtank/ctxd/rpc" @@ -77,22 +79,28 @@ func (s *SqlStreamer) GetBlock(ctx context.Context, id *rpc.BlockID) (*rpc.Compa } func (s *SqlStreamer) GetBlockRange(span *rpc.BlockRange, resp rpc.CompactTxStreamer_GetBlockRangeServer) error { - blocks := make(chan []byte) - errors := make(chan error) - done := make(chan bool) + blockChan := make(chan []byte) + errChan := make(chan error) - timeout := resp.Context().WithTimeout(1 * time.Second) - go GetBlockRange(timeout, s.db, blocks, errors, done, span.Start, span.End) + // TODO configure or stress-test this timeout + timeout, cancel := context.WithTimeout(resp.Context(), 1*time.Second) + defer cancel() + go storage.GetBlockRange(timeout, + s.db, + blockChan, + errChan, + int(span.Start.Height), + int(span.End.Height), + ) for { select { - case <-timeout.Done(): - return timeout.Err() - case err := <-errors: + case err := <-errChan: + // this will also catch context.DeadlineExceeded from the timeout return err - case blockBytes := <-blocks: + case blockBytes := <-blockChan: cBlock := &rpc.CompactBlock{} - err = proto.Unmarshal(blockBytes, cBlock) + err := proto.Unmarshal(blockBytes, cBlock) if err != nil { return err // TODO really need better logging in this whole service } diff --git a/storage/sqlite3.go b/storage/sqlite3.go index 5da2123..70bbbc2 100644 --- a/storage/sqlite3.go +++ b/storage/sqlite3.go @@ -9,7 +9,7 @@ import ( ) var ( - ErrBadRange = errors.New("no blocks in specified range") + ErrLotsOfBlocks = errors.New("requested >10k blocks at once") ) func CreateTables(conn *sql.DB) error { @@ -84,35 +84,41 @@ func GetBlockByHash(ctx context.Context, db *sql.DB, hash string) ([]byte, error } // [start, end] inclusive -func GetBlockRange(ctx context.Context, db *sql.DB, blocks chan<- []byte, errors chan<- error, start, end int) { - // TODO sanity check ranges, rate limit? +func GetBlockRange(ctx context.Context, db *sql.DB, blockOut chan<- []byte, errOut chan<- error, start, end int) { + // TODO sanity check ranges, this limit, etc numBlocks := (end - start) + 1 - query := "SELECT height, compact_encoding from blocks WHERE (height BETWEEN ? AND ?)" + if numBlocks > 10000 { + errOut <- ErrLotsOfBlocks + return + } + + query := "SELECT compact_encoding from blocks WHERE (height BETWEEN ? AND ?)" result, err := db.QueryContext(ctx, query, start, end) if err != nil { - errors <- err + errOut <- err return } defer result.Close() // My assumption here is that if the context is cancelled then result.Next() will fail. - var blockBytes []byte // avoid a copy with *RawBytes + var blockBytes []byte for result.Next() { err = result.Scan(&blockBytes) if err != nil { - errors <- err + errOut <- err return } - blocks <- blockBytes + blockOut <- blockBytes } if err := result.Err(); err != nil { - errors <- err + errOut <- err + return } // done - errors <- nil + errOut <- nil } func StoreBlock(conn *sql.DB, height int, hash string, sapling bool, encoded []byte) error { diff --git a/storage/sqlite3_test.go b/storage/sqlite3_test.go index f569a4e..e7f0777 100644 --- a/storage/sqlite3_test.go +++ b/storage/sqlite3_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "testing" + "time" "github.com/golang/protobuf/proto" _ "github.com/mattn/go-sqlite3" @@ -36,88 +37,163 @@ func TestSqliteStorage(t *testing.T) { if err != nil { t.Fatal(err) } - defer conn.Close() - err = CreateTables(conn) + db, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatal(err) } + defer db.Close() - for _, test := range compactTests { - blockData, _ := hex.DecodeString(test.Full) - block := parser.NewBlock() - blockData, err = block.ParseFromSlice(blockData) + // Fill tables + { + err = CreateTables(db) if err != nil { - t.Error(errors.Wrap(err, fmt.Sprintf("parsing testnet block %d", test.BlockHeight))) - continue + t.Fatal(err) } - height := block.GetHeight() - hash := hex.EncodeToString(block.GetEncodableHash()) - hasSapling := block.HasSaplingTransactions() - protoBlock := block.ToCompact() - marshaled, _ := proto.Marshal(protoBlock) + for _, test := range compactTests { + blockData, _ := hex.DecodeString(test.Full) + block := parser.NewBlock() + blockData, err = block.ParseFromSlice(blockData) + if err != nil { + t.Error(errors.Wrap(err, fmt.Sprintf("parsing testnet block %d", test.BlockHeight))) + continue + } - err = StoreBlock(conn, height, hash, hasSapling, marshaled) - if err != nil { - t.Error(err) - continue + height := block.GetHeight() + hash := hex.EncodeToString(block.GetEncodableHash()) + hasSapling := block.HasSaplingTransactions() + protoBlock := block.ToCompact() + marshaled, _ := proto.Marshal(protoBlock) + + err = StoreBlock(db, height, hash, hasSapling, marshaled) + if err != nil { + t.Error(err) + continue + } } } - var count int - countBlocks := "SELECT count(*) FROM blocks" - err = conn.QueryRow(countBlocks).Scan(&count) - if err != nil { - t.Error(errors.Wrap(err, fmt.Sprintf("counting compact blocks"))) + // Count the blocks + { + var count int + countBlocks := "SELECT count(*) FROM blocks" + err = db.QueryRow(countBlocks).Scan(&count) + if err != nil { + t.Error(errors.Wrap(err, fmt.Sprintf("counting compact blocks"))) + } + + if count != len(compactTests) { + t.Errorf("Wrong row count, want %d got %d", len(compactTests), count) + } } - if count != len(compactTests) { - t.Errorf("Wrong row count, want %d got %d", len(compactTests), count) + ctx := context.Background() + + // Check height state is as expected + { + blockHeight, err := GetCurrentHeight(ctx, db) + if err != nil { + t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height"))) + } + + lastBlockTest := compactTests[len(compactTests)-1] + if blockHeight != lastBlockTest.BlockHeight { + t.Errorf("Wrong block height, got: %d", blockHeight) + } + + retBlock, err := GetBlock(ctx, db, blockHeight) + if err != nil { + t.Error(errors.Wrap(err, "retrieving stored block")) + } + cblock := &rpc.CompactBlock{} + err = proto.Unmarshal(retBlock, cblock) + if err != nil { + t.Fatal(err) + } + + if int(cblock.Height) != lastBlockTest.BlockHeight { + t.Error("incorrect retrieval") + } } - blockHeight, err := GetCurrentHeight(context.Background(), conn) - if err != nil { - t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height"))) - } + // Block ranges + { + blockOut := make(chan []byte) + errOut := make(chan error) - lastBlockTest := compactTests[len(compactTests)-1] - if blockHeight != lastBlockTest.BlockHeight { - t.Errorf("Wrong block height, got: %d", blockHeight) - } + count := 0 + go GetBlockRange(ctx, db, blockOut, errOut, 289460, 289465) + recvLoop0: + for { + select { + case <-blockOut: + count++ + case err := <-errOut: + if err != nil { + t.Error(errors.Wrap(err, "in full blockrange")) + } + break recvLoop0 + } + } - retBlock, err := GetBlock(context.Background(), conn, blockHeight) - if err != nil { - t.Error(errors.Wrap(err, "retrieving stored block")) - } - cblock := &rpc.CompactBlock{} - err = proto.Unmarshal(retBlock, cblock) - if err != nil { - t.Fatal(err) - } + if count != 6 { + t.Error("failed to retrieve full range") + } - if int(cblock.Height) != lastBlockTest.BlockHeight { - t.Error("incorrect retrieval") - } + // Test timeout + timeout, _ := context.WithTimeout(ctx, 0*time.Second) + go GetBlockRange(timeout, db, blockOut, errOut, 289460, 289465) + recvLoop1: + for { + select { + case err := <-errOut: + if err != context.DeadlineExceeded { + t.Errorf("got the wrong error: %v", err) + } + break recvLoop1 + } + } - blockRange, err := GetBlockRange(conn, 289460, 289465) - if err != nil { - t.Error(err) - } - if len(blockRange) != 6 { - t.Error("failed to retrieve full range") - } + // Test a smaller range + count = 0 + go GetBlockRange(ctx, db, blockOut, errOut, 289462, 289465) + recvLoop2: + for { + select { + case <-blockOut: + count++ + case err := <-errOut: + if err != nil { + t.Error(errors.Wrap(err, "in short blockrange")) + } + break recvLoop2 + } + } - blockRange, err = GetBlockRange(conn, 289462, 289465) - if err != nil { - t.Error(err) - } - if len(blockRange) != 4 { - t.Error("failed to retrieve partial range") - } + if count != 4 { + t.Errorf("failed to retrieve the shorter range") + } + + // Test a nonsense range + count = 0 + go GetBlockRange(ctx, db, blockOut, errOut, 1, 2) + recvLoop3: + for { + select { + case <-blockOut: + count++ + case err := <-errOut: + if err != nil { + t.Error(errors.Wrap(err, "in invalid blockrange")) + } + break recvLoop3 + } + } + + if count > 0 { + t.Errorf("got some blocks that shouldn't be there") + } - blockRange, err = GetBlockRange(conn, 1337, 1338) - if err != ErrBadRange { - t.Error("Somehow retrieved nonexistent blocks!") } }