package objecttree import ( "golang.org/x/exp/slices" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/util/slice" ) type LoadIterator interface { NextBatch(maxSize int) (batch []*treechangeproto.RawTreeChangeWithId, heads []string, err error) } type loadIterator struct { loader *rawChangeLoader cache map[string]rawCacheEntry idStack []string heads []string lastHeads []string root *Change lastChange *Change iter *iterator } func newLoadIterator(loader *rawChangeLoader) *loadIterator { return &loadIterator{ loader: loader, cache: make(map[string]rawCacheEntry), } } func (l *loadIterator) NextBatch(maxSize int) (batch []*treechangeproto.RawTreeChangeWithId, heads []string, err error) { l.iter = newIterator() defer freeIterator(l.iter) curSize := 0 heads = l.lastHeads l.iter.iterateSkip(l.root, l.lastChange, func(c *Change) (isContinue bool) { rawEntry := l.cache[c.Id] if rawEntry.removed { return true } if curSize+rawEntry.size > maxSize { l.lastChange = c return false } curSize += rawEntry.size batch = append(batch, rawEntry.rawChange) heads = slice.DiscardFromSlice(heads, func(s string) bool { return slices.Contains(c.PreviousIds, s) }) heads = append(heads, c.Id) return true }) l.lastHeads = heads return } func (l *loadIterator) load(commonSnapshot string, heads, breakpoints []string) (err error) { existingBreakpoints := make([]string, 0, len(breakpoints)) for _, b := range breakpoints { _, err := l.loader.loadEntry(b) if err != nil { continue } existingBreakpoints = append(existingBreakpoints, b) } loadedCs, err := l.loader.loadEntry(commonSnapshot) if err != nil { return err } l.cache[commonSnapshot] = loadedCs l.heads = heads dfs := func( commonSnapshot string, heads []string, shouldVisit func(entry rawCacheEntry, mapExists bool) bool, visit func(entry rawCacheEntry) rawCacheEntry) (err error) { // resetting stack l.idStack = l.idStack[:0] l.idStack = append(l.idStack, heads...) for len(l.idStack) > 0 { id := l.idStack[len(l.idStack)-1] l.idStack = l.idStack[:len(l.idStack)-1] entry, exists := l.cache[id] if !shouldVisit(entry, exists) { continue } if !exists { entry, err = l.loader.loadEntry(id) if err != nil { return } } // setting the counter when we visit entry = visit(entry) l.cache[id] = entry for _, prev := range entry.change.PreviousIds { prevEntry, exists := l.cache[prev] if !shouldVisit(prevEntry, exists) { continue } l.idStack = append(l.idStack, prev) } } return nil } l.idStack = append(l.idStack, heads...) // load everything err = dfs(commonSnapshot, heads, func(_ rawCacheEntry, mapExists bool) bool { return !mapExists }, func(entry rawCacheEntry) rawCacheEntry { entry.position = 0 return entry }) if err != nil { return } // marking some changes as removed, not to send to anybody err = dfs(commonSnapshot, existingBreakpoints, func(entry rawCacheEntry, mapExists bool) bool { // only going through already loaded changes return mapExists && !entry.removed }, func(entry rawCacheEntry) rawCacheEntry { entry.removed = true return entry }) if err != nil { return } // setting next for all changes in batch err = dfs(commonSnapshot, heads, func(entry rawCacheEntry, mapExists bool) bool { return mapExists && !entry.nextSet }, func(entry rawCacheEntry) rawCacheEntry { for _, id := range entry.change.PreviousIds { prevEntry, exists := l.cache[id] if !exists { continue } prev := prevEntry.change if len(prev.Next) == 0 || prev.Next[len(prev.Next)-1].Id <= entry.change.Id { prev.Next = append(prev.Next, entry.change) } else { insertIdx := 0 for idx, el := range prev.Next { if el.Id >= entry.change.Id { insertIdx = idx break } } prev.Next = append(prev.Next[:insertIdx+1], prev.Next[insertIdx:]...) prev.Next[insertIdx] = entry.change } } entry.nextSet = true return entry }) if err != nil { return } l.lastHeads = []string{l.root.Id} l.lastChange = loadedCs.change l.root = loadedCs.change return nil }