diff --git a/core/txdb/txdb.go b/core/txdb/txdb.go index 02686e4..19bf5a1 100644 --- a/core/txdb/txdb.go +++ b/core/txdb/txdb.go @@ -1,6 +1,8 @@ package txdb import ( + "sync" + "github.com/sirupsen/logrus" . "github.com/yu-org/yu/common" @@ -18,13 +20,58 @@ const ( type TxDB struct { nodeType int - txnKV kv.KV - receiptKV kv.KV + txnKV *txnkvdb + receiptKV *receipttxnkvdb enableUseSql bool db sql.SqlDB } +type txnkvdb struct { + sync.RWMutex + txnKV kv.KV +} + +func (t *txnkvdb) GetTxn(txnHash Hash) (*SignedTxn, error) { + t.RLock() + defer t.RUnlock() + byt, err := t.txnKV.Get(txnHash.Bytes()) + if err != nil { + return nil, err + } + if byt == nil { + return nil, nil + } + return DecodeSignedTxn(byt) +} + +func (t *txnkvdb) ExistTxn(txnHash Hash) bool { + t.RLock() + defer t.RUnlock() + return t.txnKV.Exist(txnHash.Bytes()) +} + +func (t *txnkvdb) SetTxns(txns []*SignedTxn) error { + t.Lock() + defer t.Unlock() + kvtx, err := t.txnKV.NewKvTxn() + if err != nil { + return err + } + for _, txn := range txns { + txbyt, err := txn.Encode() + if err != nil { + logrus.Errorf("TxDB.SetTxns set tx(%s) failed: %v", txn.TxnHash.String(), err) + return err + } + err = kvtx.Set(txn.TxnHash.Bytes(), txbyt) + if err != nil { + return err + } + } + return kvtx.Commit() +} + type TxnDBSchema struct { Type string `gorm:"type:varchar(10)"` Key string `gorm:"primaryKey;type:text"` @@ -38,8 +85,8 @@ func (TxnDBSchema) TableName() string { func NewTxDB(nodeTyp int, kvdb kv.Kvdb, kvdbConf *config.KVconf) (ItxDB, error) { txdb := &TxDB{ nodeType: nodeTyp, - txnKV: kvdb.New(Txns), - receiptKV: kvdb.New(Results), + txnKV: &txnkvdb{txnKV: kvdb.New(Txns)}, + receiptKV: &receipttxnkvdb{receiptKV: kvdb.New(Results)}, } if kvdbConf != nil && kvdbConf.UseSQlDbConf { db, err := sql.NewSqlDB(&kvdbConf.SQLDbConf) @@ -67,14 +114,7 @@ func (bb *TxDB) GetTxn(txnHash Hash) (*SignedTxn, error) { return DecodeSignedTxn([]byte(records[0].Value)) } } - byt, err := bb.txnKV.Get(txnHash.Bytes()) - if err != nil { - return nil, err - } - if byt == nil { - return nil, nil - } - return DecodeSignedTxn(byt) + return bb.txnKV.GetTxn(txnHash) } func (bb *TxDB) GetTxns(txnHashes []Hash) ([]*SignedTxn, error) { @@ -103,7 +143,7 @@ func (bb *TxDB) ExistTxn(txnHash Hash) bool { return true } } - return bb.txnKV.Exist(txnHash.Bytes()) + return bb.txnKV.ExistTxn(txnHash) } func (bb *TxDB) SetTxns(txns []*SignedTxn) error { @@ -124,22 +164,7 @@ func (bb *TxDB) SetTxns(txns []*SignedTxn) error { } return nil } - kvtx, err := bb.txnKV.NewKvTxn() - if err != nil { - return err - } - for _, txn := range txns { - txbyt, err := txn.Encode() - if err != nil { - logrus.Errorf("TxDB.SetTxns set tx(%s) failed: %v", txn.TxnHash.String(), err) - return err - } - err = kvtx.Set(txn.TxnHash.Bytes(), txbyt) - if err != nil { - return err - } - } - return kvtx.Commit() + return bb.txnKV.SetTxns(txns) } func (bb *TxDB) SetReceipts(receipts map[Hash]*Receipt) error { @@ -151,22 +176,7 @@ func (bb *TxDB) SetReceipts(receipts map[Hash]*Receipt) error { } return nil } - kvtx, err := bb.receiptKV.NewKvTxn() - if err != nil { - return err - } - for txHash, receipt := range receipts { - byt, err := receipt.Encode() - if err != nil { - return err - } - err = kvtx.Set(txHash.Bytes(), byt) - if err != nil { - return err - } - } - - return kvtx.Commit() + return bb.receiptKV.SetReceipts(receipts) } func (bb *TxDB) SetReceipt(txHash Hash, receipt *Receipt) error { @@ -180,11 +190,7 @@ func (bb *TxDB) SetReceipt(txHash Hash, receipt *Receipt) error { } return nil } - byt, err := receipt.Encode() - if err != nil { - return err - } - return bb.receiptKV.Set(txHash.Bytes(), byt) + return bb.receiptKV.SetReceipt(txHash, receipt) } func (bb *TxDB) GetReceipt(txHash Hash) (*Receipt, error) { @@ -199,7 +205,18 @@ func (bb *TxDB) GetReceipt(txHash Hash) (*Receipt, error) { } } } - byt, err := bb.receiptKV.Get(txHash.Bytes()) + return bb.receiptKV.GetReceipt(txHash) +} + +type receipttxnkvdb struct { + sync.RWMutex + receiptKV kv.KV +} + +func (r *receipttxnkvdb) GetReceipt(txHash Hash) (*Receipt, error) { + r.RLock() + defer r.RUnlock() + byt, err := r.receiptKV.Get(txHash.Bytes()) if err != nil { logrus.Errorf("TxDB.GetReceipt(%s), failed: %s, error: %v", txHash.String(), string(byt), err) return nil, err @@ -214,3 +231,33 @@ func (bb *TxDB) GetReceipt(txHash Hash) (*Receipt, error) { } return receipt, err } + +func (r *receipttxnkvdb) SetReceipt(txHash Hash, receipt *Receipt) error { + r.Lock() + defer r.Unlock() + byt, err := receipt.Encode() + if err != nil { + return err + } + return r.receiptKV.Set(txHash.Bytes(), byt) +} + +func (r *receipttxnkvdb) SetReceipts(receipts map[Hash]*Receipt) error { + r.Lock() + defer r.Unlock() + kvtx, err := r.receiptKV.NewKvTxn() + if err != nil { + return err + } + for txHash, receipt := range receipts { + byt, err := receipt.Encode() + if err != nil { + return err + } + err = kvtx.Set(txHash.Bytes(), byt) + if err != nil { + return err + } + } + return kvtx.Commit() +}