1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-08 05:57:03 +09:00
any-sync/commonspace/object/tree/objecttree/loaditerator.go
2024-06-17 23:23:28 +02:00

192 lines
4.6 KiB
Go

package objecttree
import (
"golang.org/x/exp/slices"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/util/slice"
)
type LoadIterator interface {
NextBatch(maxSize int) (batch IteratorBatch, err error)
}
type IteratorBatch struct {
Batch []*treechangeproto.RawTreeChangeWithId
Heads []string
SnapshotPath []string
Root *treechangeproto.RawTreeChangeWithId
}
type loadIterator struct {
loader *rawChangeLoader
cache map[string]rawCacheEntry
idStack []string
heads []string
lastHeads []string
snapshotPath []string
root *Change
lastChange *Change
iter *iterator
isExhausted bool
}
func newLoadIterator(loader *rawChangeLoader, snapshotPath []string) *loadIterator {
return &loadIterator{
loader: loader,
cache: make(map[string]rawCacheEntry),
snapshotPath: snapshotPath,
}
}
func (l *loadIterator) NextBatch(maxSize int) (batch IteratorBatch, err error) {
l.iter = newIterator()
defer freeIterator(l.iter)
curSize := 0
batch.Root = l.loader.Root()
batch.Heads = l.lastHeads
batch.SnapshotPath = l.snapshotPath
if l.isExhausted {
return
}
l.isExhausted = true
l.iter.iterateSkip(l.root, l.lastChange, true, func(c *Change) (isContinue bool) {
l.lastChange = c
rawEntry := l.cache[c.Id]
if rawEntry.removed {
return true
}
if curSize+rawEntry.size > maxSize {
l.isExhausted = false
return false
}
curSize += rawEntry.size
batch.Batch = append(batch.Batch, rawEntry.rawChange)
batch.Heads = slice.DiscardFromSlice(batch.Heads, func(s string) bool {
return slices.Contains(c.PreviousIds, s)
})
batch.Heads = append(batch.Heads, c.Id)
return true
})
l.lastHeads = batch.Heads
return
}
func (l *loadIterator) load(commonSnapshot string, heads, breakpoints []string) (err error) {
existingBreakpoints := make([]string, 0, len(breakpoints))
for _, b := range breakpoints {
_, err := l.loader.loadEntry(b)
if err != nil {
continue
}
existingBreakpoints = append(existingBreakpoints, b)
}
loadedCs, err := l.loader.loadEntry(commonSnapshot)
if err != nil {
return err
}
l.cache[commonSnapshot] = loadedCs
l.heads = heads
dfs := func(
commonSnapshot string,
heads []string,
shouldVisit func(entry rawCacheEntry, mapExists bool) bool,
visit func(entry rawCacheEntry) rawCacheEntry) (err error) {
// resetting stack
l.idStack = l.idStack[:0]
l.idStack = append(l.idStack, heads...)
for len(l.idStack) > 0 {
id := l.idStack[len(l.idStack)-1]
l.idStack = l.idStack[:len(l.idStack)-1]
entry, exists := l.cache[id]
if !shouldVisit(entry, exists) {
continue
}
if !exists {
entry, err = l.loader.loadEntry(id)
if err != nil {
return
}
}
// setting the counter when we visit
entry = visit(entry)
l.cache[id] = entry
for _, prev := range entry.change.PreviousIds {
prevEntry, exists := l.cache[prev]
if !shouldVisit(prevEntry, exists) {
continue
}
l.idStack = append(l.idStack, prev)
}
}
return nil
}
l.idStack = append(l.idStack, heads...)
// load everything
err = dfs(commonSnapshot, heads,
func(_ rawCacheEntry, mapExists bool) bool {
return !mapExists
},
func(entry rawCacheEntry) rawCacheEntry {
entry.position = 0
return entry
})
if err != nil {
return
}
// marking some changes as removed, not to send to anybody
err = dfs(commonSnapshot, existingBreakpoints,
func(entry rawCacheEntry, mapExists bool) bool {
// only going through already loaded changes
return mapExists && !entry.removed
},
func(entry rawCacheEntry) rawCacheEntry {
entry.removed = true
return entry
})
if err != nil {
return
}
// setting next for all changes in batch
err = dfs(commonSnapshot, heads,
func(entry rawCacheEntry, mapExists bool) bool {
return mapExists && !entry.nextSet
},
func(entry rawCacheEntry) rawCacheEntry {
for _, id := range entry.change.PreviousIds {
prevEntry, exists := l.cache[id]
if !exists {
continue
}
prev := prevEntry.change
if len(prev.Next) == 0 || prev.Next[len(prev.Next)-1].Id <= entry.change.Id {
prev.Next = append(prev.Next, entry.change)
} else {
insertIdx := 0
for idx, el := range prev.Next {
if el.Id >= entry.change.Id {
insertIdx = idx
break
}
}
prev.Next = append(prev.Next[:insertIdx+1], prev.Next[insertIdx:]...)
prev.Next[insertIdx] = entry.change
}
}
entry.nextSet = true
return entry
})
if err != nil {
return
}
l.root = loadedCs.change
l.lastHeads = []string{l.root.Id}
l.lastChange = loadedCs.change
return nil
}