diff --git a/server/main.go b/server/main.go index 2a9e3e4..a154bdd 100644 --- a/server/main.go +++ b/server/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "database/sql" "errors" "flag" @@ -10,9 +11,12 @@ import ( "syscall" "github.com/gtank/ctxd/rpc" + "github.com/gtank/ctxd/storage" + _ "github.com/mattn/go-sqlite3" "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/reflection" ) var ( @@ -41,13 +45,13 @@ type Options struct { func main() { opts := &Options{} flag.StringVar(&opts.bindAddr, "bind-addr", "127.0.0.1:9067", "the address to listen on") - flag.StringVar(&opts.dbPath, "db-path", "", "the location of a sqlite database file") + flag.StringVar(&opts.dbPath, "db-path", "", "the path to a sqlite database file") flag.StringVar(&opts.tlsCertPath, "tls-cert", "", "the path to a TLS certificate (optional)") flag.StringVar(&opts.tlsKeyPath, "tls-key", "", "the path to a TLS key file (optional)") flag.Uint64Var(&opts.logLevel, "log-level", uint64(logrus.InfoLevel), "log level (logrus 1-7)") // TODO prod logging flag.StringVar(&opts.logPath, "log-file", "", "log file to write to") // TODO prod metrics - // TODO support config from file + // TODO support config from file and env vars flag.Parse() if opts.dbPath == "" { @@ -74,6 +78,11 @@ func main() { server = grpc.NewServer() } + // Enable reflection for debugging + if opts.logLevel >= uint64(logrus.WarnLevel) { + reflection.Register(server) + } + // Compact transaction service initialization service, err := NewSQLiteStreamer(opts.dbPath) if err != nil { @@ -110,7 +119,7 @@ func main() { if err != nil { log.WithFields(logrus.Fields{ "error": err.Error(), - }).Fatal("gRPC server failed") + }).Fatal("gRPC server exited") } } @@ -120,5 +129,43 @@ type sqlStreamer struct { } func NewSQLiteStreamer(dbPath string) (rpc.CompactTxStreamerServer, error) { + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + return nil, err + } + + // Creates our tables if they don't already exist. + err = storage.CreateTables(db) + if err != nil { + return nil, err + } + + return &sqlStreamer{db}, nil +} + +func (s *sqlStreamer) GetLatestBlock(ctx context.Context, placeholder *rpc.ChainSpec) (*rpc.BlockID, error) { + // the ChainSpec type is an empty placeholder + height, err := storage.GetCurrentHeight(ctx, s.db) + if err != nil { + log.WithFields(logrus.Fields{ + "error": err.Error(), + "context": ctx, + }).Error("GetLatestBlock call failed") + return nil, err + } + // TODO: also return block hashes here + return &rpc.BlockID{Height: uint64(height)}, nil +} + +func (s *sqlStreamer) GetBlock(context.Context, *rpc.BlockID) (*rpc.CompactBlock, error) { + return nil, ErrNoImpl +} +func (s *sqlStreamer) GetBlockRange(*rpc.BlockRange, rpc.CompactTxStreamer_GetBlockRangeServer) error { + return ErrNoImpl +} +func (s *sqlStreamer) GetTransaction(context.Context, *rpc.TxFilter) (*rpc.RawTransaction, error) { + return nil, ErrNoImpl +} +func (s *sqlStreamer) SendTransaction(context.Context, *rpc.RawTransaction) (*rpc.SendResponse, error) { return nil, ErrNoImpl } diff --git a/storage/sqlite3.go b/storage/sqlite3.go index 50ab2c1..b6e1e4b 100644 --- a/storage/sqlite3.go +++ b/storage/sqlite3.go @@ -1,6 +1,7 @@ package storage import ( + "context" "database/sql" "fmt" @@ -13,7 +14,7 @@ var ( ErrBadRange = errors.New("no blocks in specified range") ) -func createTables(conn *sql.DB) error { +func CreateTables(conn *sql.DB) error { stateTable := ` CREATE TABLE IF NOT EXISTS state ( current_height INTEGER, @@ -39,10 +40,10 @@ func createTables(conn *sql.DB) error { return err } -func CurrentHeight(conn *sql.DB) (int, error) { +func GetCurrentHeight(ctx context.Context, conn *sql.DB) (int, error) { var height int query := "SELECT current_height FROM state WHERE rowid = 1" - err := conn.QueryRow(query).Scan(&height) + err := conn.QueryRowContext(ctx, query).Scan(&height) return height, err } @@ -74,7 +75,7 @@ func SetCurrentHeight(conn *sql.DB, height int) error { return nil } -func GetBlock(conn *sql.DB, height int) (*rpc.CompactBlock, error) { +func GetBlock(ctx context.Context, conn *sql.DB, height int) (*rpc.CompactBlock, error) { var blockBytes []byte // avoid a copy with *RawBytes query := "SELECT compact_encoding from blocks WHERE height = ?" err := conn.QueryRow(query, height).Scan(&blockBytes) @@ -141,7 +142,7 @@ func StoreBlock(conn *sql.DB, height int, hash string, sapling bool, version int return errors.Wrap(err, fmt.Sprintf("storing compact block %d", height)) } - currentHeight, err := CurrentHeight(conn) + currentHeight, err := GetCurrentHeight(context.Background(), conn) if err != nil || height > currentHeight { err = SetCurrentHeight(conn, height) } diff --git a/storage/sqlite3_test.go b/storage/sqlite3_test.go index 7c5560a..14717c8 100644 --- a/storage/sqlite3_test.go +++ b/storage/sqlite3_test.go @@ -1,6 +1,7 @@ package storage import ( + "context" "database/sql" "encoding/hex" "encoding/json" @@ -39,7 +40,7 @@ func TestSqliteStorage(t *testing.T) { } defer conn.Close() - err = createTables(conn) + err = CreateTables(conn) if err != nil { t.Fatal(err) } @@ -78,7 +79,7 @@ func TestSqliteStorage(t *testing.T) { t.Errorf("Wrong row count, want %d got %d", len(compactTests), count) } - blockHeight, err := CurrentHeight(conn) + blockHeight, err := GetCurrentHeight(context.Background(), conn) if err != nil { t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height"))) } @@ -88,7 +89,7 @@ func TestSqliteStorage(t *testing.T) { t.Errorf("Wrong block height, got: %d", blockHeight) } - retBlock, err := GetBlock(conn, blockHeight) + retBlock, err := GetBlock(context.Background(), conn, blockHeight) if err != nil { t.Error(errors.Wrap(err, "retrieving stored block")) }