diff --git a/commonspace/object/tree/objecttree/mock_objecttree/mock_objecttree.go b/commonspace/object/tree/objecttree/mock_objecttree/mock_objecttree.go index aea9b3a0..eaa0815e 100644 --- a/commonspace/object/tree/objecttree/mock_objecttree/mock_objecttree.go +++ b/commonspace/object/tree/objecttree/mock_objecttree/mock_objecttree.go @@ -103,6 +103,21 @@ func (mr *MockObjectTreeMockRecorder) AddRawChanges(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawChanges", reflect.TypeOf((*MockObjectTree)(nil).AddRawChanges), arg0, arg1) } +// AddRawChangesWithUpdater mocks base method. +func (m *MockObjectTree) AddRawChangesWithUpdater(arg0 context.Context, arg1 objecttree.RawChangesPayload, arg2 objecttree.Updater) (objecttree.AddResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddRawChangesWithUpdater", arg0, arg1, arg2) + ret0, _ := ret[0].(objecttree.AddResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddRawChangesWithUpdater indicates an expected call of AddRawChangesWithUpdater. +func (mr *MockObjectTreeMockRecorder) AddRawChangesWithUpdater(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawChangesWithUpdater", reflect.TypeOf((*MockObjectTree)(nil).AddRawChangesWithUpdater), arg0, arg1, arg2) +} + // ChangeInfo mocks base method. func (m *MockObjectTree) ChangeInfo() *treechangeproto.TreeChangeInfo { m.ctrl.T.Helper() diff --git a/commonspace/object/tree/objecttree/objecttree.go b/commonspace/object/tree/objecttree/objecttree.go index ce19326a..4ae10ee0 100644 --- a/commonspace/object/tree/objecttree/objecttree.go +++ b/commonspace/object/tree/objecttree/objecttree.go @@ -29,6 +29,8 @@ var ( ErrNoAclHead = errors.New("no acl head") ) +type Updater func(tree ObjectTree, md Mode) error + type AddResultSummary int type AddResult struct { @@ -86,6 +88,7 @@ type ObjectTree interface { AddContent(ctx context.Context, content SignableChangeContent) (AddResult, error) AddContentWithValidator(ctx context.Context, content SignableChangeContent, validate func(change *treechangeproto.RawTreeChangeWithId) error) (AddResult, error) AddRawChanges(ctx context.Context, changes RawChangesPayload) (AddResult, error) + AddRawChangesWithUpdater(ctx context.Context, changes RawChangesPayload, updater Updater) (AddResult, error) UnpackChange(raw *treechangeproto.RawTreeChangeWithId) (data []byte, err error) PrepareChange(content SignableChangeContent) (res *treechangeproto.RawTreeChangeWithId, err error) @@ -93,7 +96,6 @@ type ObjectTree interface { Delete() error Close() error SetFlusher(flusher Flusher) - Flush() error TryClose(objectTTL time.Duration) (bool, error) } @@ -340,14 +342,14 @@ func (ot *objectTree) prepareBuilderContent(content SignableChangeContent) (cnt return } -func (ot *objectTree) AddRawChanges(ctx context.Context, changesPayload RawChangesPayload) (addResult AddResult, err error) { +func (ot *objectTree) AddRawChangesWithUpdater(ctx context.Context, changes RawChangesPayload, updater Updater) (addResult AddResult, err error) { if ot.isDeleted { err = ErrDeleted return } ot.logUseWhenUnlocked() lastHeadId := ot.tree.lastIteratedHeadId - addResult, err = ot.addRawChanges(ctx, changesPayload) + addResult, err = ot.addChangesToTree(ctx, changes) if err != nil { return } @@ -363,18 +365,35 @@ func (ot *objectTree) AddRawChanges(ctx context.Context, changesPayload RawChang addResult.Mode = Rebuild } - err = ot.treeStorage.AddRawChangesSetHeads(addResult.Added, addResult.Heads) - if err != nil { - // rolling back all changes made to inmemory state + rollback := func() { rebuildErr := ot.rebuildFromStorage(nil, nil) if rebuildErr != nil { log.Error("failed to rebuild after adding to storage", zap.Strings("heads", ot.Heads()), zap.Error(rebuildErr)) } } + + if updater != nil { + err = updater(ot, addResult.Mode) + if err != nil { + rollback() + return + } + } + + err = ot.treeStorage.AddRawChangesSetHeads(addResult.Added, addResult.Heads) + if err != nil { + rollback() + return + } + ot.flusher.Flush(ot) return } -func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChangesPayload) (addResult AddResult, err error) { +func (ot *objectTree) AddRawChanges(ctx context.Context, changesPayload RawChangesPayload) (addResult AddResult, err error) { + return ot.AddRawChangesWithUpdater(ctx, changesPayload, nil) +} + +func (ot *objectTree) addChangesToTree(ctx context.Context, changesPayload RawChangesPayload) (addResult AddResult, err error) { // resetting buffers ot.newChangesBuf = ot.newChangesBuf[:0] ot.notSeenIdxBuf = ot.notSeenIdxBuf[:0] @@ -528,13 +547,6 @@ func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChang } } -func (ot *objectTree) Flush() error { - if ot.isDeleted { - return ErrDeleted - } - return ot.flusher.Flush(ot) -} - func (ot *objectTree) createAddResult(oldHeads []string, mode Mode, treeChangesAdded []*Change, rawChanges []*treechangeproto.RawTreeChangeWithId) (addResult AddResult, err error) { headsCopy := func() []string { newHeads := make([]string, 0, len(ot.tree.Heads())) diff --git a/commonspace/object/tree/objecttree/objecttree_test.go b/commonspace/object/tree/objecttree/objecttree_test.go index 13b2cd4e..93aa89cb 100644 --- a/commonspace/object/tree/objecttree/objecttree_test.go +++ b/commonspace/object/tree/objecttree/objecttree_test.go @@ -427,6 +427,21 @@ func TestObjectTree(t *testing.T) { oTree, err := BuildObjectTree(store, aclList) require.NoError(t, err) + t.Run("add content validate failed", func(t *testing.T) { + _, err := oTree.AddContentWithValidator(ctx, SignableChangeContent{ + Data: []byte("some"), + Key: keys.SignKey, + IsSnapshot: false, + IsEncrypted: true, + Timestamp: 0, + DataType: mockDataType, + }, func(change *treechangeproto.RawTreeChangeWithId) error { + return errors.New("validation failed") + }) + require.Error(t, err) + require.Len(t, oTree.Heads(), 1) + require.Equal(t, root.Id, oTree.Root().Id) + }) t.Run("0 timestamp is changed to current, data type is correct", func(t *testing.T) { start := time.Now() res, err := oTree.AddContent(ctx, SignableChangeContent{ @@ -815,7 +830,24 @@ func TestObjectTree(t *testing.T) { RawChanges: rawChanges, } - res, err := objTree.AddRawChanges(context.Background(), payload) + res, err := objTree.AddRawChangesWithUpdater(context.Background(), payload, func(tree ObjectTree, md Mode) error { + // check tree iterate + var iterChangesId []string + err := objTree.IterateRoot(nil, func(change *Change) bool { + iterChangesId = append(iterChangesId, change.Id) + return true + }) + require.NoError(t, err, "iterate should be without error") + assert.Equal(t, []string{"0", "1", "2", "3", "4"}, iterChangesId) + assert.Equal(t, "0", objTree.Root().Id) + + for _, ch := range rawChanges { + treeCh, err := objTree.GetChange(ch.Id) + require.NoError(t, err) + require.True(t, treeCh.IsNew) + } + return nil + }) require.NoError(t, err, "adding changes should be without error") // check result @@ -827,36 +859,16 @@ func TestObjectTree(t *testing.T) { // check tree heads assert.Equal(t, []string{"4"}, objTree.Heads()) - // check tree iterate - var iterChangesId []string - err = objTree.IterateRoot(nil, func(change *Change) bool { - iterChangesId = append(iterChangesId, change.Id) - return true - }) - require.NoError(t, err, "iterate should be without error") - assert.Equal(t, []string{"0", "1", "2", "3", "4"}, iterChangesId) - // before Flush - assert.Equal(t, "0", objTree.Root().Id) - // check storage heads, _ := treeStorage.Heads() assert.Equal(t, []string{"4"}, heads) - for _, ch := range rawChanges { - treeCh, err := objTree.GetChange(ch.Id) - require.NoError(t, err) - require.True(t, treeCh.IsNew) - raw, err := treeStorage.GetRawChange(context.Background(), ch.Id) - assert.NoError(t, err, "storage should have all the changes") - assert.Equal(t, ch, raw, "the changes in the storage should be the same") - } - - err = objTree.Flush() - require.NoError(t, err) - // after Flush assert.Equal(t, "3", objTree.Root().Id) for _, ch := range rawChanges { + raw, err := treeStorage.GetRawChange(context.Background(), ch.Id) + assert.NoError(t, err, "storage should have all the changes") + assert.Equal(t, ch, raw, "the changes in the storage should be the same") treeCh, err := objTree.GetChange(ch.Id) if ch.Id == "3" || ch.Id == "4" { require.NoError(t, err) diff --git a/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go b/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go index fee2c652..f231e61f 100644 --- a/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go +++ b/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go @@ -127,6 +127,21 @@ func (mr *MockSyncTreeMockRecorder) AddRawChangesFromPeer(arg0, arg1, arg2 any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawChangesFromPeer", reflect.TypeOf((*MockSyncTree)(nil).AddRawChangesFromPeer), arg0, arg1, arg2) } +// AddRawChangesWithUpdater mocks base method. +func (m *MockSyncTree) AddRawChangesWithUpdater(arg0 context.Context, arg1 objecttree.RawChangesPayload, arg2 objecttree.Updater) (objecttree.AddResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddRawChangesWithUpdater", arg0, arg1, arg2) + ret0, _ := ret[0].(objecttree.AddResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddRawChangesWithUpdater indicates an expected call of AddRawChangesWithUpdater. +func (mr *MockSyncTreeMockRecorder) AddRawChangesWithUpdater(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawChangesWithUpdater", reflect.TypeOf((*MockSyncTree)(nil).AddRawChangesWithUpdater), arg0, arg1, arg2) +} + // ChangeInfo mocks base method. func (m *MockSyncTree) ChangeInfo() *treechangeproto.TreeChangeInfo { m.ctrl.T.Helper() diff --git a/commonspace/object/tree/synctree/synctree.go b/commonspace/object/tree/synctree/synctree.go index 9929b559..e187417e 100644 --- a/commonspace/object/tree/synctree/synctree.go +++ b/commonspace/object/tree/synctree/synctree.go @@ -238,21 +238,23 @@ func (s *syncTree) AddRawChanges(ctx context.Context, changesPayload objecttree. if err = s.checkAlive(); err != nil { return } - res, err = s.ObjectTree.AddRawChanges(ctx, changesPayload) + res, err = s.ObjectTree.AddRawChangesWithUpdater(ctx, changesPayload, func(tree objecttree.ObjectTree, md objecttree.Mode) error { + if s.listener != nil { + switch md { + case objecttree.Nothing: + return nil + case objecttree.Append: + return s.listener.Update(s) + case objecttree.Rebuild: + return s.listener.Rebuild(s) + } + } + return nil + }) if err != nil { return } - if s.listener != nil { - switch res.Mode { - case objecttree.Nothing: - return - case objecttree.Append: - s.listener.Update(s) - case objecttree.Rebuild: - s.listener.Rebuild(s) - } - } - s.flush() + if res.Mode != objecttree.Nothing { if s.notifiable != nil { s.notifiable.UpdateHeads(s.Id(), res.Heads) @@ -338,16 +340,8 @@ func (s *syncTree) SyncWithPeer(ctx context.Context, p peer.Peer) (err error) { func (s *syncTree) afterBuild() { if s.listener != nil { s.listener.Rebuild(s) - s.flush() } if s.notifiable != nil { s.notifiable.UpdateHeads(s.Id(), s.Heads()) } } - -func (s *syncTree) flush() { - err := s.Flush() - if err != nil { - log.Warn("flush error", zap.Error(err)) - } -} diff --git a/commonspace/object/tree/synctree/synctree_test.go b/commonspace/object/tree/synctree/synctree_test.go index af3905e0..629d8598 100644 --- a/commonspace/object/tree/synctree/synctree_test.go +++ b/commonspace/object/tree/synctree/synctree_test.go @@ -204,10 +204,16 @@ func Test_SyncTree(t *testing.T) { objTreeMock.EXPECT().Heads().Return([]string{"headId"}).Times(2) objTreeMock.EXPECT().Heads().Return([]string{"headId1"}).Times(1) objTreeMock.EXPECT().HasChanges(gomock.Any()).AnyTimes().Return(false) - objTreeMock.EXPECT().AddRawChanges(gomock.Any(), gomock.Eq(payload)). - Return(expectedRes, nil) + objTreeMock.EXPECT().AddRawChangesWithUpdater(gomock.Any(), gomock.Eq(payload), gomock.Any()). + DoAndReturn(func(ctx context.Context, changes objecttree.RawChangesPayload, updater objecttree.Updater) (addResult objecttree.AddResult, err error) { + err = updater(objTreeMock, objecttree.Append) + if err != nil { + return objecttree.AddResult{}, err + } + return expectedRes, nil + }) notifiableMock.EXPECT().UpdateHeads("id", []string{"headId1"}) - updateListenerMock.EXPECT().Update(tr) + updateListenerMock.EXPECT().Update(tr).Return(nil) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), "peerId", gomock.Eq(changes)).Return(headUpdate, nil) syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) @@ -231,10 +237,16 @@ func Test_SyncTree(t *testing.T) { objTreeMock.EXPECT().Heads().Return([]string{"headId"}).Times(2) objTreeMock.EXPECT().Heads().Return([]string{"headId1"}).Times(1) objTreeMock.EXPECT().HasChanges(gomock.Any()).AnyTimes().Return(false) - objTreeMock.EXPECT().AddRawChanges(gomock.Any(), gomock.Eq(payload)). - Return(expectedRes, nil) + objTreeMock.EXPECT().AddRawChangesWithUpdater(gomock.Any(), gomock.Eq(payload), gomock.Any()). + DoAndReturn(func(ctx context.Context, changes objecttree.RawChangesPayload, updater objecttree.Updater) (addResult objecttree.AddResult, err error) { + err = updater(objTreeMock, objecttree.Rebuild) + if err != nil { + return objecttree.AddResult{}, err + } + return expectedRes, nil + }) notifiableMock.EXPECT().UpdateHeads("id", []string{"headId1"}) - updateListenerMock.EXPECT().Rebuild(tr) + updateListenerMock.EXPECT().Rebuild(tr).Return(nil) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), "peerId", gomock.Eq(changes)).Return(headUpdate, nil) syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) @@ -275,8 +287,14 @@ func Test_SyncTree(t *testing.T) { } objTreeMock.EXPECT().Heads().Return([]string{"headId"}).AnyTimes() objTreeMock.EXPECT().HasChanges(gomock.Any()).AnyTimes().Return(false) - objTreeMock.EXPECT().AddRawChanges(gomock.Any(), gomock.Eq(payload)). - Return(expectedRes, nil) + objTreeMock.EXPECT().AddRawChangesWithUpdater(gomock.Any(), gomock.Eq(payload), gomock.Any()). + DoAndReturn(func(ctx context.Context, changes objecttree.RawChangesPayload, updater objecttree.Updater) (addResult objecttree.AddResult, err error) { + err = updater(objTreeMock, objecttree.Nothing) + if err != nil { + return objecttree.AddResult{}, err + } + return expectedRes, nil + }) res, err := tr.AddRawChangesFromPeer(ctx, "peerId", payload) require.NoError(t, err) require.Equal(t, expectedRes, res) @@ -293,7 +311,7 @@ func Test_SyncTree(t *testing.T) { Added: changes, } objTreeMock.EXPECT().Id().Return("id").AnyTimes() - objTreeMock.EXPECT().AddContent(gomock.Any(), gomock.Eq(content)). + objTreeMock.EXPECT().AddContentWithValidator(gomock.Any(), gomock.Eq(content), gomock.Any()). Return(expectedRes, nil) syncStatusMock.EXPECT().HeadsChange("id", []string{"headId"}) notifiableMock.EXPECT().UpdateHeads("id", []string{"headId"}) diff --git a/commonspace/object/tree/synctree/updatelistener/mock_updatelistener/mock_updatelistener.go b/commonspace/object/tree/synctree/updatelistener/mock_updatelistener/mock_updatelistener.go index 54035163..5e135802 100644 --- a/commonspace/object/tree/synctree/updatelistener/mock_updatelistener/mock_updatelistener.go +++ b/commonspace/object/tree/synctree/updatelistener/mock_updatelistener/mock_updatelistener.go @@ -40,9 +40,11 @@ func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder { } // Rebuild mocks base method. -func (m *MockUpdateListener) Rebuild(arg0 objecttree.ObjectTree) { +func (m *MockUpdateListener) Rebuild(arg0 objecttree.ObjectTree) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Rebuild", arg0) + ret := m.ctrl.Call(m, "Rebuild", arg0) + ret0, _ := ret[0].(error) + return ret0 } // Rebuild indicates an expected call of Rebuild. @@ -52,9 +54,11 @@ func (mr *MockUpdateListenerMockRecorder) Rebuild(arg0 any) *gomock.Call { } // Update mocks base method. -func (m *MockUpdateListener) Update(arg0 objecttree.ObjectTree) { +func (m *MockUpdateListener) Update(arg0 objecttree.ObjectTree) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Update", arg0) + ret := m.ctrl.Call(m, "Update", arg0) + ret0, _ := ret[0].(error) + return ret0 } // Update indicates an expected call of Update. diff --git a/commonspace/object/tree/synctree/updatelistener/updatelistener.go b/commonspace/object/tree/synctree/updatelistener/updatelistener.go index f0d9ab5e..ebee4f36 100644 --- a/commonspace/object/tree/synctree/updatelistener/updatelistener.go +++ b/commonspace/object/tree/synctree/updatelistener/updatelistener.go @@ -6,6 +6,6 @@ import ( ) type UpdateListener interface { - Update(tree objecttree.ObjectTree) - Rebuild(tree objecttree.ObjectTree) + Update(tree objecttree.ObjectTree) error + Rebuild(tree objecttree.ObjectTree) error } diff --git a/commonspace/settings/settingsobject.go b/commonspace/settings/settingsobject.go index c19162db..29e6bafb 100644 --- a/commonspace/settings/settingsobject.go +++ b/commonspace/settings/settingsobject.go @@ -119,15 +119,17 @@ func (s *settingsObject) updateIds(tr objecttree.ObjectTree) { } // Update is called as part of UpdateListener interface -func (s *settingsObject) Update(tr objecttree.ObjectTree) { +func (s *settingsObject) Update(tr objecttree.ObjectTree) error { s.updateIds(tr) + return nil } // Rebuild is called as part of UpdateListener interface (including when the object is built for the first time, e.g. on Init call) -func (s *settingsObject) Rebuild(tr objecttree.ObjectTree) { +func (s *settingsObject) Rebuild(tr objecttree.ObjectTree) error { // at initial build "s" may not contain the object tree, so it is safer to provide it from the function parameter s.state = nil s.updateIds(tr) + return nil } func (s *settingsObject) Init(ctx context.Context) (err error) {