diff --git a/commonspace/object/acl/list/aclstate.go b/commonspace/object/acl/list/aclstate.go index c9d0dbb0..ea8a7e46 100644 --- a/commonspace/object/acl/list/aclstate.go +++ b/commonspace/object/acl/list/aclstate.go @@ -131,6 +131,20 @@ func (st *AclState) CurrentReadKeyId() string { return st.readKeyChanges[len(st.readKeyChanges)-1] } +func (st *AclState) ReadKeyForAclId(id string) (string, error) { + recIdx, ok := st.list.indexes[id] + if !ok { + return "", ErrNoSuchRecord + } + for i := len(st.readKeyChanges) - 1; i >= 0; i-- { + recId := st.readKeyChanges[i] + if recIdx >= st.list.indexes[recId] { + return recId, nil + } + } + return "", ErrNoSuchRecord +} + func (st *AclState) AccountKey() crypto.PrivKey { return st.key } diff --git a/commonspace/object/keyvalue/keyvaluestorage/innerstorage/element.go b/commonspace/object/keyvalue/keyvaluestorage/innerstorage/element.go index 8fd74004..4c9aa8e7 100644 --- a/commonspace/object/keyvalue/keyvaluestorage/innerstorage/element.go +++ b/commonspace/object/keyvalue/keyvaluestorage/innerstorage/element.go @@ -13,11 +13,13 @@ var ErrInvalidSignature = errors.New("invalid signature") type KeyValue struct { KeyPeerId string + ReadKeyId string Key string Value Value TimestampMilli int Identity string PeerId string + AclId string } type Value struct { @@ -47,6 +49,7 @@ func KeyValueFromProto(proto *spacesyncproto.StoreKeyValue, verify bool) (kv Key kv.Identity = identity.Account() kv.PeerId = peerId.PeerId() kv.Key = innerValue.Key + kv.AclId = innerValue.AclHeadId // TODO: check that key-peerId is equal to key+peerId? if verify { if verify, _ = identity.Verify(proto.Value, proto.IdentitySignature); !verify { @@ -71,6 +74,7 @@ func (kv KeyValue) AnyEnc(a *anyenc.Arena) *anyenc.Value { obj := a.NewObject() obj.Set("id", a.NewString(kv.KeyPeerId)) obj.Set("k", a.NewString(kv.Key)) + obj.Set("r", a.NewString(kv.ReadKeyId)) obj.Set("v", kv.Value.AnyEnc(a)) obj.Set("t", a.NewNumberInt(kv.TimestampMilli)) obj.Set("i", a.NewString(kv.Identity)) diff --git a/commonspace/object/keyvalue/keyvaluestorage/innerstorage/keyvaluestorage.go b/commonspace/object/keyvalue/keyvaluestorage/innerstorage/keyvaluestorage.go index 1b62b904..ac941395 100644 --- a/commonspace/object/keyvalue/keyvaluestorage/innerstorage/keyvaluestorage.go +++ b/commonspace/object/keyvalue/keyvaluestorage/innerstorage/keyvaluestorage.go @@ -156,11 +156,12 @@ func (s *storage) keyValueFromDoc(doc anystore.Doc) KeyValue { } return KeyValue{ KeyPeerId: doc.Value().GetString("id"), + ReadKeyId: doc.Value().GetString("r"), Value: value, TimestampMilli: doc.Value().GetInt("t"), Identity: doc.Value().GetString("i"), PeerId: doc.Value().GetString("p"), - Key: doc.Value().GetString("k"), + Key: doc.Value().GetString("k"), } } diff --git a/commonspace/object/keyvalue/keyvaluestorage/storage.go b/commonspace/object/keyvalue/keyvaluestorage/storage.go index f0635211..f0835772 100644 --- a/commonspace/object/keyvalue/keyvaluestorage/storage.go +++ b/commonspace/object/keyvalue/keyvaluestorage/storage.go @@ -2,6 +2,7 @@ package keyvaluestorage import ( "context" + "fmt" "sync" "time" @@ -16,6 +17,8 @@ import ( "github.com/anyproto/any-sync/commonspace/object/keyvalue/keyvaluestorage/innerstorage" "github.com/anyproto/any-sync/commonspace/object/keyvalue/keyvaluestorage/syncstorage" "github.com/anyproto/any-sync/commonspace/spacesyncproto" + "github.com/anyproto/any-sync/util/crypto" + "github.com/anyproto/any-sync/util/slice" ) var log = logger.NewNamed("common.keyvalue.keyvaluestorage") @@ -24,9 +27,11 @@ const IndexerCName = "common.keyvalue.indexer" type Indexer interface { app.Component - Index(keyValue ...innerstorage.KeyValue) error + Index(decryptor Decryptor, keyValue ...innerstorage.KeyValue) error } +type Decryptor = func(kv innerstorage.KeyValue) (value []byte, err error) + type NoOpIndexer struct{} func (n NoOpIndexer) Init(a *app.App) (err error) { @@ -37,7 +42,7 @@ func (n NoOpIndexer) Name() (name string) { return IndexerCName } -func (n NoOpIndexer) Index(keyValue ...innerstorage.KeyValue) error { +func (n NoOpIndexer) Index(decryptor Decryptor, keyValue ...innerstorage.KeyValue) error { return nil } @@ -50,13 +55,15 @@ type Storage interface { } type storage struct { - inner innerstorage.KeyValueStorage - keys *accountdata.AccountKeys - aclList list.AclList - syncClient syncstorage.SyncClient - indexer Indexer - storageId string - mx sync.Mutex + inner innerstorage.KeyValueStorage + keys *accountdata.AccountKeys + aclList list.AclList + syncClient syncstorage.SyncClient + indexer Indexer + storageId string + readKeys map[string]crypto.SymKey + currentReadKey crypto.SymKey + mx sync.Mutex } func New( @@ -80,6 +87,7 @@ func New( aclList: aclList, indexer: indexer, syncClient: syncClient, + readKeys: make(map[string]crypto.SymKey), }, nil } @@ -92,11 +100,21 @@ func (s *storage) Set(ctx context.Context, key string, value []byte) error { defer s.mx.Unlock() s.aclList.RLock() headId := s.aclList.Head().Id - if !s.aclList.AclState().Permissions(s.aclList.AclState().Identity()).CanWrite() { + state := s.aclList.AclState() + if !s.aclList.AclState().Permissions(state.Identity()).CanWrite() { s.aclList.RUnlock() return list.ErrInsufficientPermissions } + readKeyId := state.CurrentReadKeyId() + err := s.readKeysFromAclState(state) + if err != nil { + return err + } s.aclList.RUnlock() + value, err = s.currentReadKey.Encrypt(value) + if err != nil { + return err + } peerIdKey := s.keys.PeerKey identityKey := s.keys.SignKey protoPeerKey, err := peerIdKey.GetPublic().Marshall() @@ -135,6 +153,8 @@ func (s *storage) Set(ctx context.Context, key string, value []byte) error { TimestampMilli: int(timestampMilli), Identity: identityKey.GetPublic().Account(), PeerId: peerIdKey.GetPublic().PeerId(), + AclId: headId, + ReadKeyId: readKeyId, Value: innerstorage.Value{ Value: value, PeerSignature: peerSig, @@ -145,7 +165,7 @@ func (s *storage) Set(ctx context.Context, key string, value []byte) error { if err != nil { return err } - indexErr := s.indexer.Index(keyValue) + indexErr := s.indexer.Index(s.decrypt, keyValue) if indexErr != nil { log.Warn("failed to index for key", zap.String("key", key), zap.Error(indexErr)) } @@ -156,7 +176,7 @@ func (s *storage) Set(ctx context.Context, key string, value []byte) error { return nil } -func (s *storage) SetRaw(ctx context.Context, keyValue ...*spacesyncproto.StoreKeyValue) error { +func (s *storage) SetRaw(ctx context.Context, keyValue ...*spacesyncproto.StoreKeyValue) (err error) { if len(keyValue) == 0 { return nil } @@ -170,18 +190,36 @@ func (s *storage) SetRaw(ctx context.Context, keyValue ...*spacesyncproto.StoreK } keyValues = append(keyValues, innerKv) } - err := s.inner.Set(ctx, keyValues...) + s.aclList.RLock() + state := s.aclList.AclState() + err = s.readKeysFromAclState(state) if err != nil { + s.aclList.RUnlock() return err } - indexErr := s.indexer.Index(keyValues...) - if indexErr != nil { - log.Warn("failed to index for keys", zap.Error(indexErr)) + for _, kv := range keyValues { + kv.ReadKeyId, err = state.ReadKeyForAclId(kv.AclId) + if err != nil { + kv.KeyPeerId = "" + continue + } + } + s.aclList.RUnlock() + keyValues = slice.DiscardFromSlice(keyValues, func(value innerstorage.KeyValue) bool { + return value.KeyPeerId == "" + }) + err = s.inner.Set(ctx, keyValues...) + if err != nil { + return err } sendErr := s.syncClient.Broadcast(ctx, s.storageId, keyValues...) if sendErr != nil { log.Warn("failed to send key values", zap.Error(sendErr)) } + indexErr := s.indexer.Index(s.decrypt, keyValues...) + if indexErr != nil { + log.Warn("failed to index for keys", zap.Error(indexErr)) + } return nil } @@ -196,3 +234,57 @@ func (s *storage) GetAll(ctx context.Context, key string) (values []innerstorage func (s *storage) InnerStorage() innerstorage.KeyValueStorage { return s.inner } + +func (s *storage) readKeysFromAclState(state *list.AclState) (err error) { + if len(s.readKeys) == len(state.Keys()) { + return nil + } + if state.AccountKey() == nil || !state.HadReadPermissions(state.AccountKey().GetPublic()) { + return nil + } + for key, value := range state.Keys() { + if _, exists := s.readKeys[key]; exists { + continue + } + if value.ReadKey == nil { + continue + } + treeKey, err := deriveKey(value.ReadKey, s.storageId) + if err != nil { + return err + } + s.readKeys[key] = treeKey + } + curKey, err := state.CurrentReadKey() + if err != nil { + return err + } + if curKey == nil { + return nil + } + s.currentReadKey, err = deriveKey(curKey, s.storageId) + return err +} + +func (s *storage) decrypt(kv innerstorage.KeyValue) (value []byte, err error) { + if kv.ReadKeyId == "" { + return nil, fmt.Errorf("no read key id") + } + key := s.readKeys[kv.ReadKeyId] + if key == nil { + return nil, fmt.Errorf("no read key for %s", kv.ReadKeyId) + } + value, err = key.Decrypt(kv.Value.Value) + if err != nil { + return nil, err + } + return value, nil +} + +func deriveKey(key crypto.SymKey, id string) (crypto.SymKey, error) { + raw, err := key.Raw() + if err != nil { + return nil, err + } + return crypto.DeriveSymmetricKey(raw, fmt.Sprintf(crypto.AnysyncKeyValuePath, id)) +} diff --git a/util/crypto/derived.go b/util/crypto/derived.go index 0318b17a..86f64d0d 100644 --- a/util/crypto/derived.go +++ b/util/crypto/derived.go @@ -5,8 +5,9 @@ import ( ) const ( - AnysyncSpacePath = "m/SLIP-0021/anysync/space" - AnysyncTreePath = "m/SLIP-0021/anysync/tree/%s" + AnysyncSpacePath = "m/SLIP-0021/anysync/space" + AnysyncTreePath = "m/SLIP-0021/anysync/tree/%s" + AnysyncKeyValuePath = "m/SLIP-0021/anysync/keyvalue/%s" ) // DeriveSymmetricKey derives a symmetric key from seed and path using slip-21