diff --git a/commonspace/object/tree/objecttree/objecttree.go b/commonspace/object/tree/objecttree/objecttree.go index a90cb4e0..f6c63053 100644 --- a/commonspace/object/tree/objecttree/objecttree.go +++ b/commonspace/object/tree/objecttree/objecttree.go @@ -460,11 +460,24 @@ func (ot *objectTree) addChangesToTree(ctx context.Context, changesPayload RawCh headsToUse = []string{} } rollback := func(changes []*Change) { + var visited []*Change for _, ch := range changes { - if _, exists := ot.tree.attached[ch.Id]; exists { + if ex, exists := ot.tree.attached[ch.Id]; exists { + ex.visited = true + visited = append(visited, ex) delete(ot.tree.attached, ch.Id) } } + for _, ch := range ot.tree.attached { + // deleting all visited changes from next + ch.Next = slice.DiscardFromSlice(ch.Next, func(change *Change) bool { + return change.visited + }) + } + // doing this just in case + for _, ch := range visited { + ch.visited = false + } ot.tree.headIds = headsCopy(prevHeadsCopy) ot.tree.lastIteratedHeadId = lastIteratedId } diff --git a/commonspace/object/tree/objecttree/objecttree_test.go b/commonspace/object/tree/objecttree/objecttree_test.go index 010a26e4..7736e2e8 100644 --- a/commonspace/object/tree/objecttree/objecttree_test.go +++ b/commonspace/object/tree/objecttree/objecttree_test.go @@ -818,6 +818,29 @@ func TestObjectTree(t *testing.T) { } }) + t.Run("add with rollback", func(t *testing.T) { + ctx := prepareTreeContext(t, aclList) + changeCreator := ctx.changeCreator + objTree := ctx.objTree + + rawChanges := []*treechangeproto.RawTreeChangeWithId{ + changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"), + changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"), + changeCreator.CreateRaw("3", aclList.Head().Id, "0", false, "2"), + changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "3"), + } + payload := RawChangesPayload{ + NewHeads: []string{rawChanges[len(rawChanges)-1].Id}, + RawChanges: rawChanges, + } + tr := objTree.(*objectTree) + tr.validator.(*noOpTreeValidator).fail = true + _, err := objTree.AddRawChanges(context.Background(), payload) + require.Error(t, err) + require.Len(t, tr.tree.attached, 1) + require.Empty(t, tr.tree.attached["0"].Next) + }) + t.Run("add new snapshot simple with newChangeFlusher", func(t *testing.T) { ctx := prepareTreeContext(t, aclList) treeStorage := ctx.treeStorage diff --git a/commonspace/object/tree/objecttree/objecttreevalidator.go b/commonspace/object/tree/objecttree/objecttreevalidator.go index 3ef5625e..0276c441 100644 --- a/commonspace/object/tree/objecttree/objecttreevalidator.go +++ b/commonspace/object/tree/objecttree/objecttreevalidator.go @@ -31,13 +31,20 @@ type ObjectTreeValidator interface { type noOpTreeValidator struct { filterFunc func(ch *Change) bool + fail bool } func (n *noOpTreeValidator) ValidateFullTree(tree *Tree, aclList list.AclList) error { + if n.fail { + return fmt.Errorf("failed") + } return nil } func (n *noOpTreeValidator) ValidateNewChanges(tree *Tree, aclList list.AclList, newChanges []*Change) error { + if n.fail { + return fmt.Errorf("failed") + } return nil }