Skip to content

Commit

Permalink
Merge pull request #14571 from MinaProtocol/feature/merkle-mask-prelo…
Browse files Browse the repository at this point in the history
…ading

Preload accounts into merkle path for staged ledger diff application
  • Loading branch information
deepthiskumar authored Dec 5, 2023
2 parents df8eba8 + 4b56a35 commit d9035e7
Show file tree
Hide file tree
Showing 22 changed files with 591 additions and 285 deletions.
42 changes: 26 additions & 16 deletions src/lib/merkle_address/merkle_address.ml
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ let height ~ledger_depth path = ledger_depth - depth path

let get = get

[%%define_locally
Stable.Latest.(t_of_sexp, sexp_of_t, to_yojson, compare, equal)]
[%%define_locally Stable.Latest.(t_of_sexp, sexp_of_t, to_yojson)]

include Comparable.Make_binable (Stable.Latest)
include Hashable.Make_binable (Stable.Latest)

let of_byte_string = bitstring_of_string
Expand All @@ -114,13 +114,13 @@ let copy (path : t) : t =
(* returns a slice of the original path, so the returned key needs to be
copied before mutating the path *)
let parent (path : t) =
if bitstring_length path = 0 then
if Int.equal (bitstring_length path) 0 then
Or_error.error_string "Address length should be nonzero"
else Or_error.return (slice path 0 (bitstring_length path - 1))

let parent_exn = Fn.compose Or_error.ok_exn parent

let is_leaf ~ledger_depth path = bitstring_length path >= ledger_depth
let is_leaf ~ledger_depth path = Int.(bitstring_length path >= ledger_depth)

let child ~ledger_depth (path : t) dir : t Or_error.t =
if is_leaf ~ledger_depth path then
Expand All @@ -137,10 +137,10 @@ let to_int (path : t) : int =
Sequence.range 0 (depth path)
|> Sequence.fold ~init:0 ~f:(fun acc i ->
let index = depth path - 1 - i in
acc + ((if get path index <> 0 then 1 else 0) lsl i) )
acc + ((if Int.(get path index <> 0) then 1 else 0) lsl i) )

let of_int_exn ~ledger_depth index =
if index >= 1 lsl ledger_depth then failwith "Index is too large"
if Int.(index >= 1 lsl ledger_depth) then failwith "Index is too large"
else
let buf = create_bitstring ledger_depth in
ignore
Expand All @@ -160,7 +160,7 @@ let root () = create_bitstring 0
let sibling (path : t) : t =
let path = copy path in
let last_bit_index = depth path - 1 in
let last_bit = if get path last_bit_index = 0 then 1 else 0 in
let last_bit = if Int.equal (get path last_bit_index) 0 then 1 else 0 in
put path last_bit_index last_bit ;
path

Expand All @@ -169,12 +169,12 @@ let next (path : t) : t Option.t =
let path = copy path in
let len = depth path in
let rec find_rightmost_clear_bit i =
if i < 0 then None
if Int.(i < 0) then None
else if is_clear path i then Some i
else find_rightmost_clear_bit (i - 1)
in
let rec clear_bits i =
if i >= len then ()
if Int.(i >= len) then ()
else (
clear path i ;
clear_bits (i + 1) )
Expand All @@ -189,12 +189,12 @@ let prev (path : t) : t Option.t =
let path = copy path in
let len = depth path in
let rec find_rightmost_one_bit i =
if i < 0 then None
if Int.(i < 0) then None
else if is_set path i then Some i
else find_rightmost_one_bit (i - 1)
in
let rec set_bits i =
if i >= len then ()
if Int.(i >= len) then ()
else (
set path i ;
set_bits (i + 1) )
Expand All @@ -208,28 +208,38 @@ let serialize ~ledger_depth path =
let path = add_padding path in
let path_len = depth path in
let required_bits = 8 * byte_count_of_bits ledger_depth in
assert (path_len <= required_bits) ;
assert (Int.(path_len <= required_bits)) ;
let required_padding = required_bits - path_len in
Bigstring.of_string @@ string_of_bitstring
@@ concat [ path; zeroes_bitstring required_padding ]

