1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-08 14:07:02 +09:00

Add encryption

This commit is contained in:
Mikhail Rakhmanov 2025-04-07 21:37:55 +02:00
parent d2ea2ba75d
commit c211ccf26e
No known key found for this signature in database
GPG key ID: DED12CFEF5B8396B
5 changed files with 131 additions and 19 deletions

View file

@ -131,6 +131,20 @@ func (st *AclState) CurrentReadKeyId() string {
return st.readKeyChanges[len(st.readKeyChanges)-1] return st.readKeyChanges[len(st.readKeyChanges)-1]
} }
func (st *AclState) ReadKeyForAclId(id string) (string, error) {
recIdx, ok := st.list.indexes[id]
if !ok {
return "", ErrNoSuchRecord
}
for i := len(st.readKeyChanges) - 1; i >= 0; i-- {
recId := st.readKeyChanges[i]
if recIdx >= st.list.indexes[recId] {
return recId, nil
}
}
return "", ErrNoSuchRecord
}
func (st *AclState) AccountKey() crypto.PrivKey { func (st *AclState) AccountKey() crypto.PrivKey {
return st.key return st.key
} }

View file

@ -13,11 +13,13 @@ var ErrInvalidSignature = errors.New("invalid signature")
type KeyValue struct { type KeyValue struct {
KeyPeerId string KeyPeerId string
ReadKeyId string
Key string Key string
Value Value Value Value
TimestampMilli int TimestampMilli int
Identity string Identity string
PeerId string PeerId string
AclId string
} }
type Value struct { type Value struct {
@ -47,6 +49,7 @@ func KeyValueFromProto(proto *spacesyncproto.StoreKeyValue, verify bool) (kv Key
kv.Identity = identity.Account() kv.Identity = identity.Account()
kv.PeerId = peerId.PeerId() kv.PeerId = peerId.PeerId()
kv.Key = innerValue.Key kv.Key = innerValue.Key
kv.AclId = innerValue.AclHeadId
// TODO: check that key-peerId is equal to key+peerId? // TODO: check that key-peerId is equal to key+peerId?
if verify { if verify {
if verify, _ = identity.Verify(proto.Value, proto.IdentitySignature); !verify { if verify, _ = identity.Verify(proto.Value, proto.IdentitySignature); !verify {
@ -71,6 +74,7 @@ func (kv KeyValue) AnyEnc(a *anyenc.Arena) *anyenc.Value {
obj := a.NewObject() obj := a.NewObject()
obj.Set("id", a.NewString(kv.KeyPeerId)) obj.Set("id", a.NewString(kv.KeyPeerId))
obj.Set("k", a.NewString(kv.Key)) obj.Set("k", a.NewString(kv.Key))
obj.Set("r", a.NewString(kv.ReadKeyId))
obj.Set("v", kv.Value.AnyEnc(a)) obj.Set("v", kv.Value.AnyEnc(a))
obj.Set("t", a.NewNumberInt(kv.TimestampMilli)) obj.Set("t", a.NewNumberInt(kv.TimestampMilli))
obj.Set("i", a.NewString(kv.Identity)) obj.Set("i", a.NewString(kv.Identity))

View file

@ -156,11 +156,12 @@ func (s *storage) keyValueFromDoc(doc anystore.Doc) KeyValue {
} }
return KeyValue{ return KeyValue{
KeyPeerId: doc.Value().GetString("id"), KeyPeerId: doc.Value().GetString("id"),
ReadKeyId: doc.Value().GetString("r"),
Value: value, Value: value,
TimestampMilli: doc.Value().GetInt("t"), TimestampMilli: doc.Value().GetInt("t"),
Identity: doc.Value().GetString("i"), Identity: doc.Value().GetString("i"),
PeerId: doc.Value().GetString("p"), PeerId: doc.Value().GetString("p"),
Key: doc.Value().GetString("k"), Key: doc.Value().GetString("k"),
} }
} }

View file

@ -2,6 +2,7 @@ package keyvaluestorage
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"time" "time"
@ -16,6 +17,8 @@ import (
"github.com/anyproto/any-sync/commonspace/object/keyvalue/keyvaluestorage/innerstorage" "github.com/anyproto/any-sync/commonspace/object/keyvalue/keyvaluestorage/innerstorage"
"github.com/anyproto/any-sync/commonspace/object/keyvalue/keyvaluestorage/syncstorage" "github.com/anyproto/any-sync/commonspace/object/keyvalue/keyvaluestorage/syncstorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/util/slice"
) )
var log = logger.NewNamed("common.keyvalue.keyvaluestorage") var log = logger.NewNamed("common.keyvalue.keyvaluestorage")
@ -24,9 +27,11 @@ const IndexerCName = "common.keyvalue.indexer"
type Indexer interface { type Indexer interface {
app.Component app.Component
Index(keyValue ...innerstorage.KeyValue) error Index(decryptor Decryptor, keyValue ...innerstorage.KeyValue) error
} }
type Decryptor = func(kv innerstorage.KeyValue) (value []byte, err error)
type NoOpIndexer struct{} type NoOpIndexer struct{}
func (n NoOpIndexer) Init(a *app.App) (err error) { func (n NoOpIndexer) Init(a *app.App) (err error) {
@ -37,7 +42,7 @@ func (n NoOpIndexer) Name() (name string) {
return IndexerCName return IndexerCName
} }
func (n NoOpIndexer) Index(keyValue ...innerstorage.KeyValue) error { func (n NoOpIndexer) Index(decryptor Decryptor, keyValue ...innerstorage.KeyValue) error {
return nil return nil
} }
@ -50,13 +55,15 @@ type Storage interface {
} }
type storage struct { type storage struct {
inner innerstorage.KeyValueStorage inner innerstorage.KeyValueStorage
keys *accountdata.AccountKeys keys *accountdata.AccountKeys
aclList list.AclList aclList list.AclList
syncClient syncstorage.SyncClient syncClient syncstorage.SyncClient
indexer Indexer indexer Indexer
storageId string storageId string
mx sync.Mutex readKeys map[string]crypto.SymKey
currentReadKey crypto.SymKey
mx sync.Mutex
} }
func New( func New(
@ -80,6 +87,7 @@ func New(
aclList: aclList, aclList: aclList,
indexer: indexer, indexer: indexer,
syncClient: syncClient, syncClient: syncClient,
readKeys: make(map[string]crypto.SymKey),
}, nil }, nil
} }
@ -92,11 +100,21 @@ func (s *storage) Set(ctx context.Context, key string, value []byte) error {
defer s.mx.Unlock() defer s.mx.Unlock()
s.aclList.RLock() s.aclList.RLock()
headId := s.aclList.Head().Id headId := s.aclList.Head().Id
if !s.aclList.AclState().Permissions(s.aclList.AclState().Identity()).CanWrite() { state := s.aclList.AclState()
if !s.aclList.AclState().Permissions(state.Identity()).CanWrite() {
s.aclList.RUnlock() s.aclList.RUnlock()
return list.ErrInsufficientPermissions return list.ErrInsufficientPermissions
} }
readKeyId := state.CurrentReadKeyId()
err := s.readKeysFromAclState(state)
if err != nil {
return err
}
s.aclList.RUnlock() s.aclList.RUnlock()
value, err = s.currentReadKey.Encrypt(value)
if err != nil {
return err
}
peerIdKey := s.keys.PeerKey peerIdKey := s.keys.PeerKey
identityKey := s.keys.SignKey identityKey := s.keys.SignKey
protoPeerKey, err := peerIdKey.GetPublic().Marshall() protoPeerKey, err := peerIdKey.GetPublic().Marshall()
@ -135,6 +153,8 @@ func (s *storage) Set(ctx context.Context, key string, value []byte) error {
TimestampMilli: int(timestampMilli), TimestampMilli: int(timestampMilli),
Identity: identityKey.GetPublic().Account(), Identity: identityKey.GetPublic().Account(),
PeerId: peerIdKey.GetPublic().PeerId(), PeerId: peerIdKey.GetPublic().PeerId(),
AclId: headId,
ReadKeyId: readKeyId,
Value: innerstorage.Value{ Value: innerstorage.Value{
Value: value, Value: value,
PeerSignature: peerSig, PeerSignature: peerSig,
@ -145,7 +165,7 @@ func (s *storage) Set(ctx context.Context, key string, value []byte) error {
if err != nil { if err != nil {
return err return err
} }
indexErr := s.indexer.Index(keyValue) indexErr := s.indexer.Index(s.decrypt, keyValue)
if indexErr != nil { if indexErr != nil {
log.Warn("failed to index for key", zap.String("key", key), zap.Error(indexErr)) log.Warn("failed to index for key", zap.String("key", key), zap.Error(indexErr))
} }
@ -156,7 +176,7 @@ func (s *storage) Set(ctx context.Context, key string, value []byte) error {
return nil return nil
} }
func (s *storage) SetRaw(ctx context.Context, keyValue ...*spacesyncproto.StoreKeyValue) error { func (s *storage) SetRaw(ctx context.Context, keyValue ...*spacesyncproto.StoreKeyValue) (err error) {
if len(keyValue) == 0 { if len(keyValue) == 0 {
return nil return nil
} }
@ -170,18 +190,36 @@ func (s *storage) SetRaw(ctx context.Context, keyValue ...*spacesyncproto.StoreK
} }
keyValues = append(keyValues, innerKv) keyValues = append(keyValues, innerKv)
} }
err := s.inner.Set(ctx, keyValues...) s.aclList.RLock()
state := s.aclList.AclState()
err = s.readKeysFromAclState(state)
if err != nil { if err != nil {
s.aclList.RUnlock()
return err return err
} }
indexErr := s.indexer.Index(keyValues...) for _, kv := range keyValues {
if indexErr != nil { kv.ReadKeyId, err = state.ReadKeyForAclId(kv.AclId)
log.Warn("failed to index for keys", zap.Error(indexErr)) if err != nil {
kv.KeyPeerId = ""
continue
}
}
s.aclList.RUnlock()
keyValues = slice.DiscardFromSlice(keyValues, func(value innerstorage.KeyValue) bool {
return value.KeyPeerId == ""
})
err = s.inner.Set(ctx, keyValues...)
if err != nil {
return err
} }
sendErr := s.syncClient.Broadcast(ctx, s.storageId, keyValues...) sendErr := s.syncClient.Broadcast(ctx, s.storageId, keyValues...)
if sendErr != nil { if sendErr != nil {
log.Warn("failed to send key values", zap.Error(sendErr)) log.Warn("failed to send key values", zap.Error(sendErr))
} }
indexErr := s.indexer.Index(s.decrypt, keyValues...)
if indexErr != nil {
log.Warn("failed to index for keys", zap.Error(indexErr))
}
return nil return nil
} }
@ -196,3 +234,57 @@ func (s *storage) GetAll(ctx context.Context, key string) (values []innerstorage
func (s *storage) InnerStorage() innerstorage.KeyValueStorage { func (s *storage) InnerStorage() innerstorage.KeyValueStorage {
return s.inner return s.inner
} }
func (s *storage) readKeysFromAclState(state *list.AclState) (err error) {
if len(s.readKeys) == len(state.Keys()) {
return nil
}
if state.AccountKey() == nil || !state.HadReadPermissions(state.AccountKey().GetPublic()) {
return nil
}
for key, value := range state.Keys() {
if _, exists := s.readKeys[key]; exists {
continue
}
if value.ReadKey == nil {
continue
}
treeKey, err := deriveKey(value.ReadKey, s.storageId)
if err != nil {
return err
}
s.readKeys[key] = treeKey
}
curKey, err := state.CurrentReadKey()
if err != nil {
return err
}
if curKey == nil {
return nil
}
s.currentReadKey, err = deriveKey(curKey, s.storageId)
return err
}
func (s *storage) decrypt(kv innerstorage.KeyValue) (value []byte, err error) {
if kv.ReadKeyId == "" {
return nil, fmt.Errorf("no read key id")
}
key := s.readKeys[kv.ReadKeyId]
if key == nil {
return nil, fmt.Errorf("no read key for %s", kv.ReadKeyId)
}
value, err = key.Decrypt(kv.Value.Value)
if err != nil {
return nil, err
}
return value, nil
}
func deriveKey(key crypto.SymKey, id string) (crypto.SymKey, error) {
raw, err := key.Raw()
if err != nil {
return nil, err
}
return crypto.DeriveSymmetricKey(raw, fmt.Sprintf(crypto.AnysyncKeyValuePath, id))
}

View file

@ -5,8 +5,9 @@ import (
) )
const ( const (
AnysyncSpacePath = "m/SLIP-0021/anysync/space" AnysyncSpacePath = "m/SLIP-0021/anysync/space"
AnysyncTreePath = "m/SLIP-0021/anysync/tree/%s" AnysyncTreePath = "m/SLIP-0021/anysync/tree/%s"
AnysyncKeyValuePath = "m/SLIP-0021/anysync/keyvalue/%s"
) )
// DeriveSymmetricKey derives a symmetric key from seed and path using slip-21 // DeriveSymmetricKey derives a symmetric key from seed and path using slip-21