diff --git a/commonspace/headsync/statestorage/statestorage.go b/commonspace/headsync/statestorage/statestorage.go index 24a16c61..d1a17ad5 100644 --- a/commonspace/headsync/statestorage/statestorage.go +++ b/commonspace/headsync/statestorage/statestorage.go @@ -99,13 +99,26 @@ func New(ctx context.Context, spaceId string, store anystore.DB) (StateStorage, return storage, nil } -func Create(ctx context.Context, state State, store anystore.DB) (StateStorage, error) { - arena := &anyenc.Arena{} - stateCollection, err := store.Collection(ctx, stateCollectionKey) +func Create(ctx context.Context, state State, store anystore.DB) (st StateStorage, err error) { + tx, err := store.WriteTx(ctx) if err != nil { return nil, err } - tx, err := stateCollection.WriteTx(ctx) + storage, err := CreateTx(tx.Context(), state, store) + defer func() { + if err != nil { + tx.Rollback() + } + }() + if err != nil { + return nil, err + } + return storage, tx.Commit() +} + +func CreateTx(ctx context.Context, state State, store anystore.DB) (StateStorage, error) { + arena := &anyenc.Arena{} + stateCollection, err := store.Collection(ctx, stateCollectionKey) if err != nil { return nil, err } @@ -115,9 +128,8 @@ func Create(ctx context.Context, state State, store anystore.DB) (StateStorage, doc.Set(settingsIdKey, arena.NewString(state.SettingsId)) doc.Set(headerKey, arena.NewBinary(state.SpaceHeader)) doc.Set(aclIdKey, arena.NewString(state.AclId)) - err = stateCollection.Insert(tx.Context(), doc) + err = stateCollection.Insert(ctx, doc) if err != nil { - tx.Rollback() return nil, err } return &stateStorage{ @@ -126,7 +138,7 @@ func Create(ctx context.Context, state State, store anystore.DB) (StateStorage, settingsId: state.SettingsId, stateColl: stateCollection, arena: arena, - }, tx.Commit() + }, nil } func (s *stateStorage) SettingsId() string { diff --git a/commonspace/object/acl/list/storage.go b/commonspace/object/acl/list/storage.go index 4c96d24c..a54eedf3 100644 --- a/commonspace/object/acl/list/storage.go +++ b/commonspace/object/acl/list/storage.go @@ -61,6 +61,23 @@ type storage struct { } func CreateStorage(ctx context.Context, root *consensusproto.RawRecordWithId, headStorage headstorage.HeadStorage, store anystore.DB) (Storage, error) { + tx, err := store.WriteTx(ctx) + if err != nil { + return nil, err + } + storage, err := CreateStorageTx(tx.Context(), root, headStorage, store) + defer func() { + if err != nil { + tx.Rollback() + } + }() + if err != nil { + return nil, err + } + return storage, tx.Commit() +} + +func CreateStorageTx(ctx context.Context, root *consensusproto.RawRecordWithId, headStorage headstorage.HeadStorage, store anystore.DB) (Storage, error) { st := &storage{ id: root.Id, store: store, @@ -89,24 +106,18 @@ func CreateStorage(ctx context.Context, root *consensusproto.RawRecordWithId, he st.arena = &anyenc.Arena{} defer st.arena.Reset() doc := newStorageRecordValue(rec, st.arena) - tx, err := st.store.WriteTx(ctx) + err = st.recordsColl.Insert(ctx, doc) if err != nil { return nil, err } - err = st.recordsColl.Insert(tx.Context(), doc) - if err != nil { - tx.Rollback() - return nil, err - } - err = st.headStorage.UpdateEntryTx(tx.Context(), headstorage.HeadsUpdate{ + err = st.headStorage.UpdateEntryTx(ctx, headstorage.HeadsUpdate{ Id: root.Id, Heads: []string{root.Id}, }) if err != nil { - tx.Rollback() return nil, err } - return st, tx.Commit() + return st, nil } func NewStorage(ctx context.Context, id string, headStorage headstorage.HeadStorage, store anystore.DB) (Storage, error) { diff --git a/commonspace/object/tree/objecttree/storage.go b/commonspace/object/tree/objecttree/storage.go index 587c88d0..97ae6fa2 100644 --- a/commonspace/object/tree/objecttree/storage.go +++ b/commonspace/object/tree/objecttree/storage.go @@ -77,6 +77,23 @@ type storage struct { var StorageChangeBuilder = NewChangeBuilder func CreateStorage(ctx context.Context, root *treechangeproto.RawTreeChangeWithId, headStorage headstorage.HeadStorage, store anystore.DB) (Storage, error) { + tx, err := store.WriteTx(ctx) + if err != nil { + return nil, err + } + storage, err := CreateStorageTx(tx.Context(), root, headStorage, store) + defer func() { + if err != nil { + tx.Rollback() + } + }() + if err != nil { + return nil, err + } + return storage, tx.Commit() +} + +func CreateStorageTx(ctx context.Context, root *treechangeproto.RawTreeChangeWithId, headStorage headstorage.HeadStorage, store anystore.DB) (Storage, error) { st := &storage{ id: root.Id, store: store, @@ -107,29 +124,23 @@ func CreateStorage(ctx context.Context, root *treechangeproto.RawTreeChangeWithI st.parser = &anyenc.Parser{} defer st.arena.Reset() doc := newStorageChangeValue(stChange, st.arena) - tx, err := st.store.WriteTx(ctx) + err = st.changesColl.Insert(ctx, doc) if err != nil { - return nil, err - } - err = st.changesColl.Insert(tx.Context(), doc) - if err != nil { - tx.Rollback() if errors.Is(err, anystore.ErrDocExists) { return nil, treestorage.ErrTreeExists } return nil, err } - err = st.headStorage.UpdateEntryTx(tx.Context(), headstorage.HeadsUpdate{ + err = st.headStorage.UpdateEntryTx(ctx, headstorage.HeadsUpdate{ Id: root.Id, Heads: []string{root.Id}, CommonSnapshot: &root.Id, IsDerived: &unmarshalled.IsDerived, }) if err != nil { - tx.Rollback() return nil, err } - return st, tx.Commit() + return st, nil } func NewStorage(ctx context.Context, id string, headStorage headstorage.HeadStorage, store anystore.DB) (Storage, error) { diff --git a/commonspace/spacestorage/spacestorage.go b/commonspace/spacestorage/spacestorage.go index 30f6d258..fac0f398 100644 --- a/commonspace/spacestorage/spacestorage.go +++ b/commonspace/spacestorage/spacestorage.go @@ -52,7 +52,7 @@ type SpaceStorageProvider interface { CreateSpaceStorage(ctx context.Context, payload SpaceStorageCreatePayload) (SpaceStorage, error) } -func Create(ctx context.Context, store anystore.DB, payload SpaceStorageCreatePayload) (SpaceStorage, error) { +func Create(ctx context.Context, store anystore.DB, payload SpaceStorageCreatePayload) (st SpaceStorage, err error) { spaceId := payload.SpaceHeaderWithId.Id state := statestorage.State{ AclId: payload.AclWithId.Id, @@ -60,7 +60,16 @@ func Create(ctx context.Context, store anystore.DB, payload SpaceStorageCreatePa SpaceId: payload.SpaceHeaderWithId.Id, SpaceHeader: payload.SpaceHeaderWithId.RawHeader, } - changesColl, err := store.Collection(ctx, objecttree.CollName) + tx, err := store.WriteTx(ctx) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + tx.Rollback() + } + }() + changesColl, err := store.Collection(tx.Context(), objecttree.CollName) if err != nil { return nil, err } @@ -68,27 +77,27 @@ func Create(ctx context.Context, store anystore.DB, payload SpaceStorageCreatePa Fields: []string{objecttree.TreeKey, objecttree.OrderKey}, Unique: true, } - err = changesColl.EnsureIndex(ctx, orderIdx) + err = changesColl.EnsureIndex(tx.Context(), orderIdx) if err != nil { return nil, err } // TODO: put it in one transaction - stateStorage, err := statestorage.Create(ctx, state, store) + stateStorage, err := statestorage.CreateTx(tx.Context(), state, store) if err != nil { return nil, err } - headStorage, err := headstorage.New(ctx, store) + headStorage, err := headstorage.New(tx.Context(), store) if err != nil { return nil, err } - aclStorage, err := list.CreateStorage(ctx, &consensusproto.RawRecordWithId{ + aclStorage, err := list.CreateStorageTx(tx.Context(), &consensusproto.RawRecordWithId{ Payload: payload.AclWithId.Payload, Id: payload.AclWithId.Id, }, headStorage, store) if err != nil { return nil, err } - _, err = objecttree.CreateStorage(ctx, &treechangeproto.RawTreeChangeWithId{ + _, err = objecttree.CreateStorageTx(tx.Context(), &treechangeproto.RawTreeChangeWithId{ RawChange: payload.SpaceSettingsWithId.RawChange, Id: payload.SpaceSettingsWithId.Id, }, headStorage, store) @@ -101,7 +110,7 @@ func Create(ctx context.Context, store anystore.DB, payload SpaceStorageCreatePa headStorage: headStorage, stateStorage: stateStorage, aclStorage: aclStorage, - }, nil + }, tx.Commit() } func New(ctx context.Context, spaceId string, store anystore.DB) (SpaceStorage, error) { diff --git a/commonspace/spacestorage_test.go b/commonspace/spacestorage_test.go new file mode 100644 index 00000000..51059d51 --- /dev/null +++ b/commonspace/spacestorage_test.go @@ -0,0 +1,52 @@ +package commonspace + +import ( + "context" + "path/filepath" + "testing" + + anystore "github.com/anyproto/any-store" + "github.com/stretchr/testify/require" + + "github.com/anyproto/any-sync/commonspace/object/accountdata" + "github.com/anyproto/any-sync/commonspace/spacestorage" + "github.com/anyproto/any-sync/util/crypto" +) + +func newStorageCreatePayload(t *testing.T) spacestorage.SpaceStorageCreatePayload { + keys, err := accountdata.NewRandom() + require.NoError(t, err) + masterKey, _, err := crypto.GenerateRandomEd25519KeyPair() + require.NoError(t, err) + metaKey, _, err := crypto.GenerateRandomEd25519KeyPair() + require.NoError(t, err) + readKey := crypto.NewAES() + meta := []byte("account") + payload := SpaceCreatePayload{ + SigningKey: keys.SignKey, + SpaceType: "space", + ReplicationKey: 10, + SpacePayload: nil, + MasterKey: masterKey, + ReadKey: readKey, + MetadataKey: metaKey, + Metadata: meta, + } + createSpace, err := StoragePayloadForSpaceCreate(payload) + require.NoError(t, err) + return createSpace +} + +var ctx = context.Background() + +func TestCreateSpaceStorageFailed_EmptyStorage(t *testing.T) { + payload := newStorageCreatePayload(t) + store, err := anystore.Open(ctx, filepath.Join(t.TempDir(), "store.db"), nil) + require.NoError(t, err) + payload.SpaceSettingsWithId.RawChange = nil + _, err = spacestorage.Create(ctx, store, payload) + require.Error(t, err) + collNames, err := store.GetCollectionNames(ctx) + require.NoError(t, err) + require.Empty(t, collNames) +} \ No newline at end of file