1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-08 05:57:03 +09:00

WIP fix diff

This commit is contained in:
Mikhail Rakhmanov 2025-03-21 11:45:30 +01:00
parent 7beb6b1a77
commit cfc54ad61f
No known key found for this signature in database
GPG key ID: DED12CFEF5B8396B
5 changed files with 1022 additions and 5 deletions

View file

@ -120,6 +120,7 @@ type diff struct {
compareThreshold int
ranges *hashRanges
mu sync.RWMutex
id int
}
// Compare implements skiplist interface
@ -237,11 +238,10 @@ func (d *diff) getRange(r Range) (rr RangeResult) {
if rng != nil {
rr.Hash = rng.hash
rr.Count = rng.elements
if !r.Elements && rng.isDivided {
if !r.Elements {
return
}
}
el := d.sl.Find(&element{hash: r.From})
rr.Elements = make([]Element, 0, d.divideFactor)
for el != nil && el.Key().(*element).hash <= r.To {
@ -312,7 +312,6 @@ func (d *diff) compareResults(dctx *diffCtx, r Range, myRes, otherRes RangeResul
if bytes.Equal(myRes.Hash, otherRes.Hash) {
return
}
// other has elements
if len(otherRes.Elements) == otherRes.Count {
if len(myRes.Elements) == myRes.Count {
@ -323,8 +322,7 @@ func (d *diff) compareResults(dctx *diffCtx, r Range, myRes, otherRes RangeResul
}
return
}
// request all elements from other, because we don't have enough
if len(myRes.Elements) == myRes.Count {
if otherRes.Count <= d.compareThreshold && len(otherRes.Elements) == 0 || len(myRes.Elements) == myRes.Count {
r.Elements = true
dctx.prepare = append(dctx.prepare, r)
return

View file

@ -0,0 +1,50 @@
package ldiff
import "context"
// RemoteTypeChecker is the party a DiffContainer hands itself to in
// DiffContainer.DiffTypeCheck.
//
// Judging by the result names it reports whether a sync is needed and which
// Diff should be used for it.
// NOTE(review): the selection rule lives in the implementations, which are not
// visible here - confirm before relying on it.
type RemoteTypeChecker interface {
	DiffTypeCheck(ctx context.Context, diffContainer DiffContainer) (needsSync bool, diff Diff, err error)
}
// DiffContainer bundles two Diff instances, exposed as OldDiff and NewDiff,
// behind one mutation API.
type DiffContainer interface {
	// DiffTypeCheck passes the container to typeChecker and returns its answer.
	DiffTypeCheck(ctx context.Context, typeChecker RemoteTypeChecker) (needsSync bool, diff Diff, err error)
	// OldDiff returns the diff stored as "old".
	OldDiff() Diff
	// NewDiff returns the diff stored as "new".
	NewDiff() Diff
	// Set adds or updates the given elements.
	Set(elements ...Element)
	// RemoveId removes the element with the given id.
	RemoveId(id string) error
}
// diffContainer is the DiffContainer implementation of this package; it simply
// holds the two diffs and forwards mutations to both of them.
type diffContainer struct {
	// newDiff is returned by NewDiff.
	newDiff Diff
	// oldDiff is returned by OldDiff.
	oldDiff Diff
}
// NewDiff returns the diff held in the newDiff field.
func (d *diffContainer) NewDiff() Diff {
	return d.newDiff
}
// OldDiff returns the diff held in the oldDiff field.
func (d *diffContainer) OldDiff() Diff {
	return d.oldDiff
}
// Set forwards the elements to both diffs, newDiff first, so the two stay
// populated with the same content.
func (d *diffContainer) Set(elements ...Element) {
	d.newDiff.Set(elements...)
	d.oldDiff.Set(elements...)
}
// RemoveId removes id from both diffs and always returns nil.
//
// The errors of the two underlying RemoveId calls are discarded on purpose, so
// a missing id is not reported to the caller.
// NOTE(review): presumably best-effort because the id may be absent from one of
// the diffs - confirm that callers never need the not-found signal.
func (d *diffContainer) RemoveId(id string) error {
	_ = d.newDiff.RemoveId(id)
	_ = d.oldDiff.RemoveId(id)
	return nil
}
// DiffTypeCheck delegates the decision to typeChecker, passing the container
// itself, and returns the checker's result unchanged.
func (d *diffContainer) DiffTypeCheck(ctx context.Context, typeChecker RemoteTypeChecker) (needsSync bool, diff Diff, err error) {
	return typeChecker.DiffTypeCheck(ctx, d)
}
// NewDiffContainer builds a DiffContainer whose diffs are created with the
// given divideFactor and compareThreshold.
//
// The struct has no `precalculated` field, so the literal must use the declared
// field names. Both fields are populated because Set and RemoveId call into
// newDiff and oldDiff unconditionally and would otherwise hit a nil Diff.
//
// NOTE(review): oldDiff is built with this package's newDiff because
// app/olddiff imports ldiff and so cannot be imported from here; confirm
// whether the legacy implementation should be injected by the caller instead.
func NewDiffContainer(divideFactor, compareThreshold int) DiffContainer {
	return &diffContainer{
		newDiff: newDiff(divideFactor, compareThreshold),
		oldDiff: newDiff(divideFactor, compareThreshold),
	}
}

318
app/olddiff/diff.go Normal file
View file

@ -0,0 +1,318 @@
//go:generate mockgen -destination mock_olddiff/mock_olddiff.go github.com/anyproto/any-sync/app/olddiff Diff,Remote
package olddiff
import (
"bytes"
"context"
"encoding/hex"
"errors"
"math"
"sync"
"github.com/cespare/xxhash"
"github.com/huandu/skiplist"
"github.com/zeebo/blake3"
"github.com/anyproto/any-sync/app/ldiff"
)
// New creates an empty Diff container.
//
// divideFactor is the number of sub-ranges every range is split into, i.e. how
// many hashes are requested at once. Values below 2 are raised to 2 by newDiff;
// typical values are between 4 and 64.
//
// compareThreshold is the maximum number of elements a range may hold before
// it is split; a remote diff sends elements directly for ranges at or below
// this size and only a hash above it. Values below 1 are raised to 1 by
// newDiff; typical values are between 8 and 64.
//
// Smaller threshold and divideFactor mean less traffic but more requests.
func New(divideFactor, compareThreshold int) ldiff.Diff {
	return newDiff(divideFactor, compareThreshold)
}
// newDiff builds the container behind New: it clamps the two tuning values to
// their minimums, creates the skiplist (ordered by diff.Compare) and the range
// tree on top of it, and computes the initial hashes of the empty tree.
func newDiff(divideFactor, compareThreshold int) ldiff.Diff {
	if divideFactor < 2 {
		divideFactor = 2
	}
	if compareThreshold < 1 {
		compareThreshold = 1
	}

	result := new(diff)
	result.divideFactor = divideFactor
	result.compareThreshold = compareThreshold
	// The diff itself is the skiplist comparator.
	result.sl = skiplist.New(result)

	tree := newHashRanges(divideFactor, compareThreshold, result.sl)
	result.ranges = tree
	// Force one hashing pass so the top range has a hash before any element
	// is added.
	tree.dirty[tree.topRange] = struct{}{}
	tree.recalculateHashes()
	return result
}
// hashersPool recycles blake3 hashers used for range hashing; users call Reset
// on a hasher after taking it from the pool and put it back when done.
var hashersPool = &sync.Pool{
	New: func() any {
		return blake3.New()
	},
}
// ErrElementNotFound is returned by Element and RemoveId when the requested id
// is not present in the container.
var ErrElementNotFound = errors.New("ldiff: element not found")
// element is the skiplist key: the public element plus the hash it is ordered by.
type element struct {
	ldiff.Element
	// hash is xxhash.Sum64 of the element id, as computed by Set and RemoveId.
	hash uint64
}
// diff contains elements and can compare them with a Remote diff.
type diff struct {
	// sl stores the elements ordered by hash, then id (see Compare).
	sl *skiplist.SkipList
	// divideFactor is the number of sub-ranges a range is split into.
	divideFactor int
	// compareThreshold is the element count above which a range is split.
	compareThreshold int
	// ranges is the hash tree maintained over sl.
	ranges *hashRanges
	// mu is taken for writing by Set and RemoveId and for reading by the
	// exported accessors.
	mu sync.RWMutex
}
// Compare implements the skiplist comparator interface.
//
// Two keys with the same id are always equal, whatever their hashes. Otherwise
// keys are ordered by hash, and by id when the hashes collide.
func (d *diff) Compare(lhs, rhs interface{}) int {
	left, right := lhs.(*element), rhs.(*element)
	switch {
	case left.Id == right.Id:
		return 0
	case left.hash > right.hash:
		return 1
	case left.hash < right.hash:
		return -1
	case left.Id > right.Id:
		return 1
	default:
		return -1
	}
}
// CalcScore implements the skiplist comparator interface. It returns 0 for
// every key, so ordering is decided by Compare alone.
func (d *diff) CalcScore(key interface{}) float64 {
	return 0
}
// Set adds the given elements to the container, replacing any element that
// already has the same id, and then refreshes the range hashes once for the
// whole batch.
func (d *diff) Set(elements ...ldiff.Element) {
	d.mu.Lock()
	defer d.mu.Unlock()

	for i := range elements {
		key := &element{
			Element: elements[i],
			hash:    xxhash.Sum64([]byte(elements[i].Id)),
		}
		// Remove first: Compare treats equal ids as equal keys, so this drops a
		// previous version of the element before the new one is stored.
		d.sl.Remove(key)
		d.sl.Set(key, nil)
		// NOTE(review): addElement runs even when the id already existed, so the
		// counters of the ranges on the path are incremented for updates too -
		// confirm this is intended.
		d.ranges.addElement(key.hash)
	}
	d.ranges.recalculateHashes()
}
// Ids returns the ids of all stored elements in skiplist order (by hash, not
// by id).
func (d *diff) Ids() (ids []string) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	ids = make([]string, 0, d.sl.Len())
	for node := d.sl.Front(); node != nil; node = node.Next() {
		ids = append(ids, node.Key().(*element).Id)
	}
	return ids
}
// Len returns the number of elements stored in the skiplist.
func (d *diff) Len() int {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return d.sl.Len()
}
// Elements returns a copy of all stored elements in skiplist order (by hash,
// not by id).
func (d *diff) Elements() (elements []ldiff.Element) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	elements = make([]ldiff.Element, 0, d.sl.Len())
	for node := d.sl.Front(); node != nil; node = node.Next() {
		elements = append(elements, node.Key().(*element).Element)
	}
	return elements
}
// Element looks up the element with the given id.
//
// It returns ErrElementNotFound when the id is not stored (or when the stored
// key is not an *element).
func (d *diff) Element(id string) (ldiff.Element, error) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	probe := &element{
		Element: ldiff.Element{Id: id},
		hash:    xxhash.Sum64([]byte(id)),
	}
	node := d.sl.Get(probe)
	if node == nil {
		return ldiff.Element{}, ErrElementNotFound
	}
	found, ok := node.Key().(*element)
	if !ok {
		return ldiff.Element{}, ErrElementNotFound
	}
	return found.Element, nil
}
// Hash returns the hex encoding of the top range's hash, i.e. a digest of the
// whole container.
func (d *diff) Hash() string {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return hex.EncodeToString(d.ranges.hash())
}
// RemoveId deletes the element with the given id and refreshes the range
// hashes. It returns ErrElementNotFound, leaving the ranges untouched, when no
// such element is stored.
func (d *diff) RemoveId(id string) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	idHash := xxhash.Sum64([]byte(id))
	probe := &element{
		Element: ldiff.Element{Id: id},
		hash:    idHash,
	}
	if removed := d.sl.Remove(probe); removed == nil {
		return ErrElementNotFound
	}
	d.ranges.removeElement(idHash)
	d.ranges.recalculateHashes()
	return nil
}
// getRange answers a single range query.
//
// When the range is a known, divided node of the tree and the caller did not
// ask for elements, only the stored hash and count are returned. In every other
// case the elements with a hash in [r.From, r.To] are collected from the
// skiplist and Count is set to the number collected; Hash is still the stored
// one when the range is known and stays nil otherwise.
//
// The method does no locking itself.
func (d *diff) getRange(r ldiff.Range) (rr ldiff.RangeResult) {
	if known := d.ranges.getRange(r.From, r.To); known != nil {
		rr.Hash, rr.Count = known.hash, known.elements
		if known.isDivided && !r.Elements {
			return rr
		}
	}

	rr.Elements = make([]ldiff.Element, 0, d.divideFactor)
	for node := d.sl.Find(&element{hash: r.From}); node != nil; node = node.Next() {
		entry := node.Key().(*element)
		if entry.hash > r.To {
			break
		}
		rr.Elements = append(rr.Elements, entry.Element)
	}
	rr.Count = len(rr.Elements)
	return rr
}
// Ranges evaluates every requested range under the read lock and returns one
// result per range, in request order. The results are appended to resBuf[:0],
// so a caller may pass the previous result slice back in to reuse its storage.
// The returned error is always nil.
func (d *diff) Ranges(ctx context.Context, ranges []ldiff.Range, resBuf []ldiff.RangeResult) (results []ldiff.RangeResult, err error) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	results = resBuf[:0]
	for i := range ranges {
		results = append(results, d.getRange(ranges[i]))
	}
	return results, nil
}
// diffCtx is the working state of a single Diff call.
type diffCtx struct {
	// Accumulated result of the comparison.
	newIds, changedIds, removedIds []string
	// toSend holds the ranges queried in the current round; prepare collects
	// the ranges for the next one.
	toSend, prepare []ldiff.Range
	// Reusable result buffers for the local and the remote side.
	myRes, otherRes []ldiff.RangeResult
}
// errMismatched is returned by Diff when either side answers with a different
// number of results than ranges were requested.
var errMismatched = errors.New("query and results mismatched")
// Diff compares the local container with the remote one.
//
// It starts with the full hash space and works in rounds: every round sends the
// pending ranges to both sides, compares the paired results and queues the
// ranges that need a closer look for the next round. It stops when nothing is
// queued.
//
// newIds are ids only the remote has, removedIds are ids only the local side
// has, and changedIds exist on both sides with different heads. On error
// (context cancellation, a failing Ranges call or errMismatched) all three
// slices are nil.
func (d *diff) Diff(ctx context.Context, dl ldiff.Remote) (newIds, changedIds, removedIds []string, err error) {
	state := &diffCtx{
		toSend: []ldiff.Range{{From: 0, To: math.MaxUint64}},
	}

	for len(state.toSend) > 0 {
		// Non-blocking cancellation check once per round.
		select {
		case <-ctx.Done():
			return nil, nil, nil, ctx.Err()
		default:
		}

		if state.otherRes, err = dl.Ranges(ctx, state.toSend, state.otherRes); err != nil {
			return nil, nil, nil, err
		}
		if state.myRes, err = d.Ranges(ctx, state.toSend, state.myRes); err != nil {
			return nil, nil, nil, err
		}

		asked := len(state.toSend)
		if len(state.otherRes) != asked || len(state.myRes) != asked {
			return nil, nil, nil, errMismatched
		}

		for i, r := range state.toSend {
			d.compareResults(state, r, state.myRes[i], state.otherRes[i])
		}

		// The queued ranges become the next request; the old request slice is
		// recycled as the empty queue.
		state.toSend, state.prepare = state.prepare, state.toSend[:0]
	}
	return state.newIds, state.changedIds, state.removedIds, nil
}
// compareResults decides what to do with one range after both sides answered.
//
//   - Equal hashes: the range is in sync, nothing happens.
//   - The remote sent all of its elements: compare element by element, loading
//     the local elements first if the local answer carried only a hash.
//   - Only the local side has all of its elements: re-queue the same range
//     with Elements set so the remote sends its elements next round.
//   - Neither side sent elements: queue the divideFactor sub-ranges.
//
// It is called from Diff, which holds no lock, so the local element load takes
// the read lock itself; without it the skiplist and range tree would be read
// while a concurrent Set or RemoveId mutates them.
func (d *diff) compareResults(dctx *diffCtx, r ldiff.Range, myRes, otherRes ldiff.RangeResult) {
	// both hashes equal - nothing to do
	if bytes.Equal(myRes.Hash, otherRes.Hash) {
		return
	}

	// other has elements
	if len(otherRes.Elements) == otherRes.Count {
		mine := myRes.Elements
		if len(mine) != myRes.Count {
			r.Elements = true
			d.mu.RLock()
			mine = d.getRange(r).Elements
			d.mu.RUnlock()
		}
		d.compareElements(dctx, mine, otherRes.Elements)
		return
	}

	// we have everything for this range, other does not: request its elements
	if len(myRes.Elements) == myRes.Count {
		r.Elements = true
		dctx.prepare = append(dctx.prepare, r)
		return
	}

	// both sides only have hashes: descend one level
	for _, tuple := range genTupleRanges(r.From, r.To, d.divideFactor) {
		dctx.prepare = append(dctx.prepare, ldiff.Range{From: tuple.from, To: tuple.to})
	}
}
// compareElements classifies the ids of two element lists covering the same
// range and appends them to dctx: ids only in my go to removedIds, ids in both
// lists with different heads go to changedIds, and ids only in other go to
// newIds. Lookups are linear scans that stop at the first matching id.
func (d *diff) compareElements(dctx *diffCtx, my, other []ldiff.Element) {
	headOf := func(list []ldiff.Element, id string) (head string, found bool) {
		for i := range list {
			if list[i].Id == id {
				return list[i].Head, true
			}
		}
		return "", false
	}

	for _, mine := range my {
		theirHead, found := headOf(other, mine.Id)
		switch {
		case !found:
			dctx.removedIds = append(dctx.removedIds, mine.Id)
		case theirHead != mine.Head:
			dctx.changedIds = append(dctx.changedIds, mine.Id)
		}
	}

	for _, theirs := range other {
		if _, found := headOf(my, theirs.Id); !found {
			dctx.newIds = append(dctx.newIds, theirs.Id)
		}
	}
}

428
app/olddiff/diff_test.go Normal file
View file

@ -0,0 +1,428 @@
package olddiff
import (
"context"
"fmt"
"math"
"sort"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"github.com/anyproto/any-sync/app/ldiff"
)
// TestDiff_fillRange checks that querying the whole hash space of a container
// with ten elements yields a hash and the full element count.
func TestDiff_fillRange(t *testing.T) {
	d := New(4, 4).(*diff)
	for i := 0; i < 10; i++ {
		el := ldiff.Element{
			Id:   fmt.Sprint(i),
			Head: fmt.Sprint("h", i),
		}
		d.Set(el)
	}
	t.Log(d.sl.Len())

	t.Run("elements", func(t *testing.T) {
		r := ldiff.Range{From: 0, To: math.MaxUint64}
		res := d.getRange(r)
		assert.NotNil(t, res.Hash)
		// assert.Equal takes (expected, actual); the swapped order produced a
		// misleading failure message.
		assert.Equal(t, 10, res.Count)
	})
}
// TestDiff_Diff exercises Diff between two local containers in a number of
// scenarios. All subtests share the parent ctx, so none of them may modify it.
func TestDiff_Diff(t *testing.T) {
	ctx := context.Background()

	// Identical containers produce no diff; then one new, one changed and one
	// removed element on the remote side are each reported exactly once.
	t.Run("basic", func(t *testing.T) {
		d1 := New(16, 16)
		d2 := New(16, 16)
		for i := 0; i < 1000; i++ {
			id := fmt.Sprint(i)
			head := uuid.NewString()
			d1.Set(ldiff.Element{
				Id:   id,
				Head: head,
			})
			d2.Set(ldiff.Element{
				Id:   id,
				Head: head,
			})
		}

		newIds, changedIds, removedIds, err := d1.Diff(ctx, d2)
		require.NoError(t, err)
		assert.Len(t, newIds, 0)
		assert.Len(t, changedIds, 0)
		assert.Len(t, removedIds, 0)

		d2.Set(ldiff.Element{
			Id:   "newD1",
			Head: "newD1",
		})
		d2.Set(ldiff.Element{
			Id:   "1",
			Head: "changed",
		})
		require.NoError(t, d2.RemoveId("0"))

		newIds, changedIds, removedIds, err = d1.Diff(ctx, d2)
		require.NoError(t, err)
		assert.Len(t, newIds, 1)
		assert.Len(t, changedIds, 1)
		assert.Len(t, removedIds, 1)
	})

	t.Run("complex", func(t *testing.T) {
		d1 := New(16, 128)
		d2 := New(16, 128)
		length := 10000
		for i := 0; i < length; i++ {
			id := fmt.Sprint(i)
			head := uuid.NewString()
			d1.Set(ldiff.Element{
				Id:   id,
				Head: head,
			})
		}

		// Remote is empty: everything local counts as removed.
		newIds, changedIds, removedIds, err := d1.Diff(ctx, d2)
		require.NoError(t, err)
		assert.Len(t, newIds, 0)
		assert.Len(t, changedIds, 0)
		assert.Len(t, removedIds, length)

		// Same ids with fresh random heads: everything counts as changed.
		for i := 0; i < length; i++ {
			id := fmt.Sprint(i)
			head := uuid.NewString()
			d2.Set(ldiff.Element{
				Id:   id,
				Head: head,
			})
		}

		newIds, changedIds, removedIds, err = d1.Diff(ctx, d2)
		require.NoError(t, err)
		assert.Len(t, newIds, 0)
		assert.Len(t, changedIds, length)
		assert.Len(t, removedIds, 0)

		for i := 0; i < length; i++ {
			id := fmt.Sprint(i)
			head := uuid.NewString()
			d2.Set(ldiff.Element{
				Id:   id,
				Head: head,
			})
		}

		// Copy the heads of the second half of d1 (in hash order) into d2, so
		// only the first half remains different.
		res, err := d1.Ranges(
			context.Background(),
			[]ldiff.Range{{From: 0, To: math.MaxUint64, Elements: true}},
			nil)
		require.NoError(t, err)
		require.Len(t, res, 1)
		for i, el := range res[0].Elements {
			if i < length/2 {
				continue
			}
			id := el.Id
			head := el.Head
			d2.Set(ldiff.Element{
				Id:   id,
				Head: head,
			})
		}

		newIds, changedIds, removedIds, err = d1.Diff(ctx, d2)
		require.NoError(t, err)
		assert.Len(t, newIds, 0)
		assert.Len(t, changedIds, length/2)
		assert.Len(t, removedIds, 0)
	})

	t.Run("empty", func(t *testing.T) {
		d1 := New(16, 16)
		d2 := New(16, 16)
		newIds, changedIds, removedIds, err := d1.Diff(ctx, d2)
		require.NoError(t, err)
		assert.Len(t, newIds, 0)
		assert.Len(t, changedIds, 0)
		assert.Len(t, removedIds, 0)
	})

	t.Run("one empty", func(t *testing.T) {
		d1 := New(4, 4)
		d2 := New(4, 4)
		length := 10000
		for i := 0; i < length; i++ {
			d2.Set(ldiff.Element{
				Id:   fmt.Sprint(i),
				Head: uuid.NewString(),
			})
		}
		newIds, changedIds, removedIds, err := d1.Diff(ctx, d2)
		require.NoError(t, err)
		assert.Len(t, newIds, length)
		assert.Len(t, changedIds, 0)
		assert.Len(t, removedIds, 0)
	})

	t.Run("not intersecting", func(t *testing.T) {
		d1 := New(16, 16)
		d2 := New(16, 16)
		length := 10000
		for i := 0; i < length; i++ {
			d1.Set(ldiff.Element{
				Id:   fmt.Sprint(i),
				Head: uuid.NewString(),
			})
		}
		for i := length; i < length*2; i++ {
			d2.Set(ldiff.Element{
				Id:   fmt.Sprint(i),
				Head: uuid.NewString(),
			})
		}
		newIds, changedIds, removedIds, err := d1.Diff(ctx, d2)
		require.NoError(t, err)
		assert.Len(t, newIds, length)
		assert.Len(t, changedIds, 0)
		assert.Len(t, removedIds, length)
	})

	t.Run("context cancel", func(t *testing.T) {
		d1 := New(4, 4)
		d2 := New(4, 4)
		for i := 0; i < 10; i++ {
			d2.Set(ldiff.Element{
				Id:   fmt.Sprint(i),
				Head: uuid.NewString(),
			})
		}
		// Use a context local to this subtest: assigning to the shared ctx would
		// hand an already cancelled context to any subtest that runs afterwards.
		cancelledCtx, cancel := context.WithCancel(ctx)
		cancel()
		_, _, _, err := d1.Diff(cancelledCtx, d2)
		assert.ErrorIs(t, err, context.Canceled)
	})
}
// BenchmarkDiff_Ranges measures a full-range Ranges query on a container with
// 10000 elements while reusing the result buffer between iterations.
func BenchmarkDiff_Ranges(b *testing.B) {
	d := New(16, 16)
	for i := 0; i < 10000; i++ {
		id := fmt.Sprint(i)
		head := uuid.NewString()
		d.Set(ldiff.Element{
			Id:   id,
			Head: head,
		})
	}
	ctx := context.Background()
	ranges := []ldiff.Range{{From: 0, To: math.MaxUint64}}
	var (
		resBuf []ldiff.RangeResult
		err    error
	)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Feed the returned slice back in: Ranges appends to resBuf[:0], so this
		// is what actually reuses the buffer. Discarding the result left resBuf
		// nil and made every iteration allocate.
		if resBuf, err = d.Ranges(ctx, ranges, resBuf); err != nil {
			b.Fatal(err)
		}
	}
}
// TestDiff_Hash checks that an empty container already has a hash and that
// adding an element changes it.
func TestDiff_Hash(t *testing.T) {
	d := New(16, 16)

	emptyHash := d.Hash()
	assert.NotEmpty(t, emptyHash)

	d.Set(ldiff.Element{Id: "1"})

	filledHash := d.Hash()
	assert.NotEmpty(t, filledHash)
	assert.NotEqual(t, emptyHash, filledHash)
}
// TestDiff_Element covers lookup by id: a missing id, an existing id, and an
// id whose element was replaced by a later Set.
func TestDiff_Element(t *testing.T) {
	d := New(16, 16)
	for i := 0; i < 10; i++ {
		d.Set(ldiff.Element{Id: fmt.Sprint("id", i), Head: fmt.Sprint("head", i)})
	}

	_, err := d.Element("not found")
	assert.ErrorIs(t, err, ErrElementNotFound)

	el, err := d.Element("id5")
	require.NoError(t, err)
	assert.Equal(t, "head5", el.Head)

	// Keyed fields: an unkeyed literal of an imported struct breaks silently
	// when fields are added or reordered and is flagged by go vet.
	d.Set(ldiff.Element{Id: "id5", Head: "otherHead"})
	el, err = d.Element("id5")
	require.NoError(t, err)
	assert.Equal(t, "otherHead", el.Head)
}
// TestDiff_Ids checks that Ids returns exactly the ids that were set (order is
// normalised by sorting) and that Len agrees with it.
func TestDiff_Ids(t *testing.T) {
	const total = 10

	d := New(16, 16)
	want := make([]string, 0, total)
	for i := 0; i < total; i++ {
		id := fmt.Sprint("id", i)
		want = append(want, id)
		d.Set(ldiff.Element{Id: id, Head: fmt.Sprint("head", i)})
	}

	got := d.Ids()
	sort.Strings(got)
	assert.Equal(t, want, got)
	assert.Equal(t, len(want), d.Len())
}
// TestDiff_Elements checks that Elements returns exactly the elements that
// were set; the result is sorted by id because the container orders by hash.
func TestDiff_Elements(t *testing.T) {
	const total = 10

	d := New(16, 16)
	want := make([]ldiff.Element, 0, total)
	for i := 0; i < total; i++ {
		el := ldiff.Element{
			Id:   fmt.Sprint("id", i),
			Head: fmt.Sprint("head", i),
		}
		d.Set(el)
		want = append(want, el)
	}

	got := d.Elements()
	sort.Slice(got, func(a, b int) bool {
		return got[a].Id < got[b].Id
	})
	assert.Equal(t, want, got)
}
func TestRangesAddRemove(t *testing.T) {
length := 10000
divideFactor := 4
compareThreshold := 4
addTwice := func() string {
d := New(divideFactor, compareThreshold)
var els []ldiff.Element
for i := 0; i < length; i++ {
if i < length/20 {
continue
}
els = append(els, ldiff.Element{
Id: fmt.Sprint(i),
Head: fmt.Sprint("h", i),
})
}
d.Set(els...)
els = els[:0]
for i := 0; i < length/20; i++ {
els = append(els, ldiff.Element{
Id: fmt.Sprint(i),
Head: fmt.Sprint("h", i),
})
}
d.Set(els...)
return d.Hash()
}
addOnce := func() string {
d := New(divideFactor, compareThreshold)
var els []ldiff.Element
for i := 0; i < length; i++ {
els = append(els, ldiff.Element{
Id: fmt.Sprint(i),
Head: fmt.Sprint("h", i),
})
}
d.Set(els...)
return d.Hash()
}
addRemove := func() string {
d := New(divideFactor, compareThreshold)
var els []ldiff.Element
for i := 0; i < length; i++ {
els = append(els, ldiff.Element{
Id: fmt.Sprint(i),
Head: fmt.Sprint("h", i),
})
}
d.Set(els...)
for i := 0; i < length/20; i++ {
err := d.RemoveId(fmt.Sprint(i))
require.NoError(t, err)
}
els = els[:0]
for i := 0; i < length/20; i++ {
els = append(els, ldiff.Element{
Id: fmt.Sprint(i),
Head: fmt.Sprint("h", i),
})
}
d.Set(els...)
return d.Hash()
}
require.Equal(t, addTwice(), addOnce(), addRemove())
}
// printBestParams is a manual tuning helper (nothing in this file calls it).
// For every combination of divideFactor 1..32 and compareThreshold 1..512
// (powers of two) it fills a container with random elements numTests times,
// keeps the run with the median number of ranges, and finally prints all
// medians sorted by range count. Recorded outputs are kept at the end.
func printBestParams() {
	const (
		numTests = 10
		length   = 100000
	)

	type result struct {
		divFactor, compThreshold, numRanges, maxLevel, avgLevel, zeroEls int
	}

	// calcParams builds one container and summarises its range tree.
	calcParams := func(divideFactor, compareThreshold, length int) (total, maxLevel, avgLevel, zeroEls int) {
		d := New(divideFactor, compareThreshold).(*diff)
		els := make([]ldiff.Element, 0, length)
		for i := 0; i < length; i++ {
			els = append(els, ldiff.Element{
				Id:   uuid.NewString(),
				Head: uuid.NewString(),
			})
		}
		d.Set(els...)

		levelSum := 0
		for _, rng := range d.ranges.ranges {
			if rng.elements == 0 {
				zeroEls++
			}
			if rng.level > maxLevel {
				maxLevel = rng.level
			}
			levelSum += rng.level
		}
		total = len(d.ranges.ranges)
		avgLevel = levelSum / total
		return total, maxLevel, avgLevel, zeroEls
	}

	// byNumRanges orders results by ascending range count.
	byNumRanges := func(a, b result) int {
		switch {
		case a.numRanges < b.numRanges:
			return -1
		case a.numRanges > b.numRanges:
			return 1
		default:
			return 0
		}
	}

	var results []result
	for dfExp := 0; dfExp < 6; dfExp++ {
		df := 1 << dfExp
		for ctExp := 0; ctExp < 10; ctExp++ {
			ct := 1 << ctExp
			fmt.Println("starting, df:", df, "ct:", ct)

			runs := make([]result, 0, numTests)
			for i := 0; i < numTests; i++ {
				total, maxLevel, avgLevel, zeroEls := calcParams(df, ct, length)
				runs = append(runs, result{
					divFactor:     df,
					compThreshold: ct,
					numRanges:     total,
					maxLevel:      maxLevel,
					avgLevel:      avgLevel,
					zeroEls:       zeroEls,
				})
			}
			slices.SortFunc(runs, byNumRanges)
			// keep the median run
			results = append(results, runs[len(runs)/2])
		}
	}
	slices.SortFunc(results, byNumRanges)
	fmt.Println(results)
	// 100000 - [{16 512 273 2 1 0} {4 512 341 4 3 0} {2 512 511 8 7 0} {1 512 511 8 7 0}
	// {8 256 585 3 2 0} {8 512 585 3 2 0} {1 256 1023 9 8 0} {2 256 1023 9 8 0}
	// {32 256 1057 2 1 0} {32 512 1057 2 1 0} {32 128 1089 3 1 0} {4 256 1365 5 4 0}
	// {4 128 1369 6 4 0} {2 128 2049 11 9 0} {1 128 2049 11 9 0} {1 64 4157 12 10 0}
	// {2 64 4159 12 10 0} {16 128 4369 3 2 0} {16 64 4369 3 2 0} {16 256 4369 3 2 0}
	// {8 64 4681 4 3 0} {8 128 4681 4 3 0} {4 64 5461 6 5 0} {4 32 6389 7 5 0}
	// {8 32 6505 5 4 17} {16 32 8033 4 3 374} {2 32 8619 13 11 0} {1 32 8621 13 11 0}
	// {2 16 17837 15 12 0} {1 16 17847 15 12 0} {4 16 21081 8 6 22} {32 64 33825 3 2 1578}
	// {32 32 33825 3 2 1559} {32 16 33825 3 2 1518} {8 16 35881 5 4 1313} {16 16 66737 4 3 13022}]
	// 1000000 - [{8 256 11753 5 4 0}]
	// 1000000 - [{16 128 69905 4 3 0}]
	// 1000000 - [{32 256 33825 3 2 0}]
}

223
app/olddiff/hashrange.go Normal file
View file

@ -0,0 +1,223 @@
package olddiff
import (
"math"
"github.com/huandu/skiplist"
"github.com/zeebo/blake3"
"golang.org/x/exp/slices"
)
// hashRange is one node of the range tree; it covers the element hashes in
// [from, to], both bounds inclusive.
type hashRange struct {
	from, to uint64
	// parent is the enclosing range; it is left nil for the top range.
	parent *hashRange
	// isDivided marks a range that has been split into child ranges.
	isDivided bool
	// elements is the element count tracked for the range.
	elements int
	// level is the depth in the tree: 0 for the top range, parent.level+1 below.
	level int
	// hash is the digest of the range: over its elements while undivided, over
	// the hashes of its children once divided.
	hash []byte
}
// rangeTuple identifies a range by its inclusive bounds; it is the key of
// hashRanges.ranges.
type rangeTuple struct {
	from, to uint64
}
// hashRanges is the tree of hashed ranges maintained over the element skiplist.
type hashRanges struct {
	// ranges indexes every created range by its bounds.
	ranges map[rangeTuple]*hashRange
	// topRange covers the whole uint64 hash space.
	topRange *hashRange
	// sl is the skiplist the hashes and counts are computed from.
	sl *skiplist.SkipList
	// dirty holds the ranges whose hash must be recomputed by recalculateHashes.
	dirty map[*hashRange]struct{}
	// divideFactor is the number of children a divided range has.
	divideFactor int
	// compareThreshold is the element count above which a range is divided.
	compareThreshold int
}
// newHashRanges creates the range tree for the given skiplist. The tree starts
// with a top range spanning the whole uint64 space that is divided from the
// beginning, and its first level of children is created right away.
func newHashRanges(divideFactor, compareThreshold int, sl *skiplist.SkipList) *hashRanges {
	top := &hashRange{
		from:      0,
		to:        math.MaxUint64,
		isDivided: true,
		level:     0,
	}

	h := &hashRanges{
		ranges:           make(map[rangeTuple]*hashRange),
		dirty:            make(map[*hashRange]struct{}),
		divideFactor:     divideFactor,
		compareThreshold: compareThreshold,
		sl:               sl,
		topRange:         top,
	}
	h.ranges[rangeTuple{from: 0, to: math.MaxUint64}] = top
	h.makeBottomRanges(top)
	return h
}
// hash returns the hash currently stored on the top range.
func (h *hashRanges) hash() []byte {
	return h.topRange.hash
}
// addElement accounts for one element with the given hash: it increments the
// counter of every range on the path from the top range down to the undivided
// range containing the hash, marks that range dirty and divides it when its
// count exceeds compareThreshold. Hashes are not recomputed here; that is left
// to recalculateHashes.
func (h *hashRanges) addElement(elHash uint64) {
	leaf := h.topRange
	leaf.elements++
	for leaf.isDivided {
		leaf = h.getBottomRange(leaf, elHash)
		leaf.elements++
	}

	h.dirty[leaf] = struct{}{}
	if leaf.elements > h.compareThreshold {
		leaf.isDivided = true
		h.makeBottomRanges(leaf)
	}
	// The parent is taken off the dirty set here; recalculateHashes marks it
	// again after it has processed the child.
	if leaf.parent != nil {
		delete(h.dirty, leaf.parent)
	}
}
// removeElement accounts for the removal of one element with the given hash:
// it decrements the counter of every range on the path down to the undivided
// range containing the hash. If that range's parent is not the top range and
// has dropped to compareThreshold elements or fewer, the parent's children are
// discarded and the parent becomes undivided and dirty; otherwise only the
// undivided range itself is marked dirty.
func (h *hashRanges) removeElement(elHash uint64) {
	leaf := h.topRange
	leaf.elements--
	for leaf.isDivided {
		leaf = h.getBottomRange(leaf, elHash)
		leaf.elements--
	}

	owner := leaf.parent
	if owner.elements > h.compareThreshold || owner == h.topRange {
		h.dirty[leaf] = struct{}{}
		return
	}

	// Collapse: drop the children from both indexes and hash the parent as a
	// plain range again.
	for _, tuple := range genTupleRanges(owner.from, owner.to, h.divideFactor) {
		delete(h.dirty, h.ranges[tuple])
		delete(h.ranges, tuple)
	}
	owner.isDivided = false
	h.dirty[owner] = struct{}{}
}
// recalculateHashes recomputes the hash of every dirty range and propagates
// the change upwards until the dirty set is empty.
//
// Each pass takes a snapshot of the dirty set, orders it by ascending level and
// rehashes every range in it: divided ranges from their children's hashes,
// undivided ones (together with their element count) from the skiplist. A
// processed range is removed from the dirty set and its parent is added, which
// is handled in a later pass.
func (h *hashRanges) recalculateHashes() {
	for len(h.dirty) > 0 {
		batch := make([]*hashRange, 0, len(h.dirty))
		for rng := range h.dirty {
			batch = append(batch, rng)
		}
		slices.SortFunc(batch, func(a, b *hashRange) int {
			switch {
			case a.level < b.level:
				return -1
			case a.level > b.level:
				return 1
			default:
				return 0
			}
		})

		for _, rng := range batch {
			if !rng.isDivided {
				rng.hash, rng.elements = h.calcElementsHash(rng.from, rng.to)
			} else {
				rng.hash = h.calcDividedHash(rng)
			}
			delete(h.dirty, rng)
			if rng.parent != nil {
				h.dirty[rng.parent] = struct{}{}
			}
		}
	}
}
// getRange returns the range registered with exactly these bounds, or nil when
// there is none.
func (h *hashRanges) getRange(from, to uint64) *hashRange {
	return h.ranges[rangeTuple{from: from, to: to}]
}
// getBottomRange returns the child of the divided range rng that contains
// elHash, which must lie within [rng.from, rng.to].
//
// The child bounds are computed with the same arithmetic as genTupleRanges:
// the range is cut into divideFactor slices of perRange hashes each, and when
// the size is not a multiple of divideFactor the last child also takes the
// `align` leftover hashes at the end.
func (h *hashRanges) getBottomRange(rng *hashRange, elHash uint64) *hashRange {
	df := uint64(h.divideFactor)
	perRange := (rng.to - rng.from) / df
	align := ((rng.to-rng.from)%df + 1) % df
	if align == 0 {
		perRange++
	}
	bucket := (elHash - rng.from) / perRange
	// A hash inside the leftover tail divides to an index past the last child.
	// It belongs to the last child; without the clamp the lookup below returns
	// nil and the caller dereferences it.
	if bucket >= df {
		bucket = df - 1
	}
	// The upper bound relies on uint64 wrap-around for the top range.
	tuple := rangeTuple{from: rng.from + bucket*perRange, to: rng.from - 1 + (bucket+1)*perRange}
	if bucket == df-1 {
		tuple.to += align
	}
	return h.ranges[tuple]
}
// makeBottomRanges creates and registers the child ranges of rng. A child that
// already holds more than compareThreshold elements is marked dirty, divided
// and populated recursively; in that case rng itself is taken off the dirty
// set.
func (h *hashRanges) makeBottomRanges(rng *hashRange) {
	for _, tuple := range genTupleRanges(rng.from, rng.to, h.divideFactor) {
		child := h.makeRange(tuple, rng)
		h.ranges[tuple] = child
		if child.elements <= h.compareThreshold {
			continue
		}
		delete(h.dirty, rng)
		h.dirty[child] = struct{}{}
		child.isDivided = true
		h.makeBottomRanges(child)
	}
}
// makeRange builds an undivided range for the given bounds one level below
// parent, with hash and element count computed from the skiplist. The range is
// not registered in h.ranges; that is up to the caller.
func (h *hashRanges) makeRange(tuple rangeTuple, parent *hashRange) *hashRange {
	sum, count := h.calcElementsHash(tuple.from, tuple.to)
	return &hashRange{
		from:     tuple.from,
		to:       tuple.to,
		parent:   parent,
		elements: count,
		level:    parent.level + 1,
		hash:     sum,
	}
}
// calcDividedHash returns the hash of a divided range: the blake3 digest over
// the stored hashes of its children, taken in range order. Every child must be
// registered in h.ranges.
func (h *hashRanges) calcDividedHash(rng *hashRange) (hash []byte) {
	hasher := hashersPool.Get().(*blake3.Hasher)
	defer hashersPool.Put(hasher)
	hasher.Reset()

	for _, tuple := range genTupleRanges(rng.from, rng.to, h.divideFactor) {
		hasher.Write(h.ranges[tuple].hash)
	}
	return hasher.Sum(nil)
}
// genTupleRanges splits the inclusive range [from, to] into divideFactor
// consecutive sub-ranges. All sub-ranges have the same width, except that the
// last one additionally takes the leftover hashes when the range size is not a
// multiple of divideFactor. The arithmetic relies on uint64 wrap-around, which
// makes the full 0..MaxUint64 range work.
func genTupleRanges(from, to uint64, divideFactor int) (prepare []rangeTuple) {
	parts := uint64(divideFactor)
	span := to - from // range size minus one, so it cannot overflow
	width := span / parts
	leftover := (span%parts + 1) % parts
	if leftover == 0 {
		// size is an exact multiple of parts
		width++
	}

	start := from
	for idx := 0; idx < divideFactor; idx++ {
		size := width
		if idx == divideFactor-1 {
			size += leftover
		}
		prepare = append(prepare, rangeTuple{from: start, to: start + size - 1})
		start += size
	}
	return prepare
}
// calcElementsHash walks the skiplist over the elements whose hash lies in
// [from, to] and returns the blake3 digest over their ids and heads (in
// skiplist order) together with their number. For an empty range it returns a
// nil hash and 0.
func (h *hashRanges) calcElementsHash(from, to uint64) (hash []byte, els int) {
	hasher := hashersPool.Get().(*blake3.Hasher)
	defer hashersPool.Put(hasher)
	hasher.Reset()

	for node := h.sl.Find(&element{hash: from}); node != nil; node = node.Next() {
		entry := node.Key().(*element)
		if entry.hash > to {
			break
		}
		hasher.WriteString(entry.Id)
		hasher.WriteString(entry.Head)
		els++
	}
	if els == 0 {
		return nil, 0
	}
	return hasher.Sum(nil), els
}