From d503752588705262347455aa5fab575fd1ea0b07 Mon Sep 17 00:00:00 2001 From: Aditya Kulkarni Date: Wed, 25 Sep 2019 15:55:37 -0700 Subject: [PATCH] Block Cache --- common/cache.go | 34 +++++++++++++++-------- common/common.go | 67 ++++++++++++++++++++++++++++++++++++++++++--- frontend/service.go | 7 +++-- 3 files changed, 90 insertions(+), 18 deletions(-) diff --git a/common/cache.go b/common/cache.go index 56743c3..5c691f5 100644 --- a/common/cache.go +++ b/common/cache.go @@ -1,6 +1,11 @@ package common -import "github.com/pkg/errors" +import ( + "bytes" + + "github.com/adityapk00/lightwalletd/parser" + "github.com/pkg/errors" +) type BlockCache struct { MaxEntries int @@ -8,7 +13,7 @@ type BlockCache struct { FirstBlock int LastBlock int - m map[int][]byte + m map[int]*parser.Block } func New(maxEntries int) *BlockCache { @@ -16,12 +21,12 @@ func New(maxEntries int) *BlockCache { MaxEntries: maxEntries, FirstBlock: -1, LastBlock: -1, - m: make(map[int][]byte), + m: make(map[int]*parser.Block), } } -func (c *BlockCache) Add(height int, bytes []byte) error { - +func (c *BlockCache) Add(height int, block *parser.Block) error { + println("Cache add", height) if c.FirstBlock == -1 && c.LastBlock == -1 { // If this is the first block, prep the data structure c.FirstBlock = height @@ -32,14 +37,20 @@ func (c *BlockCache) Add(height int, bytes []byte) error { for i := height; i <= c.LastBlock; i++ { delete(c.m, i) } + c.LastBlock = height - 1 } if height != c.LastBlock+1 { return errors.New("Blocks need to be added sequentially") } + if c.m[height-1] != nil && !bytes.Equal(block.GetPrevHash(), c.m[height-1].GetEncodableHash()) { + return errors.New("Prev hash of the block didn't match") + } + // Add the entry and update the counters - c.m[height] = bytes + c.m[height] = block + c.LastBlock = height // If the cache is full, remove the oldest block @@ -51,15 +62,16 @@ func (c *BlockCache) Add(height int, bytes []byte) error { return nil } -func (c *BlockCache) Get(height int) ([]byte, error) { - +func (c *BlockCache) Get(height int) *parser.Block { + println("Cache get", height) if c.LastBlock == -1 || c.FirstBlock == -1 { - return nil, errors.New("Map is empty") + return nil } if height < c.FirstBlock || height > c.LastBlock { - return nil, errors.New("Index out of range") + return nil } - return c.m[height], nil + println("Cache returned") + return c.m[height] } diff --git a/common/common.go b/common/common.go index 216f834..458e4e3 100644 --- a/common/common.go +++ b/common/common.go @@ -1,6 +1,7 @@ package common import ( + "bytes" "encoding/hex" "encoding/json" "strconv" @@ -44,7 +45,7 @@ func GetSaplingInfo(rpcClient *rpcclient.Client) (int, string, error) { return int(saplingHeight), chainName, nil } -func GetBlock(rpcClient *rpcclient.Client, height int) (*parser.Block, error) { +func getBlockFromRPC(rpcClient *rpcclient.Client, height int) (*parser.Block, error) { params := make([]json.RawMessage, 2) params[0] = json.RawMessage("\"" + strconv.Itoa(height) + "\"") params[1] = json.RawMessage("0") @@ -83,15 +84,73 @@ func GetBlock(rpcClient *rpcclient.Client, height int) (*parser.Block, error) { if len(rest) != 0 { return nil, errors.New("received overlong message") } + return block, nil } -func GetBlockRange(rpcClient *rpcclient.Client, blockOut chan<- walletrpc.CompactBlock, - errOut chan<- error, start, end int) { +func GetBlock(rpcClient *rpcclient.Client, cache *BlockCache, height int) (*parser.Block, error) { + // First, check the cache to see if we have the block + block := cache.Get(height) + if block != nil { + return block, nil + } + + block, err := getBlockFromRPC(rpcClient, height) + if err != nil { + return nil, err + } + + // Store the block in the cache, but test for reorg first + prevBlock := cache.Get(height - 1) + + if prevBlock != nil { + if !bytes.Equal(prevBlock.GetEncodableHash(), block.GetPrevHash()) { + // Reorg! + reorgCount := 0 + cacheBlock := cache.Get(height - reorgCount) + + rpcBlocks := []*parser.Block{} + + for ; reorgCount <= 100 && + cacheBlock != nil && + !bytes.Equal(block.GetPrevHash(), cacheBlock.GetEncodableHash()); reorgCount++ { + + block, err = getBlockFromRPC(rpcClient, height-reorgCount-1) + if err != nil { + return nil, err + } + + _ = append(rpcBlocks, block) + + cacheBlock = cache.Get(height - reorgCount - 2) + + } + + if reorgCount == 100 { + return nil, errors.New("Max reorg depth exceeded") + } + + // At this point, the block.prevHash == cache.hash + // Store all blocks starting with 'block' + for i := len(rpcBlocks) - 1; i >= 0; i-- { + cache.Add(rpcBlocks[i].GetHeight(), rpcBlocks[i]) + } + } + } + + cache.Add(height, block) + + return block, nil +} + +func GetBlockRange(rpcClient *rpcclient.Client, cache *BlockCache, + blockOut chan<- walletrpc.CompactBlock, errOut chan<- error, start, end int) { + + println("Getting block range") // Go over [start, end] inclusive for i := start; i <= end; i++ { - block, err := GetBlock(rpcClient, i) + block, err := GetBlock(rpcClient, cache, i) if err != nil { errOut <- err return diff --git a/frontend/service.go b/frontend/service.go index fba9bcd..ae8aa06 100644 --- a/frontend/service.go +++ b/frontend/service.go @@ -22,12 +22,13 @@ var ( // the service type type SqlStreamer struct { + cache *common.BlockCache client *rpcclient.Client log *logrus.Entry } func NewSQLiteStreamer(client *rpcclient.Client, log *logrus.Entry) (walletrpc.CompactTxStreamerServer, error) { - return &SqlStreamer{client, log}, nil + return &SqlStreamer{common.New(100000), client, log}, nil } func (s *SqlStreamer) GracefulStop() error { @@ -129,7 +130,7 @@ func (s *SqlStreamer) GetBlock(ctx context.Context, id *walletrpc.BlockID) (*wal return nil, errors.New("GetBlock by Hash is not yet implemented") } else { - cBlock, err := common.GetBlock(s.client, int(id.Height)) + cBlock, err := common.GetBlock(s.client, s.cache, int(id.Height)) if err != nil { return nil, err @@ -144,7 +145,7 @@ func (s *SqlStreamer) GetBlockRange(span *walletrpc.BlockRange, resp walletrpc.C blockChan := make(chan walletrpc.CompactBlock) errChan := make(chan error) - go common.GetBlockRange(s.client, blockChan, errChan, int(span.Start.Height), int(span.End.Height)) + go common.GetBlockRange(s.client, s.cache, blockChan, errChan, int(span.Start.Height), int(span.End.Height)) for { select {