diff --git a/commonspace/deletion_test.go b/commonspace/deletion_test.go index add216a5..2f49ce63 100644 --- a/commonspace/deletion_test.go +++ b/commonspace/deletion_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/anyproto/any-sync/commonspace/object/accountdata" + "github.com/anyproto/any-sync/commonspace/object/acl/recordverifier" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/spacepayloads" "github.com/anyproto/any-sync/commonspace/spacestorage" @@ -17,6 +18,14 @@ import ( "github.com/anyproto/any-sync/util/crypto" ) +func mockDeps() Deps { + return Deps{ + TreeSyncer: mockTreeSyncer{}, + SyncStatus: syncstatus.NewNoOpSyncStatus(), + recordVerifier: recordverifier.NewAlwaysAccept(), + } +} + func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.AccountKeys) string { bytes := make([]byte, 32) rand.Read(bytes) @@ -60,7 +69,7 @@ func TestSpaceDeleteIdsMarkDeleted(t *testing.T) { require.NotNil(t, sp) // initializing space - spc, err := fx.spaceService.NewSpace(ctx, sp, Deps{TreeSyncer: mockTreeSyncer{}, SyncStatus: syncstatus.NewNoOpSyncStatus()}) + spc, err := fx.spaceService.NewSpace(ctx, sp, mockDeps()) require.NoError(t, err) require.NotNil(t, spc) // adding space to tree manager @@ -109,7 +118,7 @@ func TestSpaceDeleteIdsMarkDeleted(t *testing.T) { time.Sleep(100 * time.Millisecond) storeSetter := fx.storageProvider.(storeSetter) storeSetter.SetStore(sp, newStore) - spc, err = fx.spaceService.NewSpace(ctx, sp, Deps{TreeSyncer: mockTreeSyncer{}, SyncStatus: syncstatus.NewNoOpSyncStatus()}) + spc, err = fx.spaceService.NewSpace(ctx, sp, mockDeps()) require.NoError(t, err) require.NotNil(t, spc) waitTest := make(chan struct{}) @@ -153,7 +162,7 @@ func TestSpaceDeleteIds(t *testing.T) { require.NotNil(t, sp) // initializing space - spc, err := fx.spaceService.NewSpace(ctx, sp, Deps{TreeSyncer: mockTreeSyncer{}, SyncStatus: syncstatus.NewNoOpSyncStatus()}) + spc, err := fx.spaceService.NewSpace(ctx, sp, mockDeps()) require.NoError(t, err) require.NotNil(t, spc) // adding space to tree manager @@ -202,7 +211,7 @@ func TestSpaceDeleteIds(t *testing.T) { time.Sleep(100 * time.Millisecond) storeSetter := fx.storageProvider.(storeSetter) storeSetter.SetStore(sp, newStore) - spc, err = fx.spaceService.NewSpace(ctx, sp, Deps{TreeSyncer: mockTreeSyncer{}, SyncStatus: syncstatus.NewNoOpSyncStatus()}) + spc, err = fx.spaceService.NewSpace(ctx, sp, mockDeps()) require.NoError(t, err) require.NotNil(t, spc) waitTest := make(chan struct{}) diff --git a/commonspace/object/acl/recordverifier/alwaysaccept.go b/commonspace/object/acl/recordverifier/alwaysaccept.go new file mode 100644 index 00000000..c5712d2a --- /dev/null +++ b/commonspace/object/acl/recordverifier/alwaysaccept.go @@ -0,0 +1,24 @@ +package recordverifier + +import ( + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/consensus/consensusproto" +) + +type AlwaysAccept struct{} + +func NewAlwaysAccept() RecordVerifier { + return &AlwaysAccept{} +} + +func (a *AlwaysAccept) Init(_ *app.App) error { + return nil +} + +func (a *AlwaysAccept) Name() string { + return CName +} + +func (a *AlwaysAccept) VerifyAcceptor(_ *consensusproto.RawRecord) error { + return nil +} diff --git a/commonspace/object/acl/recordverifier/recordverifier.go b/commonspace/object/acl/recordverifier/recordverifier.go new file mode 100644 index 00000000..8ecba76c --- /dev/null +++ b/commonspace/object/acl/recordverifier/recordverifier.go @@ -0,0 +1,55 @@ +package recordverifier + +import ( + "fmt" + + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/commonspace/object/acl/list" + "github.com/anyproto/any-sync/consensus/consensusproto" + "github.com/anyproto/any-sync/nodeconf" + "github.com/anyproto/any-sync/util/crypto" +) + +const CName = "common.acl.recordverifier" + +type RecordVerifier interface { + app.Component + list.AcceptorVerifier +} + +func New() RecordVerifier { + return &recordVerifier{} +} + +type recordVerifier struct { + configuration nodeconf.NodeConf + networkKey crypto.PubKey + store crypto.KeyStorage +} + +func (r *recordVerifier) Init(a *app.App) (err error) { + r.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) + r.store = crypto.NewKeyStorage() + networkId := r.configuration.Configuration().NetworkId + r.networkKey, err = crypto.DecodeNetworkId(networkId) + return +} + +func (r *recordVerifier) Name() (name string) { + return CName +} + +func (r *recordVerifier) VerifyAcceptor(rec *consensusproto.RawRecord) (err error) { + identity, err := r.store.PubKeyFromProto(rec.AcceptorIdentity) + if err != nil { + return fmt.Errorf("failed to get acceptor identity: %w", err) + } + if !identity.Equals(r.networkKey) { + return fmt.Errorf("acceptor identity does not match network key") + } + verified, err := r.networkKey.Verify(rec.Payload, rec.AcceptorSignature) + if !verified || err != nil { + return fmt.Errorf("failed to verify acceptor: %w", err) + } + return nil +} diff --git a/commonspace/object/acl/recordverifier/recordverifier_test.go b/commonspace/object/acl/recordverifier/recordverifier_test.go new file mode 100644 index 00000000..40b9b337 --- /dev/null +++ b/commonspace/object/acl/recordverifier/recordverifier_test.go @@ -0,0 +1,108 @@ +package recordverifier + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/consensus/consensusproto" + "github.com/anyproto/any-sync/nodeconf/testconf" + "github.com/anyproto/any-sync/testutil/accounttest" + "github.com/anyproto/any-sync/util/crypto" +) + +type fixture struct { + *recordVerifier + app *app.App + networkPrivKey crypto.PrivKey +} + +func newFixture(t *testing.T) *fixture { + accService := &accounttest.AccountTestService{} + a := &app.App{} + verifier := &recordVerifier{} + a.Register(accService). + Register(&testconf.StubConf{}). + Register(verifier) + require.NoError(t, a.Start(context.Background())) + return &fixture{ + recordVerifier: verifier, + app: a, + networkPrivKey: accService.Account().SignKey, + } +} + +func TestRecordVerifier_VerifyAcceptor(t *testing.T) { + fx := newFixture(t) + identity, err := fx.networkPrivKey.GetPublic().Marshall() + require.NoError(t, err) + testPayload := []byte("test payload") + acceptorSignature, err := fx.networkPrivKey.Sign(testPayload) + require.NoError(t, err) + rawRecord := &consensusproto.RawRecord{ + AcceptorIdentity: identity, + Payload: testPayload, + AcceptorSignature: acceptorSignature, + } + err = fx.VerifyAcceptor(rawRecord) + require.NoError(t, err) +} + +func TestRecordVerifier_VerifyAcceptor_InvalidSignature(t *testing.T) { + fx := newFixture(t) + identity, err := fx.networkPrivKey.GetPublic().Marshall() + require.NoError(t, err) + testPayload := []byte("test payload") + rawRecord := &consensusproto.RawRecord{ + AcceptorIdentity: identity, + Payload: testPayload, + AcceptorSignature: []byte("invalid signature"), + } + err = fx.VerifyAcceptor(rawRecord) + require.Error(t, err) +} + +func TestRecordVerifier_VerifyAcceptor_ModifiedPayload(t *testing.T) { + fx := newFixture(t) + identity, err := fx.networkPrivKey.GetPublic().Marshall() + require.NoError(t, err) + testPayload := []byte("test payload") + acceptorSignature, err := fx.networkPrivKey.Sign(testPayload) + require.NoError(t, err) + rawRecord := &consensusproto.RawRecord{ + AcceptorIdentity: identity, + Payload: []byte("modified payload"), + AcceptorSignature: acceptorSignature, + } + err = fx.VerifyAcceptor(rawRecord) + require.Error(t, err) +} + +func TestRecordVerifier_VerifyAcceptor_InvalidIdentity(t *testing.T) { + fx := newFixture(t) + testPayload := []byte("test payload") + acceptorSignature, err := fx.networkPrivKey.Sign(testPayload) + require.NoError(t, err) + rawRecord := &consensusproto.RawRecord{ + AcceptorIdentity: []byte("invalid identity"), + Payload: testPayload, + AcceptorSignature: acceptorSignature, + } + err = fx.VerifyAcceptor(rawRecord) + require.Error(t, err) +} + +func TestRecordVerifier_VerifyAcceptor_EmptySignature(t *testing.T) { + fx := newFixture(t) + identity, err := fx.networkPrivKey.GetPublic().Marshall() + require.NoError(t, err) + rawRecord := &consensusproto.RawRecord{ + AcceptorIdentity: identity, + Payload: []byte("test payload"), + AcceptorSignature: nil, + } + err = fx.VerifyAcceptor(rawRecord) + require.Error(t, err) +} diff --git a/commonspace/spacerpc_test.go b/commonspace/spacerpc_test.go index dd44695b..892f0f49 100644 --- a/commonspace/spacerpc_test.go +++ b/commonspace/spacerpc_test.go @@ -10,6 +10,7 @@ import ( "github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/commonspace/object/acl/recordverifier" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/synctree" "github.com/anyproto/any-sync/commonspace/object/treemanager" @@ -127,8 +128,9 @@ func (r *RpcServer) getSpace(ctx context.Context, spaceId string) (sp Space, err sp, ok := r.spaces[spaceId] if !ok { sp, err = r.spaceService.NewSpace(ctx, spaceId, Deps{ - TreeSyncer: NewTreeSyncer(spaceId), - SyncStatus: syncstatus.NewNoOpSyncStatus(), + TreeSyncer: NewTreeSyncer(spaceId), + SyncStatus: syncstatus.NewNoOpSyncStatus(), + recordVerifier: recordverifier.NewAlwaysAccept(), }) if err != nil { return nil, err diff --git a/commonspace/spaceservice.go b/commonspace/spaceservice.go index 630e15c2..e2a14f0a 100644 --- a/commonspace/spaceservice.go +++ b/commonspace/spaceservice.go @@ -13,6 +13,7 @@ import ( "github.com/anyproto/any-sync/commonspace/acl/aclclient" "github.com/anyproto/any-sync/commonspace/deletionmanager" + "github.com/anyproto/any-sync/commonspace/object/acl/recordverifier" "github.com/anyproto/any-sync/commonspace/object/keyvalue" "github.com/anyproto/any-sync/commonspace/object/keyvalue/keyvaluestorage" "github.com/anyproto/any-sync/commonspace/object/treesyncer" @@ -72,6 +73,7 @@ type Deps struct { SyncStatus syncstatus.StatusUpdater TreeSyncer treesyncer.TreeSyncer AccountService accountservice.Service + recordVerifier recordverifier.RecordVerifier Indexer keyvaluestorage.Indexer } @@ -188,8 +190,13 @@ func (s *spaceService) NewSpace(ctx context.Context, id string, deps Deps) (Spac if deps.Indexer != nil { keyValueIndexer = deps.Indexer } + recordVerifier := recordverifier.New() + if deps.recordVerifier != nil { + recordVerifier = deps.recordVerifier + } spaceApp.Register(state). Register(deps.SyncStatus). + Register(recordVerifier). Register(peerManager). Register(st). Register(keyValueIndexer). diff --git a/commonspace/spaceutils_test.go b/commonspace/spaceutils_test.go index 2d5e97d5..75eae124 100644 --- a/commonspace/spaceutils_test.go +++ b/commonspace/spaceutils_test.go @@ -10,7 +10,6 @@ import ( "time" anystore "github.com/anyproto/any-store" - "github.com/anyproto/go-chash" "github.com/stretchr/testify/require" "go.uber.org/zap" "storj.io/drpc" @@ -44,114 +43,12 @@ import ( "github.com/anyproto/any-sync/net/streampool/streamhandler" "github.com/anyproto/any-sync/node/nodeclient" "github.com/anyproto/any-sync/nodeconf" + "github.com/anyproto/any-sync/nodeconf/testconf" "github.com/anyproto/any-sync/testutil/accounttest" "github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/syncqueues" ) -type mockConf struct { - id string - networkId string - configuration nodeconf.Configuration -} - -func (m *mockConf) NetworkCompatibilityStatus() nodeconf.NetworkCompatibilityStatus { - return nodeconf.NetworkCompatibilityStatusOk -} - -func (m *mockConf) Init(a *app.App) (err error) { - accountKeys := a.MustComponent(accountService.CName).(accountService.Service).Account() - networkId := accountKeys.SignKey.GetPublic().Network() - node := nodeconf.Node{ - PeerId: accountKeys.PeerId, - Addresses: []string{"127.0.0.1:4430"}, - Types: []nodeconf.NodeType{nodeconf.NodeTypeTree}, - } - m.id = networkId - m.networkId = networkId - m.configuration = nodeconf.Configuration{ - Id: networkId, - NetworkId: networkId, - Nodes: []nodeconf.Node{node}, - CreationTime: time.Now(), - } - return nil -} - -func (m *mockConf) Name() (name string) { - return nodeconf.CName -} - -func (m *mockConf) Run(ctx context.Context) (err error) { - return nil -} - -func (m *mockConf) Close(ctx context.Context) (err error) { - return nil -} - -func (m *mockConf) Id() string { - return m.id -} - -func (m *mockConf) Configuration() nodeconf.Configuration { - return m.configuration -} - -func (m *mockConf) NodeIds(spaceId string) []string { - var nodeIds []string - for _, node := range m.configuration.Nodes { - nodeIds = append(nodeIds, node.PeerId) - } - return nodeIds -} - -func (m *mockConf) IsResponsible(spaceId string) bool { - return true -} - -func (m *mockConf) FilePeers() []string { - return nil -} - -func (m *mockConf) ConsensusPeers() []string { - return nil -} - -func (m *mockConf) CoordinatorPeers() []string { - return nil -} - -func (m *mockConf) NamingNodePeers() []string { - return nil -} - -func (m *mockConf) PaymentProcessingNodePeers() []string { - return nil -} - -func (m *mockConf) PeerAddresses(peerId string) (addrs []string, ok bool) { - if peerId == m.configuration.Nodes[0].PeerId { - return m.configuration.Nodes[0].Addresses, true - } - return nil, false -} - -func (m *mockConf) CHash() chash.CHash { - return nil -} - -func (m *mockConf) Partition(spaceId string) (part int) { - return 0 -} - -func (m *mockConf) NodeTypes(nodeId string) []nodeconf.NodeType { - if nodeId == m.configuration.Nodes[0].PeerId { - return m.configuration.Nodes[0].Types - } - return nil -} - var _ nodeclient.NodeClient = (*mockNodeClient)(nil) type mockNodeClient struct { @@ -654,7 +551,7 @@ func newFixture(t *testing.T) *spaceFixture { app: &app.App{}, config: &mockConfig{}, account: &accounttest.AccountTestService{}, - configurationService: &mockConf{}, + configurationService: &testconf.StubConf{}, streamOpener: newStreamOpener("spaceId"), peerManagerProvider: &testPeerManagerProvider{}, storageProvider: &spaceStorageProvider{rootPath: t.TempDir()}, @@ -699,7 +596,7 @@ func newPeerFixture(t *testing.T, spaceId string, keys *accountdata.AccountKeys, app: &app.App{}, config: &mockConfig{}, account: accounttest.NewWithAcc(keys), - configurationService: &mockConf{}, + configurationService: &testconf.StubConf{}, storageProvider: provider, streamOpener: newStreamOpener(spaceId), peerManagerProvider: &testPeerManagerProvider{}, diff --git a/nodeconf/testconf/nodeconf.go b/nodeconf/testconf/nodeconf.go new file mode 100644 index 00000000..242a8b53 --- /dev/null +++ b/nodeconf/testconf/nodeconf.go @@ -0,0 +1,115 @@ +package testconf + +import ( + "context" + "time" + + "github.com/anyproto/go-chash" + + accountService "github.com/anyproto/any-sync/accountservice" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/nodeconf" +) + +type StubConf struct { + id string + networkId string + configuration nodeconf.Configuration +} + +func (m *StubConf) NetworkCompatibilityStatus() nodeconf.NetworkCompatibilityStatus { + return nodeconf.NetworkCompatibilityStatusOk +} + +func (m *StubConf) Init(a *app.App) (err error) { + accountKeys := a.MustComponent(accountService.CName).(accountService.Service).Account() + networkId := accountKeys.SignKey.GetPublic().Network() + node := nodeconf.Node{ + PeerId: accountKeys.PeerId, + Addresses: []string{"127.0.0.1:4430"}, + Types: []nodeconf.NodeType{nodeconf.NodeTypeTree}, + } + m.id = networkId + m.networkId = networkId + m.configuration = nodeconf.Configuration{ + Id: networkId, + NetworkId: networkId, + Nodes: []nodeconf.Node{node}, + CreationTime: time.Now(), + } + return nil +} + +func (m *StubConf) Name() (name string) { + return nodeconf.CName +} + +func (m *StubConf) Run(ctx context.Context) (err error) { + return nil +} + +func (m *StubConf) Close(ctx context.Context) (err error) { + return nil +} + +func (m *StubConf) Id() string { + return m.id +} + +func (m *StubConf) Configuration() nodeconf.Configuration { + return m.configuration +} + +func (m *StubConf) NodeIds(spaceId string) []string { + var nodeIds []string + for _, node := range m.configuration.Nodes { + nodeIds = append(nodeIds, node.PeerId) + } + return nodeIds +} + +func (m *StubConf) IsResponsible(spaceId string) bool { + return true +} + +func (m *StubConf) FilePeers() []string { + return nil +} + +func (m *StubConf) ConsensusPeers() []string { + return nil +} + +func (m *StubConf) CoordinatorPeers() []string { + return nil +} + +func (m *StubConf) NamingNodePeers() []string { + return nil +} + +func (m *StubConf) PaymentProcessingNodePeers() []string { + return nil +} + +func (m *StubConf) PeerAddresses(peerId string) (addrs []string, ok bool) { + if peerId == m.configuration.Nodes[0].PeerId { + return m.configuration.Nodes[0].Addresses, true + } + return nil, false +} + +func (m *StubConf) CHash() chash.CHash { + return nil +} + +func (m *StubConf) Partition(spaceId string) (part int) { + return 0 +} + +func (m *StubConf) NodeTypes(nodeId string) []nodeconf.NodeType { + if nodeId == m.configuration.Nodes[0].PeerId { + return m.configuration.Nodes[0].Types + } + return nil +} diff --git a/util/crypto/decode.go b/util/crypto/decode.go index ce6c58e2..30f60183 100644 --- a/util/crypto/decode.go +++ b/util/crypto/decode.go @@ -2,8 +2,10 @@ package crypto import ( "encoding/base64" - "github.com/anyproto/any-sync/util/strkey" + "github.com/libp2p/go-libp2p/core/peer" + + "github.com/anyproto/any-sync/util/strkey" ) func EncodeKeyToString[T Key](key T) (str string, err error) { @@ -55,3 +57,11 @@ func DecodePeerId(peerId string) (PubKey, error) { } return UnmarshalEd25519PublicKey(raw) } + +func DecodeNetworkId(networkId string) (PubKey, error) { + pubKeyRaw, err := strkey.Decode(strkey.NetworkAddressVersionByte, networkId) + if err != nil { + return nil, err + } + return UnmarshalEd25519PublicKey(pubKeyRaw) +} diff --git a/util/crypto/decode_test.go b/util/crypto/decode_test.go new file mode 100644 index 00000000..1f3da8d5 --- /dev/null +++ b/util/crypto/decode_test.go @@ -0,0 +1,18 @@ +package crypto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecodeNetworkId(t *testing.T) { + _, pubKey, err := GenerateRandomEd25519KeyPair() + require.NoError(t, err) + + networkId := pubKey.Network() + require.Equal(t, uint8('N'), networkId[0]) + decodedKey, err := DecodeNetworkId(networkId) + require.NoError(t, err) + require.Equal(t, pubKey, decodedKey) +}