diff --git a/commonspace/object/tree/objecttree/objecttree_test.go b/commonspace/object/tree/objecttree/objecttree_test.go index e9dea299..010a26e4 100644 --- a/commonspace/object/tree/objecttree/objecttree_test.go +++ b/commonspace/object/tree/objecttree/objecttree_test.go @@ -323,13 +323,13 @@ func TestObjectTree(t *testing.T) { bStore = aTree.Storage().(*treestorage.InMemoryTreeStorage).Copy() root, _ = bStore.Root() heads, _ := bStore.Heads() - filteredPayload, err := ValidateFilterRawTree(treestorage.TreeStorageCreatePayload{ + newTree, err := ValidateFilterRawTree(treestorage.TreeStorageCreatePayload{ RootRawChange: root, Changes: bStore.AllChanges(), Heads: heads, - }, bAccount.Acl) + }, InMemoryStorageCreator{}, bAccount.Acl) require.NoError(t, err) - require.Equal(t, 2, len(filteredPayload.Changes)) + require.Equal(t, 2, len(newTree.Storage().(*treestorage.InMemoryTreeStorage).AllChanges())) err = aTree.IterateRoot(func(change *Change, decrypted []byte) (any, error) { return nil, nil }, func(change *Change) bool { @@ -497,6 +497,7 @@ func TestObjectTree(t *testing.T) { store, _ := treestorage.NewInMemoryTreeStorage(root, []string{root.Id}, []*treechangeproto.RawTreeChangeWithId{root}) oTree, err := BuildObjectTree(store, aclList) require.NoError(t, err) + emptyDataTreeDeps = nonVerifiableTreeDeps err = ValidateRawTree(treestorage.TreeStorageCreatePayload{ RootRawChange: oTree.Header(), Heads: []string{root.Id}, @@ -516,6 +517,7 @@ func TestObjectTree(t *testing.T) { store, _ := treestorage.NewInMemoryTreeStorage(root, []string{root.Id}, []*treechangeproto.RawTreeChangeWithId{root}) oTree, err := BuildObjectTree(store, aclList) require.NoError(t, err) + emptyDataTreeDeps = nonVerifiableTreeDeps err = ValidateRawTree(treestorage.TreeStorageCreatePayload{ RootRawChange: oTree.Header(), Heads: []string{root.Id}, @@ -558,6 +560,7 @@ func TestObjectTree(t *testing.T) { }) require.NoError(t, err) allChanges := oTree.Storage().(*treestorage.InMemoryTreeStorage).AllChanges() + emptyDataTreeDeps = nonVerifiableTreeDeps err = ValidateRawTree(treestorage.TreeStorageCreatePayload{ RootRawChange: oTree.Header(), Heads: []string{oTree.Heads()[0]}, @@ -587,6 +590,7 @@ func TestObjectTree(t *testing.T) { }, aclList) require.NoError(t, err) store, _ := treestorage.NewInMemoryTreeStorage(root, []string{root.Id}, []*treechangeproto.RawTreeChangeWithId{root}) + emptyDataTreeDeps = nonVerifiableTreeDeps oTree, err := BuildObjectTree(store, aclList) require.NoError(t, err) _, err = oTree.AddContent(ctx, SignableChangeContent{ @@ -629,7 +633,7 @@ func TestObjectTree(t *testing.T) { changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"), } - defaultObjectTreeDeps = nonVerifiableTreeDeps + emptyDataTreeDeps = nonVerifiableTreeDeps err := ValidateRawTree(treestorage.TreeStorageCreatePayload{ RootRawChange: ctx.objTree.Header(), Heads: []string{"3"}, @@ -1476,7 +1480,7 @@ func TestObjectTree(t *testing.T) { changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"), } - defaultObjectTreeDeps = nonVerifiableTreeDeps + emptyDataTreeDeps = nonVerifiableTreeDeps err := ValidateRawTree(treestorage.TreeStorageCreatePayload{ RootRawChange: ctx.objTree.Header(), Heads: []string{"3"}, @@ -1493,7 +1497,7 @@ func TestObjectTree(t *testing.T) { ctx.objTree.Header(), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"), } - defaultObjectTreeDeps = nonVerifiableTreeDeps + emptyDataTreeDeps = nonVerifiableTreeDeps err := ValidateRawTree(treestorage.TreeStorageCreatePayload{ RootRawChange: ctx.objTree.Header(), Heads: []string{"3"}, diff --git a/commonspace/object/tree/objecttree/objecttreefactory.go b/commonspace/object/tree/objecttree/objecttreefactory.go index 545bd609..de671050 100644 --- a/commonspace/object/tree/objecttree/objecttreefactory.go +++ b/commonspace/object/tree/objecttree/objecttreefactory.go @@ -63,7 +63,9 @@ func verifiableTreeDeps( } } -func emptyDataTreeDeps( +var emptyDataTreeDeps = verifiableEmptyDataTreeDeps + +func verifiableEmptyDataTreeDeps( rootChange *treechangeproto.RawTreeChangeWithId, treeStorage treestorage.TreeStorage, aclList list.AclList) objectTreeDeps { @@ -156,6 +158,16 @@ func BuildKeyFilterableObjectTree(treeStorage treestorage.TreeStorage, aclList l return buildObjectTree(deps) } +func BuildEmptyDataKeyFilterableObjectTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) { + rootChange, err := treeStorage.Root() + if err != nil { + return nil, err + } + deps := emptyDataTreeDeps(rootChange, treeStorage, aclList) + deps.validator = newTreeValidator(true, true) + return buildObjectTree(deps) +} + func BuildObjectTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) { rootChange, err := treeStorage.Root() if err != nil { diff --git a/commonspace/object/tree/objecttree/objecttreevalidator.go b/commonspace/object/tree/objecttree/objecttreevalidator.go index 0ed68bb1..3ef5625e 100644 --- a/commonspace/object/tree/objecttree/objecttreevalidator.go +++ b/commonspace/object/tree/objecttree/objecttreevalidator.go @@ -9,7 +9,17 @@ import ( "github.com/anyproto/any-sync/util/slice" ) -type ValidatorFunc func(payload treestorage.TreeStorageCreatePayload, buildFunc BuildObjectTreeFunc, aclList list.AclList) (retPayload treestorage.TreeStorageCreatePayload, err error) +type TreeStorageCreator interface { + CreateTreeStorage(payload treestorage.TreeStorageCreatePayload) (treestorage.TreeStorage, error) +} + +type InMemoryStorageCreator struct{} + +func (i InMemoryStorageCreator) CreateTreeStorage(payload treestorage.TreeStorageCreatePayload) (treestorage.TreeStorage, error) { + return treestorage.NewInMemoryTreeStorage(payload.RootRawChange, payload.Heads, payload.Changes) +} + +type ValidatorFunc func(payload treestorage.TreeStorageCreatePayload, storageCreator TreeStorageCreator, aclList list.AclList) (ret ObjectTree, err error) type ObjectTreeValidator interface { // ValidateFullTree should always be entered while holding a read lock on AclList @@ -160,12 +170,15 @@ func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.AclList, c return } -func ValidateRawTreeBuildFunc(payload treestorage.TreeStorageCreatePayload, buildFunc BuildObjectTreeFunc, aclList list.AclList) (newPayload treestorage.TreeStorageCreatePayload, err error) { - treeStorage, err := treestorage.NewInMemoryTreeStorage(payload.RootRawChange, []string{payload.RootRawChange.Id}, nil) +func ValidateRawTreeDefault(payload treestorage.TreeStorageCreatePayload, storageCreator TreeStorageCreator, aclList list.AclList) (objTree ObjectTree, err error) { + treeStorage, err := storageCreator.CreateTreeStorage(treestorage.TreeStorageCreatePayload{ + RootRawChange: payload.RootRawChange, + Heads: []string{payload.RootRawChange.Id}, + }) if err != nil { return } - tree, err := buildFunc(treeStorage, aclList) + tree, err := BuildEmptyDataObjectTree(treeStorage, aclList) if err != nil { return } @@ -179,33 +192,36 @@ func ValidateRawTreeBuildFunc(payload treestorage.TreeStorageCreatePayload, buil return } if !slice.UnsortedEquals(res.Heads, payload.Heads) { - return payload, fmt.Errorf("heads mismatch: %v != %v, %w", res.Heads, payload.Heads, ErrHasInvalidChanges) + return nil, fmt.Errorf("heads mismatch: %v != %v, %w", res.Heads, payload.Heads, ErrHasInvalidChanges) } // if tree has only one change we still should check if the snapshot id is same as root if IsEmptyDerivedTree(tree) { - return payload, ErrDerived + return nil, ErrDerived } - return payload, nil + return tree, nil } -func ValidateFilterRawTree(payload treestorage.TreeStorageCreatePayload, aclList list.AclList) (retPayload treestorage.TreeStorageCreatePayload, err error) { +func ValidateFilterRawTree(payload treestorage.TreeStorageCreatePayload, storageCreator TreeStorageCreator, aclList list.AclList) (objTree ObjectTree, err error) { aclList.RLock() if !aclList.AclState().HadReadPermissions(aclList.AclState().Identity()) { aclList.RUnlock() - return payload, list.ErrNoReadKey + return nil, list.ErrNoReadKey } aclList.RUnlock() - treeStorage, err := treestorage.NewInMemoryTreeStorage(payload.RootRawChange, []string{payload.RootRawChange.Id}, nil) + treeStorage, err := storageCreator.CreateTreeStorage(treestorage.TreeStorageCreatePayload{ + RootRawChange: payload.RootRawChange, + Heads: []string{payload.RootRawChange.Id}, + }) if err != nil { return } - tree, err := BuildKeyFilterableObjectTree(treeStorage, aclList) + tree, err := BuildEmptyDataKeyFilterableObjectTree(treeStorage, aclList) if err != nil { return } tree.Lock() defer tree.Unlock() - res, err := tree.AddRawChanges(context.Background(), RawChangesPayload{ + _, err = tree.AddRawChanges(context.Background(), RawChangesPayload{ NewHeads: payload.Heads, RawChanges: payload.Changes, }) @@ -213,16 +229,12 @@ func ValidateFilterRawTree(payload treestorage.TreeStorageCreatePayload, aclList return } if IsEmptyTree(tree) { - return payload, ErrNoChangeInTree + return nil, ErrNoChangeInTree } - return treestorage.TreeStorageCreatePayload{ - RootRawChange: payload.RootRawChange, - Heads: res.Heads, - Changes: treeStorage.(*treestorage.InMemoryTreeStorage).AllChanges(), - }, nil + return tree, nil } func ValidateRawTree(payload treestorage.TreeStorageCreatePayload, aclList list.AclList) (err error) { - _, err = ValidateRawTreeBuildFunc(payload, BuildObjectTree, aclList) + _, err = ValidateRawTreeDefault(payload, InMemoryStorageCreator{}, aclList) return } diff --git a/commonspace/object/tree/synctree/responsecollector.go b/commonspace/object/tree/synctree/responsecollector.go index 3ed06365..fac21293 100644 --- a/commonspace/object/tree/synctree/responsecollector.go +++ b/commonspace/object/tree/synctree/responsecollector.go @@ -37,21 +37,13 @@ func (r *fullResponseCollector) CollectResponse(ctx context.Context, peerId, obj } validator := r.deps.ValidateObjectTree if validator == nil { - validator = objecttree.ValidateRawTreeBuildFunc + validator = objecttree.ValidateRawTreeDefault } - payload, err := validator(createPayload, r.deps.BuildObjectTree, r.deps.AclList) + objTree, err := validator(createPayload, r.deps.SpaceStorage, r.deps.AclList) if err != nil { return err } - storage, err := r.deps.SpaceStorage.CreateTreeStorage(payload) - if err != nil { - return err - } - r.objectTree, err = r.deps.BuildObjectTree(storage, r.deps.AclList) - if err != nil { - return err - } - r.objectTree.SetEmptyData(true) + r.objectTree = objTree return nil } _, err := r.objectTree.AddRawChanges(ctx, objecttree.RawChangesPayload{