let is_parent_of parent ~maybe_child = Bitstring.is_prefix maybe_child parent

let same_height_ancestors x y =
let depth_x = depth x in
let depth_y = depth y in
if Int.(depth_x < depth_y) then (x, slice y 0 depth_x)
else (slice x 0 depth_y, y)

let is_further_right ~than path =
let than, path = same_height_ancestors than path in
Int.( < ) (compare than path) 0

module Range = struct
type nonrec t = t * t

let rec fold_exl (first, last) ~init ~f =
let comparison = compare first last in
if comparison > 0 then
if Int.(comparison > 0) then
raise (Invalid_argument "first address needs to precede last address")
else if comparison = 0 then init
else if Int.(comparison = 0) then init
else fold_exl (next first |> Option.value_exn, last) ~init:(f first init) ~f

let fold_incl (first, last) ~init ~f =
f last @@ fold_exl (first, last) ~init ~f

let fold ?(stop = `Inclusive) (first, last) ~init ~f =
assert (depth first = depth last) ;
assert (Int.(depth first = depth last)) ;
match stop with
| `Inclusive ->
fold_incl (first, last) ~init ~f
Expand All @@ -253,7 +263,7 @@ module Range = struct
| _, `Stop ->
None
| current_node, `Don't_stop ->
if compare current_node last_node = 0 then
if Int.equal (compare current_node last_node) 0 then
Some (current_node, (current_node, `Stop))
else
Option.map (next current_node) ~f:(fun next_node ->
Expand Down
6 changes: 6 additions & 0 deletions src/lib/merkle_address/merkle_address.mli
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ module Stable : sig
module Latest : module type of V1
end

include Comparable.S_binable with type t := t

include Hashable.S_binable with type t := t

val of_byte_string : string -> t
Expand Down Expand Up @@ -74,3 +76,7 @@ val height : ledger_depth:int -> t -> int
val to_int : t -> int

val of_int_exn : ledger_depth:int -> int -> t

val same_height_ancestors : t -> t -> t * t

val is_further_right : than:t -> t -> bool
5 changes: 5 additions & 0 deletions src/lib/merkle_ledger/any_ledger.ml
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,13 @@ module Make_base (Inputs : Inputs_intf) :

let merkle_path_batch (T ((module Base), t)) = Base.merkle_path_batch t

let wide_merkle_path_batch (T ((module Base), t)) =
Base.wide_merkle_path_batch t

let merkle_root (T ((module Base), t)) = Base.merkle_root t

let get_hash_batch_exn (T ((module Base), t)) = Base.get_hash_batch_exn t

let index_of_account_exn (T ((module Base), t)) =
Base.index_of_account_exn t

Expand Down
7 changes: 7 additions & 0 deletions src/lib/merkle_ledger/base_ledger_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ module type S = sig

val merkle_path_batch : t -> Location.t list -> Path.t list

val wide_merkle_path_batch :
t
-> Location.t list
-> [ `Left of hash * hash | `Right of hash * hash ] list list

val get_hash_batch_exn : t -> Location.t list -> hash list

val remove_accounts_exn : t -> account_id list -> unit

(** Triggers when the ledger has been detached and should no longer be
Expand Down
82 changes: 39 additions & 43 deletions src/lib/merkle_ledger/database.ml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ module Make (Inputs : Inputs_intf) :
| None ->
empty_hash (Location.height ~ledger_depth:mdb.depth location)

let get_hash_batch mdb locations =
let get_hash_batch_exn mdb locations =
List.iter locations ~f:(fun location -> assert (Location.is_hash location)) ;
let hashes = get_bin_batch mdb locations Hash.bin_read_t in
List.map2_exn locations hashes ~f:(fun location hash ->
Expand Down Expand Up @@ -696,11 +696,11 @@ module Make (Inputs : Inputs_intf) :
let dependency_locs, dependency_dirs =
List.unzip (Location.merkle_path_dependencies_exn location)
in
let dependency_hashes = get_hash_batch mdb dependency_locs in
let dependency_hashes = get_hash_batch_exn mdb dependency_locs in
List.map2_exn dependency_dirs dependency_hashes ~f:(fun dir hash ->
Direction.map dir ~left:(`Left hash) ~right:(`Right hash) )

let merkle_path_batch mdb locations =
let path_batch_impl ~expand_query ~compute_path mdb locations =
let locations =
List.map locations ~f:(fun location ->
if Location.is_account location then
Expand All @@ -709,48 +709,44 @@ module Make (Inputs : Inputs_intf) :
assert (Location.is_hash location) ;
location ) )
in
let rev_locations, rev_directions, rev_lengths =
let rec loop locations loc_acc dir_acc length_acc =
match (locations, length_acc) with
| [], _ :: length_acc ->
(loc_acc, dir_acc, length_acc)
| k :: locations, length :: length_acc ->
if Location.height ~ledger_depth:mdb.depth k >= mdb.depth then
loop locations loc_acc dir_acc (0 :: length :: length_acc)
else
let sibling = Location.sibling k in
let sibling_dir =
Location.last_direction (Location.to_path_exn k)
in
loop
(Location.parent k :: locations)
(sibling :: loc_acc) (sibling_dir :: dir_acc)
((length + 1) :: length_acc)
| _ ->
assert false
in
loop locations [] [] [ 0 ]
let list_of_dependencies =
List.map locations ~f:Location.merkle_path_dependencies_exn
in
let rev_hashes = get_hash_batch mdb rev_locations in
let rec loop directions hashes lengths acc =
match (directions, hashes, lengths, acc) with
| [], [], [], _ (* actually [] *) :: acc_tl ->
acc_tl
| _, _, 0 :: lengths, _ ->
loop directions hashes lengths ([] :: acc)
| ( direction :: directions
, hash :: hashes
, length :: lengths
, acc_hd :: acc_tl ) ->
let dir =
Direction.map direction ~left:(`Left hash) ~right:(`Right hash)
in
loop directions hashes ((length - 1) :: lengths)
((dir :: acc_hd) :: acc_tl)
| _ ->
failwith "Mismatched lengths"
let all_locs =
List.map list_of_dependencies ~f:(fun deps -> List.map ~f:fst deps |> expand_query) |> List.concat
in
loop rev_directions rev_hashes rev_lengths [ [] ]
let hashes = get_hash_batch_exn mdb all_locs in
snd @@ List.fold_map ~init:hashes ~f:compute_path list_of_dependencies

let merkle_path_batch =
path_batch_impl ~expand_query:ident
~compute_path:(fun all_hashes loc_and_dir_list ->
let len = List.length loc_and_dir_list in
let sibling_hashes, rest_hashes = List.split_n all_hashes len in
let res =
List.map2_exn loc_and_dir_list sibling_hashes
~f:(fun (_, direction) sibling_hash ->
Direction.map direction ~left:(`Left sibling_hash)
~right:(`Right sibling_hash) )
in
(rest_hashes, res) )

let wide_merkle_path_batch =
path_batch_impl
~expand_query:(fun sib_locs ->
sib_locs @ List.map sib_locs ~f:Location.sibling )
~compute_path:(fun all_hashes loc_and_dir_list ->
let len = List.length loc_and_dir_list in
let sibling_hashes, rest_hashes = List.split_n all_hashes len in
let self_hashes, rest_hashes' = List.split_n rest_hashes len in
let res =
List.map3_exn loc_and_dir_list sibling_hashes self_hashes
~f:(fun (_, direction) sibling_hash self_hash ->
Direction.map direction
~left:(`Left (self_hash, sibling_hash))
~right:(`Right (sibling_hash, self_hash)) )
in
(rest_hashes', res) )

let merkle_path_at_addr_exn t addr = merkle_path t (Location.Hash addr)

Expand Down
30 changes: 28 additions & 2 deletions src/lib/merkle_ledger/null_ledger.ml
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,48 @@ end = struct
let h = Location.height ~ledger_depth:t.depth k in
if h >= t.depth then []
else
let sibling_dir = Location.last_direction (Location.to_path_exn k) in
let dir = Location.last_direction (Location.to_path_exn k) in
let hash = empty_hash_at_height h in
Direction.map sibling_dir ~left:(`Left hash) ~right:(`Right hash)
Direction.map dir ~left:(`Left hash) ~right:(`Right hash)
:: loop (Location.parent k)
in
loop location

let merkle_path_batch t locations = List.map ~f:(merkle_path t) locations

let wide_merkle_path t location =
let location =
if Location.is_account location then
Location.Hash (Location.to_path_exn location)
else location
in
assert (Location.is_hash location) ;
let rec loop k =
let h = Location.height ~ledger_depth:t.depth k in
if h >= t.depth then []
else
let dir = Location.last_direction (Location.to_path_exn k) in
let hash = empty_hash_at_height h in
Direction.map dir ~left:(`Left (hash, hash)) ~right:(`Right (hash, hash))
:: loop (Location.parent k)
in
loop location

let wide_merkle_path_batch t locations =
List.map ~f:(wide_merkle_path t) locations

let merkle_root t = empty_hash_at_height t.depth

let merkle_path_at_addr_exn t addr = merkle_path t (Location.Hash addr)

let merkle_path_at_index_exn t index =
merkle_path_at_addr_exn t (Addr.of_int_exn ~ledger_depth:t.depth index)

let get_hash_batch_exn t locations =
List.map locations ~f:(fun location ->
empty_hash_at_height
(Addr.height ~ledger_depth:t.depth (Location.to_path_exn location)) )

let index_of_account_exn _t =
failwith "index_of_account_exn: null ledgers are empty"

Expand Down
14 changes: 2 additions & 12 deletions src/lib/merkle_ledger_tests/test_mask.ml
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,6 @@ module Make (Test : Test_intf) = struct
(* verify all hashes to root are same in mask and parent *)
compare_maskable_mask_hashes maskable attached_mask dummy_address )

let%test "mask delegates to parent" =
Test.with_instances (fun maskable mask ->
let attached_mask = Maskable.register_mask maskable mask in
(* set to parent, get from mask *)
Maskable.set maskable dummy_location dummy_account ;
let mask_result = Mask.Attached.get attached_mask dummy_location in
Option.is_some mask_result
&&
let mask_account = Option.value_exn mask_result in
Account.equal dummy_account mask_account )

let%test "mask prune after parent notification" =
Test.with_instances (fun maskable mask ->
let attached_mask = Maskable.register_mask maskable mask in
Expand Down Expand Up @@ -763,8 +752,9 @@ module Make_maskable_and_mask_with_depth (Depth : Depth_S) = struct
| Generic of Merkle_ledger.Location.Bigstring.t
| Account of Location.Addr.t
| Hash of Location.Addr.t
[@@deriving hash, sexp, compare]
[@@deriving hash, sexp]

include Comparable.Make_binable (Arg)
include Hashable.Make_binable (Arg) [@@deriving sexp, compare, hash, yojson]
end

Expand Down
1 change: 1 addition & 0 deletions src/lib/merkle_mask/dune
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
visualization
mina_stdlib
direction
empty_hashes
)
(preprocess
(pps
Expand Down
7 changes: 5 additions & 2 deletions src/lib/merkle_mask/inputs_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ module type S = sig

module Location : Merkle_ledger.Location_intf.S

module Location_binable :
Core_kernel.Hashable.S_binable with type t := Location.t
module Location_binable : sig
include Core_kernel.Hashable.S_binable with type t := Location.t

include Core_kernel.Comparable.S_binable with type t := Location.t
end

module Base :
Base_merkle_tree_intf.S
Expand Down
3 changes: 3 additions & 0 deletions src/lib/merkle_mask/maskable_merkle_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ module Make (Inputs : Inputs_intf) = struct
Node (summary, List.map masks ~f:(_crawl (module Mask.Attached)))
end

let unsafe_preload_accounts_from_parent =
Mask.Attached.unsafe_preload_accounts_from_parent

let register_mask t mask =
let attached_mask = Mask.set_parent mask t in
List.iter (Uuid.Table.data registered_masks) ~f:(fun ms ->
Expand Down
Loading

0 comments on commit d9035e7

Please sign in to comment.