1
0
Fork 0
mirror of https://github.com/anyproto/anytype-heart.git synced 2025-06-11 02:13:41 +09:00

Merge branch 'main' of github.com:anyproto/anytype-heart into snapshoting

This commit is contained in:
AnastasiaShemyakinskaya 2024-08-02 10:34:40 +02:00
commit 55c3b87a0e
No known key found for this signature in database
GPG key ID: CCD60ED83B103281
158 changed files with 10295 additions and 5604 deletions

22
.github/install_macos_sdk.sh vendored Executable file
View file

@ -0,0 +1,22 @@
#!/usr/bin/env bash
# Install an older MacOS SDK
OSX_SDK_DIR="$(xcode-select -p)/Platforms/MacOSX.platform/Developer/SDKs"
export MACOSX_DEPLOYMENT_TARGET=$1
export MACOSX_SDK_VERSION=$MACOSX_DEPLOYMENT_TARGET
export OSX_SYSROOT="${OSX_SDK_DIR}/MacOSX${MACOSX_SDK_VERSION}.sdk"
FILENAME="MacOSX${MACOSX_SDK_VERSION}.sdk.tar.xz"
DOWNLOAD_URL="https://github.com/phracker/MacOSX-SDKs/releases/download/10.15/${FILENAME}"
if [[ ! -d ${OSX_SYSROOT}} ]]; then
echo "MacOS SDK ${MACOSX_SDK_VERSION} is missing, downloading..."
curl -L -O --connect-timeout 5 --max-time 10 --retry 10 --retry-delay 0 --retry-max-time 40 --retry-connrefused --retry-all-errors \
${DOWNLOAD_URL}
tar -xf ${FILENAME} -C "$(dirname ${OSX_SYSROOT})"
fi
plutil -replace MinimumSDKVersion -string ${MACOSX_SDK_VERSION} $(xcode-select -p)/Platforms/MacOSX.platform/Info.plist
plutil -replace DTSDKName -string macosx${MACOSX_SDK_VERSION}internal $(xcode-select -p)/Platforms/MacOSX.platform/Info.plist
echo "SDKROOT=${OSX_SYSROOT}" >> ${GITHUB_ENV}

View file

@ -10,7 +10,7 @@ on:
run-on-runner:
description: 'Specify the runner to use'
required: true
default: 'self-hosted'
default: 'ARM64'
perf-test:
description: 'Run perf test times'
required: true
@ -32,11 +32,11 @@ permissions:
name: Build
jobs:
build:
runs-on: ${{ github.event_name == 'push' && 'macos-11' || (github.event.inputs.run-on-runner || 'self-hosted') }}
runs-on: ${{ github.event_name == 'push' && 'macos-12' || (github.event.inputs.run-on-runner || 'ARM64') }}
steps:
- name: validate agent
run: |
if [[ "${{ github.event_name }}" == "workflow_dispatch" && "${{ github.event.inputs.run-on-runner }}" != "self-hosted" ]]; then
if [[ "${{ github.event_name }}" == "workflow_dispatch" && "${{ github.event.inputs.run-on-runner }}" != "ARM64" ]]; then
echo "Invalid runner"
exit 1
fi
@ -44,7 +44,7 @@ jobs:
uses: actions/setup-go@v1
with:
go-version: 1.22
if: github.event.inputs.run-on-runner != 'self-hosted' && github.event_name != 'schedule'
if: github.event.inputs.run-on-runner != 'ARM64' && github.event_name != 'schedule'
- name: Setup GO
run: |
go version
@ -60,16 +60,17 @@ jobs:
git fetch
git checkout db6184738b77fbd5089e5fa1112177f391c91b24
go install github.com/mitchellh/gox
if: github.event.inputs.run-on-runner != 'self-hosted' && github.event_name != 'schedule'
if: github.event.inputs.run-on-runner != 'ARM64' && github.event_name != 'schedule'
- name: Install brew and node deps
run: |
curl https://raw.githubusercontent.com/Homebrew/homebrew-core/31b24d65a7210ea0a5689d5ad00dd8d1bf5211db/Formula/protobuf.rb --output protobuf.rb
curl https://raw.githubusercontent.com/Homebrew/homebrew-core/d600b1f7119f6e6a4e97fb83233b313b0468b7e4/Formula/s/swift-protobuf.rb --output swift-protobuf.rb
HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 brew install ./protobuf.rb
HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 brew install --ignore-dependencies swift-protobuf
HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 brew install --ignore-dependencies ./swift-protobuf.rb
HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 brew install mingw-w64
HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 brew install grpcurl
npm i -g node-gyp
if: github.event.inputs.run-on-runner != 'self-hosted' && github.event_name != 'schedule'
if: github.event.inputs.run-on-runner != 'ARM64' && github.event_name != 'schedule'
- name: Checkout
uses: actions/checkout@v3
- uses: actions/cache@v3
@ -79,6 +80,9 @@ jobs:
key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-${{ matrix.go-version }}-
- name: Install old MacOS SDK (for backward compatibility of CGO)
run: source .github/install_macos_sdk.sh 10.15
if: runner.os == 'macOS' && startsWith(runner.os_version, '12')
- name: Set env vars
env:
UNSPLASH_KEY: ${{ secrets.UNSPLASH_KEY }}
@ -97,7 +101,6 @@ jobs:
fi
echo VERSION=${VERSION} >> $GITHUB_ENV
echo MAVEN_ARTIFACT_VERSION=${VERSION} >> $GITHUB_ENV
echo SDKROOT=$(xcrun --sdk macosx --show-sdk-path) >> $GITHUB_ENV
echo GOPRIVATE=github.com/anyproto >> $GITHUB_ENV
echo $(pwd)/deps >> $GITHUB_PATH
echo "${GOBIN}" >> $GITHUB_PATH
@ -115,9 +118,14 @@ jobs:
which gomobile
- name: Cross-compile library mac/win
run: |
make download-tantivy-all
echo $FLAGS
mkdir -p .release
gox -cgo -ldflags="$FLAGS" -osarch="darwin/amd64 darwin/arm64" --tags="envproduction nographviz nowatchdog nosigar nomutexdeadlockdetector" -output="{{.OS}}-{{.Arch}}" github.com/anyproto/anytype-heart/cmd/grpcserver
echo $SDKROOT
gox -cgo -ldflags="$FLAGS" -osarch="darwin/amd64" --tags="envproduction nographviz nowatchdog nosigar nomutexdeadlockdetector" -output="{{.OS}}-{{.Arch}}" github.com/anyproto/anytype-heart/cmd/grpcserver
export SDKROOT=$(xcrun --sdk macosx --show-sdk-path)
echo $SDKROOT
gox -cgo -ldflags="$FLAGS" -osarch="darwin/arm64" --tags="envproduction nographviz nowatchdog nosigar nomutexdeadlockdetector" -output="{{.OS}}-{{.Arch}}" github.com/anyproto/anytype-heart/cmd/grpcserver
make protos-server
CC="x86_64-w64-mingw32-gcc" CXX="x86_64-w64-mingw32-g++" gox -cgo -ldflags="$FLAGS -linkmode external -extldflags=-static" -osarch="windows/amd64" --tags="envproduction nographviz nowatchdog nosigar nomutexdeadlockdetector noheic" -output="{{.OS}}-{{.Arch}}" github.com/anyproto/anytype-heart/cmd/grpcserver
ls -lha .
@ -125,12 +133,13 @@ jobs:
- name: run perf tests
run: |
echo "Running perf tests"
make download-tantivy-all
RUN_COUNT=${{ github.event.inputs.perf-test }}
if [[ "${{ github.event_name }}" == "schedule" ]]; then
RUN_COUNT=10
fi
cd cmd/perftester/
go run main.go $RUN_COUNT
CGO_ENABLED="1" go run main.go $RUN_COUNT
env:
ANYTYPE_REPORT_MEMORY: 'true'
TEST_MNEMONIC: ${{ secrets.TEST_MNEMONIC }}
@ -189,7 +198,7 @@ jobs:
mv js_${VERSION}_${OSARCH}.zip .release/
done
if: github.event_name == 'push'
- name: Pack server unix
- name: Pack server osx
run: |
declare -a arr=("darwin-amd64" "darwin-arm64")
for i in "${arr[@]}"
@ -314,10 +323,8 @@ jobs:
run: |
sudo apt update
sudo apt install -y protobuf-compiler libprotoc-dev
curl -O https://musl.cc/aarch64-linux-musl-cross.tgz
curl -O https://musl.cc/x86_64-linux-musl-native.tgz
tar xzf aarch64-linux-musl-cross.tgz -C $HOME
tar xzf x86_64-linux-musl-native.tgz -C $HOME
curl -O https://pub-c60a000d68b544109df4fe5837762101.r2.dev/linux-compiler-musl-x86.zip
unzip linux-compiler-musl-x86.zip -d $HOME
npm i -g node-gyp
- name: Checkout
uses: actions/checkout@v3
@ -349,10 +356,10 @@ jobs:
make setup-go
- name: Cross-compile library for linux amd64/arm64
run: |
make download-tantivy-all
echo $FLAGS
mkdir -p .release
CXX=$HOME/x86_64-linux-musl-native/bin/x86_64-linux-musl-g++ CC=$HOME/x86_64-linux-musl-native/bin/x86_64-linux-musl-gcc gox -cgo -osarch="linux/amd64" -ldflags="$FLAGS -linkmode external -extldflags=-static" --tags="envproduction nographviz nowatchdog nosigar nomutexdeadlockdetector" -output="{{.OS}}-{{.Arch}}" github.com/anyproto/anytype-heart/cmd/grpcserver
CXX=$HOME/aarch64-linux-musl-cross/bin/aarch64-linux-musl-g++ CC=$HOME/aarch64-linux-musl-cross/bin/aarch64-linux-musl-gcc gox -cgo -osarch="linux/arm64" -ldflags="$FLAGS -linkmode external -extldflags=-static" --tags="envproduction nographviz nowatchdog nosigar nomutexdeadlockdetector" -output="{{.OS}}-{{.Arch}}" github.com/anyproto/anytype-heart/cmd/grpcserver
CXX=$HOME/linux-compiler-musl-x86/bin/x86_64-linux-musl-g++ CC=$HOME/linux-compiler-musl-x86/bin/x86_64-linux-musl-gcc gox -cgo -osarch="linux/amd64" -ldflags="$FLAGS -linkmode external -extldflags=-static" --tags="envproduction nographviz nowatchdog nosigar nomutexdeadlockdetector" -output="{{.OS}}-{{.Arch}}" github.com/anyproto/anytype-heart/cmd/grpcserver
make protos-server
- name: Make JS protos
run: |
@ -379,7 +386,7 @@ jobs:
retention-days: 1
- name: Pack server unix
run: |
declare -a arr=("linux-amd64" "linux-arm64")
declare -a arr=("linux-amd64")
for i in "${arr[@]}"
do
OSARCH=${i%.*}

View file

@ -1,5 +1,10 @@
on: [ pull_request ]
name: Test
on:
push:
branches:
- main
pull_request:
branches:
- '*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
@ -47,6 +52,7 @@ jobs:
license_finder --enabled-package-managers gomodules
- name: Generate mocks
run: |
make download-tantivy-all
go install go.uber.org/mock/mockgen@v0.3.0
CGO_ENABLED=1 CGO_CFLAGS="-Wno-deprecated-declarations -Wno-deprecated-non-prototype -Wno-xor-used-as-pow" go generate ./...
- name: Go test
@ -60,7 +66,12 @@ jobs:
PACKAGE_NAMES=$(go list -tags nogrpcserver ./... | grep -v "github.com/anyproto/anytype-heart/cmd/grpserver" | grep -v "github.com/anyproto/anytype-heart/clientlibrary/clib")
rm -rf ~/gotestsum-report
mkdir ~/gotestsum-report
CGO_CFLAGS="-Wno-deprecated-non-prototype -Wno-unknown-warning-option -Wno-deprecated-declarations -Wno-xor-used-as-pow -Wno-single-bit-bitfield-constant-conversion" gotestsum --junitfile ~/gotestsum-report/gotestsum-report.xml -- -tags "nogrpcserver nographviz" -ldflags="-extldflags=-Wl,-ld_classic" -p 1 $(echo $PACKAGE_NAMES) -race -coverprofile=coverage.out -covermode=atomic ./...
if [[ "$GITHUB_REF" == "refs/heads/main" && "$GITHUB_EVENT_NAME" == "push" ]]; then
export RACE=-race
else
echo "run without race detector"
fi
CGO_CFLAGS="-Wno-deprecated-non-prototype -Wno-unknown-warning-option -Wno-deprecated-declarations -Wno-xor-used-as-pow -Wno-single-bit-bitfield-constant-conversion" gotestsum --junitfile ~/gotestsum-report/gotestsum-report.xml -- -tags "nogrpcserver nographviz" -ldflags="-extldflags=-Wl,-ld_classic" -p 1 $(echo $PACKAGE_NAMES) $(echo $RACE) -coverprofile=coverage.out -covermode=atomic ./...
generated_pattern='^\/\/ Code generated .* DO NOT EDIT\.$'
files_list=$(grep -rl "$generated_pattern" . | grep '\.go$' | sed 's/^\.\///')

View file

@ -13,6 +13,7 @@ issues:
- pb
exclude-files:
- '.*_test.go'
- 'mock*'
- 'testMock/*'
- 'clientlibrary/service/service.pb.go'

View file

@ -192,6 +192,9 @@ packages:
interfaces:
PeerStatusChecker:
SyncDetailsUpdater:
github.com/anyproto/anytype-heart/core/syncstatus/nodestatus:
interfaces:
NodeStatus:
github.com/anyproto/anytype-heart/core/syncstatus/objectsyncstatus:
interfaces:
Updater:
@ -210,4 +213,6 @@ packages:
github.com/anyproto/anytype-heart/core/syncstatus/spacesyncstatus:
interfaces:
SpaceIdGetter:
NodeUsage:
NetworkConfig:
Updater:

View file

@ -2,6 +2,7 @@ CUSTOM_NETWORK_FILE ?= ./core/anytype/config/nodes/custom.yml
CLIENT_DESKTOP_PATH ?= ../anytype-ts
CLIENT_ANDROID_PATH ?= ../anytype-kotlin
CLIENT_IOS_PATH ?= ../anytype-swift
TANTIVY_GO_PATH ?= ../tantivy-go
BUILD_FLAGS ?=
export GOLANGCI_LINT_VERSION=1.58.1
@ -66,13 +67,17 @@ test:
@echo 'Running tests...'
@ANYTYPE_LOG_NOGELF=1 go test -cover github.com/anyproto/anytype-heart/...
test-no-cache:
@echo 'Running tests...'
@ANYTYPE_LOG_NOGELF=1 go test -count=1 github.com/anyproto/anytype-heart/...
test-integration:
@echo 'Running integration tests...'
@go test -run=TestBasic -tags=integration -v -count 1 ./tests
test-race:
@echo 'Running tests with race-detector...'
@ANYTYPE_LOG_NOGELF=1 go test -race github.com/anyproto/anytype-heart/...
@ANYTYPE_LOG_NOGELF=1 go test -count=1 -race github.com/anyproto/anytype-heart/...
test-deps:
@echo 'Generating test mocks...'
@ -328,3 +333,53 @@ ifdef GOLANGCI_LINT_BRANCH
else
@golangci-lint run -v ./... --new-from-rev=origin/main --timeout 15m --fix
endif
### Tantivy Section
REPO := anyproto/tantivy-go
VERSION := go/v0.0.5
OUTPUT_DIR := deps/libs
SHA_FILE = tantivity_sha256.txt
TANTIVY_LIBS := android-386.tar.gz \
android-amd64.tar.gz \
android-arm.tar.gz \
android-arm64.tar.gz \
darwin-amd64.tar.gz \
darwin-arm64.tar.gz \
ios-amd64.tar.gz \
ios-arm64.tar.gz \
linux-amd64-musl.tar.gz \
windows-amd64.tar.gz
define download_tantivy_lib
curl -L -o $(OUTPUT_DIR)/$(1) https://github.com/$(REPO)/releases/download/$(VERSION)/$(1)
endef
define remove_arch
rm -f $(OUTPUT_DIR)/$(1)
endef
download-tantivy: $(TANTIVY_LIBS)
$(TANTIVY_LIBS):
@mkdir -p $(OUTPUT_DIR)/$(shell echo $@ | cut -d'.' -f1)
$(call download_tantivy_lib,$@)
@tar -C $(OUTPUT_DIR)/$(shell echo $@ | cut -d'.' -f1) -xvzf $(OUTPUT_DIR)/$@
download-tantivy-all-force: download-tantivy
@rm -f $(SHA_FILE)
@for file in $(TANTIVY_LIBS); do \
echo "SHA256 $(OUTPUT_DIR)/$$file" ; \
shasum -a 256 $(OUTPUT_DIR)/$$file | awk '{print $$1 " " "'$(OUTPUT_DIR)/$$file'" }' >> $(SHA_FILE); \
done
@echo "SHA256 checksums generated."
download-tantivy-all: download-tantivy
@echo "Validating SHA256 checksums..."
@shasum -a 256 -c $(SHA_FILE) --status || { echo "Hash mismatch detected."; exit 1; }
@echo "All files are valid."
download-tantivy-local:
@mkdir -p $(OUTPUT_DIR)
@cp -r $(TANTIVY_GO_PATH)/go/libs/ $(OUTPUT_DIR)

View file

@ -82,6 +82,7 @@ import (
"github.com/anyproto/anytype-heart/core/syncstatus/detailsupdater"
"github.com/anyproto/anytype-heart/core/syncstatus/nodestatus"
"github.com/anyproto/anytype-heart/core/syncstatus/spacesyncstatus"
"github.com/anyproto/anytype-heart/core/syncstatus/syncsubscriptions"
"github.com/anyproto/anytype-heart/core/wallet"
"github.com/anyproto/anytype-heart/metrics"
"github.com/anyproto/anytype-heart/pkg/lib/core"
@ -205,7 +206,8 @@ func Bootstrap(a *app.App, components ...app.Component) {
// Data storages
Register(clientds.New()).
Register(debugstat.New()).
Register(ftsearch.New()).
// Register(ftsearch.BleveNew()).
Register(ftsearch.TantivyNew()).
Register(objectstore.New()).
Register(backlinks.New()).
Register(filestore.New()).
@ -266,7 +268,7 @@ func Bootstrap(a *app.App, components ...app.Component) {
Register(treemanager.New()).
Register(block.New()).
Register(indexer.New()).
Register(detailsupdater.NewUpdater()).
Register(detailsupdater.New()).
Register(session.NewHookRunner()).
Register(spacesyncstatus.NewSpaceSyncStatus()).
Register(nodestatus.NewNodeStatus()).
@ -280,6 +282,7 @@ func Bootstrap(a *app.App, components ...app.Component) {
Register(debug.New()).
Register(collection.New()).
Register(subscription.New()).
Register(syncsubscriptions.New()).
Register(builtinobjects.New()).
Register(bookmark.New()).
Register(decorator.New()).
@ -292,7 +295,7 @@ func Bootstrap(a *app.App, components ...app.Component) {
Register(profiler.New()).
Register(identity.New(30*time.Second, 10*time.Second)).
Register(templateservice.New()).
Register(notifications.New()).
Register(notifications.New(time.Second * 10)).
Register(paymentserviceclient.New()).
Register(nameservice.New()).
Register(nameserviceclient.New()).

View file

@ -71,6 +71,7 @@ func (s *Service) AccountSelect(ctx context.Context, req *pb.RpcAccountSelectReq
if err := s.stop(); err != nil {
return nil, errors.Join(ErrFailedToStopApplication, err)
}
metrics.Service.SetWorkingDir(req.RootPath, req.Id)
return s.start(ctx, req.Id, req.RootPath, req.DisableLocalNetworkSync, req.PreferYamuxTransport, req.NetworkMode, req.NetworkCustomConfigFilePath)
}

View file

@ -190,9 +190,7 @@ func (s *Service) CreateCollection(details *types.Struct, flags []*model.Interna
newState := state.NewDoc("", nil).NewState().SetDetails(details)
tmpls := []template.StateTransformer{
template.WithRequiredRelations(),
}
tmpls := []template.StateTransformer{}
blockContent := template.MakeCollectionDataviewContent()
tmpls = append(tmpls,

View file

@ -8,6 +8,7 @@ import (
"github.com/anyproto/anytype-heart/core/block/editor/state"
"github.com/anyproto/anytype-heart/core/block/editor/template"
"github.com/anyproto/anytype-heart/core/block/migration"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/relationutils"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
"github.com/anyproto/anytype-heart/pkg/lib/database"
@ -17,6 +18,9 @@ import (
"github.com/anyproto/anytype-heart/util/slice"
)
// required relations for archive beside the bundle.RequiredInternalRelations
var archiveRequiredRelations = []domain.RelationKey{}
type Archive struct {
smartblock.SmartBlock
collection.Collection
@ -35,6 +39,7 @@ func NewArchive(
}
func (p *Archive) Init(ctx *smartblock.InitContext) (err error) {
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, archiveRequiredRelations...)
if err = p.SmartBlock.Init(ctx); err != nil {
return
}

View file

@ -10,6 +10,7 @@ import (
"github.com/anyproto/anytype-heart/core/block/editor/converter"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/block/editor/state"
"github.com/anyproto/anytype-heart/core/block/editor/table"
"github.com/anyproto/anytype-heart/core/block/editor/template"
"github.com/anyproto/anytype-heart/core/block/restriction"
"github.com/anyproto/anytype-heart/core/block/simple"
@ -19,6 +20,8 @@ import (
relationblock "github.com/anyproto/anytype-heart/core/block/simple/relation"
"github.com/anyproto/anytype-heart/core/block/simple/text"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/domain/objectorigin"
"github.com/anyproto/anytype-heart/core/files/fileobject"
"github.com/anyproto/anytype-heart/core/session"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
@ -102,19 +105,22 @@ func NewBasic(
sb smartblock.SmartBlock,
objectStore objectstore.ObjectStore,
layoutConverter converter.LayoutConverter,
fileObjectService fileobject.Service,
) AllOperations {
return &basic{
SmartBlock: sb,
objectStore: objectStore,
layoutConverter: layoutConverter,
SmartBlock: sb,
objectStore: objectStore,
layoutConverter: layoutConverter,
fileObjectService: fileObjectService,
}
}
type basic struct {
smartblock.SmartBlock
objectStore objectstore.ObjectStore
layoutConverter converter.LayoutConverter
objectStore objectstore.ObjectStore
layoutConverter converter.LayoutConverter
fileObjectService fileobject.Service
}
func (bs *basic) CreateBlock(s *state.State, req pb.RpcBlockCreateRequest) (id string, err error) {
@ -148,7 +154,7 @@ func (bs *basic) CreateBlock(s *state.State, req pb.RpcBlockCreateRequest) (id s
func (bs *basic) Duplicate(srcState, destState *state.State, targetBlockId string, position model.BlockPosition, blockIds []string) (newIds []string, err error) {
blockIds = srcState.SelectRoots(blockIds)
for _, id := range blockIds {
copyId, e := copyBlocks(srcState, destState, id)
copyId, e := bs.copyBlocks(srcState, destState, id)
if e != nil {
return nil, e
}
@ -168,7 +174,7 @@ type duplicatable interface {
Duplicate(s *state.State) (newId string, visitedIds []string, blocks []simple.Block, err error)
}
func copyBlocks(srcState, destState *state.State, sourceId string) (id string, err error) {
func (bs *basic) copyBlocks(srcState, destState *state.State, sourceId string) (id string, err error) {
b := srcState.Pick(sourceId)
if b == nil {
return "", smartblock.ErrSimpleBlockNotFound
@ -189,13 +195,37 @@ func copyBlocks(srcState, destState *state.State, sourceId string) (id string, e
result := simple.New(m)
destState.Add(result)
for i, childrenId := range result.Model().ChildrenIds {
if result.Model().ChildrenIds[i], err = copyBlocks(srcState, destState, childrenId); err != nil {
if result.Model().ChildrenIds[i], err = bs.copyBlocks(srcState, destState, childrenId); err != nil {
return
}
}
if f, ok := result.Model().Content.(*model.BlockContentOfFile); ok && srcState.SpaceID() != destState.SpaceID() {
bs.processFileBlock(f, destState.SpaceID())
}
return result.Model().Id, nil
}
func (bs *basic) processFileBlock(f *model.BlockContentOfFile, spaceId string) {
fileId, err := bs.fileObjectService.GetFileIdFromObject(f.File.TargetObjectId)
if err != nil {
log.Errorf("failed to get fileId: %v", err)
return
}
objectId, err := bs.fileObjectService.CreateFromImport(
domain.FullFileId{SpaceId: spaceId, FileId: fileId.FileId},
objectorigin.ObjectOrigin{Origin: model.ObjectOrigin_clipboard},
)
if err != nil {
log.Errorf("failed to create file object: %v", err)
return
}
f.File.TargetObjectId = objectId
}
func (bs *basic) Unlink(ctx session.Context, ids ...string) (err error) {
s := bs.NewStateCtx(ctx)
@ -236,6 +266,11 @@ func (bs *basic) Move(srcState, destState *state.State, targetBlockId string, po
}
}
targetBlockId, position, err = table.CheckTableBlocksMove(srcState, targetBlockId, position, blockIds)
if err != nil {
return err
}
var replacementCandidate simple.Block
for _, id := range blockIds {
if b := srcState.Pick(id); b != nil {

View file

@ -1,19 +1,26 @@
package basic
import (
"errors"
"math/rand"
"testing"
"github.com/gogo/protobuf/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/anyproto/anytype-heart/core/block/editor/converter"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock/smarttest"
"github.com/anyproto/anytype-heart/core/block/editor/table"
"github.com/anyproto/anytype-heart/core/block/editor/template"
"github.com/anyproto/anytype-heart/core/block/restriction"
"github.com/anyproto/anytype-heart/core/block/simple"
"github.com/anyproto/anytype-heart/core/block/simple/text"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/files/fileobject"
"github.com/anyproto/anytype-heart/core/files/fileobject/mock_fileobject"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
@ -38,7 +45,7 @@ func TestBasic_Create(t *testing.T) {
t.Run("generic", func(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test"}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
id, err := b.CreateBlock(st, pb.RpcBlockCreateRequest{
Block: &model.Block{Content: &model.BlockContentOfText{Text: &model.BlockContentText{Text: "ll"}}},
@ -52,7 +59,7 @@ func TestBasic_Create(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test"}))
require.NoError(t, smartblock.ObjectApplyTemplate(sb, sb.NewState(), template.WithTitle))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
s := sb.NewState()
id, err := b.CreateBlock(s, pb.RpcBlockCreateRequest{
TargetId: template.TitleBlockId,
@ -73,29 +80,123 @@ func TestBasic_Create(t *testing.T) {
}
sb.AddBlock(simple.New(&model.Block{Id: "test"}))
require.NoError(t, smartblock.ObjectApplyTemplate(sb, sb.NewState(), template.WithTitle))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
_, err := b.CreateBlock(sb.NewState(), pb.RpcBlockCreateRequest{})
assert.ErrorIs(t, err, restriction.ErrRestricted)
})
}
func TestBasic_Duplicate(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test", ChildrenIds: []string{"2"}})).
AddBlock(simple.New(&model.Block{Id: "2", ChildrenIds: []string{"3"}})).
AddBlock(simple.New(&model.Block{Id: "3"}))
t.Run("dup blocks to same state", func(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test", ChildrenIds: []string{"2"}})).
AddBlock(simple.New(&model.Block{Id: "2", ChildrenIds: []string{"3"}})).
AddBlock(simple.New(&model.Block{Id: "3"}))
st := sb.NewState()
newIds, err := NewBasic(sb, nil, converter.NewLayoutConverter()).Duplicate(st, st, "", 0, []string{"2"})
require.NoError(t, err)
st := sb.NewState()
newIds, err := NewBasic(sb, nil, converter.NewLayoutConverter(), nil).Duplicate(st, st, "", 0, []string{"2"})
require.NoError(t, err)
err = sb.Apply(st)
require.NoError(t, err)
err = sb.Apply(st)
require.NoError(t, err)
require.Len(t, newIds, 1)
s := sb.NewState()
assert.Len(t, s.Pick(newIds[0]).Model().ChildrenIds, 1)
assert.Len(t, sb.Blocks(), 5)
})
for _, tc := range []struct {
name string
fos func() fileobject.Service
spaceIds []string
targets []string
}{
{
name: "dup file block - same space",
fos: func() fileobject.Service {
return nil
},
spaceIds: []string{"space1", "space1"},
targets: []string{"file1_space1", "file2_space1"},
},
{
name: "dup file block - other space",
fos: func() fileobject.Service {
fos := mock_fileobject.NewMockService(t)
fos.EXPECT().GetFileIdFromObject("file1_space1").Return(domain.FullFileId{SpaceId: "space1", FileId: "file1"}, nil)
fos.EXPECT().CreateFromImport(domain.FullFileId{SpaceId: "space2", FileId: "file1"}, mock.Anything).Return("file1_space2", nil)
fos.EXPECT().GetFileIdFromObject("file2_space1").Return(domain.FullFileId{SpaceId: "space1", FileId: "file2"}, nil)
fos.EXPECT().CreateFromImport(domain.FullFileId{SpaceId: "space2", FileId: "file2"}, mock.Anything).Return("file2_space2", nil)
return fos
},
spaceIds: []string{"space1", "space2"},
targets: []string{"file1_space2", "file2_space2"},
},
{
name: "dup file block - no target change if failed to retrieve file id",
fos: func() fileobject.Service {
fos := mock_fileobject.NewMockService(t)
fos.EXPECT().GetFileIdFromObject(mock.Anything).Return(domain.FullFileId{}, errors.New("no such file")).Times(2)
return fos
},
spaceIds: []string{"space1", "space2"},
targets: []string{"file1_space1", "file2_space1"},
},
{
name: "dup file block - no target change if failed to create file object",
fos: func() fileobject.Service {
fos := mock_fileobject.NewMockService(t)
fos.EXPECT().GetFileIdFromObject("file1_space1").Return(domain.FullFileId{SpaceId: "space1", FileId: "file1"}, nil)
fos.EXPECT().GetFileIdFromObject("file2_space1").Return(domain.FullFileId{SpaceId: "space1", FileId: "file2"}, nil)
fos.EXPECT().CreateFromImport(mock.Anything, mock.Anything).Return("", errors.New("creation failure"))
return fos
},
spaceIds: []string{"space1", "space2"},
targets: []string{"file1_space1", "file2_space1"},
},
} {
t.Run(tc.name, func(t *testing.T) {
// given
source := smarttest.New("source").
AddBlock(simple.New(&model.Block{Id: "source", ChildrenIds: []string{"1", "f1"}})).
AddBlock(simple.New(&model.Block{Id: "1", ChildrenIds: []string{"f2"}})).
AddBlock(simple.New(&model.Block{Id: "f1", Content: &model.BlockContentOfFile{File: &model.BlockContentFile{TargetObjectId: "file1_space1"}}})).
AddBlock(simple.New(&model.Block{Id: "f2", Content: &model.BlockContentOfFile{File: &model.BlockContentFile{TargetObjectId: "file2_space1"}}}))
ss := source.NewState()
ss.SetDetail(bundle.RelationKeySpaceId.String(), pbtypes.String(tc.spaceIds[0]))
target := smarttest.New("target").
AddBlock(simple.New(&model.Block{Id: "target"}))
ts := target.NewState()
ts.SetDetail(bundle.RelationKeySpaceId.String(), pbtypes.String(tc.spaceIds[1]))
// when
newIds, err := NewBasic(source, nil, nil, tc.fos()).Duplicate(ss, ts, "target", model.Block_Inner, []string{"1", "f1"})
require.NoError(t, err)
require.NoError(t, target.Apply(ts))
// then
assert.Len(t, newIds, 2)
ts = target.NewState()
root := ts.Pick("target")
assert.Equal(t, newIds, root.Model().ChildrenIds)
block1 := ts.Pick(newIds[0])
require.NotNil(t, block1)
blockChildren := block1.Model().ChildrenIds
assert.Len(t, blockChildren, 1)
for fbID, targetID := range map[string]string{newIds[1]: tc.targets[0], blockChildren[0]: tc.targets[1]} {
fb := ts.Pick(fbID)
assert.NotNil(t, fb)
f, ok := fb.Model().Content.(*model.BlockContentOfFile)
assert.True(t, ok)
assert.Equal(t, targetID, f.File.TargetObjectId)
}
})
}
require.Len(t, newIds, 1)
s := sb.NewState()
assert.Len(t, s.Pick(newIds[0]).Model().ChildrenIds, 1)
assert.Len(t, sb.Blocks(), 5)
}
func TestBasic_Unlink(t *testing.T) {
@ -105,7 +206,7 @@ func TestBasic_Unlink(t *testing.T) {
AddBlock(simple.New(&model.Block{Id: "2", ChildrenIds: []string{"3"}})).
AddBlock(simple.New(&model.Block{Id: "3"}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
err := b.Unlink(nil, "2")
require.NoError(t, err)
@ -119,7 +220,7 @@ func TestBasic_Unlink(t *testing.T) {
AddBlock(simple.New(&model.Block{Id: "2", ChildrenIds: []string{"3"}})).
AddBlock(simple.New(&model.Block{Id: "3"}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
err := b.Unlink(nil, "2", "3")
require.NoError(t, err)
@ -136,7 +237,7 @@ func TestBasic_Move(t *testing.T) {
AddBlock(simple.New(&model.Block{Id: "3"})).
AddBlock(simple.New(&model.Block{Id: "4"}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
err := b.Move(st, st, "4", model.Block_Inner, []string{"3"})
@ -150,7 +251,7 @@ func TestBasic_Move(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test"}))
require.NoError(t, smartblock.ObjectApplyTemplate(sb, sb.NewState(), template.WithTitle))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
s := sb.NewState()
id1, err := b.CreateBlock(s, pb.RpcBlockCreateRequest{
TargetId: template.HeaderLayoutId,
@ -199,7 +300,7 @@ func TestBasic_Move(t *testing.T) {
},
),
)
basic := NewBasic(testDoc, nil, converter.NewLayoutConverter())
basic := NewBasic(testDoc, nil, converter.NewLayoutConverter(), nil)
state := testDoc.NewState()
// when
@ -215,7 +316,7 @@ func TestBasic_Move(t *testing.T) {
AddBlock(newTextBlock("1", "", nil)).
AddBlock(newTextBlock("2", "one", nil))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
err := b.Move(st, st, "1", model.Block_InnerFirst, []string{"2"})
require.NoError(t, err)
@ -235,7 +336,7 @@ func TestBasic_Move(t *testing.T) {
AddBlock(firstBlock).
AddBlock(secondBlock)
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
err := b.Move(st, st, "1", model.Block_InnerFirst, []string{"2"})
require.NoError(t, err)
@ -249,7 +350,7 @@ func TestBasic_Move(t *testing.T) {
AddBlock(newTextBlock("1", "", nil)).
AddBlock(newTextBlock("2", "one", nil))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
err := b.Move(st, nil, "1", model.Block_Top, []string{"2"})
require.NoError(t, err)
@ -258,6 +359,152 @@ func TestBasic_Move(t *testing.T) {
})
}
func TestBasic_MoveTableBlocks(t *testing.T) {
getSB := func() *smarttest.SmartTest {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test", ChildrenIds: []string{"upper", "table", "block"}})).
AddBlock(simple.New(&model.Block{Id: "table", ChildrenIds: []string{"columns", "rows"}, Content: &model.BlockContentOfTable{Table: &model.BlockContentTable{}}})).
AddBlock(simple.New(&model.Block{Id: "columns", ChildrenIds: []string{"column"}, Content: &model.BlockContentOfLayout{Layout: &model.BlockContentLayout{Style: model.BlockContentLayout_TableColumns}}})).
AddBlock(simple.New(&model.Block{Id: "column", ChildrenIds: []string{}, Content: &model.BlockContentOfTableColumn{TableColumn: &model.BlockContentTableColumn{}}})).
AddBlock(simple.New(&model.Block{Id: "rows", ChildrenIds: []string{"row", "row2"}, Content: &model.BlockContentOfLayout{Layout: &model.BlockContentLayout{Style: model.BlockContentLayout_TableRows}}})).
AddBlock(simple.New(&model.Block{Id: "row", ChildrenIds: []string{"column-row"}, Content: &model.BlockContentOfTableRow{TableRow: &model.BlockContentTableRow{IsHeader: false}}})).
AddBlock(simple.New(&model.Block{Id: "row2", ChildrenIds: []string{}, Content: &model.BlockContentOfTableRow{TableRow: &model.BlockContentTableRow{IsHeader: false}}})).
AddBlock(simple.New(&model.Block{Id: "column-row", ChildrenIds: []string{}})).
AddBlock(simple.New(&model.Block{Id: "block", ChildrenIds: []string{}})).
AddBlock(simple.New(&model.Block{Id: "upper", ChildrenIds: []string{}}))
return sb
}
for _, block := range []string{"columns", "rows", "column", "row", "column-row"} {
t.Run("moving non-root table block '"+block+"' leads to error", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
// when
err := b.Move(st, st, "block", model.Block_Bottom, []string{block})
// then
assert.Error(t, err)
assert.True(t, errors.Is(err, table.ErrCannotMoveTableBlocks))
})
}
t.Run("no error on moving root table block", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
// when
err := b.Move(st, st, "block", model.Block_Bottom, []string{"table"})
// then
assert.NoError(t, err)
assert.Equal(t, []string{"upper", "block", "table"}, st.Pick("test").Model().ChildrenIds)
})
t.Run("no error on moving one row between another", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
// when
err := b.Move(st, st, "row2", model.Block_Bottom, []string{"row"})
// then
assert.NoError(t, err)
assert.Equal(t, []string{"row2", "row"}, st.Pick("rows").Model().ChildrenIds)
})
t.Run("moving rows with incorrect position leads to error", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
// when
err := b.Move(st, st, "row2", model.Block_Left, []string{"row"})
// then
assert.Error(t, err)
})
t.Run("moving rows and some other blocks between another leads to error", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
// when
err := b.Move(st, st, "row2", model.Block_Top, []string{"row", "rows"})
// then
assert.Error(t, err)
})
t.Run("moving the row between itself leads to error", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
// when
err := b.Move(st, st, "row2", model.Block_Bottom, []string{"row2"})
// then
assert.Error(t, err)
})
t.Run("moving table block from invalid table leads to error", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
st.Unlink("columns")
// when
err := b.Move(st, st, "block", model.Block_Bottom, []string{"column-row"})
// then
assert.Error(t, err)
assert.True(t, errors.Is(err, table.ErrCannotMoveTableBlocks))
})
for _, block := range []string{"columns", "rows", "column", "row", "column-row"} {
t.Run("moving a block to '"+block+"' block leads to moving it under the table", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
// when
err := b.Move(st, st, block, model.BlockPosition(rand.Intn(len(model.BlockPosition_name))), []string{"upper"})
// then
assert.NoError(t, err)
assert.Equal(t, []string{"table", "upper", "block"}, st.Pick("test").Model().ChildrenIds)
})
}
t.Run("moving a block to the invalid table leads to moving it under the table", func(t *testing.T) {
// given
sb := getSB()
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
st := sb.NewState()
st.Unlink("columns")
// when
err := b.Move(st, st, "rows", model.BlockPosition(rand.Intn(6)), []string{"upper"})
// then
assert.NoError(t, err)
assert.Equal(t, []string{"table", "upper", "block"}, st.Pick("test").Model().ChildrenIds)
})
}
func TestBasic_MoveToAnotherObject(t *testing.T) {
t.Run("basic", func(t *testing.T) {
sb1 := smarttest.New("test1")
@ -269,7 +516,7 @@ func TestBasic_MoveToAnotherObject(t *testing.T) {
sb2 := smarttest.New("test2")
sb2.AddBlock(simple.New(&model.Block{Id: "test2", ChildrenIds: []string{}}))
b := NewBasic(sb1, nil, converter.NewLayoutConverter())
b := NewBasic(sb1, nil, converter.NewLayoutConverter(), nil)
srcState := sb1.NewState()
destState := sb2.NewState()
@ -304,7 +551,7 @@ func TestBasic_Replace(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test", ChildrenIds: []string{"2"}})).
AddBlock(simple.New(&model.Block{Id: "2"}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
newId, err := b.Replace(nil, "2", &model.Block{Content: &model.BlockContentOfText{Text: &model.BlockContentText{Text: "l"}}})
require.NoError(t, err)
require.NotEmpty(t, newId)
@ -314,7 +561,7 @@ func TestBasic_SetFields(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test", ChildrenIds: []string{"2"}})).
AddBlock(simple.New(&model.Block{Id: "2"}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
fields := &types.Struct{
Fields: map[string]*types.Value{
@ -333,7 +580,7 @@ func TestBasic_Update(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test", ChildrenIds: []string{"2"}})).
AddBlock(simple.New(&model.Block{Id: "2"}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
err := b.Update(nil, func(b simple.Block) error {
b.Model().BackgroundColor = "test"
@ -347,7 +594,7 @@ func TestBasic_SetDivStyle(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test", ChildrenIds: []string{"2"}})).
AddBlock(simple.New(&model.Block{Id: "2", Content: &model.BlockContentOfDiv{Div: &model.BlockContentDiv{}}}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
err := b.SetDivStyle(nil, model.BlockContentDiv_Dots, "2")
require.NoError(t, err)
@ -358,7 +605,7 @@ func TestBasic_SetDivStyle(t *testing.T) {
func TestBasic_PasteBlocks(t *testing.T) {
sb := smarttest.New("test")
sb.AddBlock(simple.New(&model.Block{Id: "test"}))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
s := sb.NewState()
err := b.PasteBlocks(s, "", model.Block_Inner, []simple.Block{
simple.New(&model.Block{Id: "1", ChildrenIds: []string{"1.1"}}),
@ -385,7 +632,7 @@ func TestBasic_SetRelationKey(t *testing.T) {
t.Run("correct", func(t *testing.T) {
sb := smarttest.New("test")
fillSb(sb)
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
err := b.SetRelationKey(nil, pb.RpcBlockRelationSetKeyRequest{
BlockId: "2",
Key: "testRelKey",
@ -407,7 +654,7 @@ func TestBasic_SetRelationKey(t *testing.T) {
t.Run("not relation block", func(t *testing.T) {
sb := smarttest.New("test")
fillSb(sb)
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
require.Error(t, b.SetRelationKey(nil, pb.RpcBlockRelationSetKeyRequest{
BlockId: "1",
Key: "key",
@ -416,7 +663,7 @@ func TestBasic_SetRelationKey(t *testing.T) {
t.Run("relation not found", func(t *testing.T) {
sb := smarttest.New("test")
fillSb(sb)
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
require.Error(t, b.SetRelationKey(nil, pb.RpcBlockRelationSetKeyRequest{
BlockId: "2",
Key: "not exists",
@ -428,11 +675,11 @@ func TestBasic_FeaturedRelationAdd(t *testing.T) {
sb := smarttest.New("test")
s := sb.NewState()
template.WithTitle(s)
s.AddBundledRelations(bundle.RelationKeyName)
s.AddBundledRelations(bundle.RelationKeyDescription)
s.AddBundledRelationLinks(bundle.RelationKeyName)
s.AddBundledRelationLinks(bundle.RelationKeyDescription)
require.NoError(t, sb.Apply(s))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
newRel := []string{bundle.RelationKeyDescription.String(), bundle.RelationKeyName.String()}
require.NoError(t, b.FeaturedRelationAdd(nil, newRel...))
@ -448,7 +695,7 @@ func TestBasic_FeaturedRelationRemove(t *testing.T) {
template.WithDescription(s)
require.NoError(t, sb.Apply(s))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
require.NoError(t, b.FeaturedRelationRemove(nil, bundle.RelationKeyDescription.String()))
res := sb.NewState()
@ -485,7 +732,7 @@ func TestBasic_ReplaceLink(t *testing.T) {
}
require.NoError(t, sb.Apply(s))
b := NewBasic(sb, nil, converter.NewLayoutConverter())
b := NewBasic(sb, nil, converter.NewLayoutConverter(), nil)
require.NoError(t, b.ReplaceLink(oldId, newId))
res := sb.NewState()

View file

@ -32,7 +32,7 @@ func newDUFixture(t *testing.T) *duFixture {
store := objectstore.NewStoreFixture(t)
b := NewBasic(sb, store, converter.NewLayoutConverter())
b := NewBasic(sb, store, converter.NewLayoutConverter(), nil)
return &duFixture{
sb: sb,

View file

@ -2,7 +2,6 @@ package basic
import (
"context"
"errors"
"fmt"
"github.com/globalsign/mgo/bson"
@ -119,31 +118,44 @@ func insertBlocksToState(
}
func (bs *basic) changeToBlockWithLink(newState *state.State, blockToReplace simple.Block, objectID string, linkBlock *model.Block) (string, error) {
if linkBlock == nil {
linkBlock = &model.Block{
Content: &model.BlockContentOfLink{
Link: &model.BlockContentLink{
TargetBlockId: objectID,
Style: model.BlockContentLink_Page,
},
},
}
} else {
link := linkBlock.GetLink()
if link == nil {
return "", errors.New("linkBlock content is not a link")
} else {
link.TargetBlockId = objectID
}
}
linkBlockCopy := pbtypes.CopyBlock(linkBlock)
return bs.CreateBlock(newState, pb.RpcBlockCreateRequest{
TargetId: blockToReplace.Model().Id,
Block: linkBlockCopy,
Block: buildBlock(linkBlock, objectID),
Position: model.Block_Replace,
})
}
func buildBlock(b *model.Block, targetID string) (result *model.Block) {
fallback := &model.Block{
Content: &model.BlockContentOfLink{
Link: &model.BlockContentLink{
TargetBlockId: targetID,
Style: model.BlockContentLink_Page,
},
},
}
if b == nil {
return fallback
}
result = pbtypes.CopyBlock(b)
switch v := result.Content.(type) {
case *model.BlockContentOfLink:
v.Link.TargetBlockId = targetID
case *model.BlockContentOfBookmark:
v.Bookmark.TargetObjectId = targetID
case *model.BlockContentOfFile:
v.File.TargetObjectId = targetID
case *model.BlockContentOfDataview:
v.Dataview.TargetObjectId = targetID
default:
result = fallback
}
return
}
func removeBlocks(state *state.State, descendants []simple.Block) {
for _, b := range descendants {
state.Unlink(b.Model().Id)

View file

@ -59,7 +59,7 @@ func (tts testTemplateService) CreateTemplateStateWithDetails(id string, details
template.InitTemplate(st, template.WithEmpty,
template.WithDefaultFeaturedRelations,
template.WithFeaturedRelations,
template.WithRequiredRelations(),
template.WithRequiredRelations,
template.WithTitle,
)
return st, nil
@ -290,7 +290,7 @@ func TestExtractObjects(t *testing.T) {
ObjectTypeUniqueKey: domain.MustUniqueKey(coresb.SmartBlockTypeObjectType, bundle.TypeKeyNote.String()).Marshal(),
}
ctx := session.NewContext()
linkIds, err := NewBasic(sb, fixture.store, converter.NewLayoutConverter()).ExtractBlocksToObjects(ctx, creator, ts, req)
linkIds, err := NewBasic(sb, fixture.store, converter.NewLayoutConverter(), nil).ExtractBlocksToObjects(ctx, creator, ts, req)
assert.NoError(t, err)
var gotBlockIds []string
@ -345,7 +345,7 @@ func TestExtractObjects(t *testing.T) {
}},
}
ctx := session.NewContext()
_, err := NewBasic(sb, fixture.store, converter.NewLayoutConverter()).ExtractBlocksToObjects(ctx, creator, ts, req)
_, err := NewBasic(sb, fixture.store, converter.NewLayoutConverter(), nil).ExtractBlocksToObjects(ctx, creator, ts, req)
assert.NoError(t, err)
var block *model.Block
for _, block = range sb.Blocks() {
@ -378,7 +378,7 @@ func TestExtractObjects(t *testing.T) {
}},
}
ctx := session.NewContext()
_, err := NewBasic(sb, fixture.store, converter.NewLayoutConverter()).ExtractBlocksToObjects(ctx, creator, ts, req)
_, err := NewBasic(sb, fixture.store, converter.NewLayoutConverter(), nil).ExtractBlocksToObjects(ctx, creator, ts, req)
assert.NoError(t, err)
var addedBlocks []*model.Block
for _, message := range sb.Results.Events {
@ -394,6 +394,84 @@ func TestExtractObjects(t *testing.T) {
})
}
func TestBuildBlock(t *testing.T) {
const target = "target"
for _, tc := range []struct {
name string
input, output *model.Block
}{
{
name: "nil",
input: nil,
output: &model.Block{Content: &model.BlockContentOfLink{Link: &model.BlockContentLink{
TargetBlockId: target,
Style: model.BlockContentLink_Page,
}}},
},
{
name: "link",
input: &model.Block{Content: &model.BlockContentOfLink{Link: &model.BlockContentLink{
Style: model.BlockContentLink_Dashboard,
CardStyle: model.BlockContentLink_Card,
}}},
output: &model.Block{Content: &model.BlockContentOfLink{Link: &model.BlockContentLink{
TargetBlockId: target,
Style: model.BlockContentLink_Dashboard,
CardStyle: model.BlockContentLink_Card,
}}},
},
{
name: "bookmark",
input: &model.Block{Content: &model.BlockContentOfBookmark{Bookmark: &model.BlockContentBookmark{
Type: model.LinkPreview_Image,
State: model.BlockContentBookmark_Fetching,
}}},
output: &model.Block{Content: &model.BlockContentOfBookmark{Bookmark: &model.BlockContentBookmark{
TargetObjectId: target,
Type: model.LinkPreview_Image,
State: model.BlockContentBookmark_Fetching,
}}},
},
{
name: "file",
input: &model.Block{Content: &model.BlockContentOfFile{File: &model.BlockContentFile{
Type: model.BlockContentFile_Image,
}}},
output: &model.Block{Content: &model.BlockContentOfFile{File: &model.BlockContentFile{
TargetObjectId: target,
Type: model.BlockContentFile_Image,
}}},
},
{
name: "dataview",
input: &model.Block{Content: &model.BlockContentOfDataview{Dataview: &model.BlockContentDataview{
IsCollection: true,
Source: []string{"ot-note"},
}}},
output: &model.Block{Content: &model.BlockContentOfDataview{Dataview: &model.BlockContentDataview{
TargetObjectId: target,
IsCollection: true,
Source: []string{"ot-note"},
}}},
},
{
name: "other",
input: &model.Block{Content: &model.BlockContentOfTableRow{TableRow: &model.BlockContentTableRow{
IsHeader: true,
}}},
output: &model.Block{Content: &model.BlockContentOfLink{Link: &model.BlockContentLink{
TargetBlockId: target,
Style: model.BlockContentLink_Page,
}}},
},
} {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.output, buildBlock(tc.input, target))
})
}
}
type fixture struct {
t *testing.T
ctrl *gomock.Controller

View file

@ -19,6 +19,7 @@ import (
"github.com/anyproto/anytype-heart/core/block/simple"
"github.com/anyproto/anytype-heart/core/block/simple/text"
"github.com/anyproto/anytype-heart/core/converter/html"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/domain/objectorigin"
"github.com/anyproto/anytype-heart/core/files"
"github.com/anyproto/anytype-heart/core/files/fileobject"
@ -386,6 +387,9 @@ func (cb *clipboard) pasteAny(
return
}
}
if f, ok := b.Content.(*model.BlockContentOfFile); ok {
cb.processFileBlock(f)
}
}
srcState := cb.blocksToState(req.AnySlot)
visited := map[string]struct{}{}
@ -585,6 +589,29 @@ func (cb *clipboard) newHTMLConverter(s *state.State) *html.HTML {
return html.NewHTMLConverter(cb.fileService, s, cb.fileObjectService)
}
func (cb *clipboard) processFileBlock(f *model.BlockContentOfFile) {
fileId, err := cb.fileObjectService.GetFileIdFromObject(f.File.TargetObjectId)
if err != nil {
log.Errorf("failed to get fileId: %v", err)
return
}
if cb.SpaceID() == fileId.SpaceId {
return
}
objectId, err := cb.fileObjectService.CreateFromImport(
domain.FullFileId{SpaceId: cb.SpaceID(), FileId: fileId.FileId},
objectorigin.ObjectOrigin{Origin: model.ObjectOrigin_clipboard},
)
if err != nil {
log.Errorf("failed to create file object: %v", err)
return
}
f.File.TargetObjectId = objectId
}
func renderText(s *state.State, ignoreStyle bool) string {
texts := make([]string, 0)
texts, _ = renderBlock(s, texts, s.RootId(), -1, 0, ignoreStyle)

View file

@ -13,6 +13,8 @@ import (
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock/smarttest"
"github.com/anyproto/anytype-heart/core/block/simple"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/files/fileobject/mock_fileobject"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
@ -216,7 +218,9 @@ func checkBlockMarksDebug(t *testing.T, sb *smarttest.SmartTest, marksArr [][]*m
func newFixture(t *testing.T, sb smartblock.SmartBlock) Clipboard {
file := file.NewMockFile(t)
file.EXPECT().UploadState(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
return NewClipboard(sb, file, nil, nil, nil, nil)
fos := mock_fileobject.NewMockService(t)
fos.EXPECT().GetFileIdFromObject(mock.Anything).Return(domain.FullFileId{}, fmt.Errorf("no fileId")).Maybe()
return NewClipboard(sb, file, nil, nil, nil, fos)
}
func pasteAny(t *testing.T, sb *smarttest.SmartTest, id string, textRange model.Range, selectedBlockIds []string, blocks []*model.Block) ([]string, bool) {

View file

@ -2,12 +2,14 @@ package clipboard
import (
"errors"
"fmt"
"strconv"
"testing"
"github.com/gogo/protobuf/types"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
@ -18,6 +20,8 @@ import (
"github.com/anyproto/anytype-heart/core/block/simple"
_ "github.com/anyproto/anytype-heart/core/block/simple/base"
"github.com/anyproto/anytype-heart/core/block/simple/text"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/files/fileobject/mock_fileobject"
"github.com/anyproto/anytype-heart/core/session"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
@ -1083,7 +1087,7 @@ func TestClipboard_TitleOps(t *testing.T) {
t.Run("do not paste if Blocks restriction is set to smartblock", func(t *testing.T) {
// given
sb := smarttest.New("test")
sb.SetRestrictions(restriction.Restrictions{Object: restriction.ObjectRestrictions{model.Restrictions_Blocks}})
sb.TestRestrictions = restriction.Restrictions{Object: restriction.ObjectRestrictions{model.Restrictions_Blocks}}
cb := newFixture(t, sb)
// when
@ -1799,3 +1803,94 @@ bbb`},
})
}
}
func TestProcessFileBlock(t *testing.T) {
const (
fileObject1 = "fileObject1"
fileObject2 = "fileObject2"
space1 = "space1"
space2 = "space2"
fileId = domain.FileId("fileId")
)
sb := smarttest.New("test")
sb.SetSpaceId(space1)
t.Run("old target object id remains if space is the same", func(t *testing.T) {
// given
file := mock_fileobject.NewMockService(t)
file.EXPECT().GetFileIdFromObject(fileObject1).Return(domain.FullFileId{SpaceId: space1, FileId: fileId}, nil)
c := &clipboard{
SmartBlock: sb,
fileObjectService: file,
}
fb := &model.BlockContentOfFile{File: &model.BlockContentFile{TargetObjectId: fileObject1}}
// when
c.processFileBlock(fb)
// then
assert.Equal(t, fileObject1, fb.File.TargetObjectId)
})
t.Run("new target object id is set if space is different", func(t *testing.T) {
// given
file := mock_fileobject.NewMockService(t)
file.EXPECT().GetFileIdFromObject(fileObject1).Return(domain.FullFileId{SpaceId: space2, FileId: fileId}, nil)
file.EXPECT().CreateFromImport(domain.FullFileId{FileId: fileId, SpaceId: space1}, mock.Anything).Return(fileObject2, nil)
c := &clipboard{
SmartBlock: sb,
fileObjectService: file,
}
fb := &model.BlockContentOfFile{File: &model.BlockContentFile{TargetObjectId: fileObject1}}
// when
c.processFileBlock(fb)
// then
assert.Equal(t, fileObject2, fb.File.TargetObjectId)
})
t.Run("old target object id remains if failed to create new object", func(t *testing.T) {
// given
file := mock_fileobject.NewMockService(t)
file.EXPECT().GetFileIdFromObject(fileObject1).Return(domain.FullFileId{SpaceId: space2, FileId: fileId}, nil)
file.EXPECT().CreateFromImport(domain.FullFileId{FileId: fileId, SpaceId: space1}, mock.Anything).Return("", fmt.Errorf("some error"))
c := &clipboard{
SmartBlock: sb,
fileObjectService: file,
}
fb := &model.BlockContentOfFile{File: &model.BlockContentFile{TargetObjectId: fileObject1}}
// when
c.processFileBlock(fb)
// then
assert.Equal(t, fileObject1, fb.File.TargetObjectId)
})
t.Run("old target object id remains if failed to get file id", func(t *testing.T) {
// given
file := mock_fileobject.NewMockService(t)
file.EXPECT().GetFileIdFromObject(fileObject1).Return(domain.FullFileId{}, fmt.Errorf("not found"))
c := &clipboard{
SmartBlock: sb,
fileObjectService: file,
}
fb := &model.BlockContentOfFile{File: &model.BlockContentFile{TargetObjectId: fileObject1}}
// when
c.processFileBlock(fb)
// then
assert.Equal(t, fileObject1, fb.File.TargetObjectId)
})
}

View file

@ -19,6 +19,9 @@ import (
"github.com/anyproto/anytype-heart/util/slice"
)
// required relations for archive beside the bundle.RequiredInternalRelations
var dashboardRequiredRelations = []domain.RelationKey{}
type Dashboard struct {
smartblock.SmartBlock
basic.AllOperations
@ -30,13 +33,14 @@ type Dashboard struct {
func NewDashboard(sb smartblock.SmartBlock, objectStore objectstore.ObjectStore, layoutConverter converter.LayoutConverter) *Dashboard {
return &Dashboard{
SmartBlock: sb,
AllOperations: basic.NewBasic(sb, objectStore, layoutConverter),
AllOperations: basic.NewBasic(sb, objectStore, layoutConverter, nil),
Collection: collection.NewCollection(sb, objectStore),
objectStore: objectStore,
}
}
func (p *Dashboard) Init(ctx *smartblock.InitContext) (err error) {
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, dashboardRequiredRelations...)
if err = p.SmartBlock.Init(ctx); err != nil {
return
}
@ -55,7 +59,6 @@ func (p *Dashboard) CreationStateMigration(ctx *smartblock.InitContext) migratio
template.WithEmpty,
template.WithDetailName("Home"),
template.WithDetailIconEmoji("🏠"),
template.WithRequiredRelations(),
template.WithNoDuplicateLinks(),
)
},

View file

@ -2,8 +2,12 @@ package editor
import (
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/domain"
)
// required relations for device beside the bundle.RequiredInternalRelations
var deviceRequiredRelations = []domain.RelationKey{}
type DevicesObject struct {
smartblock.SmartBlock
deviceService deviceService
@ -17,6 +21,7 @@ func NewDevicesObject(sb smartblock.SmartBlock, deviceService deviceService) *De
}
func (d *DevicesObject) Init(ctx *smartblock.InitContext) (err error) {
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, deviceRequiredRelations...)
if err = d.SmartBlock.Init(ctx); err != nil {
return
}

View file

@ -110,7 +110,7 @@ func TestDropFiles(t *testing.T) {
t.Run("do not drop files to object with Blocks restriction", func(t *testing.T) {
// given
fx := newFixture(t)
fx.sb.SetRestrictions(restriction.Restrictions{Object: restriction.ObjectRestrictions{model.Restrictions_Blocks}})
fx.sb.TestRestrictions = restriction.Restrictions{Object: restriction.ObjectRestrictions{model.Restrictions_Blocks}}
// when
err := fx.sfile.DropFiles(pb.RpcFileDropRequest{})

View file

@ -14,11 +14,18 @@ import (
"github.com/anyproto/anytype-heart/core/files/fileobject"
"github.com/anyproto/anytype-heart/core/files/reconciler"
"github.com/anyproto/anytype-heart/core/filestorage"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
coresb "github.com/anyproto/anytype-heart/pkg/lib/core/smartblock"
)
// required relations for files beside the bundle.RequiredInternalRelations
var fileRequiredRelations = append(pageRequiredRelations, []domain.RelationKey{
bundle.RelationKeyFileBackupStatus,
bundle.RelationKeyFileSyncStatus,
}...)
func (f *ObjectFactory) newFile(sb smartblock.SmartBlock) *File {
basicComponent := basic.NewBasic(sb, f.objectStore, f.layoutConverter)
basicComponent := basic.NewBasic(sb, f.objectStore, f.layoutConverter, f.fileObjectService)
return &File{
SmartBlock: sb,
ChangeReceiver: sb.(source.ChangeReceiver),
@ -65,6 +72,8 @@ func (f *File) Init(ctx *smartblock.InitContext) error {
return fmt.Errorf("source type should be a file")
}
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, fileRequiredRelations...)
if ctx.BuildOpts.DisableRemoteLoad {
ctx.Ctx = context.WithValue(ctx.Ctx, filestorage.CtxKeyRemoteLoadDisabled, true)
}

View file

@ -5,12 +5,16 @@ import (
"github.com/anyproto/anytype-heart/core/block/editor/state"
"github.com/anyproto/anytype-heart/core/block/editor/template"
"github.com/anyproto/anytype-heart/core/block/migration"
"github.com/anyproto/anytype-heart/core/domain"
)
type NotificationObject struct {
smartblock.SmartBlock
}
// required relations for notifications beside the bundle.RequiredInternalRelations
var notificationsRequiredRelations = []domain.RelationKey{}
func NewNotificationObject(sb smartblock.SmartBlock) *NotificationObject {
return &NotificationObject{
SmartBlock: sb,
@ -30,6 +34,7 @@ func (n *NotificationObject) CreationStateMigration(ctx *smartblock.InitContext)
}
func (n *NotificationObject) Init(ctx *smartblock.InitContext) (err error) {
ctx.RequiredInternalRelationKeys = notificationsRequiredRelations
if err = n.SmartBlock.Init(ctx); err != nil {
return
}

View file

@ -23,6 +23,18 @@ import (
"github.com/anyproto/anytype-heart/util/pbtypes"
)
var pageRequiredRelations = []domain.RelationKey{
bundle.RelationKeyCoverId,
bundle.RelationKeyCoverScale,
bundle.RelationKeyCoverType,
bundle.RelationKeyCoverX,
bundle.RelationKeyCoverY,
bundle.RelationKeySnippet,
bundle.RelationKeyFeaturedRelations,
bundle.RelationKeyLinks,
bundle.RelationKeyLayoutAlign,
}
type Page struct {
smartblock.SmartBlock
basic.AllOperations
@ -46,7 +58,7 @@ func (f *ObjectFactory) newPage(sb smartblock.SmartBlock) *Page {
return &Page{
SmartBlock: sb,
ChangeReceiver: sb.(source.ChangeReceiver),
AllOperations: basic.NewBasic(sb, f.objectStore, f.layoutConverter),
AllOperations: basic.NewBasic(sb, f.objectStore, f.layoutConverter, f.fileObjectService),
IHistory: basic.NewHistory(sb),
Text: stext.NewText(
sb,
@ -72,6 +84,7 @@ func (f *ObjectFactory) newPage(sb smartblock.SmartBlock) *Page {
}
func (p *Page) Init(ctx *smartblock.InitContext) (err error) {
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, pageRequiredRelations...)
if ctx.ObjectTypeKeys == nil && (ctx.State == nil || len(ctx.State.ObjectTypeKeys()) == 0) && ctx.IsNewObject {
ctx.ObjectTypeKeys = []domain.TypeKey{bundle.TypeKeyPage}
}
@ -162,7 +175,6 @@ func (p *Page) CreationStateMigration(ctx *smartblock.InitContext) migration.Mig
template.WithLayout(layout),
template.WithDefaultFeaturedRelations,
template.WithFeaturedRelations,
template.WithRequiredRelations(),
template.WithLinkFieldsMigration,
template.WithCreatorRemovedFromFeaturedRelations,
}

View file

@ -8,19 +8,30 @@ import (
"github.com/anyproto/anytype-heart/core/block/editor/basic"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/block/editor/template"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
"github.com/anyproto/anytype-heart/space/spaceinfo"
"github.com/anyproto/anytype-heart/util/pbtypes"
)
var participantRequiredRelations = []domain.RelationKey{
bundle.RelationKeyGlobalName,
bundle.RelationKeyIdentity,
bundle.RelationKeyBacklinks,
bundle.RelationKeyParticipantPermissions,
bundle.RelationKeyParticipantStatus,
bundle.RelationKeyIdentityProfileLink,
bundle.RelationKeyIsHiddenDiscovery,
}
type participant struct {
smartblock.SmartBlock
basic.DetailsUpdatable
}
func (f *ObjectFactory) newParticipant(sb smartblock.SmartBlock) *participant {
basicComponent := basic.NewBasic(sb, f.objectStore, f.layoutConverter)
basicComponent := basic.NewBasic(sb, f.objectStore, f.layoutConverter, nil)
return &participant{
SmartBlock: sb,
DetailsUpdatable: basicComponent,
@ -29,6 +40,7 @@ func (f *ObjectFactory) newParticipant(sb smartblock.SmartBlock) *participant {
func (p *participant) Init(ctx *smartblock.InitContext) (err error) {
// Details come from aclobjectmanager, see buildParticipantDetails
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, participantRequiredRelations...)
if err = p.SmartBlock.Init(ctx); err != nil {
return

View file

@ -135,7 +135,7 @@ func newStoreFixture(t *testing.T) *objectstore.StoreFixture {
func newParticipantTest(t *testing.T) (*participant, error) {
sb := smarttest.New("root")
store := newStoreFixture(t)
basicComponent := basic.NewBasic(sb, store, nil)
basicComponent := basic.NewBasic(sb, store, nil, nil)
p := &participant{
SmartBlock: sb,
DetailsUpdatable: basicComponent,

View file

@ -40,7 +40,7 @@ func (f *ObjectFactory) newProfile(sb smartblock.SmartBlock) *Profile {
fileComponent := file.NewFile(sb, f.fileBlockService, f.picker, f.processService, f.fileUploaderService)
return &Profile{
SmartBlock: sb,
AllOperations: basic.NewBasic(sb, f.objectStore, f.layoutConverter),
AllOperations: basic.NewBasic(sb, f.objectStore, f.layoutConverter, f.fileObjectService),
IHistory: basic.NewHistory(sb),
Text: stext.NewText(
sb,
@ -82,7 +82,6 @@ func (p *Profile) CreationStateMigration(ctx *smartblock.InitContext) migration.
template.InitTemplate(st,
template.WithObjectTypesAndLayout([]domain.TypeKey{bundle.TypeKeyProfile}, model.ObjectType_profile),
template.WithDetail(bundle.RelationKeyLayoutAlign, pbtypes.Float64(float64(model.Block_AlignCenter))),
template.WithRequiredRelations(),
migrationSetHidden,
)
},

View file

@ -29,8 +29,6 @@ func (sb *smartBlock) updateBackLinks(s *state.State) {
func (sb *smartBlock) injectLinksDetails(s *state.State) {
links := sb.navigationalLinks(s)
links = slice.RemoveMut(links, sb.Id())
// todo: we need to move it to the injectDerivedDetails, but we don't call it now on apply
s.SetLocalDetail(bundle.RelationKeyLinks.String(), pbtypes.StringList(links))
}

View file

@ -40,6 +40,7 @@ import (
"github.com/anyproto/anytype-heart/pkg/lib/logging"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
"github.com/anyproto/anytype-heart/pkg/lib/threads"
"github.com/anyproto/anytype-heart/util/anonymize"
"github.com/anyproto/anytype-heart/util/internalflag"
"github.com/anyproto/anytype-heart/util/pbtypes"
"github.com/anyproto/anytype-heart/util/slice"
@ -161,7 +162,6 @@ type SmartBlock interface {
CheckSubscriptions() (changed bool)
GetDocInfo() DocInfo
Restrictions() restriction.Restrictions
SetRestrictions(r restriction.Restrictions)
ObjectClose(ctx session.Context)
ObjectCloseAllSessions()
@ -187,17 +187,18 @@ type DocInfo struct {
// TODO Maybe create constructor? Don't want to forget required fields
type InitContext struct {
IsNewObject bool
Source source.Source
ObjectTypeKeys []domain.TypeKey
RelationKeys []string
State *state.State
Relations []*model.Relation
Restriction restriction.Service
ObjectStore objectstore.ObjectStore
SpaceID string
BuildOpts source.BuildOptions
Ctx context.Context
IsNewObject bool
Source source.Source
ObjectTypeKeys []domain.TypeKey
RelationKeys []string
RequiredInternalRelationKeys []domain.RelationKey // bundled relations that MUST be present in the state
State *state.State
Relations []*model.Relation
Restriction restriction.Service
ObjectStore objectstore.ObjectStore
SpaceID string
BuildOpts source.BuildOptions
Ctx context.Context
}
type linkSource interface {
@ -309,6 +310,7 @@ func (sb *smartBlock) ObjectTypeID() string {
}
func (sb *smartBlock) Init(ctx *InitContext) (err error) {
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, bundle.RequiredInternalRelations...)
if sb.Doc, err = ctx.Source.ReadDoc(ctx.Ctx, sb, ctx.State != nil); err != nil {
return fmt.Errorf("reading document: %w", err)
}
@ -336,18 +338,24 @@ func (sb *smartBlock) Init(ctx *InitContext) (err error) {
ctx.State.SetParent(sb.Doc.(*state.State))
}
injectRequiredRelationLinks := func(s *state.State) {
s.AddBundledRelationLinks(bundle.RequiredInternalRelations...)
s.AddBundledRelationLinks(ctx.RequiredInternalRelationKeys...)
}
injectRequiredRelationLinks(ctx.State)
injectRequiredRelationLinks(ctx.State.ParentState())
if err = sb.AddRelationLinksToState(ctx.State, ctx.RelationKeys...); err != nil {
return
}
// Add bundled relations
var relKeys []domain.RelationKey
for k := range ctx.State.Details().GetFields() {
if _, err := bundle.GetRelation(domain.RelationKey(k)); err == nil {
if bundle.HasRelation(k) {
relKeys = append(relKeys, domain.RelationKey(k))
}
}
ctx.State.AddBundledRelations(relKeys...)
ctx.State.AddBundledRelationLinks(relKeys...)
if ctx.IsNewObject && ctx.State != nil {
source.NewSubObjectsAndProfileLinksMigration(sb.Type(), sb.space, sb.currentParticipantId, sb.objectStore).Migrate(ctx.State)
}
@ -377,11 +385,7 @@ func (sb *smartBlock) sendObjectCloseEvent(_ ApplyInfo) error {
// updateRestrictions refetch restrictions from restriction service and update them in the smartblock
func (sb *smartBlock) updateRestrictions() {
restrictions := sb.restrictionService.GetRestrictions(sb)
sb.SetRestrictions(restrictions)
}
func (sb *smartBlock) SetRestrictions(r restriction.Restrictions) {
r := sb.restrictionService.GetRestrictions(sb)
if sb.restrictions.Equal(r) {
return
}
@ -448,7 +452,7 @@ func (sb *smartBlock) fetchMeta() (details []*model.ObjectViewDetailsSet, err er
recordsCh := make(chan *types.Struct, 10)
sb.recordsSub = database.NewSubscription(nil, recordsCh)
depIDs := sb.dependentSmartIds(sb.includeRelationObjectsAsDependents, true, true, true)
depIDs := sb.dependentSmartIds(sb.includeRelationObjectsAsDependents, true, true)
sb.setDependentIDs(depIDs)
var records []database.Record
@ -538,7 +542,7 @@ func (sb *smartBlock) onMetaChange(details *types.Struct) {
}
// dependentSmartIds returns list of dependent objects in this order: Simple blocks(Link, mentions in Text), Relations. Both of them are returned in the order of original blocks/relations
func (sb *smartBlock) dependentSmartIds(includeRelations, includeObjTypes, includeCreatorModifier, _ bool) (ids []string) {
func (sb *smartBlock) dependentSmartIds(includeRelations, includeObjTypes, includeCreatorModifier bool) (ids []string) {
return objectlink.DependentObjectIDs(sb.Doc.(*state.State), sb.Space(), true, true, includeRelations, includeObjTypes, includeCreatorModifier)
}
@ -630,8 +634,6 @@ func (sb *smartBlock) Apply(s *state.State, flags ...ApplyFlag) (err error) {
}
}
sb.beforeStateApply(s)
if !keepInternalFlags {
removeInternalFlags(s)
}
@ -706,7 +708,7 @@ func (sb *smartBlock) Apply(s *state.State, flags ...ApplyFlag) (err error) {
if !act.IsEmpty() {
if len(changes) == 0 && !doSnapshot {
log.Errorf("apply 0 changes %s: %v", st.RootId(), msgs)
log.Errorf("apply 0 changes %s: %v", st.RootId(), anonymize.Events(msgsToEvents(msgs)))
}
err = pushChange()
if err != nil {
@ -775,7 +777,6 @@ func (sb *smartBlock) ResetToVersion(s *state.State) (err error) {
s.SetParent(sb.Doc.(*state.State))
sb.storeFileKeys(s)
sb.injectLocalDetails(s)
sb.injectDerivedDetails(s, sb.SpaceID(), sb.Type())
if err = sb.Apply(s, NoHistory, DoSnapshot, NoRestrictions); err != nil {
return
}
@ -786,7 +787,7 @@ func (sb *smartBlock) ResetToVersion(s *state.State) (err error) {
}
func (sb *smartBlock) CheckSubscriptions() (changed bool) {
depIDs := sb.dependentSmartIds(sb.includeRelationObjectsAsDependents, true, true, true)
depIDs := sb.dependentSmartIds(sb.includeRelationObjectsAsDependents, true, true)
changed = sb.setDependentIDs(depIDs)
if sb.recordsSub == nil {
@ -845,6 +846,8 @@ func (sb *smartBlock) AddRelationLinksToState(s *state.State, relationKeys ...st
if len(relationKeys) == 0 {
return
}
// todo: filter-out existing relation links?
// in the most cases it should save as an objectstore query
relations, err := sb.objectStore.FetchRelationByKeys(sb.SpaceID(), relationKeys...)
if err != nil {
return
@ -1284,11 +1287,6 @@ func (sb *smartBlock) runIndexer(s *state.State, opts ...IndexOption) {
}
}
func (sb *smartBlock) beforeStateApply(s *state.State) {
sb.setRestrictionsDetail(s)
sb.injectLinksDetails(s)
}
func removeInternalFlags(s *state.State) {
flags := internalflag.NewFromState(s)
@ -1449,6 +1447,7 @@ func (sb *smartBlock) injectDerivedDetails(s *state.State, spaceID string, sbt s
s.SetDetailAndBundledRelation(bundle.RelationKeyIsDeleted, pbtypes.Bool(isDeleted))
}
sb.injectLinksDetails(s)
sb.updateBackLinks(s)
}

View file

@ -43,8 +43,14 @@ func TestSmartBlock_Init(t *testing.T) {
fx.store.EXPECT().UpdatePendingLocalDetails(mock.Anything, mock.Anything).Return(nil).Maybe()
// when
fx.init(t, []*model.Block{{Id: id}})
initCtx := fx.init(t, []*model.Block{{Id: id}})
require.NotNil(t, initCtx)
require.NotNil(t, initCtx.State)
links := initCtx.State.GetRelationLinks()
for _, key := range bundle.RequiredInternalRelations {
assert.Truef(t, links.Has(key.String()), "missing relation %s", key)
}
// then
assert.Equal(t, id, fx.RootId())
}
@ -464,6 +470,32 @@ func TestInjectLocalDetails(t *testing.T) {
// TODO More tests
}
func TestInjectDerivedDetails(t *testing.T) {
const (
id = "id"
spaceId = "testSpace"
)
t.Run("links are updated on injection", func(t *testing.T) {
// given
fx := newFixture(id, t)
fx.store.EXPECT().GetInboundLinksByID(id).Return(nil, nil)
st := state.NewDoc("id", map[string]simple.Block{
id: simple.New(&model.Block{Id: id, ChildrenIds: []string{"dataview", "link"}}),
"dataview": simple.New(&model.Block{Id: "dataview", Content: &model.BlockContentOfDataview{Dataview: &model.BlockContentDataview{TargetObjectId: "some_set"}}}),
"link": simple.New(&model.Block{Id: "link", Content: &model.BlockContentOfLink{Link: &model.BlockContentLink{TargetBlockId: "some_obj"}}}),
}).NewState()
st.AddRelationLinks(&model.RelationLink{Key: bundle.RelationKeyAssignee.String(), Format: model.RelationFormat_object})
st.SetDetail(bundle.RelationKeyAssignee.String(), pbtypes.String("Kirill"))
// when
fx.injectDerivedDetails(st, spaceId, smartblock.SmartBlockTypePage)
// then
assert.Len(t, pbtypes.GetStringList(st.LocalDetails(), bundle.RelationKeyLinks.String()), 3)
})
}
type fixture struct {
store *mock_objectstore.MockObjectStore
restrictionService *mock_restriction.MockService
@ -501,7 +533,7 @@ func newFixture(id string, t *testing.T) *fixture {
}
}
func (fx *fixture) init(t *testing.T, blocks []*model.Block) {
func (fx *fixture) init(t *testing.T, blocks []*model.Block) *InitContext {
bm := make(map[string]simple.Block)
for _, b := range blocks {
bm[b.Id] = simple.New(b)
@ -509,12 +541,14 @@ func (fx *fixture) init(t *testing.T, blocks []*model.Block) {
doc := state.NewDoc(fx.source.id, bm)
fx.source.doc = doc
err := fx.Init(&InitContext{
initCtx := &InitContext{
Ctx: context.Background(),
SpaceID: "space1",
Source: fx.source,
})
}
err := fx.Init(initCtx)
require.NoError(t, err)
return initCtx
}
type sourceStub struct {

View file

@ -175,10 +175,6 @@ func (st *SmartTest) Tree() objecttree.ObjectTree {
return st.objectTree
}
func (st *SmartTest) SetRestrictions(r restriction.Restrictions) {
st.TestRestrictions = r
}
func (st *SmartTest) Restrictions() restriction.Restrictions {
return st.TestRestrictions
}

View file

@ -26,6 +26,21 @@ var spaceViewLog = logging.Logger("core.block.editor.spaceview")
var ErrIncorrectSpaceInfo = errors.New("space info is incorrect")
// required relations for spaceview beside the bundle.RequiredInternalRelations
var spaceViewRequiredRelations = []domain.RelationKey{
bundle.RelationKeySpaceLocalStatus,
bundle.RelationKeySpaceRemoteStatus,
bundle.RelationKeyTargetSpaceId,
bundle.RelationKeySpaceInviteFileCid,
bundle.RelationKeySpaceInviteFileKey,
bundle.RelationKeyIsAclShared,
bundle.RelationKeySharedSpacesLimit,
bundle.RelationKeySpaceAccountStatus,
bundle.RelationKeySpaceShareableStatus,
bundle.RelationKeySpaceAccessType,
bundle.RelationKeyLatestAclHeadId,
}
type spaceService interface {
OnViewUpdated(info spaceinfo.SpacePersistentInfo)
OnWorkspaceChanged(spaceId string, details *types.Struct)
@ -51,6 +66,7 @@ func (f *ObjectFactory) newSpaceView(sb smartblock.SmartBlock) *SpaceView {
// Init initializes SpaceView
func (s *SpaceView) Init(ctx *smartblock.InitContext) (err error) {
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, spaceViewRequiredRelations...)
if err = s.SmartBlock.Init(ctx); err != nil {
return
}
@ -95,11 +111,6 @@ func (s *SpaceView) StateMigrations() migration.Migrations {
func (s *SpaceView) initTemplate(st *state.State) {
template.InitTemplate(st,
template.WithObjectTypesAndLayout([]domain.TypeKey{bundle.TypeKeySpaceView}, model.ObjectType_spaceView),
template.WithRelations([]domain.RelationKey{
bundle.RelationKeySpaceLocalStatus,
bundle.RelationKeySpaceRemoteStatus,
bundle.RelationKeyTargetSpaceId,
}),
)
}

View file

@ -581,6 +581,8 @@ func (s *State) fillChanges(msgs []simple.EventMessage) {
updMsgs = append(updMsgs, msg.Msg)
case *pb.EventMessageValueOfBlockSetRestrictions:
updMsgs = append(updMsgs, msg.Msg)
case *pb.EventMessageValueOfBlockSetTableRow:
updMsgs = append(updMsgs, msg.Msg)
default:
log.Errorf("unexpected event - can't convert to changes: %T", msg.Msg.GetValue())
}

View file

@ -943,3 +943,54 @@ func TestRootDeviceChanges(t *testing.T) {
assert.Equal(t, device, s.GetChanges()[0].GetDeviceAdd().GetDevice())
})
}
func TestTableChanges(t *testing.T) {
t.Run("change row header", func(t *testing.T) {
contRow := &model.BlockContentOfLayout{
Layout: &model.BlockContentLayout{
Style: model.BlockContentLayout_Row,
},
}
contColumn := &model.BlockContentOfLayout{
Layout: &model.BlockContentLayout{
Style: model.BlockContentLayout_Column,
},
}
r := NewDoc("root", nil).(*State)
s := r.NewState()
s.Add(simple.New(&model.Block{Id: "root", ChildrenIds: []string{"r1", "t1"}}))
s.Add(simple.New(&model.Block{Id: "r1", ChildrenIds: []string{"c1", "c2"}, Content: contRow}))
s.Add(simple.New(&model.Block{Id: "c1", Content: contColumn}))
s.Add(simple.New(&model.Block{Id: "c2", Content: contColumn}))
s.Add(simple.New(&model.Block{Id: "t1", ChildrenIds: []string{"tableRows", "tableColumns"}, Content: &model.BlockContentOfTable{
Table: &model.BlockContentTable{},
}}))
s.Add(simple.New(&model.Block{Id: "tableRows", ChildrenIds: []string{"tableRow1"}, Content: &model.BlockContentOfLayout{
Layout: &model.BlockContentLayout{
Style: model.BlockContentLayout_TableRows,
},
}}))
s.Add(simple.New(&model.Block{Id: "tableRow1", Content: &model.BlockContentOfTableRow{TableRow: &model.BlockContentTableRow{IsHeader: false}}}))
s.Add(simple.New(&model.Block{Id: "tableColumns", Content: &model.BlockContentOfLayout{
Layout: &model.BlockContentLayout{
Style: model.BlockContentLayout_TableColumns,
},
}}))
msgs, _, err := ApplyState(s, true)
require.NoError(t, err)
assert.Len(t, msgs, 1)
s = s.NewState()
rows := s.Get("tableRow1")
require.NotNil(t, rows)
rows.Model().GetTableRow().IsHeader = true
msgs, _, err = ApplyState(s, true)
require.NoError(t, err)
assert.Len(t, msgs, 1)
})
}

View file

@ -174,7 +174,7 @@ func (s *State) wrapToRow(opId string, parent, b simple.Block) (row simple.Block
if pos == -1 {
return nil, fmt.Errorf("creating row: can't find child[%s] in given parent[%s]", b.Model().Id, parent.Model().Id)
}
s.removeFromCache(parent.Model().ChildrenIds[pos])
// do not need to remove from cache
parent.Model().ChildrenIds[pos] = row.Model().Id
s.addCacheIds(parent.Model(), row.Model().Id)
return
@ -185,6 +185,16 @@ func (s *State) setChildrenIds(parent *model.Block, childrenIds []string) {
s.addCacheIds(parent, childrenIds...)
}
// do not use this method outside of normalization
func (s *State) SetChildrenIds(parent *model.Block, childrenIds []string) {
s.setChildrenIds(parent, childrenIds)
}
// do not use this method outside of normalization
func (s *State) RemoveFromCache(childrenIds []string) {
s.removeFromCache(childrenIds...)
}
func (s *State) removeChildren(parent *model.Block, childrenId string) {
parent.ChildrenIds = slice.RemoveMut(parent.ChildrenIds, childrenId)
s.removeFromCache(childrenId)

View file

@ -265,6 +265,7 @@ func (s *State) CleanupBlock(id string) bool {
)
for t != nil {
if _, ok = t.blocks[id]; ok {
s.removeFromCache(id)
delete(t.blocks, id)
return true
}
@ -939,7 +940,7 @@ func (s *State) SetDetails(d *types.Struct) *State {
// SetDetailAndBundledRelation sets the detail value and bundled relation in case it is missing
func (s *State) SetDetailAndBundledRelation(key domain.RelationKey, value *types.Value) {
s.AddBundledRelations(key)
s.AddBundledRelationLinks(key)
s.SetDetail(key.String(), value)
return
}
@ -1933,13 +1934,19 @@ func (s *State) SelectRoots(ids []string) []string {
return res
}
func (s *State) AddBundledRelations(keys ...domain.RelationKey) {
links := make([]*model.RelationLink, 0, len(keys))
func (s *State) AddBundledRelationLinks(keys ...domain.RelationKey) {
existingLinks := s.PickRelationLinks()
var links []*model.RelationLink
for _, key := range keys {
rel := bundle.MustGetRelation(key)
links = append(links, &model.RelationLink{Format: rel.Format, Key: rel.Key})
if !existingLinks.Has(key.String()) {
rel := bundle.MustGetRelation(key)
links = append(links, &model.RelationLink{Format: rel.Format, Key: rel.Key})
}
}
if len(links) > 0 {
s.AddRelationLinks(links...)
}
s.AddRelationLinks(links...)
}
func (s *State) GetNotificationById(id string) *model.Notification {

View file

@ -16,6 +16,7 @@ import (
"github.com/anyproto/anytype-heart/core/block/simple/text"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
"github.com/anyproto/anytype-heart/tests/blockbuilder"
"github.com/anyproto/anytype-heart/util/pbtypes"
@ -2730,3 +2731,92 @@ func TestState_SetDeviceName(t *testing.T) {
assert.Equal(t, newState.deviceStore["id"].Name, "test1")
})
}
func TestAddBundledRealtionLinks(t *testing.T) {
t.Run("with relationLinks in state", func(t *testing.T) {
t.Run("empty", func(t *testing.T) {
st := &State{
relationLinks: []*model.RelationLink{},
}
st.AddBundledRelationLinks(bundle.RelationKeyName, bundle.RelationKeyPriority)
want := &State{
relationLinks: []*model.RelationLink{
{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
},
{
Key: bundle.RelationKeyPriority.String(),
Format: model.RelationFormat_number,
},
},
}
assert.Equal(t, want, st)
})
t.Run("one already exists, one not", func(t *testing.T) {
st := &State{
relationLinks: []*model.RelationLink{
{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
},
},
}
st.AddBundledRelationLinks(bundle.RelationKeyName, bundle.RelationKeyPriority)
want := &State{
relationLinks: []*model.RelationLink{
{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
},
{
Key: bundle.RelationKeyPriority.String(),
Format: model.RelationFormat_number,
},
},
}
assert.Equal(t, want, st)
})
})
t.Run("with relationLinks only in parent state", func(t *testing.T) {
st := &State{
relationLinks: nil,
parent: &State{
relationLinks: []*model.RelationLink{
{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
},
},
},
}
st.AddBundledRelationLinks(bundle.RelationKeyName, bundle.RelationKeyPriority)
want := &State{
relationLinks: []*model.RelationLink{
{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
},
{
Key: bundle.RelationKeyPriority.String(),
Format: model.RelationFormat_number,
},
},
parent: &State{
relationLinks: []*model.RelationLink{
{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
},
},
},
}
assert.Equal(t, want, st)
})
}

View file

@ -62,7 +62,7 @@ func (b *block) Normalize(s *state.State) error {
}
continue
}
normalizeRow(colIdx, row)
normalizeRow(s, colIdx, row)
}
if err := normalizeRows(s, tb); err != nil {
@ -169,7 +169,6 @@ func normalizeRows(s *state.State, tb *Table) error {
var headers []string
regular := make([]string, 0, len(rows.Model().ChildrenIds))
for _, rowID := range rows.Model().ChildrenIds {
row, err := pickRow(s, rowID)
if err != nil {
@ -183,12 +182,11 @@ func normalizeRows(s *state.State, tb *Table) error {
}
}
// nolint:gocritic
rows.Model().ChildrenIds = append(headers, regular...)
s.SetChildrenIds(rows.Model(), append(headers, regular...))
return nil
}
func normalizeRow(colIdx map[string]int, row simple.Block) {
func normalizeRow(s *state.State, colIdx map[string]int, row simple.Block) {
if row == nil || row.Model() == nil {
return
}
@ -196,10 +194,12 @@ func normalizeRow(colIdx map[string]int, row simple.Block) {
cells: make([]string, 0, len(row.Model().ChildrenIds)),
indices: make([]int, 0, len(row.Model().ChildrenIds)),
}
toRemove := []string{}
for _, id := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(id)
if err != nil {
log.Warnf("normalize row %s: discard cell %s: invalid id", row.Model().Id, id)
toRemove = append(toRemove, id)
rs.touched = true
continue
}
@ -207,6 +207,7 @@ func normalizeRow(colIdx map[string]int, row simple.Block) {
v, ok := colIdx[colID]
if !ok {
log.Warnf("normalize row %s: discard cell %s: column %s not found", row.Model().Id, id, colID)
toRemove = append(toRemove, id)
rs.touched = true
continue
}
@ -216,6 +217,11 @@ func normalizeRow(colIdx map[string]int, row simple.Block) {
sort.Sort(rs)
if rs.touched {
row.Model().ChildrenIds = rs.cells
if s == nil {
row.Model().ChildrenIds = rs.cells
} else {
s.RemoveFromCache(toRemove)
s.SetChildrenIds(row.Model(), rs.cells)
}
}
}

View file

@ -0,0 +1,783 @@
package table
import (
"errors"
"fmt"
"sort"
"github.com/globalsign/mgo/bson"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/block/editor/state"
"github.com/anyproto/anytype-heart/core/block/simple"
"github.com/anyproto/anytype-heart/core/block/simple/text"
"github.com/anyproto/anytype-heart/core/block/source"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
)
// nolint:revive,interfacebloat
type TableEditor interface {
TableCreate(s *state.State, req pb.RpcBlockTableCreateRequest) (string, error)
CellCreate(s *state.State, rowID string, colID string, b *model.Block) (string, error)
RowCreate(s *state.State, req pb.RpcBlockTableRowCreateRequest) (string, error)
RowDelete(s *state.State, req pb.RpcBlockTableRowDeleteRequest) error
RowDuplicate(s *state.State, req pb.RpcBlockTableRowDuplicateRequest) (newRowID string, err error)
// RowMove is done via BlockListMoveToExistingObject
RowListFill(s *state.State, req pb.RpcBlockTableRowListFillRequest) error
RowListClean(s *state.State, req pb.RpcBlockTableRowListCleanRequest) error
RowSetHeader(s *state.State, req pb.RpcBlockTableRowSetHeaderRequest) error
ColumnCreate(s *state.State, req pb.RpcBlockTableColumnCreateRequest) (string, error)
ColumnDelete(s *state.State, req pb.RpcBlockTableColumnDeleteRequest) error
ColumnDuplicate(s *state.State, req pb.RpcBlockTableColumnDuplicateRequest) (id string, err error)
ColumnMove(s *state.State, req pb.RpcBlockTableColumnMoveRequest) error
ColumnListFill(s *state.State, req pb.RpcBlockTableColumnListFillRequest) error
Expand(s *state.State, req pb.RpcBlockTableExpandRequest) error
Sort(s *state.State, req pb.RpcBlockTableSortRequest) error
cleanupTables(_ smartblock.ApplyInfo) error
cloneColumnStyles(s *state.State, srcColID string, targetColID string) error
}
type editor struct {
sb smartblock.SmartBlock
generateRowID func() string
generateColID func() string
}
var _ TableEditor = &editor{}
func NewEditor(sb smartblock.SmartBlock) TableEditor {
genID := func() string {
return bson.NewObjectId().Hex()
}
t := editor{
sb: sb,
generateRowID: genID,
generateColID: genID,
}
if sb != nil {
sb.AddHook(t.cleanupTables, smartblock.HookOnBlockClose)
}
return &t
}
func (t *editor) TableCreate(s *state.State, req pb.RpcBlockTableCreateRequest) (string, error) {
if t.sb != nil {
if err := t.sb.Restrictions().Object.Check(model.Restrictions_Blocks); err != nil {
return "", err
}
}
tableBlock := simple.New(&model.Block{
Content: &model.BlockContentOfTable{
Table: &model.BlockContentTable{},
},
})
if !s.Add(tableBlock) {
return "", fmt.Errorf("add table block")
}
if err := s.InsertTo(req.TargetId, req.Position, tableBlock.Model().Id); err != nil {
return "", fmt.Errorf("insert block: %w", err)
}
columnIds := make([]string, 0, req.Columns)
for i := uint32(0); i < req.Columns; i++ {
id, err := t.addColumnHeader(s)
if err != nil {
return "", err
}
columnIds = append(columnIds, id)
}
columnsLayout := simple.New(&model.Block{
ChildrenIds: columnIds,
Content: &model.BlockContentOfLayout{
Layout: &model.BlockContentLayout{
Style: model.BlockContentLayout_TableColumns,
},
},
})
if !s.Add(columnsLayout) {
return "", fmt.Errorf("add columns block")
}
rowIDs := make([]string, 0, req.Rows)
for i := uint32(0); i < req.Rows; i++ {
id, err := t.addRow(s)
if err != nil {
return "", err
}
rowIDs = append(rowIDs, id)
}
rowsLayout := simple.New(&model.Block{
ChildrenIds: rowIDs,
Content: &model.BlockContentOfLayout{
Layout: &model.BlockContentLayout{
Style: model.BlockContentLayout_TableRows,
},
},
})
if !s.Add(rowsLayout) {
return "", fmt.Errorf("add rows block")
}
tableBlock.Model().ChildrenIds = []string{columnsLayout.Model().Id, rowsLayout.Model().Id}
if !req.WithHeaderRow {
return tableBlock.Model().Id, nil
}
if len(rowIDs) == 0 {
return "", fmt.Errorf("no rows to make header row")
}
headerID := rowIDs[0]
if err := t.RowSetHeader(s, pb.RpcBlockTableRowSetHeaderRequest{
TargetId: headerID,
IsHeader: true,
}); err != nil {
return "", fmt.Errorf("row set header: %w", err)
}
if err := t.RowListFill(s, pb.RpcBlockTableRowListFillRequest{
BlockIds: []string{headerID},
}); err != nil {
return "", fmt.Errorf("fill header row: %w", err)
}
row, err := getRow(s, headerID)
if err != nil {
return "", fmt.Errorf("get header row: %w", err)
}
for _, cellID := range row.Model().ChildrenIds {
cell := s.Get(cellID)
if cell == nil {
return "", fmt.Errorf("get header cell id %s", cellID)
}
cell.Model().BackgroundColor = "grey"
}
return tableBlock.Model().Id, nil
}
func (t *editor) CellCreate(s *state.State, rowID string, colID string, b *model.Block) (string, error) {
tb, err := NewTable(s, rowID)
if err != nil {
return "", fmt.Errorf("initialize table state: %w", err)
}
row, err := getRow(s, rowID)
if err != nil {
return "", fmt.Errorf("get row: %w", err)
}
if _, err = pickColumn(s, colID); err != nil {
return "", fmt.Errorf("pick column: %w", err)
}
cellID, err := addCell(s, rowID, colID)
if err != nil {
return "", fmt.Errorf("add cell: %w", err)
}
cell := s.Get(cellID)
cell.Model().Content = b.Content
if err := s.InsertTo(rowID, model.Block_Inner, cellID); err != nil {
return "", fmt.Errorf("insert to: %w", err)
}
colIdx := tb.MakeColumnIndex()
normalizeRow(nil, colIdx, row)
return cellID, nil
}
func (t *editor) RowCreate(s *state.State, req pb.RpcBlockTableRowCreateRequest) (string, error) {
switch req.Position {
case model.Block_Top, model.Block_Bottom:
case model.Block_Inner:
tb, err := NewTable(s, req.TargetId)
if err != nil {
return "", fmt.Errorf("initialize table state: %w", err)
}
req.TargetId = tb.Rows().Id
default:
return "", fmt.Errorf("position is not supported")
}
rowID, err := t.addRow(s)
if err != nil {
return "", err
}
if err := s.InsertTo(req.TargetId, req.Position, rowID); err != nil {
return "", fmt.Errorf("insert row: %w", err)
}
return rowID, nil
}
func (t *editor) RowDelete(s *state.State, req pb.RpcBlockTableRowDeleteRequest) error {
_, err := pickRow(s, req.TargetId)
if err != nil {
return fmt.Errorf("pick target row: %w", err)
}
if !s.Unlink(req.TargetId) {
return fmt.Errorf("unlink row block")
}
return nil
}
func (t *editor) RowDuplicate(s *state.State, req pb.RpcBlockTableRowDuplicateRequest) (newRowID string, err error) {
if req.Position != model.Block_Top && req.Position != model.Block_Bottom {
return "", fmt.Errorf("position %s is not supported", model.BlockPosition_name[int32(req.Position)])
}
srcRow, err := pickRow(s, req.BlockId)
if err != nil {
return "", fmt.Errorf("pick source row: %w", err)
}
if _, err = pickRow(s, req.TargetId); err != nil {
return "", fmt.Errorf("pick target row: %w", err)
}
newRow := srcRow.Copy()
newRow.Model().Id = t.generateRowID()
if !s.Add(newRow) {
return "", fmt.Errorf("add new row %s", newRow.Model().Id)
}
if err = s.InsertTo(req.TargetId, req.Position, newRow.Model().Id); err != nil {
return "", fmt.Errorf("insert column: %w", err)
}
for i, srcID := range newRow.Model().ChildrenIds {
cell := s.Pick(srcID)
if cell == nil {
return "", fmt.Errorf("cell %s is not found", srcID)
}
_, colID, err := ParseCellID(srcID)
if err != nil {
return "", fmt.Errorf("parse cell id %s: %w", srcID, err)
}
newCell := cell.Copy()
newCell.Model().Id = MakeCellID(newRow.Model().Id, colID)
if !s.Add(newCell) {
return "", fmt.Errorf("add new cell %s", newCell.Model().Id)
}
newRow.Model().ChildrenIds[i] = newCell.Model().Id
}
return newRow.Model().Id, nil
}
func (t *editor) RowListFill(s *state.State, req pb.RpcBlockTableRowListFillRequest) error {
if len(req.BlockIds) == 0 {
return fmt.Errorf("empty row list")
}
tb, err := NewTable(s, req.BlockIds[0])
if err != nil {
return fmt.Errorf("init table: %w", err)
}
columns := tb.ColumnIDs()
for _, rowID := range req.BlockIds {
row, err := getRow(s, rowID)
if err != nil {
return fmt.Errorf("get row %s: %w", rowID, err)
}
newIds := make([]string, 0, len(columns))
for _, colID := range columns {
id := MakeCellID(rowID, colID)
newIds = append(newIds, id)
if !s.Exists(id) {
_, err := addCell(s, rowID, colID)
if err != nil {
return fmt.Errorf("add cell %s: %w", id, err)
}
}
}
row.Model().ChildrenIds = newIds
}
return nil
}
func (t *editor) RowListClean(s *state.State, req pb.RpcBlockTableRowListCleanRequest) error {
if len(req.BlockIds) == 0 {
return fmt.Errorf("empty row list")
}
for _, rowID := range req.BlockIds {
row, err := pickRow(s, rowID)
if err != nil {
return fmt.Errorf("pick row: %w", err)
}
for _, cellID := range row.Model().ChildrenIds {
cell := s.Pick(cellID)
if v, ok := cell.(text.Block); ok && v.IsEmpty() {
s.Unlink(cellID)
}
}
}
return nil
}
func (t *editor) RowSetHeader(s *state.State, req pb.RpcBlockTableRowSetHeaderRequest) error {
tb, err := NewTable(s, req.TargetId)
if err != nil {
return fmt.Errorf("init table: %w", err)
}
row, err := getRow(s, req.TargetId)
if err != nil {
return fmt.Errorf("get target row: %w", err)
}
if row.Model().GetTableRow().IsHeader != req.IsHeader {
row.Model().GetTableRow().IsHeader = req.IsHeader
err = normalizeRows(s, tb)
if err != nil {
return fmt.Errorf("normalize rows: %w", err)
}
}
return nil
}
func (t *editor) ColumnCreate(s *state.State, req pb.RpcBlockTableColumnCreateRequest) (string, error) {
switch req.Position {
case model.Block_Left:
req.Position = model.Block_Top
if _, err := pickColumn(s, req.TargetId); err != nil {
return "", fmt.Errorf("pick column: %w", err)
}
case model.Block_Right:
req.Position = model.Block_Bottom
if _, err := pickColumn(s, req.TargetId); err != nil {
return "", fmt.Errorf("pick column: %w", err)
}
case model.Block_Inner:
tb, err := NewTable(s, req.TargetId)
if err != nil {
return "", fmt.Errorf("initialize table state: %w", err)
}
req.TargetId = tb.Columns().Id
default:
return "", fmt.Errorf("position is not supported")
}
colID, err := t.addColumnHeader(s)
if err != nil {
return "", err
}
if err = s.InsertTo(req.TargetId, req.Position, colID); err != nil {
return "", fmt.Errorf("insert column header: %w", err)
}
return colID, t.cloneColumnStyles(s, req.TargetId, colID)
}
func (t *editor) ColumnDelete(s *state.State, req pb.RpcBlockTableColumnDeleteRequest) error {
_, err := pickColumn(s, req.TargetId)
if err != nil {
return fmt.Errorf("pick target column: %w", err)
}
tb, err := NewTable(s, req.TargetId)
if err != nil {
return fmt.Errorf("initialize table state: %w", err)
}
for _, rowID := range tb.RowIDs() {
row, err := pickRow(s, rowID)
if err != nil {
return fmt.Errorf("pick row %s: %w", rowID, err)
}
for _, cellID := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(cellID)
if err != nil {
return fmt.Errorf("parse cell id %s: %w", cellID, err)
}
if colID == req.TargetId {
if !s.Unlink(cellID) {
return fmt.Errorf("unlink cell %s", cellID)
}
break
}
}
}
if !s.Unlink(req.TargetId) {
return fmt.Errorf("unlink column header")
}
return nil
}
func (t *editor) ColumnDuplicate(s *state.State, req pb.RpcBlockTableColumnDuplicateRequest) (id string, err error) {
switch req.Position {
case model.Block_Left:
req.Position = model.Block_Top
case model.Block_Right:
req.Position = model.Block_Bottom
default:
return "", fmt.Errorf("position is not supported")
}
srcCol, err := pickColumn(s, req.BlockId)
if err != nil {
return "", fmt.Errorf("pick source column: %w", err)
}
_, err = pickColumn(s, req.TargetId)
if err != nil {
return "", fmt.Errorf("pick target column: %w", err)
}
tb, err := NewTable(s, req.TargetId)
if err != nil {
return "", fmt.Errorf("init table block: %w", err)
}
newCol := srcCol.Copy()
newCol.Model().Id = t.generateColID()
if !s.Add(newCol) {
return "", fmt.Errorf("add column block")
}
if err = s.InsertTo(req.TargetId, req.Position, newCol.Model().Id); err != nil {
return "", fmt.Errorf("insert column: %w", err)
}
colIdx := tb.MakeColumnIndex()
for _, rowID := range tb.RowIDs() {
row, err := getRow(s, rowID)
if err != nil {
return "", fmt.Errorf("get row %s: %w", rowID, err)
}
var cellID string
for _, id := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(id)
if err != nil {
return "", fmt.Errorf("parse cell %s in row %s: %w", cellID, rowID, err)
}
if colID == req.BlockId {
cellID = id
break
}
}
if cellID == "" {
continue
}
cell := s.Pick(cellID)
if cell == nil {
return "", fmt.Errorf("cell %s is not found", cellID)
}
cell = cell.Copy()
cell.Model().Id = MakeCellID(rowID, newCol.Model().Id)
if !s.Add(cell) {
return "", fmt.Errorf("add cell block")
}
row.Model().ChildrenIds = append(row.Model().ChildrenIds, cell.Model().Id)
normalizeRow(nil, colIdx, row)
}
return newCol.Model().Id, nil
}
func (t *editor) ColumnMove(s *state.State, req pb.RpcBlockTableColumnMoveRequest) error {
switch req.Position {
case model.Block_Left:
req.Position = model.Block_Top
case model.Block_Right:
req.Position = model.Block_Bottom
default:
return fmt.Errorf("position is not supported")
}
_, err := pickColumn(s, req.TargetId)
if err != nil {
return fmt.Errorf("get target column: %w", err)
}
_, err = pickColumn(s, req.DropTargetId)
if err != nil {
return fmt.Errorf("get drop target column: %w", err)
}
tb, err := NewTable(s, req.TargetId)
if err != nil {
return fmt.Errorf("init table block: %w", err)
}
if !s.Unlink(req.TargetId) {
return fmt.Errorf("unlink target column")
}
if err = s.InsertTo(req.DropTargetId, req.Position, req.TargetId); err != nil {
return fmt.Errorf("insert column: %w", err)
}
colIdx := tb.MakeColumnIndex()
for _, id := range tb.RowIDs() {
row, err := getRow(s, id)
if err != nil {
return fmt.Errorf("get row %s: %w", id, err)
}
normalizeRow(nil, colIdx, row)
}
return nil
}
func (t *editor) ColumnListFill(s *state.State, req pb.RpcBlockTableColumnListFillRequest) error {
if len(req.BlockIds) == 0 {
return fmt.Errorf("empty row list")
}
tb, err := NewTable(s, req.BlockIds[0])
if err != nil {
return fmt.Errorf("init table: %w", err)
}
rows := tb.RowIDs()
for _, colID := range req.BlockIds {
for _, rowID := range rows {
id := MakeCellID(rowID, colID)
if s.Exists(id) {
continue
}
_, err := addCell(s, rowID, colID)
if err != nil {
return fmt.Errorf("add cell %s: %w", id, err)
}
row, err := getRow(s, rowID)
if err != nil {
return fmt.Errorf("get row %s: %w", rowID, err)
}
row.Model().ChildrenIds = append(row.Model().ChildrenIds, id)
}
}
colIdx := tb.MakeColumnIndex()
for _, rowID := range rows {
row, err := getRow(s, rowID)
if err != nil {
return fmt.Errorf("get row %s: %w", rowID, err)
}
normalizeRow(nil, colIdx, row)
}
return nil
}
func (t *editor) Expand(s *state.State, req pb.RpcBlockTableExpandRequest) error {
tb, err := NewTable(s, req.TargetId)
if err != nil {
return fmt.Errorf("init table block: %w", err)
}
for i := uint32(0); i < req.Columns; i++ {
_, err := t.ColumnCreate(s, pb.RpcBlockTableColumnCreateRequest{
TargetId: req.TargetId,
Position: model.Block_Inner,
})
if err != nil {
return fmt.Errorf("create column: %w", err)
}
}
for i := uint32(0); i < req.Rows; i++ {
rows := tb.Rows()
_, err := t.RowCreate(s, pb.RpcBlockTableRowCreateRequest{
TargetId: rows.ChildrenIds[len(rows.ChildrenIds)-1],
Position: model.Block_Bottom,
})
if err != nil {
return fmt.Errorf("create row: %w", err)
}
}
return nil
}
func (t *editor) Sort(s *state.State, req pb.RpcBlockTableSortRequest) error {
_, err := pickColumn(s, req.ColumnId)
if err != nil {
return fmt.Errorf("pick column: %w", err)
}
tb, err := NewTable(s, req.ColumnId)
if err != nil {
return fmt.Errorf("init table block: %w", err)
}
rows := s.Get(tb.Rows().Id)
sorter := tableSorter{
rowIDs: make([]string, 0, len(rows.Model().ChildrenIds)),
values: make([]string, len(rows.Model().ChildrenIds)),
}
var headers []string
var i int
for _, rowID := range rows.Model().ChildrenIds {
row, err := pickRow(s, rowID)
if err != nil {
return fmt.Errorf("pick row %s: %w", rowID, err)
}
if row.Model().GetTableRow().GetIsHeader() {
headers = append(headers, rowID)
continue
}
sorter.rowIDs = append(sorter.rowIDs, rowID)
for _, cellID := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(cellID)
if err != nil {
return fmt.Errorf("parse cell id %s: %w", cellID, err)
}
if colID == req.ColumnId {
cell := s.Pick(cellID)
if cell == nil {
return fmt.Errorf("cell %s is not found", cellID)
}
sorter.values[i] = cell.Model().GetText().GetText()
}
}
i++
}
if req.Type == model.BlockContentDataviewSort_Asc {
sort.Stable(sorter)
} else {
sort.Stable(sort.Reverse(sorter))
}
// nolint:gocritic
rows.Model().ChildrenIds = append(headers, sorter.rowIDs...)
return nil
}
func (t *editor) cleanupTables(_ smartblock.ApplyInfo) error {
if t.sb == nil {
return fmt.Errorf("nil smartblock")
}
s := t.sb.NewState()
err := s.Iterate(func(b simple.Block) bool {
if b.Model().GetTable() == nil {
return true
}
tb, err := NewTable(s, b.Model().Id)
if err != nil {
log.Errorf("cleanup: init table %s: %s", b.Model().Id, err)
return true
}
err = t.RowListClean(s, pb.RpcBlockTableRowListCleanRequest{
BlockIds: tb.RowIDs(),
})
if err != nil {
log.Errorf("cleanup table %s: %s", b.Model().Id, err)
return true
}
return true
})
if err != nil {
log.Errorf("cleanup iterate: %s", err)
}
if err = t.sb.Apply(s, smartblock.KeepInternalFlags); err != nil {
if errors.Is(err, source.ErrReadOnly) {
return nil
}
log.Errorf("cleanup apply: %s", err)
}
return nil
}
func (t *editor) cloneColumnStyles(s *state.State, srcColID, targetColID string) error {
tb, err := NewTable(s, srcColID)
if err != nil {
return fmt.Errorf("init table block: %w", err)
}
colIdx := tb.MakeColumnIndex()
for _, rowID := range tb.RowIDs() {
row, err := pickRow(s, rowID)
if err != nil {
return fmt.Errorf("pick row: %w", err)
}
var protoBlock simple.Block
for _, cellID := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(cellID)
if err != nil {
return fmt.Errorf("parse cell id: %w", err)
}
if colID == srcColID {
protoBlock = s.Pick(cellID)
}
}
if protoBlock != nil && protoBlock.Model().BackgroundColor != "" {
targetCellID := MakeCellID(rowID, targetColID)
if !s.Exists(targetCellID) {
_, err := addCell(s, rowID, targetColID)
if err != nil {
return fmt.Errorf("add cell: %w", err)
}
}
cell := s.Get(targetCellID)
cell.Model().BackgroundColor = protoBlock.Model().BackgroundColor
row = s.Get(row.Model().Id)
row.Model().ChildrenIds = append(row.Model().ChildrenIds, targetCellID)
normalizeRow(nil, colIdx, row)
}
}
return nil
}
func (t *editor) addColumnHeader(s *state.State) (string, error) {
b := simple.New(&model.Block{
Id: t.generateColID(),
Content: &model.BlockContentOfTableColumn{
TableColumn: &model.BlockContentTableColumn{},
},
})
if !s.Add(b) {
return "", fmt.Errorf("add column block")
}
return b.Model().Id, nil
}
func (t *editor) addRow(s *state.State) (string, error) {
row := makeRow(t.generateRowID())
if !s.Add(row) {
return "", fmt.Errorf("add row block")
}
return row.Model().Id, nil
}

File diff suppressed because it is too large Load diff

View file

@ -2,751 +2,20 @@ package table
import (
"fmt"
"sort"
"strings"
"github.com/globalsign/mgo/bson"
"github.com/samber/lo"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/block/editor/state"
"github.com/anyproto/anytype-heart/core/block/simple"
"github.com/anyproto/anytype-heart/core/block/simple/table"
"github.com/anyproto/anytype-heart/core/block/simple/text"
"github.com/anyproto/anytype-heart/core/block/source"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/logging"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
)
var log = logging.Logger("anytype-simple-tables")
// nolint:revive,interfacebloat
type TableEditor interface {
TableCreate(s *state.State, req pb.RpcBlockTableCreateRequest) (string, error)
RowCreate(s *state.State, req pb.RpcBlockTableRowCreateRequest) (string, error)
RowDelete(s *state.State, req pb.RpcBlockTableRowDeleteRequest) error
ColumnDelete(s *state.State, req pb.RpcBlockTableColumnDeleteRequest) error
ColumnMove(s *state.State, req pb.RpcBlockTableColumnMoveRequest) error
RowDuplicate(s *state.State, req pb.RpcBlockTableRowDuplicateRequest) (newRowID string, err error)
RowListFill(s *state.State, req pb.RpcBlockTableRowListFillRequest) error
RowListClean(s *state.State, req pb.RpcBlockTableRowListCleanRequest) error
RowSetHeader(s *state.State, req pb.RpcBlockTableRowSetHeaderRequest) error
ColumnListFill(s *state.State, req pb.RpcBlockTableColumnListFillRequest) error
cleanupTables(_ smartblock.ApplyInfo) error
ColumnCreate(s *state.State, req pb.RpcBlockTableColumnCreateRequest) (string, error)
cloneColumnStyles(s *state.State, srcColID string, targetColID string) error
ColumnDuplicate(s *state.State, req pb.RpcBlockTableColumnDuplicateRequest) (id string, err error)
Expand(s *state.State, req pb.RpcBlockTableExpandRequest) error
Sort(s *state.State, req pb.RpcBlockTableSortRequest) error
CellCreate(s *state.State, rowID string, colID string, b *model.Block) (string, error)
}
type Editor struct {
sb smartblock.SmartBlock
generateRowID func() string
generateColID func() string
}
var _ TableEditor = &Editor{}
func NewEditor(sb smartblock.SmartBlock) *Editor {
genID := func() string {
return bson.NewObjectId().Hex()
}
t := Editor{
sb: sb,
generateRowID: genID,
generateColID: genID,
}
if sb != nil {
sb.AddHook(t.cleanupTables, smartblock.HookOnBlockClose)
}
return &t
}
func (t *Editor) TableCreate(s *state.State, req pb.RpcBlockTableCreateRequest) (string, error) {
if t.sb != nil {
if err := t.sb.Restrictions().Object.Check(model.Restrictions_Blocks); err != nil {
return "", err
}
}
tableBlock := simple.New(&model.Block{
Content: &model.BlockContentOfTable{
Table: &model.BlockContentTable{},
},
})
if !s.Add(tableBlock) {
return "", fmt.Errorf("add table block")
}
if err := s.InsertTo(req.TargetId, req.Position, tableBlock.Model().Id); err != nil {
return "", fmt.Errorf("insert block: %w", err)
}
columnIds := make([]string, 0, req.Columns)
for i := uint32(0); i < req.Columns; i++ {
id, err := t.addColumnHeader(s)
if err != nil {
return "", err
}
columnIds = append(columnIds, id)
}
columnsLayout := simple.New(&model.Block{
ChildrenIds: columnIds,
Content: &model.BlockContentOfLayout{
Layout: &model.BlockContentLayout{
Style: model.BlockContentLayout_TableColumns,
},
},
})
if !s.Add(columnsLayout) {
return "", fmt.Errorf("add columns block")
}
rowIDs := make([]string, 0, req.Rows)
for i := uint32(0); i < req.Rows; i++ {
id, err := t.addRow(s)
if err != nil {
return "", err
}
rowIDs = append(rowIDs, id)
}
rowsLayout := simple.New(&model.Block{
ChildrenIds: rowIDs,
Content: &model.BlockContentOfLayout{
Layout: &model.BlockContentLayout{
Style: model.BlockContentLayout_TableRows,
},
},
})
if !s.Add(rowsLayout) {
return "", fmt.Errorf("add rows block")
}
tableBlock.Model().ChildrenIds = []string{columnsLayout.Model().Id, rowsLayout.Model().Id}
if req.WithHeaderRow {
headerID := rowIDs[0]
if err := t.RowSetHeader(s, pb.RpcBlockTableRowSetHeaderRequest{
TargetId: headerID,
IsHeader: true,
}); err != nil {
return "", fmt.Errorf("row set header: %w", err)
}
if err := t.RowListFill(s, pb.RpcBlockTableRowListFillRequest{
BlockIds: []string{headerID},
}); err != nil {
return "", fmt.Errorf("fill header row: %w", err)
}
row, err := getRow(s, headerID)
if err != nil {
return "", fmt.Errorf("get header row: %w", err)
}
for _, cellID := range row.Model().ChildrenIds {
cell := s.Get(cellID)
if cell == nil {
return "", fmt.Errorf("get header cell id %s", cellID)
}
cell.Model().BackgroundColor = "grey"
}
}
return tableBlock.Model().Id, nil
}
func (t *Editor) RowCreate(s *state.State, req pb.RpcBlockTableRowCreateRequest) (string, error) {
switch req.Position {
case model.Block_Top, model.Block_Bottom:
case model.Block_Inner:
tb, err := NewTable(s, req.TargetId)
if err != nil {
return "", fmt.Errorf("initialize table state: %w", err)
}
req.TargetId = tb.Rows().Id
default:
return "", fmt.Errorf("position is not supported")
}
rowID, err := t.addRow(s)
if err != nil {
return "", err
}
if err := s.InsertTo(req.TargetId, req.Position, rowID); err != nil {
return "", fmt.Errorf("insert row: %w", err)
}
return rowID, nil
}
func (t *Editor) RowDelete(s *state.State, req pb.RpcBlockTableRowDeleteRequest) error {
_, err := pickRow(s, req.TargetId)
if err != nil {
return fmt.Errorf("pick target row: %w", err)
}
if !s.Unlink(req.TargetId) {
return fmt.Errorf("unlink row block")
}
return nil
}
func (t *Editor) ColumnDelete(s *state.State, req pb.RpcBlockTableColumnDeleteRequest) error {
_, err := pickColumn(s, req.TargetId)
if err != nil {
return fmt.Errorf("pick target column: %w", err)
}
tb, err := NewTable(s, req.TargetId)
if err != nil {
return fmt.Errorf("initialize table state: %w", err)
}
for _, rowID := range tb.RowIDs() {
row, err := pickRow(s, rowID)
if err != nil {
return fmt.Errorf("pick row %s: %w", rowID, err)
}
for _, cellID := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(cellID)
if err != nil {
return fmt.Errorf("parse cell id %s: %w", cellID, err)
}
if colID == req.TargetId {
if !s.Unlink(cellID) {
return fmt.Errorf("unlink cell %s", cellID)
}
break
}
}
}
if !s.Unlink(req.TargetId) {
return fmt.Errorf("unlink column header")
}
return nil
}
func (t *Editor) ColumnMove(s *state.State, req pb.RpcBlockTableColumnMoveRequest) error {
switch req.Position {
case model.Block_Left:
req.Position = model.Block_Top
case model.Block_Right:
req.Position = model.Block_Bottom
default:
return fmt.Errorf("position is not supported")
}
_, err := pickColumn(s, req.TargetId)
if err != nil {
return fmt.Errorf("get target column: %w", err)
}
_, err = pickColumn(s, req.DropTargetId)
if err != nil {
return fmt.Errorf("get drop target column: %w", err)
}
tb, err := NewTable(s, req.TargetId)
if err != nil {
return fmt.Errorf("init table block: %w", err)
}
if !s.Unlink(req.TargetId) {
return fmt.Errorf("unlink target column")
}
if err = s.InsertTo(req.DropTargetId, req.Position, req.TargetId); err != nil {
return fmt.Errorf("insert column: %w", err)
}
colIdx := tb.MakeColumnIndex()
for _, id := range tb.RowIDs() {
row, err := getRow(s, id)
if err != nil {
return fmt.Errorf("get row %s: %w", id, err)
}
normalizeRow(colIdx, row)
}
return nil
}
func (t *Editor) RowDuplicate(s *state.State, req pb.RpcBlockTableRowDuplicateRequest) (newRowID string, err error) {
srcRow, err := pickRow(s, req.BlockId)
if err != nil {
return "", fmt.Errorf("pick source row: %w", err)
}
newRow := srcRow.Copy()
newRow.Model().Id = t.generateRowID()
if !s.Add(newRow) {
return "", fmt.Errorf("add new row %s", newRow.Model().Id)
}
if err = s.InsertTo(req.TargetId, req.Position, newRow.Model().Id); err != nil {
return "", fmt.Errorf("insert column: %w", err)
}
for i, srcID := range newRow.Model().ChildrenIds {
cell := s.Pick(srcID)
if cell == nil {
return "", fmt.Errorf("cell %s is not found", srcID)
}
_, colID, err := ParseCellID(srcID)
if err != nil {
return "", fmt.Errorf("parse cell id %s: %w", srcID, err)
}
newCell := cell.Copy()
newCell.Model().Id = MakeCellID(newRow.Model().Id, colID)
if !s.Add(newCell) {
return "", fmt.Errorf("add new cell %s", newCell.Model().Id)
}
newRow.Model().ChildrenIds[i] = newCell.Model().Id
}
return newRow.Model().Id, nil
}
func (t *Editor) RowListFill(s *state.State, req pb.RpcBlockTableRowListFillRequest) error {
if len(req.BlockIds) == 0 {
return fmt.Errorf("empty row list")
}
tb, err := NewTable(s, req.BlockIds[0])
if err != nil {
return fmt.Errorf("init table: %w", err)
}
columns := tb.ColumnIDs()
for _, rowID := range req.BlockIds {
row, err := getRow(s, rowID)
if err != nil {
return fmt.Errorf("get row %s: %w", rowID, err)
}
newIds := make([]string, 0, len(columns))
for _, colID := range columns {
id := MakeCellID(rowID, colID)
newIds = append(newIds, id)
if !s.Exists(id) {
_, err := addCell(s, rowID, colID)
if err != nil {
return fmt.Errorf("add cell %s: %w", id, err)
}
}
}
row.Model().ChildrenIds = newIds
}
return nil
}
func (t *Editor) RowListClean(s *state.State, req pb.RpcBlockTableRowListCleanRequest) error {
if len(req.BlockIds) == 0 {
return fmt.Errorf("empty row list")
}
for _, rowID := range req.BlockIds {
row, err := pickRow(s, rowID)
if err != nil {
return fmt.Errorf("pick row: %w", err)
}
for _, cellID := range row.Model().ChildrenIds {
cell := s.Pick(cellID)
if v, ok := cell.(text.Block); ok && v.IsEmpty() {
s.Unlink(cellID)
}
}
}
return nil
}
func (t *Editor) RowSetHeader(s *state.State, req pb.RpcBlockTableRowSetHeaderRequest) error {
tb, err := NewTable(s, req.TargetId)
if err != nil {
return fmt.Errorf("init table: %w", err)
}
row, err := getRow(s, req.TargetId)
if err != nil {
return fmt.Errorf("get target row: %w", err)
}
if row.Model().GetTableRow().IsHeader != req.IsHeader {
row.Model().GetTableRow().IsHeader = req.IsHeader
err = normalizeRows(s, tb)
if err != nil {
return fmt.Errorf("normalize rows: %w", err)
}
}
return nil
}
func (t *Editor) ColumnListFill(s *state.State, req pb.RpcBlockTableColumnListFillRequest) error {
if len(req.BlockIds) == 0 {
return fmt.Errorf("empty row list")
}
tb, err := NewTable(s, req.BlockIds[0])
if err != nil {
return fmt.Errorf("init table: %w", err)
}
rows := tb.RowIDs()
for _, colID := range req.BlockIds {
for _, rowID := range rows {
id := MakeCellID(rowID, colID)
if s.Exists(id) {
continue
}
_, err := addCell(s, rowID, colID)
if err != nil {
return fmt.Errorf("add cell %s: %w", id, err)
}
row, err := getRow(s, rowID)
if err != nil {
return fmt.Errorf("get row %s: %w", rowID, err)
}
row.Model().ChildrenIds = append(row.Model().ChildrenIds, id)
}
}
colIdx := tb.MakeColumnIndex()
for _, rowID := range rows {
row, err := getRow(s, rowID)
if err != nil {
return fmt.Errorf("get row %s: %w", rowID, err)
}
normalizeRow(colIdx, row)
}
return nil
}
func (t *Editor) cleanupTables(_ smartblock.ApplyInfo) error {
if t.sb == nil {
return fmt.Errorf("nil smartblock")
}
s := t.sb.NewState()
err := s.Iterate(func(b simple.Block) bool {
if b.Model().GetTable() == nil {
return true
}
tb, err := NewTable(s, b.Model().Id)
if err != nil {
log.Errorf("cleanup: init table %s: %s", b.Model().Id, err)
return true
}
err = t.RowListClean(s, pb.RpcBlockTableRowListCleanRequest{
BlockIds: tb.RowIDs(),
})
if err != nil {
log.Errorf("cleanup table %s: %s", b.Model().Id, err)
return true
}
return true
})
if err != nil {
log.Errorf("cleanup iterate: %s", err)
}
if err = t.sb.Apply(s, smartblock.KeepInternalFlags); err != nil {
if err == source.ErrReadOnly {
return nil
}
log.Errorf("cleanup apply: %s", err)
}
return nil
}
func (t *Editor) ColumnCreate(s *state.State, req pb.RpcBlockTableColumnCreateRequest) (string, error) {
switch req.Position {
case model.Block_Left:
req.Position = model.Block_Top
if _, err := pickColumn(s, req.TargetId); err != nil {
return "", fmt.Errorf("pick column: %w", err)
}
case model.Block_Right:
req.Position = model.Block_Bottom
if _, err := pickColumn(s, req.TargetId); err != nil {
return "", fmt.Errorf("pick column: %w", err)
}
case model.Block_Inner:
tb, err := NewTable(s, req.TargetId)
if err != nil {
return "", fmt.Errorf("initialize table state: %w", err)
}
req.TargetId = tb.Columns().Id
default:
return "", fmt.Errorf("position is not supported")
}
colID, err := t.addColumnHeader(s)
if err != nil {
return "", err
}
if err = s.InsertTo(req.TargetId, req.Position, colID); err != nil {
return "", fmt.Errorf("insert column header: %w", err)
}
return colID, t.cloneColumnStyles(s, req.TargetId, colID)
}
func (t *Editor) cloneColumnStyles(s *state.State, srcColID, targetColID string) error {
tb, err := NewTable(s, srcColID)
if err != nil {
return fmt.Errorf("init table block: %w", err)
}
colIdx := tb.MakeColumnIndex()
for _, rowID := range tb.RowIDs() {
row, err := pickRow(s, rowID)
if err != nil {
return fmt.Errorf("pick row: %w", err)
}
var protoBlock simple.Block
for _, cellID := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(cellID)
if err != nil {
return fmt.Errorf("parse cell id: %w", err)
}
if colID == srcColID {
protoBlock = s.Pick(cellID)
}
}
if protoBlock != nil && protoBlock.Model().BackgroundColor != "" {
targetCellID := MakeCellID(rowID, targetColID)
if !s.Exists(targetCellID) {
_, err := addCell(s, rowID, targetColID)
if err != nil {
return fmt.Errorf("add cell: %w", err)
}
}
cell := s.Get(targetCellID)
cell.Model().BackgroundColor = protoBlock.Model().BackgroundColor
row = s.Get(row.Model().Id)
row.Model().ChildrenIds = append(row.Model().ChildrenIds, targetCellID)
normalizeRow(colIdx, row)
}
}
return nil
}
func (t *Editor) ColumnDuplicate(s *state.State, req pb.RpcBlockTableColumnDuplicateRequest) (id string, err error) {
switch req.Position {
case model.Block_Left:
req.Position = model.Block_Top
case model.Block_Right:
req.Position = model.Block_Bottom
default:
return "", fmt.Errorf("position is not supported")
}
srcCol, err := pickColumn(s, req.BlockId)
if err != nil {
return "", fmt.Errorf("pick source column: %w", err)
}
_, err = pickColumn(s, req.TargetId)
if err != nil {
return "", fmt.Errorf("pick target column: %w", err)
}
tb, err := NewTable(s, req.TargetId)
if err != nil {
return "", fmt.Errorf("init table block: %w", err)
}
newCol := srcCol.Copy()
newCol.Model().Id = t.generateColID()
if !s.Add(newCol) {
return "", fmt.Errorf("add column block")
}
if err = s.InsertTo(req.TargetId, req.Position, newCol.Model().Id); err != nil {
return "", fmt.Errorf("insert column: %w", err)
}
colIdx := tb.MakeColumnIndex()
for _, rowID := range tb.RowIDs() {
row, err := getRow(s, rowID)
if err != nil {
return "", fmt.Errorf("get row %s: %w", rowID, err)
}
var cellID string
for _, id := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(id)
if err != nil {
return "", fmt.Errorf("parse cell %s in row %s: %w", cellID, rowID, err)
}
if colID == req.BlockId {
cellID = id
break
}
}
if cellID == "" {
continue
}
cell := s.Pick(cellID)
if cell == nil {
return "", fmt.Errorf("cell %s is not found", cellID)
}
cell = cell.Copy()
cell.Model().Id = MakeCellID(rowID, newCol.Model().Id)
if !s.Add(cell) {
return "", fmt.Errorf("add cell block")
}
row.Model().ChildrenIds = append(row.Model().ChildrenIds, cell.Model().Id)
normalizeRow(colIdx, row)
}
return newCol.Model().Id, nil
}
func (t *Editor) Expand(s *state.State, req pb.RpcBlockTableExpandRequest) error {
tb, err := NewTable(s, req.TargetId)
if err != nil {
return fmt.Errorf("init table block: %w", err)
}
for i := uint32(0); i < req.Columns; i++ {
_, err := t.ColumnCreate(s, pb.RpcBlockTableColumnCreateRequest{
TargetId: req.TargetId,
Position: model.Block_Inner,
})
if err != nil {
return fmt.Errorf("create column: %w", err)
}
}
for i := uint32(0); i < req.Rows; i++ {
rows := tb.Rows()
_, err := t.RowCreate(s, pb.RpcBlockTableRowCreateRequest{
TargetId: rows.ChildrenIds[len(rows.ChildrenIds)-1],
Position: model.Block_Bottom,
})
if err != nil {
return fmt.Errorf("create row: %w", err)
}
}
return nil
}
func (t *Editor) Sort(s *state.State, req pb.RpcBlockTableSortRequest) error {
_, err := pickColumn(s, req.ColumnId)
if err != nil {
return fmt.Errorf("pick column: %w", err)
}
tb, err := NewTable(s, req.ColumnId)
if err != nil {
return fmt.Errorf("init table block: %w", err)
}
rows := s.Get(tb.Rows().Id)
sorter := tableSorter{
rowIDs: make([]string, 0, len(rows.Model().ChildrenIds)),
values: make([]string, len(rows.Model().ChildrenIds)),
}
var headers []string
var i int
for _, rowID := range rows.Model().ChildrenIds {
row, err := pickRow(s, rowID)
if err != nil {
return fmt.Errorf("pick row %s: %w", rowID, err)
}
if row.Model().GetTableRow().GetIsHeader() {
headers = append(headers, rowID)
continue
}
sorter.rowIDs = append(sorter.rowIDs, rowID)
for _, cellID := range row.Model().ChildrenIds {
_, colID, err := ParseCellID(cellID)
if err != nil {
return fmt.Errorf("parse cell id %s: %w", cellID, err)
}
if colID == req.ColumnId {
cell := s.Pick(cellID)
if cell == nil {
return fmt.Errorf("cell %s is not found", cellID)
}
sorter.values[i] = cell.Model().GetText().GetText()
}
}
i++
}
if req.Type == model.BlockContentDataviewSort_Asc {
sort.Stable(sorter)
} else {
sort.Stable(sort.Reverse(sorter))
}
// nolint:gocritic
rows.Model().ChildrenIds = append(headers, sorter.rowIDs...)
return nil
}
func (t *Editor) CellCreate(s *state.State, rowID string, colID string, b *model.Block) (string, error) {
tb, err := NewTable(s, rowID)
if err != nil {
return "", fmt.Errorf("initialize table state: %w", err)
}
row, err := getRow(s, rowID)
if err != nil {
return "", fmt.Errorf("get row: %w", err)
}
if _, err = pickColumn(s, colID); err != nil {
return "", fmt.Errorf("pick column: %w", err)
}
cellID, err := addCell(s, rowID, colID)
if err != nil {
return "", fmt.Errorf("add cell: %w", err)
}
cell := s.Get(cellID)
cell.Model().Content = b.Content
if err := s.InsertTo(rowID, model.Block_Inner, cellID); err != nil {
return "", fmt.Errorf("insert to: %w", err)
}
colIdx := tb.MakeColumnIndex()
normalizeRow(colIdx, row)
return cellID, nil
}
var ErrCannotMoveTableBlocks = fmt.Errorf("can not move table blocks")
type tableSorter struct {
rowIDs []string
@ -766,27 +35,6 @@ func (t tableSorter) Swap(i, j int) {
t.rowIDs[i], t.rowIDs[j] = t.rowIDs[j], t.rowIDs[i]
}
func (t *Editor) addColumnHeader(s *state.State) (string, error) {
b := simple.New(&model.Block{
Id: t.generateColID(),
Content: &model.BlockContentOfTableColumn{
TableColumn: &model.BlockContentTableColumn{},
},
})
if !s.Add(b) {
return "", fmt.Errorf("add column block")
}
return b.Model().Id, nil
}
func (t *Editor) addRow(s *state.State) (string, error) {
row := makeRow(t.generateRowID())
if !s.Add(row) {
return "", fmt.Errorf("add row block")
}
return row.Model().Id, nil
}
func makeRow(id string) simple.Block {
return simple.New(&model.Block{
Id: id,
@ -868,14 +116,7 @@ func NewTable(s *state.State, id string) (*Table, error) {
s: s,
}
next := s.Pick(id)
for next != nil {
if next.Model().GetTable() != nil {
tb.block = next
break
}
next = s.PickParentOf(next.Model().Id)
}
tb.block = PickTableRootBlock(s, id)
if tb.block == nil {
return nil, fmt.Errorf("root table block is not found")
}
@ -901,6 +142,19 @@ func NewTable(s *state.State, id string) (*Table, error) {
return &tb, nil
}
// PickTableRootBlock iterates over parents of block. Returns nil in case root table block is not found
func PickTableRootBlock(s *state.State, id string) (block simple.Block) {
next := s.Pick(id)
for next != nil {
if next.Model().GetTable() != nil {
block = next
break
}
next = s.PickParentOf(next.Model().Id)
}
return block
}
// destructureDivs removes child dividers from block
func destructureDivs(s *state.State, blockID string) {
parent := s.Pick(blockID)
@ -1006,3 +260,32 @@ func (tb Table) Iterate(f func(b simple.Block, pos CellPosition) bool) error {
}
return nil
}
// CheckTableBlocksMove checks if Insert operation is allowed in case table blocks are affected
func CheckTableBlocksMove(st *state.State, target string, pos model.BlockPosition, blockIds []string) (string, model.BlockPosition, error) {
if t, err := NewTable(st, target); err == nil && t != nil {
// we allow moving rows between each other
if lo.Every(t.RowIDs(), append(blockIds, target)) {
if pos == model.Block_Bottom || pos == model.Block_Top {
return target, pos, nil
}
return "", 0, fmt.Errorf("failed to move rows: position should be Top or Bottom, got %s", model.BlockPosition_name[int32(pos)])
}
}
for _, id := range blockIds {
t := PickTableRootBlock(st, id)
if t != nil && t.Model().Id != id {
// we should not move table blocks except table root block
return "", 0, ErrCannotMoveTableBlocks
}
}
t := PickTableRootBlock(st, target)
if t != nil && t.Model().Id != target {
// if the target is one of table blocks, but not table root, we should insert blocks under the table
return t.Model().Id, model.Block_Bottom, nil
}
return target, pos, nil
}

File diff suppressed because it is too large Load diff

View file

@ -81,20 +81,12 @@ var WithNoDuplicateLinks = func() StateTransformer {
var WithRelations = func(rels []domain.RelationKey) StateTransformer {
return func(s *state.State) {
var links []*model.RelationLink
for _, relKey := range rels {
if s.HasRelation(relKey.String()) {
continue
}
rel := bundle.MustGetRelation(relKey)
links = append(links, &model.RelationLink{Format: rel.Format, Key: rel.Key})
}
s.AddRelationLinks(links...)
s.AddBundledRelationLinks(rels...)
}
}
var WithRequiredRelations = func() StateTransformer {
return WithRelations(bundle.RequiredInternalRelations)
var WithRequiredRelations = func(s *state.State) {
WithRelations(bundle.RequiredInternalRelations)(s)
}
var WithObjectTypesAndLayout = func(otypes []domain.TypeKey, layout model.ObjectTypeLayout) StateTransformer {
@ -622,7 +614,7 @@ var WithBookmarkBlocks = func(s *state.State) {
for _, k := range bookmarkRelationKeys {
if !s.HasRelation(k) {
s.AddBundledRelations(domain.RelationKey(k))
s.AddBundledRelationLinks(domain.RelationKey(k))
}
}

View file

@ -30,7 +30,7 @@ func NewWidgetObject(
objectStore objectstore.ObjectStore,
layoutConverter converter.LayoutConverter,
) *WidgetObject {
bs := basic.NewBasic(sb, objectStore, layoutConverter)
bs := basic.NewBasic(sb, objectStore, layoutConverter, nil)
return &WidgetObject{
SmartBlock: sb,
Movable: bs,

View file

@ -16,6 +16,10 @@ import (
"github.com/anyproto/anytype-heart/util/pbtypes"
)
var workspaceRequiredRelations = []domain.RelationKey{
// SpaceInviteFileCid and SpaceInviteFileKey are added only when creating invite
}
type Workspaces struct {
smartblock.SmartBlock
basic.AllOperations
@ -31,7 +35,7 @@ type Workspaces struct {
func (f *ObjectFactory) newWorkspace(sb smartblock.SmartBlock) *Workspaces {
w := &Workspaces{
SmartBlock: sb,
AllOperations: basic.NewBasic(sb, f.objectStore, f.layoutConverter),
AllOperations: basic.NewBasic(sb, f.objectStore, f.layoutConverter, f.fileObjectService),
IHistory: basic.NewHistory(sb),
Text: stext.NewText(
sb,
@ -49,6 +53,7 @@ func (f *ObjectFactory) newWorkspace(sb smartblock.SmartBlock) *Workspaces {
}
func (w *Workspaces) Init(ctx *smartblock.InitContext) (err error) {
ctx.RequiredInternalRelationKeys = append(ctx.RequiredInternalRelationKeys, workspaceRequiredRelations...)
err = w.SmartBlock.Init(ctx)
if err != nil {
return err

View file

@ -50,7 +50,7 @@ func (d *derivedObject) GetIDAndPayload(ctx context.Context, spaceID string, sn
}
var key string
if d.isDeletedObject(uniqueKey.Marshal()) {
if d.isDeletedObject(spaceID, uniqueKey.Marshal()) {
key = bson.NewObjectId().Hex()
uniqueKey, err = domain.NewUniqueKey(sn.SbType, key)
if err != nil {
@ -74,9 +74,14 @@ func (d *derivedObject) GetInternalKey(sbType sb.SmartBlockType) string {
return d.internalKey
}
func (d *derivedObject) isDeletedObject(uniqueKey string) bool {
func (d *derivedObject) isDeletedObject(spaceId string, uniqueKey string) bool {
ids, _, err := d.objectStore.QueryObjectIDs(database.Query{
Filters: []*model.BlockContentDataviewFilter{
{
Condition: model.BlockContentDataviewFilter_Equal,
RelationKey: bundle.RelationKeySpaceId.String(),
Value: pbtypes.String(spaceId),
},
{
Condition: model.BlockContentDataviewFilter_Equal,
RelationKey: bundle.RelationKeyUniqueKey.String(),

View file

@ -32,9 +32,7 @@ func (s *service) createSet(ctx context.Context, space clientspace.Space, req *p
newState.AddDetails(req.Details)
newState.BlocksInit(newState)
tmpls := []template.StateTransformer{
template.WithRequiredRelations(),
}
tmpls := []template.StateTransformer{}
for i, view := range dvContent.Dataview.Views {
if view.Relations == nil {

View file

@ -74,12 +74,13 @@ func (gr *Builder) ObjectGraph(req *pb.RpcObjectGraphRequest) ([]*types.Struct,
return rel.Key, isRelationShouldBeIncludedAsEdge(rel)
})...)
resp, err := gr.subscriptionService.Search(pb.RpcObjectSearchSubscribeRequest{
resp, err := gr.subscriptionService.Search(subscription.SubscribeRequest{
Source: req.SetSource,
Filters: req.Filters,
Keys: lo.Map(relations.Models(), func(rel *model.Relation, _ int) string { return rel.Key }),
CollectionId: req.CollectionId,
Limit: int64(req.Limit),
Internal: true,
})
if err != nil {

View file

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/mock"
"github.com/anyproto/anytype-heart/core/relationutils"
"github.com/anyproto/anytype-heart/core/subscription"
"github.com/anyproto/anytype-heart/core/subscription/mock_subscription"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
@ -51,7 +52,7 @@ func Test(t *testing.T) {
{Relation: bundle.MustGetRelation(bundle.RelationKeyAuthor)},
{Relation: bundle.MustGetRelation(bundle.RelationKeyAttachments)},
}, nil)
fixture.subscriptionServiceMock.EXPECT().Search(mock.Anything).Return(&pb.RpcObjectSearchSubscribeResponse{
fixture.subscriptionServiceMock.EXPECT().Search(mock.Anything).Return(&subscription.SubscribeResponse{
Records: []*types.Struct{},
}, nil)
fixture.subscriptionServiceMock.EXPECT().Unsubscribe(mock.Anything).Return(nil)
@ -73,7 +74,7 @@ func Test(t *testing.T) {
{Relation: bundle.MustGetRelation(bundle.RelationKeyAssignee)},
{Relation: bundle.MustGetRelation(bundle.RelationKeyAttachments)},
}, nil)
fixture.subscriptionServiceMock.EXPECT().Search(mock.Anything).Return(&pb.RpcObjectSearchSubscribeResponse{
fixture.subscriptionServiceMock.EXPECT().Search(mock.Anything).Return(&subscription.SubscribeResponse{
Records: []*types.Struct{
{Fields: map[string]*types.Value{
bundle.RelationKeyId.String(): pbtypes.String("id1"),

View file

@ -4,7 +4,6 @@ package mock_treesyncer
import (
app "github.com/anyproto/any-sync/app"
domain "github.com/anyproto/anytype-heart/core/domain"
mock "github.com/stretchr/testify/mock"
)
@ -112,38 +111,37 @@ func (_c *MockSyncDetailsUpdater_Name_Call) RunAndReturn(run func() string) *Moc
return _c
}
// UpdateDetails provides a mock function with given fields: objectId, status, syncError, spaceId
func (_m *MockSyncDetailsUpdater) UpdateDetails(objectId []string, status domain.ObjectSyncStatus, syncError domain.SyncError, spaceId string) {
_m.Called(objectId, status, syncError, spaceId)
// UpdateSpaceDetails provides a mock function with given fields: existing, missing, spaceId
func (_m *MockSyncDetailsUpdater) UpdateSpaceDetails(existing []string, missing []string, spaceId string) {
_m.Called(existing, missing, spaceId)
}
// MockSyncDetailsUpdater_UpdateDetails_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateDetails'
type MockSyncDetailsUpdater_UpdateDetails_Call struct {
// MockSyncDetailsUpdater_UpdateSpaceDetails_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateSpaceDetails'
type MockSyncDetailsUpdater_UpdateSpaceDetails_Call struct {
*mock.Call
}
// UpdateDetails is a helper method to define mock.On call
// - objectId []string
// - status domain.ObjectSyncStatus
// - syncError domain.SyncError
// UpdateSpaceDetails is a helper method to define mock.On call
// - existing []string
// - missing []string
// - spaceId string
func (_e *MockSyncDetailsUpdater_Expecter) UpdateDetails(objectId interface{}, status interface{}, syncError interface{}, spaceId interface{}) *MockSyncDetailsUpdater_UpdateDetails_Call {
return &MockSyncDetailsUpdater_UpdateDetails_Call{Call: _e.mock.On("UpdateDetails", objectId, status, syncError, spaceId)}
func (_e *MockSyncDetailsUpdater_Expecter) UpdateSpaceDetails(existing interface{}, missing interface{}, spaceId interface{}) *MockSyncDetailsUpdater_UpdateSpaceDetails_Call {
return &MockSyncDetailsUpdater_UpdateSpaceDetails_Call{Call: _e.mock.On("UpdateSpaceDetails", existing, missing, spaceId)}
}
func (_c *MockSyncDetailsUpdater_UpdateDetails_Call) Run(run func(objectId []string, status domain.ObjectSyncStatus, syncError domain.SyncError, spaceId string)) *MockSyncDetailsUpdater_UpdateDetails_Call {
func (_c *MockSyncDetailsUpdater_UpdateSpaceDetails_Call) Run(run func(existing []string, missing []string, spaceId string)) *MockSyncDetailsUpdater_UpdateSpaceDetails_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string), args[1].(domain.ObjectSyncStatus), args[2].(domain.SyncError), args[3].(string))
run(args[0].([]string), args[1].([]string), args[2].(string))
})
return _c
}
func (_c *MockSyncDetailsUpdater_UpdateDetails_Call) Return() *MockSyncDetailsUpdater_UpdateDetails_Call {
func (_c *MockSyncDetailsUpdater_UpdateSpaceDetails_Call) Return() *MockSyncDetailsUpdater_UpdateSpaceDetails_Call {
_c.Call.Return()
return _c
}
func (_c *MockSyncDetailsUpdater_UpdateDetails_Call) RunAndReturn(run func([]string, domain.ObjectSyncStatus, domain.SyncError, string)) *MockSyncDetailsUpdater_UpdateDetails_Call {
func (_c *MockSyncDetailsUpdater_UpdateSpaceDetails_Call) RunAndReturn(run func([]string, []string, string)) *MockSyncDetailsUpdater_UpdateSpaceDetails_Call {
_c.Call.Return(run)
return _c
}

View file

@ -15,8 +15,6 @@ import (
"github.com/anyproto/any-sync/net/streampool"
"github.com/anyproto/any-sync/nodeconf"
"go.uber.org/zap"
"github.com/anyproto/anytype-heart/core/domain"
)
var log = logger.NewNamed(treemanager.CName)
@ -62,14 +60,9 @@ type SyncedTreeRemover interface {
RemoveAllExcept(senderId string, differentRemoteIds []string)
}
type PeerStatusChecker interface {
app.Component
IsPeerOffline(peerId string) bool
}
type SyncDetailsUpdater interface {
app.Component
UpdateDetails(objectId []string, status domain.ObjectSyncStatus, syncError domain.SyncError, spaceId string)
UpdateSpaceDetails(existing, missing []string, spaceId string)
}
type treeSyncer struct {
@ -84,7 +77,6 @@ type treeSyncer struct {
treeManager treemanager.TreeManager
isRunning bool
isSyncing bool
peerManager PeerStatusChecker
nodeConf nodeconf.NodeConf
syncedTreeRemover SyncedTreeRemover
syncDetailsUpdater SyncDetailsUpdater
@ -106,7 +98,6 @@ func NewTreeSyncer(spaceId string) treesyncer.TreeSyncer {
func (t *treeSyncer) Init(a *app.App) (err error) {
t.isSyncing = true
t.treeManager = app.MustComponent[treemanager.TreeManager](a)
t.peerManager = app.MustComponent[PeerStatusChecker](a)
t.nodeConf = app.MustComponent[nodeconf.NodeConf](a)
t.syncedTreeRemover = app.MustComponent[SyncedTreeRemover](a)
t.syncDetailsUpdater = app.MustComponent[SyncDetailsUpdater](a)
@ -161,13 +152,11 @@ func (t *treeSyncer) ShouldSync(peerId string) bool {
return t.isSyncing
}
func (t *treeSyncer) SyncAll(ctx context.Context, peerId string, existing, missing []string) error {
func (t *treeSyncer) SyncAll(ctx context.Context, peerId string, existing, missing []string) (err error) {
t.Lock()
defer t.Unlock()
var err error
isResponsible := slices.Contains(t.nodeConf.NodeIds(t.spaceId), peerId)
defer t.sendResultEvent(err, isResponsible, peerId, existing)
t.sendSyncingEvent(peerId, existing, missing, isResponsible)
t.sendSyncEvents(existing, missing, isResponsible)
reqExec, exists := t.requestPools[peerId]
if !exists {
reqExec = newExecutor(t.requests, 0)
@ -206,31 +195,15 @@ func (t *treeSyncer) SyncAll(ctx context.Context, peerId string, existing, missi
return nil
}
func (t *treeSyncer) sendSyncingEvent(peerId string, existing []string, missing []string, nodePeer bool) {
func (t *treeSyncer) sendSyncEvents(existing, missing []string, nodePeer bool) {
if !nodePeer {
return
}
if t.peerManager.IsPeerOffline(peerId) {
t.sendDetailsUpdates(existing, domain.ObjectError, domain.NetworkError)
return
}
if len(existing) != 0 || len(missing) != 0 {
t.sendDetailsUpdates(existing, domain.ObjectSyncing, domain.Null)
}
t.sendDetailsUpdates(existing, missing)
}
func (t *treeSyncer) sendResultEvent(err error, nodePeer bool, peerId string, existing []string) {
if nodePeer && !t.peerManager.IsPeerOffline(peerId) {
if err != nil {
t.sendDetailsUpdates(existing, domain.ObjectError, domain.NetworkError)
} else {
t.sendDetailsUpdates(existing, domain.ObjectSynced, domain.Null)
}
}
}
func (t *treeSyncer) sendDetailsUpdates(existing []string, status domain.ObjectSyncStatus, syncError domain.SyncError) {
t.syncDetailsUpdater.UpdateDetails(existing, status, syncError, t.spaceId)
func (t *treeSyncer) sendDetailsUpdates(existing, missing []string) {
t.syncDetailsUpdater.UpdateSpaceDetails(existing, missing, t.spaceId)
}
func (t *treeSyncer) requestTree(peerId, id string) {
@ -257,6 +230,7 @@ func (t *treeSyncer) updateTree(peerId, id string) {
syncTree, ok := tr.(synctree.SyncTree)
if !ok {
log.Warn("not a sync tree")
return
}
if err = syncTree.SyncWithPeer(ctx, peerId); err != nil {
log.Warn("synctree.SyncWithPeer error", zap.Error(err))

View file

@ -16,7 +16,6 @@ import (
"go.uber.org/mock/gomock"
"github.com/anyproto/anytype-heart/core/block/object/treesyncer/mock_treesyncer"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/tests/testutil"
)
@ -26,7 +25,6 @@ type fixture struct {
missingMock *mock_objecttree.MockObjectTree
existingMock *mock_synctree.MockSyncTree
treeManager *mock_treemanager.MockTreeManager
checker *mock_treesyncer.MockPeerStatusChecker
nodeConf *mock_nodeconf.MockService
syncStatus *mock_treesyncer.MockSyncedTreeRemover
syncDetailsUpdater *mock_treesyncer.MockSyncDetailsUpdater
@ -37,8 +35,6 @@ func newFixture(t *testing.T, spaceId string) *fixture {
treeManager := mock_treemanager.NewMockTreeManager(ctrl)
missingMock := mock_objecttree.NewMockObjectTree(ctrl)
existingMock := mock_synctree.NewMockSyncTree(ctrl)
checker := mock_treesyncer.NewMockPeerStatusChecker(t)
checker.EXPECT().Name().Return("checker").Maybe()
nodeConf := mock_nodeconf.NewMockService(ctrl)
nodeConf.EXPECT().Name().Return("nodeConf").AnyTimes()
syncStatus := mock_treesyncer.NewMockSyncedTreeRemover(t)
@ -46,7 +42,6 @@ func newFixture(t *testing.T, spaceId string) *fixture {
a := new(app.App)
a.Register(testutil.PrepareMock(context.Background(), a, treeManager)).
Register(testutil.PrepareMock(context.Background(), a, checker)).
Register(testutil.PrepareMock(context.Background(), a, syncStatus)).
Register(testutil.PrepareMock(context.Background(), a, nodeConf)).
Register(testutil.PrepareMock(context.Background(), a, syncDetailsUpdater))
@ -59,7 +54,6 @@ func newFixture(t *testing.T, spaceId string) *fixture {
missingMock: missingMock,
existingMock: existingMock,
treeManager: treeManager,
checker: checker,
nodeConf: nodeConf,
syncStatus: syncStatus,
syncDetailsUpdater: syncDetailsUpdater,
@ -91,6 +85,25 @@ func TestTreeSyncer(t *testing.T) {
fx.Close(ctx)
})
t.Run("delayed sync notify sync status", func(t *testing.T) {
ctx := context.Background()
fx := newFixture(t, spaceId)
fx.treeManager.EXPECT().GetTree(gomock.Any(), spaceId, existingId).Return(fx.existingMock, nil)
fx.existingMock.EXPECT().SyncWithPeer(gomock.Any(), peerId).Return(nil)
fx.treeManager.EXPECT().GetTree(gomock.Any(), spaceId, missingId).Return(fx.missingMock, nil)
fx.nodeConf.EXPECT().NodeIds(spaceId).Return([]string{peerId})
fx.syncDetailsUpdater.EXPECT().UpdateSpaceDetails([]string{existingId}, []string{missingId}, spaceId)
fx.syncStatus.EXPECT().RemoveAllExcept(peerId, []string{existingId}).Return()
err := fx.SyncAll(context.Background(), peerId, []string{existingId}, []string{missingId})
require.NoError(t, err)
require.NotNil(t, fx.requestPools[peerId])
require.NotNil(t, fx.headPools[peerId])
fx.StartSync()
time.Sleep(100 * time.Millisecond)
fx.Close(ctx)
})
t.Run("sync after run", func(t *testing.T) {
ctx := context.Background()
fx := newFixture(t, spaceId)
@ -189,45 +202,5 @@ func TestTreeSyncer(t *testing.T) {
require.Equal(t, []string{"before close", "after done"}, events)
mutex.Unlock()
})
t.Run("send offline event", func(t *testing.T) {
ctx := context.Background()
fx := newFixture(t, spaceId)
fx.treeManager.EXPECT().GetTree(gomock.Any(), spaceId, existingId).Return(fx.existingMock, nil)
fx.existingMock.EXPECT().SyncWithPeer(gomock.Any(), peerId).Return(nil)
fx.treeManager.EXPECT().GetTree(gomock.Any(), spaceId, missingId).Return(fx.missingMock, nil)
fx.nodeConf.EXPECT().NodeIds(spaceId).Return([]string{peerId})
fx.checker.EXPECT().IsPeerOffline(peerId).Return(true)
fx.syncStatus.EXPECT().RemoveAllExcept(peerId, []string{existingId}).Return()
fx.syncDetailsUpdater.EXPECT().UpdateDetails([]string{"existing"}, domain.ObjectError, domain.NetworkError, "spaceId").Return()
fx.StartSync()
err := fx.SyncAll(context.Background(), peerId, []string{existingId}, []string{missingId})
require.NoError(t, err)
require.NotNil(t, fx.requestPools[peerId])
require.NotNil(t, fx.headPools[peerId])
time.Sleep(100 * time.Millisecond)
fx.Close(ctx)
})
t.Run("send syncing and synced event", func(t *testing.T) {
ctx := context.Background()
fx := newFixture(t, spaceId)
fx.treeManager.EXPECT().GetTree(gomock.Any(), spaceId, existingId).Return(fx.existingMock, nil)
fx.existingMock.EXPECT().SyncWithPeer(gomock.Any(), peerId).Return(nil)
fx.treeManager.EXPECT().GetTree(gomock.Any(), spaceId, missingId).Return(fx.missingMock, nil)
fx.nodeConf.EXPECT().NodeIds(spaceId).Return([]string{peerId})
fx.checker.EXPECT().IsPeerOffline(peerId).Return(false)
fx.syncStatus.EXPECT().RemoveAllExcept(peerId, []string{existingId}).Return()
fx.syncDetailsUpdater.EXPECT().UpdateDetails([]string{"existing"}, domain.ObjectSynced, domain.Null, "spaceId").Return()
fx.syncDetailsUpdater.EXPECT().UpdateDetails([]string{"existing"}, domain.ObjectSyncing, domain.Null, "spaceId").Return()
fx.StartSync()
err := fx.SyncAll(context.Background(), peerId, []string{existingId}, []string{missingId})
require.NoError(t, err)
require.NotNil(t, fx.requestPools[peerId])
require.NotNil(t, fx.headPools[peerId])
time.Sleep(100 * time.Millisecond)
fx.Close(ctx)
})
}

View file

@ -45,6 +45,7 @@ type accountService interface {
type Space interface {
Id() string
IsPersonal() bool
TreeBuilder() objecttreebuilder.TreeBuilder
GetRelationIdByKey(ctx context.Context, key domain.RelationKey) (id string, err error)
GetTypeIdByKey(ctx context.Context, key domain.TypeKey) (id string, err error)

View file

@ -171,7 +171,7 @@ type fileObjectMigrator interface {
}
type RelationGetter interface {
GetRelationByKey(key string) (*model.Relation, error)
GetRelationByKey(spaceId string, key string) (*model.Relation, error)
}
type source struct {
@ -294,6 +294,8 @@ func (s *source) buildState() (doc state.Doc, err error) {
migration := NewSubObjectsAndProfileLinksMigration(s.smartblockType, s.space, s.accountService.MyParticipantId(s.spaceID), s.objectStore)
migration.Migrate(st)
// we need to have required internal relations for all objects, including system
st.AddBundledRelationLinks(bundle.RequiredInternalRelations...)
if s.Type() == smartblock.SmartBlockTypePage || s.Type() == smartblock.SmartBlockTypeProfilePage {
template.WithAddedFeaturedRelation(bundle.RelationKeyBacklinks)(st)
template.WithRelations([]domain.RelationKey{bundle.RelationKeyBacklinks})(st)
@ -347,7 +349,6 @@ func (s *source) PushChange(params PushChangeParams) (id string, err error) {
change := s.buildChange(params)
data, dataType, err := MarshalChange(change)
if err != nil {
return
}

View file

@ -91,7 +91,12 @@ func (m *subObjectsAndProfileLinksMigration) replaceLinksInDetails(s *state.Stat
}
}
// Migrate works only in personal space
func (m *subObjectsAndProfileLinksMigration) Migrate(s *state.State) {
if !m.space.IsPersonal() {
return
}
uk, err := domain.NewUniqueKey(smartblock.SmartBlockTypeProfilePage, "")
if err != nil {
log.Errorf("migration: failed to create unique key for profile: %s", err)
@ -187,7 +192,7 @@ func (m *subObjectsAndProfileLinksMigration) migrateFilter(filter *model.BlockCo
log.With("relationKey", filter.RelationKey).Warnf("empty filter value")
return nil
}
relation, err := m.objectStore.GetRelationByKey(filter.RelationKey)
relation, err := m.objectStore.GetRelationByKey(m.space.Id(), filter.RelationKey)
if err != nil {
log.Warnf("migration: failed to get relation by key %s: %s", filter.RelationKey, err)
}

View file

@ -314,7 +314,6 @@ func (s *service) createBlankTemplateState(layout model.ObjectTypeLayout) (st *s
template.WithFeaturedRelations,
template.WithAddedFeaturedRelation(bundle.RelationKeyTag),
template.WithDetail(bundle.RelationKeyTag, pbtypes.StringList(nil)),
template.WithRequiredRelations(),
template.WithTitle,
)
_ = s.converter.Convert(nil, st, model.ObjectType_basic, layout)

View file

@ -7,7 +7,7 @@ import (
const (
// ObjectPathSeparator is the separator between object id and block id or relation key
objectPathSeparator = "/"
ObjectPathSeparator = "/"
blockPrefix = "b"
relationPrefix = "r"
)
@ -21,10 +21,10 @@ type ObjectPath struct {
// String returns the full path, e.g. "objectId-b-blockId" or "objectId-r-relationKey"
func (o ObjectPath) String() string {
if o.HasBlock() {
return strings.Join([]string{o.ObjectId, blockPrefix, o.BlockId}, objectPathSeparator)
return strings.Join([]string{o.ObjectId, blockPrefix, o.BlockId}, ObjectPathSeparator)
}
if o.HasRelation() {
return strings.Join([]string{o.ObjectId, relationPrefix, o.RelationKey}, objectPathSeparator)
return strings.Join([]string{o.ObjectId, relationPrefix, o.RelationKey}, ObjectPathSeparator)
}
return o.ObjectId
}
@ -32,10 +32,10 @@ func (o ObjectPath) String() string {
// ObjectRelativePath returns the relative path of the object without the object id prefix
func (o ObjectPath) ObjectRelativePath() string {
if o.HasBlock() {
return strings.Join([]string{blockPrefix, o.BlockId}, objectPathSeparator)
return strings.Join([]string{blockPrefix, o.BlockId}, ObjectPathSeparator)
}
if o.HasRelation() {
return strings.Join([]string{relationPrefix, o.RelationKey}, objectPathSeparator)
return strings.Join([]string{relationPrefix, o.RelationKey}, ObjectPathSeparator)
}
return ""
}
@ -67,7 +67,7 @@ func NewObjectPathWithRelation(objectId, relationKey string) ObjectPath {
}
func NewFromPath(path string) (ObjectPath, error) {
parts := strings.Split(path, objectPathSeparator)
parts := strings.Split(path, ObjectPathSeparator)
if len(parts) == 3 && parts[1] == blockPrefix {
return NewObjectPathWithBlock(parts[0], parts[2]), nil
}

View file

@ -1,53 +1,29 @@
package domain
type SyncType int32
const (
Objects SyncType = 0
Files SyncType = 1
)
type SpaceSyncStatus int32
const (
Synced SpaceSyncStatus = 0
Syncing SpaceSyncStatus = 1
Error SpaceSyncStatus = 2
Offline SpaceSyncStatus = 3
Unknown SpaceSyncStatus = 4
SpaceSyncStatusSynced SpaceSyncStatus = 0
SpaceSyncStatusSyncing SpaceSyncStatus = 1
SpaceSyncStatusError SpaceSyncStatus = 2
SpaceSyncStatusOffline SpaceSyncStatus = 3
SpaceSyncStatusUnknown SpaceSyncStatus = 4
)
type ObjectSyncStatus int32
const (
ObjectSynced ObjectSyncStatus = 0
ObjectSyncing ObjectSyncStatus = 1
ObjectError ObjectSyncStatus = 2
ObjectQueued ObjectSyncStatus = 3
ObjectSyncStatusSynced ObjectSyncStatus = 0
ObjectSyncStatusSyncing ObjectSyncStatus = 1
ObjectSyncStatusError ObjectSyncStatus = 2
ObjectSyncStatusQueued ObjectSyncStatus = 3
)
type SyncError int32
const (
Null SyncError = 0
StorageLimitExceed SyncError = 1
IncompatibleVersion SyncError = 2
NetworkError SyncError = 3
Oversized SyncError = 4
SyncErrorNull SyncError = 0
SyncErrorIncompatibleVersion SyncError = 2
SyncErrorNetworkError SyncError = 3
SyncErrorOversized SyncError = 4
)
type SpaceSync struct {
SpaceId string
Status SpaceSyncStatus
SyncError SyncError
SyncType SyncType
}
func MakeSyncStatus(spaceId string, status SpaceSyncStatus, syncError SyncError, syncType SyncType) *SpaceSync {
return &SpaceSync{
SpaceId: spaceId,
Status: status,
SyncError: syncError,
SyncType: syncType,
}
}

View file

@ -217,7 +217,7 @@ func (ind *indexer) injectMetadataToState(ctx context.Context, st *state.State,
for k := range details.Fields {
keys = append(keys, domain.RelationKey(k))
}
st.AddBundledRelations(keys...)
st.AddBundledRelationLinks(keys...)
details = pbtypes.StructMerge(prevDetails, details, false)
st.SetDetails(details)

View file

@ -304,8 +304,8 @@ func (s *service) makeInitialDetails(fileId domain.FileId, origin objectorigin.O
// Use general file layout. It will be changed for proper layout after indexing
bundle.RelationKeyLayout.String(): pbtypes.Int64(int64(model.ObjectType_file)),
bundle.RelationKeyFileIndexingStatus.String(): pbtypes.Int64(int64(model.FileIndexingStatus_NotIndexed)),
bundle.RelationKeySyncStatus.String(): pbtypes.Int64(int64(domain.ObjectQueued)),
bundle.RelationKeySyncError.String(): pbtypes.Int64(int64(domain.Null)),
bundle.RelationKeySyncStatus.String(): pbtypes.Int64(int64(domain.ObjectSyncStatusQueued)),
bundle.RelationKeySyncError.String(): pbtypes.Int64(int64(domain.SyncErrorNull)),
bundle.RelationKeyFileBackupStatus.String(): pbtypes.Int64(int64(filesyncstatus.Queued)),
},
}

View file

@ -92,10 +92,12 @@ type fileSync struct {
importEventsMutex sync.Mutex
importEvents []*pb.Event
cfg *config.Config
closeWg *sync.WaitGroup
}
func New() FileSync {
return &fileSync{}
return &fileSync{closeWg: &sync.WaitGroup{}}
}
func (s *fileSync) Init(a *app.App) (err error) {
@ -173,7 +175,10 @@ func (s *fileSync) Run(ctx context.Context) (err error) {
s.retryDeletionQueue.Run()
s.loopCtx, s.loopCancel = context.WithCancel(context.Background())
s.closeWg.Add(1)
go s.runNodeUsageUpdater()
return
}
@ -211,5 +216,7 @@ func (s *fileSync) Close(ctx context.Context) error {
}
}
s.closeWg.Wait()
return nil
}

View file

@ -62,6 +62,8 @@ func (s FileStat) IsPinned() bool {
}
func (s *fileSync) runNodeUsageUpdater() {
defer s.closeWg.Done()
s.precacheNodeUsage()
ticker := time.NewTicker(time.Second * 10)

View file

@ -97,6 +97,9 @@ func (s *fileSync) handleLimitReachedError(err error, it *QueueItem) *errLimitRe
func (s *fileSync) uploadingHandler(ctx context.Context, it *QueueItem) (persistentqueue.Action, error) {
spaceId, fileId := it.SpaceId, it.FileId
err := s.uploadFile(ctx, spaceId, fileId, it.ObjectId)
if errors.Is(err, context.Canceled) {
return persistentqueue.ActionRetry, nil
}
if isObjectDeletedError(err) {
return persistentqueue.ActionDone, s.DeleteFile(it.ObjectId, it.FullFileId())
}
@ -143,6 +146,9 @@ func (s *fileSync) addToRetryUploadingQueue(it *QueueItem) persistentqueue.Actio
func (s *fileSync) retryingHandler(ctx context.Context, it *QueueItem) (persistentqueue.Action, error) {
spaceId, fileId := it.SpaceId, it.FileId
err := s.uploadFile(ctx, spaceId, fileId, it.ObjectId)
if errors.Is(err, context.Canceled) {
return persistentqueue.ActionRetry, nil
}
if isObjectDeletedError(err) {
return persistentqueue.ActionDone, s.removeFromUploadingQueues(it.ObjectId)
}

View file

@ -30,7 +30,7 @@ type service struct {
func (s *service) Init(a *app.App) (err error) {
s.pool = a.MustComponent(pool.CName).(pool.Pool)
s.peerStore = a.MustComponent(peerstore.CName).(peerstore.PeerStore)
s.peerStore.AddObserver(func(peerId string, spaceIds []string) {
s.peerStore.AddObserver(func(peerId string, _, spaceIds []string, peerRemoved bool) {
select {
case s.peerUpdateCh <- struct{}{}:
default:

View file

@ -90,6 +90,9 @@ func (i *indexer) runFullTextIndexer(ctx context.Context) {
}
for _, doc := range objDocs {
if err != nil {
return fmt.Errorf("batcher delete: %w", err)
}
err = batcher.UpdateDoc(doc)
if err != nil {
return fmt.Errorf("batcher add: %w", err)

View file

@ -207,11 +207,11 @@ func (i *indexer) ReindexSpace(space clientspace.Space) (err error) {
func (i *indexer) addSyncDetails(space clientspace.Space) {
typesForSyncRelations := helper.SyncRelationsSmartblockTypes()
syncStatus := domain.ObjectSynced
syncError := domain.Null
syncStatus := domain.ObjectSyncStatusSynced
syncError := domain.SyncErrorNull
if i.config.IsLocalOnlyMode() {
syncStatus = domain.ObjectError
syncError = domain.NetworkError
syncStatus = domain.ObjectSyncStatusError
syncError = domain.SyncErrorNetworkError
}
ids, err := i.getIdsForTypes(space, typesForSyncRelations...)
if err != nil {

View file

@ -156,7 +156,7 @@ func TestReindexDeletedObjects(t *testing.T) {
space1 := mock_space.NewMockSpace(t)
space1.EXPECT().Id().Return(spaceId1)
space1.EXPECT().Storage().Return(storage1)
space1.EXPECT().StoredIds().Return([]string{})
space1.EXPECT().StoredIds().Return([]string{}).Maybe()
fx.sourceFx.EXPECT().IDsListerBySmartblockType(mock.Anything, mock.Anything).Return(idsLister{Ids: []string{}}, nil)
@ -175,7 +175,7 @@ func TestReindexDeletedObjects(t *testing.T) {
space2 := mock_space.NewMockSpace(t)
space2.EXPECT().Id().Return(spaceId2)
space2.EXPECT().Storage().Return(storage2)
space2.EXPECT().StoredIds().Return([]string{})
space2.EXPECT().StoredIds().Return([]string{}).Maybe()
fx.sourceFx.EXPECT().IDsListerBySmartblockType(mock.Anything, mock.Anything).Return(idsLister{Ids: []string{}}, nil)
err = fx.ReindexSpace(space2)
@ -374,7 +374,7 @@ func TestReindex_addSyncRelations(t *testing.T) {
space1 := mock_space.NewMockSpace(t)
space1.EXPECT().Id().Return(spaceId1)
space1.EXPECT().StoredIds().Return([]string{})
space1.EXPECT().StoredIds().Return([]string{}).Maybe()
fx.sourceFx.EXPECT().IDsListerBySmartblockType(space1, coresb.SmartBlockTypePage).Return(idsLister{Ids: []string{"1", "2"}}, nil)
fx.sourceFx.EXPECT().IDsListerBySmartblockType(space1, coresb.SmartBlockTypeRelation).Return(idsLister{Ids: []string{}}, nil)
@ -412,7 +412,7 @@ func TestReindex_addSyncRelations(t *testing.T) {
space1 := mock_space.NewMockSpace(t)
space1.EXPECT().Id().Return(spaceId1)
space1.EXPECT().StoredIds().Return([]string{})
space1.EXPECT().StoredIds().Return([]string{}).Maybe()
fx.sourceFx.EXPECT().IDsListerBySmartblockType(space1, coresb.SmartBlockTypePage).Return(idsLister{Ids: []string{"1", "2"}}, nil)
fx.sourceFx.EXPECT().IDsListerBySmartblockType(space1, coresb.SmartBlockTypeRelation).Return(idsLister{Ids: []string{}}, nil)

View file

@ -47,16 +47,18 @@ type notificationService struct {
spaceService space.Service
picker cache.ObjectGetter
mu sync.Mutex
loadTimeout time.Duration
loadFinish chan struct{}
sync.RWMutex
lastNotificationIdToAcl map[string]string
}
func New() Notifications {
func New(loadTimeout time.Duration) Notifications {
return &notificationService{
lastNotificationIdToAcl: make(map[string]string, 0),
loadFinish: make(chan struct{}),
loadTimeout: loadTimeout,
}
}
@ -227,7 +229,7 @@ func (n *notificationService) Reply(notificationIds []string, notificationAction
}
func (n *notificationService) List(limit int64, includeRead bool) ([]*model.Notification, error) {
ticker := time.NewTicker(time.Second * 10)
ticker := time.NewTicker(n.loadTimeout)
defer ticker.Stop()
select {

View file

@ -29,6 +29,7 @@ func TestNotificationService_List(t *testing.T) {
notifications := notificationService{
eventSender: sender,
notificationStore: NewTestStore(t),
loadTimeout: 10 * time.Millisecond,
}
// when
@ -49,6 +50,7 @@ func TestNotificationService_List(t *testing.T) {
notifications := notificationService{
eventSender: sender,
notificationStore: storeFixture,
loadTimeout: 10 * time.Millisecond,
}
// when
@ -69,6 +71,7 @@ func TestNotificationService_List(t *testing.T) {
notifications := notificationService{
eventSender: sender,
notificationStore: storeFixture,
loadTimeout: 10 * time.Millisecond,
}
// when
@ -89,6 +92,7 @@ func TestNotificationService_List(t *testing.T) {
notifications := notificationService{
eventSender: sender,
notificationStore: storeFixture,
loadTimeout: 10 * time.Millisecond,
}
// when
@ -109,6 +113,7 @@ func TestNotificationService_List(t *testing.T) {
notifications := notificationService{
eventSender: sender,
notificationStore: storeFixture,
loadTimeout: 10 * time.Millisecond,
}
// when
@ -134,6 +139,7 @@ func TestNotificationService_Reply(t *testing.T) {
eventSender: sender,
notificationStore: storeFixture,
lastNotificationIdToAcl: map[string]string{},
loadTimeout: 10 * time.Millisecond,
}
// when
@ -157,6 +163,7 @@ func TestNotificationService_Reply(t *testing.T) {
notifications := notificationService{
eventSender: sender,
notificationStore: storeFixture,
loadTimeout: 10 * time.Millisecond,
}
// when
@ -183,6 +190,7 @@ func TestNotificationService_Reply(t *testing.T) {
notifications := notificationService{
eventSender: sender,
notificationStore: storeFixture,
loadTimeout: 10 * time.Millisecond,
}
// when
@ -219,6 +227,7 @@ func TestNotificationService_CreateAndSend(t *testing.T) {
notifications := notificationService{
eventSender: sender,
notificationStore: storeFixture,
loadTimeout: 10 * time.Millisecond,
}
// when
@ -255,6 +264,7 @@ func TestNotificationService_CreateAndSend(t *testing.T) {
notificationStore: storeFixture,
picker: objectGetter,
notificationId: notificationObjectId,
loadTimeout: 10 * time.Millisecond,
}
// when
@ -301,6 +311,7 @@ func TestNotificationService_CreateAndSend(t *testing.T) {
picker: objectGetter,
notificationId: notificationObjectId,
lastNotificationIdToAcl: map[string]string{},
loadTimeout: 10 * time.Millisecond,
}
// when

View file

@ -349,12 +349,30 @@ func (mw *Middleware) ObjectSearchSubscribe(cctx context.Context, req *pb.RpcObj
subService := mw.applicationService.GetApp().MustComponent(subscription.CName).(subscription.Service)
resp, err := subService.Search(*req)
resp, err := subService.Search(subscription.SubscribeRequest{
SubId: req.SubId,
Filters: req.Filters,
Sorts: req.Sorts,
Limit: req.Limit,
Offset: req.Offset,
Keys: req.Keys,
AfterId: req.AfterId,
BeforeId: req.BeforeId,
Source: req.Source,
IgnoreWorkspace: req.IgnoreWorkspace,
NoDepSubscription: req.NoDepSubscription,
CollectionId: req.CollectionId,
})
if err != nil {
return errResponse(err)
}
return resp
return &pb.RpcObjectSearchSubscribeResponse{
SubId: resp.SubId,
Records: resp.Records,
Dependencies: resp.Dependencies,
Counters: resp.Counters,
}
}
func (mw *Middleware) ObjectGroupsSubscribe(cctx context.Context, req *pb.RpcObjectGroupsSubscribeRequest) *pb.RpcObjectGroupsSubscribeResponse {

View file

@ -27,13 +27,11 @@ var (
ErrCacheDbNotInitialized = errors.New("cache db is not initialized")
ErrCacheDbError = errors.New("cache db error")
ErrUnsupportedCacheVersion = errors.New("unsupported cache version")
ErrCacheDisabled = errors.New("cache is disabled")
ErrCacheExpired = errors.New("cache is empty")
)
// once you change the cache format, you need to update this variable
// it will cause cache to be dropped and recreated
const cacheLastVersion = 6
const cacheLastVersion = 7
const (
cacheLifetimeDurExplorer = 24 * time.Hour
@ -79,8 +77,6 @@ func newStorageStruct() *StorageStruct {
}
type CacheService interface {
// if cache is disabled -> will return objects and ErrCacheDisabled
// if cache is expired -> will return objects and ErrCacheExpired
CacheGet() (status *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse, err error)
// if cache is disabled -> will return no error
@ -88,7 +84,9 @@ type CacheService interface {
// status or tiers can be nil depending on what you want to update
CacheSet(status *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error)
IsCacheEnabled() (enabled bool)
IsCacheDisabled() (disabled bool)
IsCacheExpired() (expired bool)
// if already enabled -> will not return error
CacheEnable() (err error)
@ -134,10 +132,17 @@ func (s *cacheservice) Run(_ context.Context) (err error) {
}
func (s *cacheservice) Close(_ context.Context) (err error) {
return s.db.Close()
s.m.Lock()
defer s.m.Unlock()
s.db = nil
return nil
}
func (s *cacheservice) CacheGet() (status *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse, err error) {
s.m.Lock()
defer s.m.Unlock()
// 1 - check in storage
ss, err := s.get()
if err != nil {
@ -152,19 +157,7 @@ func (s *cacheservice) CacheGet() (status *pb.RpcMembershipGetStatusResponse, ti
return nil, nil, ErrUnsupportedCacheVersion
}
// 2 - check if cache is disabled
if !s.IsCacheEnabled() {
// return object too
return &ss.SubscriptionStatus, &ss.TiersData, ErrCacheDisabled
}
// 3 - check if cache is outdated
if time.Now().UTC().After(ss.ExpireTime) {
// return object too
return &ss.SubscriptionStatus, &ss.TiersData, ErrCacheExpired
}
// 4 - return value
// 2 - return value
return &ss.SubscriptionStatus, &ss.TiersData, nil
}
@ -204,6 +197,9 @@ func getExpireTime(latestStatus *model.Membership) time.Time {
}
func (s *cacheservice) CacheSet(status *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
s.m.Lock()
defer s.m.Unlock()
var latestStatus *model.Membership
// 1 - get existing storage
@ -231,23 +227,47 @@ func (s *cacheservice) CacheSet(status *pb.RpcMembershipGetStatusResponse, tiers
return s.set(ss)
}
func (s *cacheservice) IsCacheEnabled() (enabled bool) {
func (s *cacheservice) IsCacheDisabled() (disabled bool) {
s.m.Lock()
defer s.m.Unlock()
// 1 - get existing storage
ss, err := s.get()
if err != nil {
return false
}
// 2 - check if cache is disabled
if !ss.DisableUntilTime.IsZero() && time.Now().UTC().Before(ss.DisableUntilTime) {
return true
}
return false
}
func (s *cacheservice) IsCacheExpired() (expired bool) {
s.m.Lock()
defer s.m.Unlock()
// 1 - get existing storage
ss, err := s.get()
if err != nil {
return true
}
// 2 - check if cache is disabled
if (ss.DisableUntilTime != time.Time{}) && time.Now().UTC().Before(ss.DisableUntilTime) {
return false
// 2 - check if cache is outdated
if time.Now().UTC().After(ss.ExpireTime) {
return true
}
return true
return false
}
// will not return error if already enabled
func (s *cacheservice) CacheEnable() (err error) {
s.m.Lock()
defer s.m.Unlock()
// 1 - get existing storage
ss, err := s.get()
if err != nil {
@ -265,6 +285,9 @@ func (s *cacheservice) CacheEnable() (err error) {
// will not return error if already disabled
// if currently disabled - will disable for next N minutes
func (s *cacheservice) CacheDisableForNextMinutes(minutes int) (err error) {
s.m.Lock()
defer s.m.Unlock()
// 1 - get existing storage
ss, err := s.get()
if err != nil {
@ -281,6 +304,9 @@ func (s *cacheservice) CacheDisableForNextMinutes(minutes int) (err error) {
// does not take into account if cache is enabled or not, erases always
func (s *cacheservice) CacheClear() (err error) {
s.m.Lock()
defer s.m.Unlock()
// 1 - get existing storage
_, err = s.get()
if err != nil {
@ -300,9 +326,6 @@ func (s *cacheservice) get() (out *StorageStruct, err error) {
return nil, ErrCacheDbNotInitialized
}
s.m.Lock()
defer s.m.Unlock()
var ss StorageStruct
err = s.db.View(func(txn *badger.Txn) error {
item, err := txn.Get([]byte(dbKey))
@ -321,8 +344,9 @@ func (s *cacheservice) get() (out *StorageStruct, err error) {
}
func (s *cacheservice) set(in *StorageStruct) (err error) {
s.m.Lock()
defer s.m.Unlock()
if s.db == nil {
return ErrCacheDbNotInitialized
}
return s.db.Update(func(txn *badger.Txn) error {
// convert

View file

@ -97,13 +97,15 @@ func TestPayments_DisableCache(t *testing.T) {
require.NoError(t, err)
_, _, err = fx.CacheGet()
require.Equal(t, ErrCacheDisabled, err)
require.NoError(t, err)
require.True(t, fx.IsCacheDisabled())
err = fx.CacheClear()
require.NoError(t, err)
_, _, err = fx.CacheGet()
require.Equal(t, ErrCacheExpired, err)
require.NoError(t, err)
require.False(t, fx.IsCacheDisabled())
})
}
@ -159,7 +161,8 @@ func TestPayments_ClearCache(t *testing.T) {
require.NoError(t, err)
_, _, err = fx.CacheGet()
require.Equal(t, ErrCacheExpired, err)
require.NoError(t, err)
require.True(t, fx.IsCacheExpired())
})
}
@ -216,14 +219,14 @@ func TestPayments_CacheGetSubscriptionStatus(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
en := fx.IsCacheEnabled()
require.Equal(t, true, en)
dis := fx.IsCacheDisabled()
require.False(t, dis)
err := fx.CacheDisableForNextMinutes(10)
require.NoError(t, err)
en = fx.IsCacheEnabled()
require.Equal(t, false, en)
dis = fx.IsCacheDisabled()
require.True(t, dis)
err = fx.CacheSet(&pb.RpcMembershipGetStatusResponse{
Data: &model.Membership{
@ -236,17 +239,18 @@ func TestPayments_CacheGetSubscriptionStatus(t *testing.T) {
},
)
require.NoError(t, err)
dis = fx.IsCacheDisabled()
require.True(t, dis)
out, _, err := fx.CacheGet()
require.Equal(t, ErrCacheDisabled, err)
// HERE: weird semantics, error is returned too :-)
require.NoError(t, err)
require.Equal(t, uint32(psp.SubscriptionTier_TierExplorer), out.Data.Tier)
err = fx.CacheEnable()
require.NoError(t, err)
en = fx.IsCacheEnabled()
require.Equal(t, true, en)
dis = fx.IsCacheDisabled()
require.False(t, dis)
out, _, err = fx.CacheGet()
require.NoError(t, err)
@ -272,8 +276,12 @@ func TestPayments_CacheGetSubscriptionStatus(t *testing.T) {
err = fx.CacheClear()
require.NoError(t, err)
// check if cache is expired
exp := fx.IsCacheExpired()
require.True(t, exp)
_, _, err = fx.CacheGet()
require.Equal(t, ErrCacheExpired, err)
require.NoError(t, err)
})
}

View file

@ -318,12 +318,12 @@ func (_c *MockCacheService_Init_Call) RunAndReturn(run func(*app.App) error) *Mo
return _c
}
// IsCacheEnabled provides a mock function with given fields:
func (_m *MockCacheService) IsCacheEnabled() bool {
// IsCacheDisabled provides a mock function with given fields:
func (_m *MockCacheService) IsCacheDisabled() bool {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for IsCacheEnabled")
panic("no return value specified for IsCacheDisabled")
}
var r0 bool
@ -336,29 +336,74 @@ func (_m *MockCacheService) IsCacheEnabled() bool {
return r0
}
// MockCacheService_IsCacheEnabled_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsCacheEnabled'
type MockCacheService_IsCacheEnabled_Call struct {
// MockCacheService_IsCacheDisabled_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsCacheDisabled'
type MockCacheService_IsCacheDisabled_Call struct {
*mock.Call
}
// IsCacheEnabled is a helper method to define mock.On call
func (_e *MockCacheService_Expecter) IsCacheEnabled() *MockCacheService_IsCacheEnabled_Call {
return &MockCacheService_IsCacheEnabled_Call{Call: _e.mock.On("IsCacheEnabled")}
// IsCacheDisabled is a helper method to define mock.On call
func (_e *MockCacheService_Expecter) IsCacheDisabled() *MockCacheService_IsCacheDisabled_Call {
return &MockCacheService_IsCacheDisabled_Call{Call: _e.mock.On("IsCacheDisabled")}
}
func (_c *MockCacheService_IsCacheEnabled_Call) Run(run func()) *MockCacheService_IsCacheEnabled_Call {
func (_c *MockCacheService_IsCacheDisabled_Call) Run(run func()) *MockCacheService_IsCacheDisabled_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCacheService_IsCacheEnabled_Call) Return(enabled bool) *MockCacheService_IsCacheEnabled_Call {
_c.Call.Return(enabled)
func (_c *MockCacheService_IsCacheDisabled_Call) Return(disabled bool) *MockCacheService_IsCacheDisabled_Call {
_c.Call.Return(disabled)
return _c
}
func (_c *MockCacheService_IsCacheEnabled_Call) RunAndReturn(run func() bool) *MockCacheService_IsCacheEnabled_Call {
func (_c *MockCacheService_IsCacheDisabled_Call) RunAndReturn(run func() bool) *MockCacheService_IsCacheDisabled_Call {
_c.Call.Return(run)
return _c
}
// IsCacheExpired provides a mock function with given fields:
func (_m *MockCacheService) IsCacheExpired() bool {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for IsCacheExpired")
}
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockCacheService_IsCacheExpired_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsCacheExpired'
type MockCacheService_IsCacheExpired_Call struct {
*mock.Call
}
// IsCacheExpired is a helper method to define mock.On call
func (_e *MockCacheService_Expecter) IsCacheExpired() *MockCacheService_IsCacheExpired_Call {
return &MockCacheService_IsCacheExpired_Call{Call: _e.mock.On("IsCacheExpired")}
}
func (_c *MockCacheService_IsCacheExpired_Call) Run(run func()) *MockCacheService_IsCacheExpired_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCacheService_IsCacheExpired_Call) Return(expired bool) *MockCacheService_IsCacheExpired_Call {
_c.Call.Return(expired)
return _c
}
func (_c *MockCacheService_IsCacheExpired_Call) RunAndReturn(run func() bool) *MockCacheService_IsCacheExpired_Call {
_c.Call.Return(run)
return _c
}

View file

@ -171,14 +171,12 @@ func (s *service) Close(_ context.Context) (err error) {
func (s *service) getPeriodicStatus(ctx context.Context) error {
// get subscription status (from cache or from the PP node)
// if status has changed -> it will send an event
log.Debug("periodic: getting subscription status from cache/PP node")
// if status has changed -> it will send events, etc
_, err := s.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
return err
}
func (s *service) sendEvent(status *pb.RpcMembershipGetStatusResponse) {
func (s *service) sendMembershipUpdateEvent(status *pb.RpcMembershipGetStatusResponse) {
s.eventSender.Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
@ -208,55 +206,50 @@ func (s *service) GetSubscriptionStatus(ctx context.Context, req *pb.RpcMembersh
s.mx.Lock()
defer s.mx.Unlock()
ownerID := s.wallet.Account().SignKey.GetPublic().Account()
privKey := s.wallet.GetAccountPrivkey()
// 1 - check in cache first
var (
cachedStatus *pb.RpcMembershipGetStatusResponse
isCacheExpired bool
isCacheDisabled bool
cacheErr error
)
// 1 - check in cache
// if cache is disabled -> will return objects and ErrCacheDisabled
// if cache is expired -> will return objects and ErrCacheExpired
cachedStatus, _, err := s.cache.CacheGet()
if !req.NoCache {
isCacheExpired = s.cache.IsCacheExpired()
isCacheDisabled = s.cache.IsCacheDisabled()
// if NoCache flag -> skip returning from cache
if !req.NoCache && (err == nil) && (cachedStatus != nil) && (cachedStatus.Data != nil) {
// 2. If found in cache -> return it
log.Debug("returning subscription status from cache", zap.Error(err), zap.Any("cachedStatus", cachedStatus))
return cachedStatus, nil
cachedStatus, _, cacheErr = s.cache.CacheGet()
isNotExpiredAndNotDisabled := !isCacheExpired && !isCacheDisabled
if cacheErr == nil && isNotExpiredAndNotDisabled && canReturnCachedStatus(cachedStatus) {
log.Debug("returning subscription status from cache", zap.Error(cacheErr), zap.Any("cachedStatus", cachedStatus))
return cachedStatus, nil
}
}
// 3 - send request to PP node
gsr := proto.GetSubscriptionRequest{
// payment node will check if signature matches with this OwnerAnyID
OwnerAnyID: ownerID,
}
payload, err := gsr.Marshal()
// 2 - if not in cache - send request to PP node
ppReq, err := s.generateRequest()
if err != nil {
log.Error("can not marshal GetSubscriptionRequest", zap.Error(err))
return nil, ErrCanNotSign
return nil, err
}
// this is the SignKey
signature, err := privKey.Sign(payload)
log.Debug("get sub from PP node")
status, err := s.ppclient.GetSubscriptionStatus(ctx, ppReq)
// 3 - on PP node error
// try returning from cache again (do not care about the NoCache flag here!)
if err != nil {
log.Error("can not sign GetSubscriptionRequest", zap.Error(err))
return nil, ErrCanNotSign
}
isCacheExpired = s.cache.IsCacheExpired()
isCacheDisabled = s.cache.IsCacheDisabled()
cachedStatus, _, cacheErr = s.cache.CacheGet()
reqSigned := proto.GetSubscriptionRequestSigned{
Payload: payload,
Signature: signature,
}
log.Debug("get sub from PP node", zap.Any("cachedStatus", cachedStatus), zap.Bool("noCache", req.NoCache))
status, err := s.ppclient.GetSubscriptionStatus(ctx, &reqSigned)
if err != nil {
// 4a. try reading from cache again
if (cachedStatus != nil) && (cachedStatus.Data != nil) {
log.Debug("returning subscription status from cache again", zap.Error(err), zap.Any("cachedStatus", cachedStatus))
// if cache is expired/disabled or OK -> use this data
isExpiredOrDisabled := isCacheExpired || isCacheDisabled
if (cacheErr == nil || isExpiredOrDisabled) && canReturnCachedStatus(cachedStatus) {
log.Debug("returning subscription status from cache", zap.Error(err), zap.Any("cachedStatus", cachedStatus))
return cachedStatus, nil
}
// 4b. If PP node didn't answer -> create empty response
// If PP node didn't answer -> create empty response
log.Info("creating empty subscription in cache because can not get subscription status from the payment node")
// eat error and create empty status ("no tier") so that we will then save it to the cache
@ -268,7 +261,7 @@ func (s *service) GetSubscriptionStatus(ctx context.Context, req *pb.RpcMembersh
out := convertMembershipStatus(status)
// 5. Save to cache. Lifetime - min(subscription ends, now + TTL)
// 4 - Save to cache. Lifetime - min(subscription ends, now + TTL)
// update only status, not tiers
err = s.cache.CacheSet(&out, nil)
if err != nil {
@ -276,31 +269,110 @@ func (s *service) GetSubscriptionStatus(ctx context.Context, req *pb.RpcMembersh
// return nil, ErrCacheProblem
}
isDiffTier := (cachedStatus != nil) && (cachedStatus.Data != nil) && (cachedStatus.Data.Tier != status.Tier)
isDiffStatus := (cachedStatus != nil) && (cachedStatus.Data != nil) && (cachedStatus.Data.Status != model.MembershipStatus(status.Status))
isEmailDiff := (cachedStatus != nil) && (cachedStatus.Data != nil) && (cachedStatus.Data.UserEmail != status.UserEmail)
log.Debug("subscription status", zap.Any("from server", status), zap.Any("cached", cachedStatus))
log.Debug("subscription status", zap.Any("from server", status), zap.Any("cached", cachedStatus), zap.Bool("isEmailDiff", isEmailDiff))
// 5 - Send all messages to the client if needed
if !isUpdateRequired(cacheErr, isCacheDisabled, isCacheExpired, cachedStatus, status) {
// no need to send events or enable/disable cache
return &out, nil
}
s.updateStatus(ctx, status)
// 6. If tier or status has changed -> send event
if cachedStatus != nil && !isDiffTier && !isDiffStatus && !isEmailDiff {
// 6 - Enable or disable cache (only if status has changed)
if isNeedToEnableCache(status) {
s.enableCache(status)
} else if isNeedToDisableCache(status) {
// also the cache will be automatically enbaled in N minutes
s.disableCache(status)
}
return &out, nil
}
func (s *service) generateRequest() (*proto.GetSubscriptionRequestSigned, error) {
ownerID := s.wallet.Account().SignKey.GetPublic().Account()
privKey := s.wallet.GetAccountPrivkey()
gsr := proto.GetSubscriptionRequest{
// payment node will check if signature matches with this OwnerAnyID
OwnerAnyID: ownerID,
}
payload, err := gsr.Marshal()
if err != nil {
log.Error("can not marshal GetSubscriptionRequest", zap.Error(err))
return nil, ErrCanNotSign
}
signature, err := privKey.Sign(payload)
if err != nil {
log.Error("can not sign GetSubscriptionRequest", zap.Error(err))
return nil, ErrCanNotSign
}
return &proto.GetSubscriptionRequestSigned{
Payload: payload,
Signature: signature,
}, nil
}
func isCacheContainsError(s *pb.RpcMembershipGetStatusResponse) bool {
return s != nil && s.Error != nil && s.Error.Code != pb.RpcMembershipGetStatusResponseError_NULL
}
func canReturnCachedStatus(s *pb.RpcMembershipGetStatusResponse) bool {
return s != nil && s.Data != nil && (s.Error == nil || s.Error.Code == pb.RpcMembershipGetStatusResponseError_NULL)
}
func isUpdateRequired(cacheErr error, isCacheDisabled bool, isCacheExpired bool, cachedStatus *pb.RpcMembershipGetStatusResponse, status *proto.GetSubscriptionResponse) bool {
// 1 - If cache was empty or expired
// -> treat at is if data was different
isCacheEmpty := cacheErr != nil || cachedStatus == nil || cachedStatus.Data == nil || isCacheExpired
if isCacheEmpty {
log.Debug("subscription status treated as changed because cache was empty/expired")
return true
}
// 2 - Extra check that cache contained previous error
if isCacheContainsError(cachedStatus) {
log.Debug("subscription status treated as changed because cache contained previous error")
return true
}
// 3 - Check if tier or status has changed
if status == nil {
return false
}
isDiffTier := cachedStatus.Data.Tier != status.Tier
isDiffStatus := cachedStatus.Data.Status != model.MembershipStatus(status.Status)
isEmailDiff := cachedStatus.Data.UserEmail != status.UserEmail
if !isDiffTier && !isDiffStatus && !isEmailDiff {
log.Debug("subscription status has NOT changed",
zap.Bool("cache was empty", cachedStatus == nil),
zap.Bool("isDiffTier", isDiffTier),
zap.Bool("isDiffStatus", isDiffStatus),
)
return &out, nil
return false
}
log.Info("subscription status has changed. sending EventMembershipUpdate",
log.Info("subscription status has been changed. sending EventMembershipUpdate",
zap.Bool("cache was empty", cachedStatus == nil),
zap.Bool("isDiffTier", isDiffTier),
zap.Bool("isDiffStatus", isDiffStatus),
zap.Bool("isEmailDiff", isEmailDiff),
)
s.sendEvent(&out)
return true
}
// 7. If name has changed -> update global name or own identity
func (s *service) updateStatus(ctx context.Context, status *proto.GetSubscriptionResponse) {
out := convertMembershipStatus(status)
// 1 - Broadcast event
log.Debug("sending EventMembershipUpdate", zap.Any("status", status))
s.sendMembershipUpdateEvent(&out)
// 2 - If name has changed -> update global name or own identity
if status.RequestedAnyName != "" {
log.Debug("update global name",
zap.String("requestedAnyName", status.RequestedAnyName),
@ -309,36 +381,12 @@ func (s *service) GetSubscriptionStatus(ctx context.Context, req *pb.RpcMembersh
s.profileUpdater.UpdateOwnGlobalName(status.RequestedAnyName)
}
// 8. UpdateLimits
err = s.updateLimits(ctx)
// 3 - Update limits
err := s.updateLimits(ctx)
if err != nil {
// eat error
log.Error("update limits", zap.Error(err))
}
// 9. Disable cache in case status is Pending
if status.Status == proto.SubscriptionStatus_StatusPending {
log.Info("disabling cache to wait for Active state")
err = s.cache.CacheDisableForNextMinutes(cacheDisableMinutes)
if err != nil {
log.Warn("can not disable cache", zap.Error(err))
// return nil, errors.Wrap(ErrCacheProblem, err.Error())
}
}
// 10. Enable cache again if status is active
isFinished := status.Status == proto.SubscriptionStatus_StatusActive
if isFinished {
log.Info("enabling cache again")
// or it will be automatically enabled after N minutes of DisableForNextMinutes() call
err = s.cache.CacheEnable()
if err != nil {
log.Warn("can not enable cache", zap.Error(err))
// return nil, errors.Wrap(ErrCacheProblem, err.Error())
}
}
return &out, nil
}
func (s *service) updateLimits(ctx context.Context) error {
@ -346,6 +394,38 @@ func (s *service) updateLimits(ctx context.Context) error {
return s.fileLimitsUpdater.UpdateNodeUsage(ctx)
}
func isNeedToDisableCache(status *proto.GetSubscriptionResponse) bool {
return status.Status == proto.SubscriptionStatus_StatusPending
}
func (s *service) disableCache(status *proto.GetSubscriptionResponse) {
log.Info("disabling cache to wait for Active state")
err := s.cache.CacheDisableForNextMinutes(cacheDisableMinutes)
if err != nil {
log.Warn("can not disable cache", zap.Error(err))
// return nil, errors.Wrap(ErrCacheProblem, err.Error())
}
}
func isNeedToEnableCache(status *proto.GetSubscriptionResponse) bool {
isEnableCacheStatus := (status.Status != proto.SubscriptionStatus_StatusUnknown) && (status.Status != proto.SubscriptionStatus_StatusPending)
isEnableCacheTier := status.Tier > uint32(proto.SubscriptionTier_TierExplorer)
return isEnableCacheStatus && isEnableCacheTier
}
func (s *service) enableCache(status *proto.GetSubscriptionResponse) {
log.Info("enabling cache again")
// or it will be automatically enabled after N minutes of DisableForNextMinutes() call
err := s.cache.CacheEnable()
if err != nil {
log.Warn("can not enable cache", zap.Error(err))
// return nil, errors.Wrap(ErrCacheProblem, err.Error())
}
}
func (s *service) IsNameValid(ctx context.Context, req *pb.RpcMembershipIsNameValidRequest) (*pb.RpcMembershipIsNameValidResponse, error) {
var code proto.IsNameValidResponse_Code
var desc string
@ -762,15 +842,22 @@ func (s *service) GetTiers(ctx context.Context, req *pb.RpcMembershipGetTiersReq
return filtered, nil
}
func (s *service) getAllTiers(ctx context.Context, req *pb.RpcMembershipGetTiersRequest) (*pb.RpcMembershipGetTiersResponse, error) {
// 1 - check in cache
// status var. is unused here
_, cachedTiers, err := s.cache.CacheGet()
func canReturnCachedTiers(t *pb.RpcMembershipGetTiersResponse) bool {
return t != nil && t.Tiers != nil && (t.Error == nil || t.Error.Code == pb.RpcMembershipGetTiersResponseError_NULL)
}
// if NoCache -> skip returning from cache
if !req.NoCache && (err == nil) && (cachedTiers != nil) && (cachedTiers.Tiers != nil) {
log.Debug("returning tiers from cache", zap.Error(err), zap.Any("cachedTiers", cachedTiers))
return cachedTiers, nil
func (s *service) getAllTiers(ctx context.Context, req *pb.RpcMembershipGetTiersRequest) (*pb.RpcMembershipGetTiersResponse, error) {
// 1 - check in cache in case NoCache is False
if !req.NoCache {
isCacheExpired := s.cache.IsCacheExpired()
isCacheDisabled := s.cache.IsCacheDisabled()
_, cachedTiers, cacheErr := s.cache.CacheGet()
isNotExpiredAndNotDisabled := !isCacheExpired && !isCacheDisabled
if cacheErr == nil && isNotExpiredAndNotDisabled && canReturnCachedTiers(cachedTiers) {
log.Debug("returning tiers from cache", zap.Any("cachedTiers", cachedTiers))
return cachedTiers, nil
}
}
// 2 - send request

View file

@ -122,7 +122,7 @@ func (fx *fixture) finish(t *testing.T) {
}
func TestGetStatus(t *testing.T) {
t.Run("fail if no cache and GetSubscriptionStatus returns error", func(t *testing.T) {
t.Run("return default if no cache and GetSubscriptionStatus returns error", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
@ -130,12 +130,14 @@ func TestGetStatus(t *testing.T) {
return nil, errors.New("test error")
}).MinTimes(1)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheDbError)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
// fx.cache.EXPECT().CacheEnable().Return(nil)
// changing from NO CACHE -> default "Unknown" tier
fx.expectLimitsUpdated()
// Call the function being tested
@ -146,7 +148,7 @@ func TestGetStatus(t *testing.T) {
assert.Equal(t, model.Membership_StatusUnknown, resp.Data.Status)
})
t.Run("success if NoCache flag is passed", func(t *testing.T) {
t.Run("return default if no cache and GetSubscriptionStatus returns error, NoCache is passed", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
@ -154,14 +156,112 @@ func TestGetStatus(t *testing.T) {
return nil, errors.New("test error")
}).MinTimes(1)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheDbError)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
// fx.cache.EXPECT().CacheEnable().Return(nil)
// changing from NO CACHE -> default "Unknown" tier
fx.expectLimitsUpdated()
// Call the function being tested
resp, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{
// / >>> here:
NoCache: true,
})
assert.NoError(t, err)
assert.Equal(t, uint32(psp.SubscriptionTier_TierUnknown), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusUnknown, resp.Data.Status)
})
t.Run("return prev values if ErrCacheExpired and GetSubscriptionStatus returns error", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
fx.ppclient.EXPECT().GetSubscriptionStatus(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx interface{}, in *psp.GetSubscriptionRequestSigned) (*psp.GetSubscriptionResponse, error) {
return nil, errors.New("test error")
}).MinTimes(1)
sr := psp.GetSubscriptionResponse{
Tier: uint32(psp.SubscriptionTier_TierExplorer),
Status: psp.SubscriptionStatus_StatusActive,
DateStarted: uint64(timeNow.Unix()),
DateEnds: uint64(subsExpire.Unix()),
IsAutoRenew: true,
PaymentMethod: psp.PaymentMethod_MethodCrypto,
RequestedAnyName: "something.any",
}
psgsr := pb.RpcMembershipGetStatusResponse{
Error: &pb.RpcMembershipGetStatusResponseError{
Code: pb.RpcMembershipGetStatusResponseError_NULL,
},
Data: &model.Membership{
Tier: uint32(sr.Tier),
Status: model.MembershipStatus(sr.Status),
DateStarted: sr.DateStarted,
DateEnds: sr.DateEnds,
IsAutoRenew: sr.IsAutoRenew,
PaymentMethod: PaymentMethodToModel(sr.PaymentMethod),
NsName: "something",
NsNameType: model.NameserviceNameType_AnyName,
},
}
fx.cache.EXPECT().IsCacheExpired().Return(true)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
// Call the function being tested
resp, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
assert.NoError(t, err)
assert.Equal(t, uint32(psp.SubscriptionTier_TierExplorer), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
})
t.Run("return prev values if ErrCacheExpired, GetSubscriptionStatus returns error, and if NoCache flag is passed", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
fx.ppclient.EXPECT().GetSubscriptionStatus(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx interface{}, in *psp.GetSubscriptionRequestSigned) (*psp.GetSubscriptionResponse, error) {
return nil, errors.New("test error")
}).MinTimes(1)
sr := psp.GetSubscriptionResponse{
Tier: uint32(psp.SubscriptionTier_TierExplorer),
Status: psp.SubscriptionStatus_StatusActive,
DateStarted: uint64(timeNow.Unix()),
DateEnds: uint64(subsExpire.Unix()),
IsAutoRenew: true,
PaymentMethod: psp.PaymentMethod_MethodCrypto,
RequestedAnyName: "something.any",
}
psgsr := pb.RpcMembershipGetStatusResponse{
Error: &pb.RpcMembershipGetStatusResponseError{
Code: pb.RpcMembershipGetStatusResponseError_NULL,
},
Data: &model.Membership{
Tier: uint32(sr.Tier),
Status: model.MembershipStatus(sr.Status),
DateStarted: sr.DateStarted,
DateEnds: sr.DateEnds,
IsAutoRenew: sr.IsAutoRenew,
PaymentMethod: PaymentMethodToModel(sr.PaymentMethod),
NsName: "something",
NsNameType: model.NameserviceNameType_AnyName,
},
}
// in case of cache.ErrCacheExpired this should always return objects
fx.cache.EXPECT().IsCacheExpired().Return(true)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
// Call the function being tested
req := pb.RpcMembershipGetStatusRequest{
// / >>> here:
@ -170,8 +270,8 @@ func TestGetStatus(t *testing.T) {
resp, err := fx.GetSubscriptionStatus(ctx, &req)
assert.NoError(t, err)
assert.Equal(t, uint32(psp.SubscriptionTier_TierUnknown), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusUnknown, resp.Data.Status)
assert.Equal(t, uint32(psp.SubscriptionTier_TierExplorer), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
})
t.Run("success if NoCache flag is passed, but no connectivity", func(t *testing.T) {
@ -195,6 +295,8 @@ func TestGetStatus(t *testing.T) {
NsNameType: model.NameserviceNameType_AnyName,
},
}
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
// Call the function being tested
@ -209,37 +311,7 @@ func TestGetStatus(t *testing.T) {
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
})
t.Run("fail if NoCache flag is passed, no cache, no connectivity", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
fx.ppclient.EXPECT().GetSubscriptionStatus(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx interface{}, in *psp.GetSubscriptionRequestSigned) (*psp.GetSubscriptionResponse, error) {
// >>> here
return nil, ErrNoConnection
}).MinTimes(1)
// >>> here:
fx.cache.EXPECT().CacheGet().Return(nil, nil, nil)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
fx.expectLimitsUpdated()
// Call the function being tested
req := pb.RpcMembershipGetStatusRequest{
// / >>> here:
NoCache: true,
}
resp, err := fx.GetSubscriptionStatus(ctx, &req)
assert.NoError(t, err)
// default values
assert.Equal(t, uint32(psp.SubscriptionTier_TierUnknown), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusUnknown, resp.Data.Status)
})
t.Run("fail if no cache, GetSubscriptionStatus returns error, and default tiers", func(t *testing.T) {
t.Run("return from cache, if cache expired and GetSubscriptionStatus returns error, and default tiers", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
@ -283,11 +355,16 @@ func TestGetStatus(t *testing.T) {
return nil, errors.New("no internet")
}).MinTimes(1)
fx.cache.EXPECT().CacheGet().Return(&psgsr, &tgr, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(true)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, &tgr, nil)
// Call the function being tested
_, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
resp, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
assert.NoError(t, err)
assert.Equal(t, uint32(psp.SubscriptionTier_TierExplorer), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
})
t.Run("success if no cache, GetSubscriptionStatus returns error and data", func(t *testing.T) {
@ -335,13 +412,16 @@ func TestGetStatus(t *testing.T) {
return nil, errors.New("no internet")
}).MinTimes(1)
// TODO: refactor - bad method semantics:
// returns error, but also returns data...
fx.cache.EXPECT().CacheGet().Return(&psgsr, &tgr, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(true)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, &tgr, nil)
// Call the function being tested
_, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
resp, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
assert.NoError(t, err)
assert.Equal(t, uint32(psp.SubscriptionTier_TierExplorer), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
})
t.Run("success if cache is expired and GetSubscriptionStatus returns no error", func(t *testing.T) {
@ -360,7 +440,8 @@ func TestGetStatus(t *testing.T) {
psgsr := pb.RpcMembershipGetStatusResponse{
Data: &model.Membership{
Tier: uint32(sr.Tier),
// >>> here: different tier returned by cache!
Tier: uint32(psp.SubscriptionTier_TierBuilder1WeekTEST),
Status: model.MembershipStatus(sr.Status),
DateStarted: sr.DateStarted,
DateEnds: sr.DateEnds,
@ -375,16 +456,22 @@ func TestGetStatus(t *testing.T) {
return &sr, nil
}).MinTimes(1)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(true)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
// fx.cache.EXPECT().CacheEnable().Return(nil)
// this should not be called because server returned Explorer tier
//fx.cache.EXPECT().CacheEnable().Return(nil)
fx.expectLimitsUpdated()
// Call the function being tested
resp, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
assert.NoError(t, err)
// the tier should be as returned by GetSubscriptionStatus, not from cache
assert.Equal(t, uint32(psp.SubscriptionTier_TierExplorer), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
assert.Equal(t, sr.DateStarted, resp.Data.DateStarted)
@ -410,6 +497,7 @@ func TestGetStatus(t *testing.T) {
psgsr := pb.RpcMembershipGetStatusResponse{
Data: &model.Membership{
// same tier returned by cache here
Tier: uint32(sr.Tier),
Status: model.MembershipStatus(sr.Status),
DateStarted: sr.DateStarted,
@ -425,12 +513,119 @@ func TestGetStatus(t *testing.T) {
return &sr, nil
}).MinTimes(1)
// here: cache is disabled
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, cache.ErrCacheDisabled)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(true)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
// fx.cache.EXPECT().CacheEnable().Return(nil)
// tier was not changed
//fx.expectLimitsUpdated()
// Call the function being tested
resp, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
assert.NoError(t, err)
assert.Equal(t, uint32(psp.SubscriptionTier_TierExplorer), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
assert.Equal(t, sr.DateStarted, resp.Data.DateStarted)
assert.Equal(t, sr.DateEnds, resp.Data.DateEnds)
assert.Equal(t, true, resp.Data.IsAutoRenew)
assert.Equal(t, model.Membership_MethodCrypto, resp.Data.PaymentMethod)
assert.Equal(t, "something", resp.Data.NsName)
})
t.Run("success if cache was disabled and GetSubscriptionStatus returns error", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
sr := psp.GetSubscriptionResponse{
Tier: uint32(psp.SubscriptionTier_TierExplorer),
Status: psp.SubscriptionStatus_StatusActive,
DateStarted: uint64(timeNow.Unix()),
DateEnds: uint64(subsExpire.Unix()),
IsAutoRenew: true,
PaymentMethod: psp.PaymentMethod_MethodCrypto,
RequestedAnyName: "something.any",
}
psgsr := pb.RpcMembershipGetStatusResponse{
Data: &model.Membership{
// same tier returned by cache here
Tier: uint32(sr.Tier),
Status: model.MembershipStatus(sr.Status),
DateStarted: sr.DateStarted,
DateEnds: sr.DateEnds,
IsAutoRenew: sr.IsAutoRenew,
PaymentMethod: PaymentMethodToModel(sr.PaymentMethod),
NsName: "something",
NsNameType: model.NameserviceNameType_AnyName,
},
}
fx.ppclient.EXPECT().GetSubscriptionStatus(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx interface{}, in *psp.GetSubscriptionRequestSigned) (*psp.GetSubscriptionResponse, error) {
return nil, errors.New("no internet")
}).MinTimes(1)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(true)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
// tier was not changed
//fx.expectLimitsUpdated()
// Call the function being tested
resp, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
assert.NoError(t, err)
assert.Equal(t, uint32(psp.SubscriptionTier_TierExplorer), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
assert.Equal(t, sr.DateStarted, resp.Data.DateStarted)
assert.Equal(t, sr.DateEnds, resp.Data.DateEnds)
assert.Equal(t, true, resp.Data.IsAutoRenew)
assert.Equal(t, model.Membership_MethodCrypto, resp.Data.PaymentMethod)
assert.Equal(t, "something", resp.Data.NsName)
})
t.Run("success if cache was expired and GetSubscriptionStatus returns error", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
sr := psp.GetSubscriptionResponse{
Tier: uint32(psp.SubscriptionTier_TierExplorer),
Status: psp.SubscriptionStatus_StatusActive,
DateStarted: uint64(timeNow.Unix()),
DateEnds: uint64(subsExpire.Unix()),
IsAutoRenew: true,
PaymentMethod: psp.PaymentMethod_MethodCrypto,
RequestedAnyName: "something.any",
}
psgsr := pb.RpcMembershipGetStatusResponse{
Data: &model.Membership{
// same tier returned by cache here
Tier: uint32(sr.Tier),
Status: model.MembershipStatus(sr.Status),
DateStarted: sr.DateStarted,
DateEnds: sr.DateEnds,
IsAutoRenew: sr.IsAutoRenew,
PaymentMethod: PaymentMethodToModel(sr.PaymentMethod),
NsName: "something",
NsNameType: model.NameserviceNameType_AnyName,
},
}
fx.ppclient.EXPECT().GetSubscriptionStatus(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx interface{}, in *psp.GetSubscriptionRequestSigned) (*psp.GetSubscriptionResponse, error) {
return nil, errors.New("no internet")
}).MinTimes(1)
fx.cache.EXPECT().IsCacheExpired().Return(true)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
// tier was not changed
//fx.expectLimitsUpdated()
// Call the function being tested
resp, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
@ -479,10 +674,16 @@ func TestGetStatus(t *testing.T) {
return &sr, nil
}).MinTimes(1)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(true)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return errors.New("can not write to cache!")
})
// this should not be called because server returned Explorer tier
//fx.cache.EXPECT().CacheEnable().Return(nil)
fx.expectLimitsUpdated()
// Call the function being tested
_, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
@ -510,6 +711,8 @@ func TestGetStatus(t *testing.T) {
},
}
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
// HERE>>>
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
@ -557,12 +760,15 @@ func TestGetStatus(t *testing.T) {
return &sr, nil
}).MinTimes(1)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(true)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(nil, nil, nil)
fx.cache.EXPECT().CacheSet(&psgsr, mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
fx.cache.EXPECT().CacheEnable().Return(nil)
// because cache was expired before!
fx.expectLimitsUpdated()
// Call the function being tested
@ -573,13 +779,13 @@ func TestGetStatus(t *testing.T) {
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
})
t.Run("if cache was disabled and tier has changed -> save, but enable cache back", func(t *testing.T) {
t.Run("if cache was disabled and tier has changed -> save, and enable cache back", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
var subsExpire5 time.Time = timeNow.Add(365 * 24 * time.Hour)
// this is from PP node
// this is from PP node (new status)
sr := psp.GetSubscriptionResponse{
Tier: uint32(psp.SubscriptionTier_TierBuilder1Year),
Status: psp.SubscriptionStatus_StatusActive,
@ -607,16 +813,29 @@ func TestGetStatus(t *testing.T) {
},
}
// this is the new state
var psgsr2 pb.RpcMembershipGetStatusResponse = psgsr
psgsr2.Data.Tier = uint32(psp.SubscriptionTier_TierBuilder1Year)
psgsr2 := pb.RpcMembershipGetStatusResponse{
Error: &pb.RpcMembershipGetStatusResponseError{
Code: pb.RpcMembershipGetStatusResponseError_NULL,
},
Data: &model.Membership{
Tier: uint32(psp.SubscriptionTier_TierBuilder1Year),
Status: model.MembershipStatus(sr.Status),
DateStarted: sr.DateStarted,
DateEnds: sr.DateEnds,
IsAutoRenew: sr.IsAutoRenew,
PaymentMethod: PaymentMethodToModel(sr.PaymentMethod),
NsName: "",
NsNameType: model.NameserviceNameType_AnyName,
},
}
fx.ppclient.EXPECT().GetSubscriptionStatus(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx interface{}, in *psp.GetSubscriptionRequestSigned) (*psp.GetSubscriptionResponse, error) {
return &sr, nil
}).MinTimes(1)
// return real struct and error
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheDisabled)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(true)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
fx.cache.EXPECT().CacheSet(&psgsr2, mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
@ -633,6 +852,57 @@ func TestGetStatus(t *testing.T) {
assert.Equal(t, uint32(psp.SubscriptionTier_TierBuilder1Year), resp.Data.Tier)
assert.Equal(t, model.Membership_StatusActive, resp.Data.Status)
})
t.Run("cache has error saved, GetSubscriptionStatus returns no error", func(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
sr := psp.GetSubscriptionResponse{
Tier: uint32(psp.SubscriptionTier_TierExplorer),
Status: psp.SubscriptionStatus_StatusActive,
DateStarted: uint64(timeNow.Unix()),
DateEnds: uint64(subsExpire.Unix()),
IsAutoRenew: true,
PaymentMethod: psp.PaymentMethod_MethodCrypto,
RequestedAnyName: "something.any",
}
psgsr := pb.RpcMembershipGetStatusResponse{
Error: &pb.RpcMembershipGetStatusResponseError{
// >> here:
Code: pb.RpcMembershipGetStatusResponseError_PAYMENT_NODE_ERROR,
},
Data: &model.Membership{
Tier: uint32(sr.Tier),
Status: model.MembershipStatus(sr.Status),
DateStarted: sr.DateStarted,
DateEnds: sr.DateEnds,
IsAutoRenew: sr.IsAutoRenew,
PaymentMethod: PaymentMethodToModel(sr.PaymentMethod),
NsName: "something",
NsNameType: model.NameserviceNameType_AnyName,
},
}
fx.ppclient.EXPECT().GetSubscriptionStatus(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx interface{}, in *psp.GetSubscriptionRequestSigned) (*psp.GetSubscriptionResponse, error) {
return &sr, nil
}).MinTimes(1)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
// this should not be called because server returned Explorer tier
//fx.cache.EXPECT().CacheEnable().Return(nil)
fx.expectLimitsUpdated()
// Call the function being tested
_, err := fx.GetSubscriptionStatus(ctx, &pb.RpcMembershipGetStatusRequest{})
assert.NoError(t, err)
})
}
func (fx *fixture) expectLimitsUpdated() {
@ -872,7 +1142,9 @@ func TestGetTiers(t *testing.T) {
return nil, errors.New("test error")
}).MinTimes(1)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().CacheGet().Return(nil, nil, nil)
req := pb.RpcMembershipGetTiersRequest{
NoCache: false,
@ -886,7 +1158,9 @@ func TestGetTiers(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheDbError)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
@ -907,8 +1181,6 @@ func TestGetTiers(t *testing.T) {
}, nil
}).MinTimes(1)
fx.cache.EXPECT().CacheEnable().Return(nil)
fx.expectLimitsUpdated()
req := pb.RpcMembershipGetTiersRequest{
@ -923,7 +1195,10 @@ func TestGetTiers(t *testing.T) {
fx := newFixture(t)
defer fx.finish(t)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheExpired)
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(nil, nil, cache.ErrCacheDbError)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
})
@ -972,7 +1247,8 @@ func TestGetTiers(t *testing.T) {
}, nil
}).MinTimes(1)
fx.cache.EXPECT().CacheEnable().Return(nil)
// this should not be called because server returned Explorer tier
//fx.cache.EXPECT().CacheEnable().Return(nil)
fx.expectLimitsUpdated()
@ -1024,6 +1300,9 @@ func TestGetTiers(t *testing.T) {
NsNameType: model.NameserviceNameType_AnyName,
},
}
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, nil, nil)
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {
return nil
@ -1093,6 +1372,8 @@ func TestGetTiers(t *testing.T) {
},
},
}
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, &tgr, nil)
req := pb.RpcMembershipGetTiersRequest{
@ -1159,6 +1440,8 @@ func TestGetTiers(t *testing.T) {
},
},
}
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(&psgsr, &tgr, nil)
req := pb.RpcMembershipGetTiersRequest{
@ -1210,6 +1493,9 @@ func TestGetTiers(t *testing.T) {
},
},
}
fx.cache.EXPECT().IsCacheExpired().Return(false)
fx.cache.EXPECT().IsCacheDisabled().Return(false)
fx.cache.EXPECT().CacheGet().Return(nil, &tgr, nil)
// should call it to save status
fx.cache.EXPECT().CacheSet(mock.AnythingOfType("*pb.RpcMembershipGetStatusResponse"), mock.AnythingOfType("*pb.RpcMembershipGetTiersResponse")).RunAndReturn(func(in *pb.RpcMembershipGetStatusResponse, tiers *pb.RpcMembershipGetTiersResponse) (err error) {

View file

@ -2,22 +2,28 @@ package peerstatus
import (
"context"
"errors"
"sync"
"time"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/net/pool"
"github.com/samber/lo"
"github.com/anyproto/anytype-heart/core/event"
"github.com/anyproto/anytype-heart/core/session"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/logging"
"github.com/anyproto/anytype-heart/space/spacecore/peerstore"
)
const CName = "core.syncstatus.p2p"
var log = logging.Logger(CName)
type Status int32
var ErrClosed = errors.New("component is closing")
const (
Unknown Status = 0
Connected Status = 1
@ -25,6 +31,19 @@ const (
NotConnected Status = 3
)
func (s Status) ToPb() pb.EventP2PStatusStatus {
switch s {
case Connected:
return pb.EventP2PStatus_Connected
case NotConnected:
return pb.EventP2PStatus_NotConnected
case NotPossible:
return pb.EventP2PStatus_NotPossible
}
// default status is NotConnected
return pb.EventP2PStatus_NotConnected
}
type LocalDiscoveryHook interface {
app.Component
RegisterP2PNotPossible(hook func())
@ -33,39 +52,35 @@ type LocalDiscoveryHook interface {
type PeerToPeerStatus interface {
app.ComponentRunnable
SendNotPossibleStatus()
CheckPeerStatus()
ResetNotPossibleStatus()
RegisterSpace(spaceId string)
UnregisterSpace(spaceId string)
}
type spaceStatus struct {
status Status
connectionsCount int64
}
type p2pStatus struct {
spaceIds map[string]struct{}
spaceIds map[string]*spaceStatus
eventSender event.Sender
contextCancel context.CancelFunc
ctx context.Context
peerStore peerstore.PeerStore
sync.Mutex
status Status
connectionsCount int64
forceCheckSpace chan struct{}
updateStatus chan Status
resetNotPossibleStatus chan struct{}
finish chan struct{}
p2pNotPossible bool // global flag means p2p is not possible because of network
workerFinished chan struct{}
refreshSpaceId chan string
peersConnectionPool pool.Pool
}
func New() PeerToPeerStatus {
p2pStatusService := &p2pStatus{
forceCheckSpace: make(chan struct{}, 1),
updateStatus: make(chan Status, 1),
resetNotPossibleStatus: make(chan struct{}, 1),
finish: make(chan struct{}),
spaceIds: make(map[string]struct{}),
workerFinished: make(chan struct{}),
refreshSpaceId: make(chan string),
spaceIds: make(map[string]*spaceStatus),
}
return p2pStatusService
@ -77,20 +92,35 @@ func (p *p2pStatus) Init(a *app.App) (err error) {
p.peersConnectionPool = app.MustComponent[pool.Service](a)
localDiscoveryHook := app.MustComponent[LocalDiscoveryHook](a)
sessionHookRunner := app.MustComponent[session.HookRunner](a)
localDiscoveryHook.RegisterP2PNotPossible(p.SendNotPossibleStatus)
localDiscoveryHook.RegisterResetNotPossible(p.ResetNotPossibleStatus)
localDiscoveryHook.RegisterP2PNotPossible(p.setNotPossibleStatus)
localDiscoveryHook.RegisterResetNotPossible(p.resetNotPossibleStatus)
sessionHookRunner.RegisterHook(p.sendStatusForNewSession)
p.ctx, p.contextCancel = context.WithCancel(context.Background())
p.peerStore.AddObserver(func(peerId string, spaceIdsBefore, spaceIdsAfter []string, peerRemoved bool) {
// we need to update status for all spaces that were either added or removed to some local peer
// because we start this observer on init we can be sure that the spaceIdsBefore is empty on the first run for peer
removed, added := lo.Difference(spaceIdsBefore, spaceIdsAfter)
err := p.refreshSpaces(lo.Union(removed, added))
if errors.Is(err, ErrClosed) {
return
} else if err != nil {
log.Errorf("refreshSpaces failed: %v", err)
}
})
return nil
}
func (p *p2pStatus) sendStatusForNewSession(ctx session.Context) error {
p.sendStatus(p.status)
p.Lock()
defer p.Unlock()
for spaceId, space := range p.spaceIds {
p.sendEvent(ctx.ID(), spaceId, space.status.ToPb(), space.connectionsCount)
}
return nil
}
func (p *p2pStatus) Run(ctx context.Context) error {
p.ctx, p.contextCancel = context.WithCancel(context.Background())
go p.checkP2PDevices()
go p.worker()
return nil
}
@ -98,7 +128,7 @@ func (p *p2pStatus) Close(ctx context.Context) error {
if p.contextCancel != nil {
p.contextCancel()
}
<-p.finish
<-p.workerFinished
return nil
}
@ -106,159 +136,145 @@ func (p *p2pStatus) Name() (name string) {
return CName
}
func (p *p2pStatus) CheckPeerStatus() {
p.forceCheckSpace <- struct{}{}
}
func (p *p2pStatus) SendNotPossibleStatus() {
p.updateStatus <- NotPossible
}
func (p *p2pStatus) ResetNotPossibleStatus() {
p.resetNotPossibleStatus <- struct{}{}
}
func (p *p2pStatus) RegisterSpace(spaceId string) {
func (p *p2pStatus) setNotPossibleStatus() {
p.Lock()
defer p.Unlock()
p.spaceIds[spaceId] = struct{}{}
connection := p.connectionsCount
if connection == 0 {
connection++ // count current device
if p.p2pNotPossible {
p.Unlock()
return
}
p.eventSender.Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: spaceId,
Status: p.mapStatusToEvent(p.status),
DevicesCounter: connection,
},
},
},
},
})
p.p2pNotPossible = true
p.Unlock()
p.refreshAllSpaces()
}
func (p *p2pStatus) resetNotPossibleStatus() {
p.Lock()
if !p.p2pNotPossible {
p.Unlock()
return
}
p.p2pNotPossible = false
p.Unlock()
p.refreshAllSpaces()
}
// RegisterSpace registers spaceId to be monitored for p2p status changes
// must be called only when p2pStatus is Running
func (p *p2pStatus) RegisterSpace(spaceId string) {
select {
case <-p.ctx.Done():
return
case p.refreshSpaceId <- spaceId:
}
}
// UnregisterSpace unregisters spaceId from monitoring
// must be called only when p2pStatus is Running
func (p *p2pStatus) UnregisterSpace(spaceId string) {
p.Lock()
defer p.Unlock()
delete(p.spaceIds, spaceId)
}
func (p *p2pStatus) checkP2PDevices() {
defer close(p.finish)
timer := time.NewTicker(10 * time.Second)
defer timer.Stop()
p.updateSpaceP2PStatus()
func (p *p2pStatus) worker() {
defer close(p.workerFinished)
for {
select {
case <-p.ctx.Done():
return
case <-timer.C:
p.updateSpaceP2PStatus()
case <-p.forceCheckSpace:
p.updateSpaceP2PStatus()
case newStatus := <-p.updateStatus:
p.sendStatus(newStatus)
case <-p.resetNotPossibleStatus:
p.resetNotPossible()
case spaceId := <-p.refreshSpaceId:
p.processSpaceStatusUpdate(spaceId)
}
}
}
func (p *p2pStatus) updateSpaceP2PStatus() {
func (p *p2pStatus) refreshAllSpaces() {
p.Lock()
defer p.Unlock()
connectionCount := p.countOpenConnections()
newStatus, event := p.getResultStatus(connectionCount)
if newStatus == NotPossible {
return
var spaceIds = make([]string, 0, len(p.spaceIds))
for spaceId := range p.spaceIds {
spaceIds = append(spaceIds, spaceId)
}
connectionCount++ // count current device
if p.status != newStatus || p.connectionsCount != connectionCount {
p.sendEvent(event, connectionCount)
p.status = newStatus
p.connectionsCount = connectionCount
p.Unlock()
err := p.refreshSpaces(spaceIds)
if errors.Is(err, ErrClosed) {
return
} else if err != nil {
log.Errorf("refreshSpaces failed: %v", err)
}
}
func (p *p2pStatus) getResultStatus(connectionCount int64) (Status, pb.EventP2PStatusStatus) {
func (p *p2pStatus) refreshSpaces(spaceIds []string) error {
for _, spaceId := range spaceIds {
select {
case <-p.ctx.Done():
return ErrClosed
case p.refreshSpaceId <- spaceId:
}
}
return nil
}
// updateSpaceP2PStatus updates status for specific spaceId and sends event if status changed
func (p *p2pStatus) processSpaceStatusUpdate(spaceId string) {
p.Lock()
defer p.Unlock()
var (
newStatus Status
event pb.EventP2PStatusStatus
currentStatus *spaceStatus
ok bool
)
if p.status == NotPossible && connectionCount == 0 {
return NotPossible, pb.EventP2PStatus_NotPossible
if currentStatus, ok = p.spaceIds[spaceId]; !ok {
currentStatus = &spaceStatus{
status: Unknown,
connectionsCount: 0,
}
p.spaceIds[spaceId] = currentStatus
}
connectionCount := p.countOpenConnections(spaceId)
newStatus := p.getResultStatus(p.p2pNotPossible, connectionCount)
if currentStatus.status != newStatus || currentStatus.connectionsCount != connectionCount {
p.sendEvent("", spaceId, newStatus.ToPb(), connectionCount)
currentStatus.status = newStatus
currentStatus.connectionsCount = connectionCount
}
}
func (p *p2pStatus) getResultStatus(notPossible bool, connectionCount int64) Status {
if notPossible && connectionCount == 0 {
return NotPossible
}
if connectionCount == 0 {
event = pb.EventP2PStatus_NotConnected
newStatus = NotConnected
return NotConnected
} else {
event = pb.EventP2PStatus_Connected
newStatus = Connected
return Connected
}
return newStatus, event
}
func (p *p2pStatus) countOpenConnections() int64 {
var connectionCount int64
ctx, cancelFunc := context.WithTimeout(context.Background(), time.Second*20)
defer cancelFunc()
peerIds := p.peerStore.AllLocalPeers()
for _, peerId := range peerIds {
_, err := p.peersConnectionPool.Pick(ctx, peerId)
if err != nil {
continue
}
connectionCount++
}
return connectionCount
}
func (p *p2pStatus) sendStatus(status Status) {
p.Lock()
defer p.Unlock()
pbStatus := p.mapStatusToEvent(status)
p.status = status
p.sendEvent(pbStatus, p.connectionsCount)
func (p *p2pStatus) countOpenConnections(spaceId string) int64 {
peerIds := p.peerStore.LocalPeerIds(spaceId)
return int64(len(peerIds))
}
func (p *p2pStatus) mapStatusToEvent(status Status) pb.EventP2PStatusStatus {
var pbStatus pb.EventP2PStatusStatus
switch status {
case Connected:
pbStatus = pb.EventP2PStatus_Connected
case NotConnected:
pbStatus = pb.EventP2PStatus_NotConnected
case NotPossible:
pbStatus = pb.EventP2PStatus_NotPossible
}
return pbStatus
}
func (p *p2pStatus) sendEvent(status pb.EventP2PStatusStatus, count int64) {
for spaceId := range p.spaceIds {
p.eventSender.Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: spaceId,
Status: status,
DevicesCounter: count,
},
// sendEvent sends event to session with sessionToken or broadcast to all sessions if sessionToken is empty
func (p *p2pStatus) sendEvent(sessionToken string, spaceId string, status pb.EventP2PStatusStatus, count int64) {
event := &pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: spaceId,
Status: status,
DevicesCounter: count,
},
},
},
})
},
}
}
func (p *p2pStatus) resetNotPossible() {
p.Lock()
defer p.Unlock()
if p.status == NotPossible {
p.status = NotConnected
if sessionToken != "" {
p.eventSender.SendToSession(sessionToken, event)
return
}
p.eventSender.Broadcast(event)
}

View file

@ -2,6 +2,7 @@ package peerstatus
import (
"context"
"fmt"
"testing"
"time"
@ -22,7 +23,7 @@ import (
)
type fixture struct {
PeerToPeerStatus
*p2pStatus
sender *mock_event.MockSender
service *mock_nodeconf.MockService
store peerstore.PeerStore
@ -35,9 +36,6 @@ func TestP2PStatus_Init(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
// when
f.Run(nil)
// then
f.Close(nil)
})
@ -47,7 +45,6 @@ func TestP2pStatus_SendNewStatus(t *testing.T) {
t.Run("send NotPossible status", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
f.Run(nil)
// when
f.sender.EXPECT().Broadcast(&pb.Event{
@ -57,22 +54,40 @@ func TestP2pStatus_SendNewStatus(t *testing.T) {
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_NotPossible,
DevicesCounter: 1,
DevicesCounter: 0,
},
},
},
},
})
f.SendNotPossibleStatus()
f.setNotPossibleStatus()
// then
status := f.PeerToPeerStatus.(*p2pStatus)
assert.NotNil(t, status)
err := waitForStatus(status, NotPossible)
err := waitForStatus("spaceId", f.p2pStatus, NotPossible)
assert.Nil(t, err)
f.CheckPeerStatus()
err = waitForStatus(status, NotPossible)
// when
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_NotConnected,
DevicesCounter: 0,
},
},
},
},
})
f.resetNotPossibleStatus()
err = f.refreshSpaces([]string{"spaceId"})
assert.Nil(t, err)
checkStatus(t, "spaceId", f.p2pStatus, NotConnected)
assert.Nil(t, err)
f.Close(nil)
@ -81,13 +96,10 @@ func TestP2pStatus_SendNewStatus(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
// when
f.Run(nil)
// then
status := f.PeerToPeerStatus.(*p2pStatus)
status := f.p2pStatus
assert.NotNil(t, status)
err := waitForStatus(status, NotConnected)
err := waitForStatus("spaceId", status, NotConnected)
assert.Nil(t, err)
f.Close(nil)
})
@ -104,8 +116,6 @@ func TestP2pStatus_SendPeerUpdate(t *testing.T) {
err := f.pool.AddPeer(context.Background(), peer)
assert.Nil(t, err)
// when
f.Run(nil)
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
@ -113,25 +123,117 @@ func TestP2pStatus_SendPeerUpdate(t *testing.T) {
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_Connected,
DevicesCounter: 2,
DevicesCounter: 1,
},
},
},
},
})
f.CheckPeerStatus()
// then
f.Close(nil)
status := f.PeerToPeerStatus.(*p2pStatus)
assert.NotNil(t, status)
err = waitForStatus(status, Connected)
assert.Nil(t, err)
checkStatus(t, "spaceId", f.p2pStatus, Connected)
// should not create a problem, cause we already closed
f.store.RemoveLocalPeer("peerId")
})
t.Run("send NotConnected status, because we have peer were disconnected", func(t *testing.T) {
t.Run("send NotConnected status, because we have peer and then were disconnected", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_Connected, 1)
ctrl := gomock.NewController(t)
peer := mock_peer.NewMockPeer(ctrl)
peer.EXPECT().Id().Return("peerId")
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_Connected,
DevicesCounter: 1,
},
},
},
},
})
err := f.pool.AddPeer(context.Background(), peer)
assert.Nil(t, err)
f.store.UpdateLocalPeer("peerId", []string{"spaceId"})
checkStatus(t, "spaceId", f.p2pStatus, Connected)
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_NotConnected,
DevicesCounter: 0,
},
},
},
},
})
f.store.RemoveLocalPeer("peerId")
checkStatus(t, "spaceId", f.p2pStatus, NotConnected)
// then
f.Close(nil)
assert.Nil(t, err)
})
t.Run("connection was not possible, but after a while starts working", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
// when
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_NotPossible,
DevicesCounter: 0,
},
},
},
},
})
f.setNotPossibleStatus()
checkStatus(t, "spaceId", f.p2pStatus, NotPossible)
f.store.UpdateLocalPeer("peerId", []string{"spaceId"})
ctrl := gomock.NewController(t)
peer := mock_peer.NewMockPeer(ctrl)
peer.EXPECT().Id().Return("peerId")
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_Connected,
DevicesCounter: 1,
},
},
},
},
})
err := f.pool.AddPeer(context.Background(), peer)
assert.Nil(t, err)
checkStatus(t, "spaceId", f.p2pStatus, Connected)
// then
f.Close(nil)
})
t.Run("no peers were connected, but after a while one is connected", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
// when
checkStatus(t, "spaceId", f.p2pStatus, NotConnected)
f.store.UpdateLocalPeer("peerId", []string{"spaceId"})
ctrl := gomock.NewController(t)
peer := mock_peer.NewMockPeer(ctrl)
@ -139,8 +241,6 @@ func TestP2pStatus_SendPeerUpdate(t *testing.T) {
err := f.pool.AddPeer(context.Background(), peer)
assert.Nil(t, err)
// when
f.Run(nil)
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
@ -148,145 +248,22 @@ func TestP2pStatus_SendPeerUpdate(t *testing.T) {
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_Connected,
DevicesCounter: 2,
},
},
},
},
})
err = waitForStatus(f.PeerToPeerStatus.(*p2pStatus), Connected)
assert.Nil(t, err)
f.store.RemoveLocalPeer("peerId")
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_NotConnected,
DevicesCounter: 1,
},
},
},
},
})
f.CheckPeerStatus()
err = waitForStatus(f.PeerToPeerStatus.(*p2pStatus), NotConnected)
assert.Nil(t, err)
checkStatus(t, "spaceId", f.p2pStatus, Connected)
// then
f.Close(nil)
assert.Nil(t, err)
status := f.PeerToPeerStatus.(*p2pStatus)
assert.NotNil(t, status)
err = waitForStatus(status, NotConnected)
})
t.Run("connection was not possible, but after a while starts working", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
// when
f.Run(nil)
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_NotPossible,
DevicesCounter: 1,
},
},
},
},
})
f.SendNotPossibleStatus()
err := waitForStatus(f.PeerToPeerStatus.(*p2pStatus), NotPossible)
assert.Nil(t, err)
f.store.UpdateLocalPeer("peerId", []string{"spaceId"})
ctrl := gomock.NewController(t)
peer := mock_peer.NewMockPeer(ctrl)
peer.EXPECT().Id().Return("peerId")
err = f.pool.AddPeer(context.Background(), peer)
assert.Nil(t, err)
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_Connected,
DevicesCounter: 2,
},
},
},
},
})
f.CheckPeerStatus()
err = waitForStatus(f.PeerToPeerStatus.(*p2pStatus), Connected)
assert.Nil(t, err)
// then
f.Close(nil)
assert.Nil(t, err)
status := f.PeerToPeerStatus.(*p2pStatus)
assert.NotNil(t, status)
checkStatus(t, status, Connected)
})
t.Run("no peers were connected, but after a while one is connected", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
// when
f.Run(nil)
err := waitForStatus(f.PeerToPeerStatus.(*p2pStatus), NotConnected)
f.store.UpdateLocalPeer("peerId", []string{"spaceId"})
ctrl := gomock.NewController(t)
peer := mock_peer.NewMockPeer(ctrl)
peer.EXPECT().Id().Return("peerId")
err = f.pool.AddPeer(context.Background(), peer)
assert.Nil(t, err)
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_Connected,
DevicesCounter: 2,
},
},
},
},
})
f.CheckPeerStatus()
err = waitForStatus(f.PeerToPeerStatus.(*p2pStatus), Connected)
assert.Nil(t, err)
// then
f.Close(nil)
assert.Nil(t, err)
status := f.PeerToPeerStatus.(*p2pStatus)
assert.NotNil(t, status)
checkStatus(t, status, Connected)
})
t.Run("reset not possible status", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
// when
f.Run(nil)
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
@ -294,16 +271,19 @@ func TestP2pStatus_SendPeerUpdate(t *testing.T) {
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_NotPossible,
DevicesCounter: 1,
DevicesCounter: 0,
},
},
},
},
})
f.SendNotPossibleStatus()
status := f.PeerToPeerStatus.(*p2pStatus)
assert.NotNil(t, status)
err := waitForStatus(status, NotPossible)
f.setNotPossibleStatus()
checkStatus(t, "spaceId", f.p2pStatus, NotPossible)
// double set should not generate new event
f.setNotPossibleStatus()
checkStatus(t, "spaceId", f.p2pStatus, NotPossible)
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
@ -311,41 +291,78 @@ func TestP2pStatus_SendPeerUpdate(t *testing.T) {
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_NotConnected,
DevicesCounter: 1,
DevicesCounter: 0,
},
},
},
},
})
f.ResetNotPossibleStatus()
err = waitForStatus(status, NotConnected)
assert.Nil(t, err)
f.resetNotPossibleStatus()
checkStatus(t, "spaceId", f.p2pStatus, NotConnected)
// then
f.Close(nil)
assert.Nil(t, err)
checkStatus(t, status, NotConnected)
})
t.Run("don't reset not possible status, because status != NotPossible", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_NotConnected, 1)
// when
f.Run(nil)
status := f.PeerToPeerStatus.(*p2pStatus)
err := waitForStatus(status, NotConnected)
f.ResetNotPossibleStatus()
err = waitForStatus(status, NotConnected)
checkStatus(t, "spaceId", f.p2pStatus, NotConnected)
f.resetNotPossibleStatus()
checkStatus(t, "spaceId", f.p2pStatus, NotConnected)
// then
f.Close(nil)
assert.Nil(t, err)
checkStatus(t, status, NotConnected)
checkStatus(t, "spaceId", f.p2pStatus, NotConnected)
})
}
func TestP2pStatus_SendToNewSession(t *testing.T) {
t.Run("send event only to new session", func(t *testing.T) {
// given
f := newFixture(t, "spaceId", pb.EventP2PStatus_Connected, 1)
ctrl := gomock.NewController(t)
peer := mock_peer.NewMockPeer(ctrl)
peer.EXPECT().Id().Return("peerId")
f.sender.EXPECT().Broadcast(&pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_Connected,
DevicesCounter: 1,
},
},
},
},
})
err := f.pool.AddPeer(context.Background(), peer)
assert.Nil(t, err)
f.store.UpdateLocalPeer("peerId", []string{"spaceId"})
checkStatus(t, "spaceId", f.p2pStatus, Connected)
f.sender.EXPECT().SendToSession("token1", &pb.Event{
Messages: []*pb.EventMessage{
{
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: "spaceId",
Status: pb.EventP2PStatus_Connected,
DevicesCounter: 1,
},
},
},
},
})
err = f.sendStatusForNewSession(session.NewContext(session.WithSession("token1")))
assert.Nil(t, err)
// then
f.Close(nil)
})
}
func TestP2pStatus_UnregisterSpace(t *testing.T) {
t.Run("success", func(t *testing.T) {
// given
@ -356,7 +373,7 @@ func TestP2pStatus_UnregisterSpace(t *testing.T) {
// then
status := f.PeerToPeerStatus.(*p2pStatus)
status := f.p2pStatus
assert.Len(t, status.spaceIds, 0)
})
t.Run("delete non existing space", func(t *testing.T) {
@ -367,7 +384,7 @@ func TestP2pStatus_UnregisterSpace(t *testing.T) {
f.UnregisterSpace("spaceId1")
// then
status := f.PeerToPeerStatus.(*p2pStatus)
status := f.p2pStatus
assert.Len(t, status.spaceIds, 1)
})
}
@ -404,7 +421,7 @@ func newFixture(t *testing.T, spaceId string, initialStatus pb.EventP2PStatusSta
Value: &pb.EventMessageValueOfP2PStatusUpdate{
P2PStatusUpdate: &pb.EventP2PStatusUpdate{
SpaceId: spaceId,
DevicesCounter: 1,
DevicesCounter: 0,
},
},
},
@ -423,40 +440,54 @@ func newFixture(t *testing.T, spaceId string, initialStatus pb.EventP2PStatusSta
},
},
}).Maybe()
status.RegisterSpace(spaceId)
err = status.Run(context.Background())
assert.Nil(t, err)
status.RegisterSpace(spaceId)
f := &fixture{
PeerToPeerStatus: status,
sender: sender,
service: service,
store: store,
pool: pool,
hookRegister: hookRegister,
p2pStatus: status.(*p2pStatus),
sender: sender,
service: service,
store: store,
pool: pool,
hookRegister: hookRegister,
}
return f
}
func waitForStatus(statusSender *p2pStatus, expectedStatus Status) error {
func waitForStatus(spaceId string, statusSender *p2pStatus, expectedStatus Status) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
case <-time.After(time.Millisecond * 10):
statusSender.Lock()
if statusSender.status == expectedStatus {
if status, ok := statusSender.spaceIds[spaceId]; !ok {
statusSender.Unlock()
return nil
return fmt.Errorf("spaceId %s not found", spaceId)
} else {
if status.status == expectedStatus {
statusSender.Unlock()
return nil
}
}
statusSender.Unlock()
}
}
}
func checkStatus(t *testing.T, statusSender *p2pStatus, expectedStatus Status) {
func checkStatus(t *testing.T, spaceId string, statusSender *p2pStatus, expectedStatus Status) {
time.Sleep(time.Millisecond * 300)
statusSender.Lock()
defer statusSender.Unlock()
assert.Equal(t, expectedStatus, statusSender.status)
if status, ok := statusSender.spaceIds[spaceId]; !ok {
assert.Fail(t, "spaceId %s not found", spaceId)
} else {
assert.Equal(t, expectedStatus, status.status)
}
}

View file

@ -153,6 +153,10 @@ func (c *collectionSub) hasDep() bool {
return c.sortedSub.hasDep()
}
func (c *collectionSub) getDep() subscription {
return c.sortedSub.depSub
}
func (c *collectionSub) close() {
c.observer.close()
c.sortedSub.close()

View file

@ -43,6 +43,8 @@ type opGroup struct {
}
type opCtx struct {
outputs map[string][]*pb.EventMessage
// subIds for remove
remove []opRemove
change []opChange
@ -60,14 +62,31 @@ type opCtx struct {
c *cache
}
func (ctx *opCtx) apply() (event *pb.Event) {
var subMsgs = make([]*pb.EventMessage, 0, 10)
const defaultOutput = "_default"
func (ctx *opCtx) apply() {
addEvent := func(subId string, ev *pb.EventMessage) {
_, ok := ctx.outputs[subId]
if ok {
ctx.outputs[subId] = append(ctx.outputs[subId], ev)
} else {
ctx.outputs[defaultOutput] = append(ctx.outputs[defaultOutput], ev)
}
}
// changes
for _, ch := range ctx.change {
ctx.collectKeys(ch.id, ch.subId, ch.keys)
}
// details events
ctx.detailsEvents()
// adds, positions
for _, pos := range ctx.position {
if pos.isAdd {
ctx.collectKeys(pos.id, pos.subId, pos.keys)
subMsgs = append(subMsgs, &pb.EventMessage{
addEvent(pos.subId, &pb.EventMessage{
Value: &pb.EventMessageValueOfSubscriptionAdd{
SubscriptionAdd: &pb.EventObjectSubscriptionAdd{
Id: pos.id,
@ -77,7 +96,7 @@ func (ctx *opCtx) apply() (event *pb.Event) {
},
})
} else {
subMsgs = append(subMsgs, &pb.EventMessage{
addEvent(pos.subId, &pb.EventMessage{
Value: &pb.EventMessageValueOfSubscriptionPosition{
SubscriptionPosition: &pb.EventObjectSubscriptionPosition{
Id: pos.id,
@ -89,17 +108,9 @@ func (ctx *opCtx) apply() (event *pb.Event) {
}
}
// changes
for _, ch := range ctx.change {
ctx.collectKeys(ch.id, ch.subId, ch.keys)
}
// details events
eventMsgs := ctx.detailsEvents()
// removes
for _, rem := range ctx.remove {
subMsgs = append(subMsgs, &pb.EventMessage{
addEvent(rem.subId, &pb.EventMessage{
Value: &pb.EventMessageValueOfSubscriptionRemove{
SubscriptionRemove: &pb.EventObjectSubscriptionRemove{
Id: rem.id,
@ -111,7 +122,7 @@ func (ctx *opCtx) apply() (event *pb.Event) {
// counters
for _, count := range ctx.counters {
subMsgs = append(subMsgs, &pb.EventMessage{
addEvent(count.subId, &pb.EventMessage{
Value: &pb.EventMessageValueOfSubscriptionCounters{
SubscriptionCounters: &pb.EventObjectSubscriptionCounters{
Total: int64(count.total),
@ -133,7 +144,7 @@ func (ctx *opCtx) apply() (event *pb.Event) {
}
for _, opGroup := range ctx.groups {
subMsgs = append(subMsgs, &pb.EventMessage{
addEvent(opGroup.subId, &pb.EventMessage{
Value: &pb.EventMessageValueOfSubscriptionGroups{
SubscriptionGroups: &pb.EventObjectSubscriptionGroups{
SubId: opGroup.subId,
@ -143,13 +154,14 @@ func (ctx *opCtx) apply() (event *pb.Event) {
},
})
}
return &pb.Event{
Messages: append(eventMsgs, subMsgs...),
}
}
func (ctx *opCtx) detailsEvents() (msgs []*pb.EventMessage) {
// detailsEvents produces following types of events:
// EventObjectDetailsAmend
// EventObjectDetailsUnset
// EventMessageValueOfObjectDetailsSet
func (ctx *opCtx) detailsEvents() {
var msgs []*pb.EventMessage
var getEntry = func(id string) *entry {
for _, e := range ctx.entries {
if e.id == id {
@ -186,7 +198,110 @@ func (ctx *opCtx) detailsEvents() (msgs []*pb.EventMessage) {
curr.SetSub(sub, true, true)
}
}
return
ctx.groupDetailsEvents(msgs)
}
func (ctx *opCtx) groupDetailsEvents(msgs []*pb.EventMessage) {
for _, msg := range msgs {
if v := msg.GetObjectDetailsAmend(); v != nil {
ctx.groupEventsDetailsAmend(v)
} else if v := msg.GetObjectDetailsUnset(); v != nil {
ctx.groupEventsDetailsUnset(v)
} else if v := msg.GetObjectDetailsSet(); v != nil {
ctx.groupEventsDetailsSet(v)
}
}
}
func (ctx *opCtx) groupEventsDetailsSet(v *pb.EventObjectDetailsSet) {
defaultSubIds := v.SubIds[:0]
for _, subId := range v.SubIds {
if _, ok := ctx.outputs[subId]; ok {
ctx.outputs[subId] = append(ctx.outputs[subId], &pb.EventMessage{
Value: &pb.EventMessageValueOfObjectDetailsSet{
ObjectDetailsSet: &pb.EventObjectDetailsSet{
Id: v.Id,
Details: v.Details,
SubIds: []string{subId},
},
},
})
} else {
defaultSubIds = append(defaultSubIds, subId)
}
}
if len(defaultSubIds) > 0 {
ctx.outputs[defaultOutput] = append(ctx.outputs[defaultOutput], &pb.EventMessage{
Value: &pb.EventMessageValueOfObjectDetailsSet{
ObjectDetailsSet: &pb.EventObjectDetailsSet{
Id: v.Id,
Details: v.Details,
SubIds: defaultSubIds,
},
},
})
}
}
func (ctx *opCtx) groupEventsDetailsUnset(v *pb.EventObjectDetailsUnset) {
defaultSubIds := v.SubIds[:0]
for _, subId := range v.SubIds {
if _, ok := ctx.outputs[subId]; ok {
ctx.outputs[subId] = append(ctx.outputs[subId], &pb.EventMessage{
Value: &pb.EventMessageValueOfObjectDetailsUnset{
ObjectDetailsUnset: &pb.EventObjectDetailsUnset{
Id: v.Id,
Keys: v.Keys,
SubIds: []string{subId},
},
},
})
} else {
defaultSubIds = append(defaultSubIds, subId)
}
}
if len(defaultSubIds) > 0 {
ctx.outputs[defaultOutput] = append(ctx.outputs[defaultOutput], &pb.EventMessage{
Value: &pb.EventMessageValueOfObjectDetailsUnset{
ObjectDetailsUnset: &pb.EventObjectDetailsUnset{
Id: v.Id,
Keys: v.Keys,
SubIds: defaultSubIds,
},
},
})
}
}
func (ctx *opCtx) groupEventsDetailsAmend(v *pb.EventObjectDetailsAmend) {
defaultSubIds := v.SubIds[:0]
for _, subId := range v.SubIds {
if _, ok := ctx.outputs[subId]; ok {
ctx.outputs[subId] = append(ctx.outputs[subId], &pb.EventMessage{
Value: &pb.EventMessageValueOfObjectDetailsAmend{
ObjectDetailsAmend: &pb.EventObjectDetailsAmend{
Id: v.Id,
Details: v.Details,
SubIds: []string{subId},
},
},
})
} else {
defaultSubIds = append(defaultSubIds, subId)
}
}
if len(defaultSubIds) > 0 {
ctx.outputs[defaultOutput] = append(ctx.outputs[defaultOutput], &pb.EventMessage{
Value: &pb.EventMessageValueOfObjectDetailsAmend{
ObjectDetailsAmend: &pb.EventObjectDetailsAmend{
Id: v.Id,
Details: v.Details,
SubIds: defaultSubIds,
},
},
})
}
}
func (ctx *opCtx) collectKeys(id string, subId string, keys []string) {
@ -233,4 +348,9 @@ func (ctx *opCtx) reset() {
ctx.keysBuf = ctx.keysBuf[:0]
ctx.entries = ctx.entries[:0]
ctx.groups = ctx.groups[:0]
if ctx.outputs == nil {
ctx.outputs = map[string][]*pb.EventMessage{
defaultOutput: nil,
}
}
}

View file

@ -109,7 +109,6 @@ var ignoredKeys = map[string]struct{}{
bundle.RelationKeyId.String(): {},
bundle.RelationKeySpaceId.String(): {}, // relation format for spaceId has mistakenly set to Object instead of shorttext
bundle.RelationKeyFeaturedRelations.String(): {}, // relation format for featuredRelations has mistakenly set to Object instead of shorttext
bundle.RelationKeyLinks.String(): {}, // skip links because it's aggregated from other relations and blocks
}
func (ds *dependencyService) isRelationObject(key string) bool {
@ -123,12 +122,12 @@ func (ds *dependencyService) isRelationObject(key string) bool {
if isObj, ok := ds.isRelationObjMap[key]; ok {
return isObj
}
rel, err := ds.s.objectStore.GetRelationByKey(key)
relFormat, err := ds.s.objectStore.GetRelationFormatByKey(key)
if err != nil {
log.Errorf("can't get relation %s: %v", key, err)
return false
}
isObj := rel.Format == model.RelationFormat_object || rel.Format == model.RelationFormat_file || rel.Format == model.RelationFormat_tag || rel.Format == model.RelationFormat_status
isObj := relFormat == model.RelationFormat_object || relFormat == model.RelationFormat_file || relFormat == model.RelationFormat_tag || relFormat == model.RelationFormat_status
ds.isRelationObjMap[key] = isObj
return isObj
}

View file

@ -0,0 +1,75 @@
package subscription
import (
"context"
"testing"
"github.com/anyproto/any-sync/app"
"github.com/stretchr/testify/require"
"github.com/anyproto/anytype-heart/core/event/mock_event"
"github.com/anyproto/anytype-heart/core/kanban"
"github.com/anyproto/anytype-heart/pkg/lib/localstore/objectstore"
"github.com/anyproto/anytype-heart/tests/testutil"
)
type InternalTestService struct {
Service
*objectstore.StoreFixture
}
func (s *InternalTestService) Init(a *app.App) error {
return s.Service.Init(a)
}
func (s *InternalTestService) Run(ctx context.Context) error {
err := s.StoreFixture.Run(ctx)
if err != nil {
return err
}
return s.Service.Run(ctx)
}
func (s *InternalTestService) Close(ctx context.Context) (err error) {
_ = s.Service.Close(ctx)
return s.StoreFixture.Close(ctx)
}
func NewInternalTestService(t *testing.T) *InternalTestService {
s := New()
ctx := context.Background()
objectStore := objectstore.NewStoreFixture(t)
a := &app.App{}
a.Register(objectStore)
a.Register(kanban.New())
a.Register(&collectionServiceMock{MockCollectionService: NewMockCollectionService(t)})
a.Register(testutil.PrepareMock(ctx, a, mock_event.NewMockSender(t)))
a.Register(s)
err := a.Start(ctx)
require.NoError(t, err)
return &InternalTestService{Service: s, StoreFixture: objectStore}
}
func RegisterSubscriptionService(t *testing.T, a *app.App) *InternalTestService {
s := New()
ctx := context.Background()
objectStore := objectstore.NewStoreFixture(t)
a.Register(objectStore).
Register(kanban.New()).
Register(&collectionServiceMock{MockCollectionService: NewMockCollectionService(t)}).
Register(testutil.PrepareMock(ctx, a, mock_event.NewMockSender(t))).
Register(s)
return &InternalTestService{Service: s, StoreFixture: objectStore}
}
type collectionServiceMock struct {
*MockCollectionService
}
func (c *collectionServiceMock) Name() string {
return "collectionService"
}
func (c *collectionServiceMock) Init(a *app.App) error { return nil }

View file

@ -23,16 +23,6 @@ import (
"github.com/anyproto/anytype-heart/util/testMock/mockKanban"
)
type collectionServiceMock struct {
*MockCollectionService
}
func (c *collectionServiceMock) Name() string {
return "collectionService"
}
func (c *collectionServiceMock) Init(a *app.App) error { return nil }
type fixture struct {
Service
a *app.App

View file

@ -123,6 +123,10 @@ func (gs *groupSub) hasDep() bool {
return false
}
func (gs *groupSub) getDep() subscription {
return nil
}
func (gs *groupSub) close() {
for id := range gs.set {
gs.cache.RemoveSubId(id, gs.id)

View file

@ -0,0 +1,403 @@
package subscription
import (
"context"
"errors"
"testing"
"time"
mb2 "github.com/cheggaaa/mb/v3"
"github.com/gogo/protobuf/types"
"github.com/stretchr/testify/require"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
"github.com/anyproto/anytype-heart/pkg/lib/localstore/objectstore"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
"github.com/anyproto/anytype-heart/util/pbtypes"
)
func wrapToEventMessages(vals []pb.IsEventMessageValue) []*pb.EventMessage {
msgs := make([]*pb.EventMessage, len(vals))
for i, v := range vals {
msgs[i] = &pb.EventMessage{Value: v}
}
return msgs
}
func TestInternalSubscriptionSingle(t *testing.T) {
fx := NewInternalTestService(t)
resp, err := fx.Search(SubscribeRequest{
SubId: "test",
Filters: []*model.BlockContentDataviewFilter{
{
RelationKey: bundle.RelationKeyPriority.String(),
Condition: model.BlockContentDataviewFilter_Equal,
Value: pbtypes.Int64(10),
},
},
Keys: []string{bundle.RelationKeyId.String(), bundle.RelationKeyName.String(), bundle.RelationKeyPriority.String()},
Internal: true,
})
require.NoError(t, err)
require.Empty(t, resp.Records)
t.Run("amend details not related to filter", func(t *testing.T) {
fx.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id1"),
bundle.RelationKeyName: pbtypes.String("task1"),
bundle.RelationKeyPriority: pbtypes.Int64(10),
bundle.RelationKeyLinkedProjects: pbtypes.StringList([]string{"project1", "project2"}), // Should be ignored as not listed in keys
},
})
time.Sleep(batchTime)
fx.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id1"),
bundle.RelationKeyName: pbtypes.String("task1 renamed"),
bundle.RelationKeyPriority: pbtypes.Int64(10),
},
})
time.Sleep(batchTime)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
want := givenMessagesForFirstObject("test")
msgs, err := resp.Output.NewCond().WithMin(len(want)).Wait(ctx)
require.NoError(t, err)
require.Equal(t, wrapToEventMessages(want), msgs)
})
t.Run("amend details related to filter -- remove from subscription", func(t *testing.T) {
fx.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id2"),
bundle.RelationKeyName: pbtypes.String("task2"),
bundle.RelationKeyPriority: pbtypes.Int64(10),
},
})
time.Sleep(batchTime)
fx.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id2"),
bundle.RelationKeyName: pbtypes.String("task2"),
bundle.RelationKeyPriority: pbtypes.Int64(9),
},
})
time.Sleep(batchTime)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
want := givenMessagesForSecondObject("test")
msgs, err := resp.Output.NewCond().WithMin(len(want)).Wait(ctx)
require.NoError(t, err)
require.Equal(t, wrapToEventMessages(want), msgs)
})
t.Run("close", func(t *testing.T) {
err = fx.Unsubscribe("test")
require.NoError(t, err)
err = resp.Output.Add(context.Background(), &pb.EventMessage{})
require.True(t, errors.Is(err, mb2.ErrClosed))
})
t.Run("try to add after close", func(t *testing.T) {
time.Sleep(batchTime)
fx.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id3"),
bundle.RelationKeyName: pbtypes.String("task2"),
bundle.RelationKeyPriority: pbtypes.Int64(10),
},
})
})
}
func TestInternalSubscriptionMultiple(t *testing.T) {
fx := newFixtureWithRealObjectStore(t)
resp1, err := fx.Search(SubscribeRequest{
SubId: "internal1",
Filters: []*model.BlockContentDataviewFilter{
{
RelationKey: bundle.RelationKeyPriority.String(),
Condition: model.BlockContentDataviewFilter_Equal,
Value: pbtypes.Int64(10),
},
},
Keys: []string{bundle.RelationKeyId.String(), bundle.RelationKeyName.String(), bundle.RelationKeyPriority.String()},
Internal: true,
})
_, err = fx.Search(SubscribeRequest{
SubId: "client1",
Filters: []*model.BlockContentDataviewFilter{
{
RelationKey: bundle.RelationKeyPriority.String(),
Condition: model.BlockContentDataviewFilter_Equal,
Value: pbtypes.Int64(10),
},
},
Keys: []string{bundle.RelationKeyId.String(), bundle.RelationKeyName.String(), bundle.RelationKeyPriority.String()},
})
_, err = fx.Search(SubscribeRequest{
SubId: "client2",
Filters: []*model.BlockContentDataviewFilter{
{
RelationKey: bundle.RelationKeyPriority.String(),
Condition: model.BlockContentDataviewFilter_Equal,
Value: pbtypes.Int64(10),
},
},
Keys: []string{bundle.RelationKeyId.String(), bundle.RelationKeyName.String(), bundle.RelationKeyPriority.String()},
})
resp4, err := fx.Search(SubscribeRequest{
SubId: "internal2",
Filters: []*model.BlockContentDataviewFilter{
{
RelationKey: bundle.RelationKeyName.String(),
Condition: model.BlockContentDataviewFilter_Equal,
Value: pbtypes.String("Jane Doe"),
},
},
Keys: []string{bundle.RelationKeyId.String(), bundle.RelationKeyName.String(), bundle.RelationKeyPriority.String()},
Internal: true,
})
require.NoError(t, err)
require.Empty(t, resp1.Records)
t.Run("amend details not related to filter", func(t *testing.T) {
fx.store.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id1"),
bundle.RelationKeyName: pbtypes.String("task1"),
bundle.RelationKeyPriority: pbtypes.Int64(10),
bundle.RelationKeyLinkedProjects: pbtypes.StringList([]string{"project1", "project2"}), // Should be ignored as not listed in keys
},
})
time.Sleep(batchTime)
fx.store.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id1"),
bundle.RelationKeyName: pbtypes.String("task1 renamed"),
bundle.RelationKeyPriority: pbtypes.Int64(10),
},
})
time.Sleep(batchTime)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
want := givenMessagesForFirstObject("internal1")
msgs, err := resp1.Output.NewCond().WithMin(len(want)).Wait(ctx)
require.NoError(t, err)
require.Equal(t, wrapToEventMessages(want), msgs)
want = givenMessagesForFirstObject("client1", "client2")
fx.waitEvents(t, want...)
})
t.Run("amend details related to filter -- remove from subscription", func(t *testing.T) {
fx.store.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id2"),
bundle.RelationKeyName: pbtypes.String("task2"),
bundle.RelationKeyPriority: pbtypes.Int64(10),
},
})
time.Sleep(batchTime)
fx.store.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id2"),
bundle.RelationKeyName: pbtypes.String("task2"),
bundle.RelationKeyPriority: pbtypes.Int64(9),
},
})
time.Sleep(batchTime)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
want := givenMessagesForSecondObject("internal1")
msgs, err := resp1.Output.NewCond().WithMin(len(want)).Wait(ctx)
require.NoError(t, err)
require.Equal(t, wrapToEventMessages(want), msgs)
want = givenMessagesForSecondObject("client1", "client2")
fx.waitEvents(t, want...)
})
t.Run("add item satisfying filters from all subscription", func(t *testing.T) {
fx.store.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id3"),
bundle.RelationKeyName: pbtypes.String("Jane Doe"),
bundle.RelationKeyPriority: pbtypes.Int64(10),
},
})
time.Sleep(batchTime)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
want := givenMessagesForThirdObject(2, "id1", "internal1")
msgs, err := resp1.Output.NewCond().WithMin(len(want)).Wait(ctx)
require.NoError(t, err)
require.Equal(t, wrapToEventMessages(want), msgs)
want = givenMessagesForThirdObject(1, "", "internal2")
msgs, err = resp4.Output.NewCond().WithMin(len(want)).Wait(ctx)
require.NoError(t, err)
require.Equal(t, wrapToEventMessages(want), msgs)
want = givenMessagesForThirdObject(2, "id1", "client1", "client2")
fx.waitEvents(t, want...)
})
}
func givenMessagesForFirstObject(subIds ...string) []pb.IsEventMessageValue {
var msgs []pb.IsEventMessageValue
msgs = append(msgs, &pb.EventMessageValueOfObjectDetailsSet{
ObjectDetailsSet: &pb.EventObjectDetailsSet{
Id: "id1",
SubIds: subIds,
Details: &types.Struct{
Fields: map[string]*types.Value{
bundle.RelationKeyId.String(): pbtypes.String("id1"),
bundle.RelationKeyName.String(): pbtypes.String("task1"),
bundle.RelationKeyPriority.String(): pbtypes.Int64(10),
},
},
},
})
for _, subId := range subIds {
msgs = append(msgs, &pb.EventMessageValueOfSubscriptionAdd{
SubscriptionAdd: &pb.EventObjectSubscriptionAdd{
SubId: subId,
Id: "id1",
},
})
}
for _, subId := range subIds {
msgs = append(msgs, &pb.EventMessageValueOfSubscriptionCounters{
SubscriptionCounters: &pb.EventObjectSubscriptionCounters{
SubId: subId,
Total: 1,
},
})
}
msgs = append(msgs, &pb.EventMessageValueOfObjectDetailsAmend{
ObjectDetailsAmend: &pb.EventObjectDetailsAmend{
Id: "id1",
SubIds: subIds,
Details: []*pb.EventObjectDetailsAmendKeyValue{
{
Key: bundle.RelationKeyName.String(),
Value: pbtypes.String("task1 renamed"),
},
},
},
})
return msgs
}
func givenMessagesForSecondObject(subIds ...string) []pb.IsEventMessageValue {
var msgs []pb.IsEventMessageValue
msgs = append(msgs, &pb.EventMessageValueOfObjectDetailsSet{
ObjectDetailsSet: &pb.EventObjectDetailsSet{
Id: "id2",
SubIds: subIds,
Details: &types.Struct{
Fields: map[string]*types.Value{
bundle.RelationKeyId.String(): pbtypes.String("id2"),
bundle.RelationKeyName.String(): pbtypes.String("task2"),
bundle.RelationKeyPriority.String(): pbtypes.Int64(10),
},
},
},
})
for _, subId := range subIds {
msgs = append(msgs, &pb.EventMessageValueOfSubscriptionAdd{
SubscriptionAdd: &pb.EventObjectSubscriptionAdd{
SubId: subId,
AfterId: "id1",
Id: "id2",
},
})
}
for _, subId := range subIds {
msgs = append(msgs, &pb.EventMessageValueOfSubscriptionCounters{
SubscriptionCounters: &pb.EventObjectSubscriptionCounters{
SubId: subId,
Total: 2,
},
})
}
for _, subId := range subIds {
msgs = append(msgs, &pb.EventMessageValueOfSubscriptionRemove{
SubscriptionRemove: &pb.EventObjectSubscriptionRemove{
Id: "id2",
SubId: subId,
},
})
}
for _, subId := range subIds {
msgs = append(msgs, &pb.EventMessageValueOfSubscriptionCounters{
SubscriptionCounters: &pb.EventObjectSubscriptionCounters{
SubId: subId,
Total: 1,
},
})
}
return msgs
}
func givenMessagesForThirdObject(total int, afterId string, subIds ...string) []pb.IsEventMessageValue {
var msgs []pb.IsEventMessageValue
msgs = append(msgs, &pb.EventMessageValueOfObjectDetailsSet{
ObjectDetailsSet: &pb.EventObjectDetailsSet{
Id: "id3",
SubIds: subIds,
Details: &types.Struct{
Fields: map[string]*types.Value{
bundle.RelationKeyId.String(): pbtypes.String("id3"),
bundle.RelationKeyName.String(): pbtypes.String("Jane Doe"),
bundle.RelationKeyPriority.String(): pbtypes.Int64(10),
},
},
},
})
for _, subId := range subIds {
msgs = append(msgs, &pb.EventMessageValueOfSubscriptionAdd{
SubscriptionAdd: &pb.EventObjectSubscriptionAdd{
SubId: subId,
Id: "id3",
AfterId: afterId,
},
})
}
for _, subId := range subIds {
msgs = append(msgs, &pb.EventMessageValueOfSubscriptionCounters{
SubscriptionCounters: &pb.EventObjectSubscriptionCounters{
SubId: subId,
Total: int64(total),
},
})
}
return msgs
}

View file

@ -13,6 +13,8 @@ import (
session "github.com/anyproto/anytype-heart/core/session"
subscription "github.com/anyproto/anytype-heart/core/subscription"
types "github.com/gogo/protobuf/types"
)
@ -213,27 +215,27 @@ func (_c *MockService_Run_Call) RunAndReturn(run func(context.Context) error) *M
}
// Search provides a mock function with given fields: req
func (_m *MockService) Search(req pb.RpcObjectSearchSubscribeRequest) (*pb.RpcObjectSearchSubscribeResponse, error) {
func (_m *MockService) Search(req subscription.SubscribeRequest) (*subscription.SubscribeResponse, error) {
ret := _m.Called(req)
if len(ret) == 0 {
panic("no return value specified for Search")
}
var r0 *pb.RpcObjectSearchSubscribeResponse
var r0 *subscription.SubscribeResponse
var r1 error
if rf, ok := ret.Get(0).(func(pb.RpcObjectSearchSubscribeRequest) (*pb.RpcObjectSearchSubscribeResponse, error)); ok {
if rf, ok := ret.Get(0).(func(subscription.SubscribeRequest) (*subscription.SubscribeResponse, error)); ok {
return rf(req)
}
if rf, ok := ret.Get(0).(func(pb.RpcObjectSearchSubscribeRequest) *pb.RpcObjectSearchSubscribeResponse); ok {
if rf, ok := ret.Get(0).(func(subscription.SubscribeRequest) *subscription.SubscribeResponse); ok {
r0 = rf(req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*pb.RpcObjectSearchSubscribeResponse)
r0 = ret.Get(0).(*subscription.SubscribeResponse)
}
}
if rf, ok := ret.Get(1).(func(pb.RpcObjectSearchSubscribeRequest) error); ok {
if rf, ok := ret.Get(1).(func(subscription.SubscribeRequest) error); ok {
r1 = rf(req)
} else {
r1 = ret.Error(1)
@ -248,24 +250,24 @@ type MockService_Search_Call struct {
}
// Search is a helper method to define mock.On call
// - req pb.RpcObjectSearchSubscribeRequest
// - req subscription.SubscribeRequest
func (_e *MockService_Expecter) Search(req interface{}) *MockService_Search_Call {
return &MockService_Search_Call{Call: _e.mock.On("Search", req)}
}
func (_c *MockService_Search_Call) Run(run func(req pb.RpcObjectSearchSubscribeRequest)) *MockService_Search_Call {
func (_c *MockService_Search_Call) Run(run func(req subscription.SubscribeRequest)) *MockService_Search_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(pb.RpcObjectSearchSubscribeRequest))
run(args[0].(subscription.SubscribeRequest))
})
return _c
}
func (_c *MockService_Search_Call) Return(resp *pb.RpcObjectSearchSubscribeResponse, err error) *MockService_Search_Call {
func (_c *MockService_Search_Call) Return(resp *subscription.SubscribeResponse, err error) *MockService_Search_Call {
_c.Call.Return(resp, err)
return _c
}
func (_c *MockService_Search_Call) RunAndReturn(run func(pb.RpcObjectSearchSubscribeRequest) (*pb.RpcObjectSearchSubscribeResponse, error)) *MockService_Search_Call {
func (_c *MockService_Search_Call) RunAndReturn(run func(subscription.SubscribeRequest) (*subscription.SubscribeResponse, error)) *MockService_Search_Call {
_c.Call.Return(run)
return _c
}

View file

@ -2,15 +2,17 @@ package subscription
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/anyproto/any-sync/app"
"github.com/cheggaaa/mb"
mb2 "github.com/cheggaaa/mb/v3"
"github.com/globalsign/mgo/bson"
"github.com/gogo/protobuf/types"
"github.com/samber/lo"
"golang.org/x/exp/slices"
"github.com/anyproto/anytype-heart/core/domain"
@ -38,8 +40,40 @@ func New() Service {
return &service{}
}
type SubscribeRequest struct {
SubId string
Filters []*model.BlockContentDataviewFilter
Sorts []*model.BlockContentDataviewSort
Limit int64
Offset int64
// (required) needed keys in details for return, for object fields mw will return (and subscribe) objects as dependent
Keys []string
// (optional) pagination: middleware will return results after given id
AfterId string
// (optional) pagination: middleware will return results before given id
BeforeId string
Source []string
IgnoreWorkspace string
// disable dependent subscription
NoDepSubscription bool
CollectionId string
// Internal indicates that subscription will send events into message queue instead of global client's event system
Internal bool
}
type SubscribeResponse struct {
SubId string
Records []*types.Struct
Dependencies []*types.Struct
Counters *pb.EventObjectSubscriptionCounters
// Used when Internal flag is set to true
Output *mb2.MB[*pb.EventMessage]
}
type Service interface {
Search(req pb.RpcObjectSearchSubscribeRequest) (resp *pb.RpcObjectSearchSubscribeResponse, err error)
Search(req SubscribeRequest) (resp *SubscribeResponse, err error)
SubscribeIdsReq(req pb.RpcObjectSubscribeIdsRequest) (resp *pb.RpcObjectSubscribeIdsResponse, err error)
SubscribeIds(subId string, ids []string) (records []*types.Struct, err error)
SubscribeGroups(ctx session.Context, req pb.RpcObjectGroupsSubscribeRequest) (*pb.RpcObjectGroupsSubscribeResponse, error)
@ -56,6 +90,7 @@ type subscription interface {
onChange(ctx *opCtx)
getActiveRecords() (res []*types.Struct)
hasDep() bool
getDep() subscription
close()
}
@ -65,10 +100,14 @@ type CollectionService interface {
}
type service struct {
cache *cache
ds *dependencyService
subscriptions map[string]subscription
recBatch *mb.MB
cache *cache
ds *dependencyService
subscriptionKeys []string
subscriptions map[string]subscription
customOutput map[string]*mb2.MB[*pb.EventMessage]
recBatch *mb.MB
objectStore objectstore.ObjectStore
kanban kanban.Service
@ -85,11 +124,12 @@ func (s *service) Init(a *app.App) (err error) {
s.cache = newCache()
s.ds = newDependencyService(s)
s.subscriptions = make(map[string]subscription)
s.objectStore = a.MustComponent(objectstore.CName).(objectstore.ObjectStore)
s.kanban = a.MustComponent(kanban.CName).(kanban.Service)
s.customOutput = map[string]*mb2.MB[*pb.EventMessage]{}
s.objectStore = app.MustComponent[objectstore.ObjectStore](a)
s.kanban = app.MustComponent[kanban.Service](a)
s.recBatch = mb.New(0)
s.collectionService = app.MustComponent[CollectionService](a)
s.eventSender = a.MustComponent(event.CName).(event.Sender)
s.eventSender = app.MustComponent[event.Sender](a)
s.ctxBuf = &opCtx{c: s.cache}
s.initDebugger()
return
@ -107,7 +147,33 @@ func (s *service) Run(context.Context) (err error) {
return
}
func (s *service) Search(req pb.RpcObjectSearchSubscribeRequest) (*pb.RpcObjectSearchSubscribeResponse, error) {
func (s *service) getSubscription(id string) (subscription, bool) {
sub, ok := s.subscriptions[id]
return sub, ok
}
func (s *service) setSubscription(id string, sub subscription) {
s.subscriptions[id] = sub
if !slices.Contains(s.subscriptionKeys, id) {
s.subscriptionKeys = append(s.subscriptionKeys, id)
}
}
func (s *service) deleteSubscription(id string) {
delete(s.subscriptions, id)
s.subscriptionKeys = slice.RemoveMut(s.subscriptionKeys, id)
}
func (s *service) iterateSubscriptions(proc func(sub subscription)) {
for _, subId := range s.subscriptionKeys {
sub, ok := s.getSubscription(subId)
if ok && sub != nil {
proc(sub)
}
}
}
func (s *service) Search(req SubscribeRequest) (*SubscribeResponse, error) {
if req.SubId == "" {
req.SubId = bson.NewObjectId().Hex()
}
@ -135,9 +201,9 @@ func (s *service) Search(req pb.RpcObjectSearchSubscribeRequest) (*pb.RpcObjectS
defer s.m.Unlock()
filterDepIds := s.depIdsFromFilter(req.Filters)
if exists, ok := s.subscriptions[req.SubId]; ok {
delete(s.subscriptions, req.SubId)
exists.close()
if existing, ok := s.getSubscription(req.SubId); ok {
s.deleteSubscription(req.SubId)
existing.close()
}
if req.Offset < 0 {
req.Offset = 0
@ -152,7 +218,7 @@ func (s *service) Search(req pb.RpcObjectSearchSubscribeRequest) (*pb.RpcObjectS
return s.subscribeForQuery(req, f, filterDepIds)
}
func (s *service) subscribeForQuery(req pb.RpcObjectSearchSubscribeRequest, f *database.Filters, filterDepIds []string) (*pb.RpcObjectSearchSubscribeResponse, error) {
func (s *service) subscribeForQuery(req SubscribeRequest, f *database.Filters, filterDepIds []string) (*SubscribeResponse, error) {
sub := s.newSortedSub(req.SubId, req.Keys, f.FilterObj, f.Order, int(req.Limit), int(req.Offset))
if req.NoDepSubscription {
sub.disableDep = true
@ -176,7 +242,7 @@ func (s *service) subscribeForQuery(req pb.RpcObjectSearchSubscribeRequest, f *d
sub.nested = append(sub.nested, childSub)
childSub.parent = sub
childSub.parentFilter = f
s.subscriptions[childSub.id] = childSub
s.setSubscription(childSub.id, childSub)
}
return nil
})
@ -189,7 +255,7 @@ func (s *service) subscribeForQuery(req pb.RpcObjectSearchSubscribeRequest, f *d
if err != nil {
return nil, fmt.Errorf("init sub entries: %w", err)
}
s.subscriptions[sub.id] = sub
s.setSubscription(sub.id, sub)
prev, next := sub.counters()
var depRecords, subRecords []*types.Struct
@ -198,8 +264,10 @@ func (s *service) subscribeForQuery(req pb.RpcObjectSearchSubscribeRequest, f *d
if sub.depSub != nil {
depRecords = sub.depSub.getActiveRecords()
}
return &pb.RpcObjectSearchSubscribeResponse{
if req.Internal {
s.customOutput[req.SubId] = mb2.New[*pb.EventMessage](0)
}
return &SubscribeResponse{
Records: subRecords,
Dependencies: depRecords,
SubId: sub.id,
@ -208,6 +276,7 @@ func (s *service) subscribeForQuery(req pb.RpcObjectSearchSubscribeRequest, f *d
NextCount: int64(prev),
PrevCount: int64(next),
},
Output: s.customOutput[req.SubId],
}, nil
}
@ -237,7 +306,7 @@ func queryEntries(objectStore objectstore.ObjectStore, f *database.Filters) ([]*
return entries, nil
}
func (s *service) subscribeForCollection(req pb.RpcObjectSearchSubscribeRequest, f *database.Filters, filterDepIds []string) (*pb.RpcObjectSearchSubscribeResponse, error) {
func (s *service) subscribeForCollection(req SubscribeRequest, f *database.Filters, filterDepIds []string) (*SubscribeResponse, error) {
sub, err := s.newCollectionSub(req.SubId, req.CollectionId, req.Keys, filterDepIds, f.FilterObj, f.Order, int(req.Limit), int(req.Offset), req.NoDepSubscription)
if err != nil {
return nil, err
@ -245,7 +314,7 @@ func (s *service) subscribeForCollection(req pb.RpcObjectSearchSubscribeRequest,
if err := sub.init(nil); err != nil {
return nil, fmt.Errorf("subscription init error: %w", err)
}
s.subscriptions[sub.sortedSub.id] = sub
s.setSubscription(sub.sortedSub.id, sub)
prev, next := sub.counters()
var depRecords, subRecords []*types.Struct
@ -255,7 +324,11 @@ func (s *service) subscribeForCollection(req pb.RpcObjectSearchSubscribeRequest,
depRecords = sub.sortedSub.depSub.getActiveRecords()
}
return &pb.RpcObjectSearchSubscribeResponse{
if req.Internal {
s.customOutput[req.SubId] = mb2.New[*pb.EventMessage](0)
}
return &SubscribeResponse{
Records: subRecords,
Dependencies: depRecords,
SubId: sub.sortedSub.id,
@ -264,6 +337,7 @@ func (s *service) subscribeForCollection(req pb.RpcObjectSearchSubscribeRequest,
NextCount: int64(prev),
PrevCount: int64(next),
},
Output: s.customOutput[req.SubId],
}, nil
}
@ -291,7 +365,7 @@ func (s *service) SubscribeIdsReq(req pb.RpcObjectSubscribeIdsRequest) (resp *pb
if err = sub.init(entries); err != nil {
return
}
s.subscriptions[sub.id] = sub
s.setSubscription(sub.id, sub)
var depRecords, subRecords []*types.Struct
subRecords = sub.getActiveRecords()
@ -390,7 +464,7 @@ func (s *service) SubscribeGroups(ctx session.Context, req pb.RpcObjectGroupsSub
if err := sub.init(entries); err != nil {
return nil, err
}
s.subscriptions[subId] = sub
s.setSubscription(subId, sub)
} else if colObserver != nil {
colObserver.close()
}
@ -406,16 +480,24 @@ func (s *service) SubscribeIds(subId string, ids []string) (records []*types.Str
return
}
func (s *service) Unsubscribe(subIds ...string) (err error) {
func (s *service) Unsubscribe(subIds ...string) error {
s.m.Lock()
defer s.m.Unlock()
for _, subId := range subIds {
if sub, ok := s.subscriptions[subId]; ok {
if sub, ok := s.getSubscription(subId); ok {
out := s.customOutput[subId]
if out != nil {
err := out.Close()
if err != nil {
return fmt.Errorf("close subscription %s: %w", subId, err)
}
s.customOutput[subId] = nil
}
sub.close()
delete(s.subscriptions, subId)
s.deleteSubscription(subId)
}
}
return
return nil
}
func (s *service) UnsubscribeAll() (err error) {
@ -425,13 +507,14 @@ func (s *service) UnsubscribeAll() (err error) {
sub.close()
}
s.subscriptions = make(map[string]subscription)
s.subscriptionKeys = s.subscriptionKeys[:0]
return
}
func (s *service) SubscriptionIDs() []string {
s.m.Lock()
defer s.m.Unlock()
return lo.Keys(s.subscriptions)
return s.subscriptionKeys
}
func (s *service) recordsHandler() {
@ -480,21 +563,52 @@ func (s *service) onChange(entries []*entry) time.Duration {
st := time.Now()
s.ctxBuf.reset()
s.ctxBuf.entries = entries
for _, sub := range s.subscriptions {
s.iterateSubscriptions(func(sub subscription) {
sub.onChange(s.ctxBuf)
subCount++
if sub.hasDep() {
sub.getDep().onChange(s.ctxBuf)
depCount++
}
}
})
handleTime := time.Since(st)
event := s.ctxBuf.apply()
// Reset output buffer
for subId := range s.ctxBuf.outputs {
if subId == defaultOutput {
s.ctxBuf.outputs[subId] = nil
} else if _, ok := s.customOutput[subId]; ok {
s.ctxBuf.outputs[subId] = nil
} else {
delete(s.ctxBuf.outputs, subId)
}
}
for subId := range s.customOutput {
if _, ok := s.ctxBuf.outputs[subId]; !ok {
s.ctxBuf.outputs[subId] = nil
}
}
s.ctxBuf.apply()
dur := time.Since(st)
s.debugEvents(event)
for id, msgs := range s.ctxBuf.outputs {
if len(msgs) > 0 {
s.debugEvents(&pb.Event{Messages: msgs})
if id == defaultOutput {
s.eventSender.Broadcast(&pb.Event{Messages: msgs})
} else {
err := s.customOutput[id].Add(context.TODO(), msgs...)
if err != nil && !errors.Is(err, mb2.ErrClosed) {
log.With("subId", id, "error", err).Errorf("push to output")
}
}
}
}
log.Debugf("handle %d entries; %v(handle:%v;genEvents:%v); cacheSize: %d; subCount:%d; subDepCount:%d", len(entries), dur, handleTime, dur-handleTime, len(s.cache.entries), subCount, depCount)
s.eventSender.Broadcast(event)
return dur
}
@ -564,8 +678,8 @@ func (s *service) Close(ctx context.Context) (err error) {
s.m.Lock()
defer s.m.Unlock()
s.recBatch.Close()
for _, sub := range s.subscriptions {
s.iterateSubscriptions(func(sub subscription) {
sub.close()
}
})
return
}

View file

@ -32,14 +32,8 @@ func TestService_Search(t *testing.T) {
},
nil,
)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyAuthor.String()).Return(&model.Relation{
Key: bundle.RelationKeyAuthor.String(),
Format: model.RelationFormat_object,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyAuthor.String()).Return(model.RelationFormat_object, nil).AnyTimes()
fx.store.EXPECT().QueryByID([]string{"author1"}).Return([]database.Record{
{Details: &types.Struct{Fields: map[string]*types.Value{
@ -48,7 +42,7 @@ func TestService_Search(t *testing.T) {
}}},
}, nil).AnyTimes()
resp, err := fx.Search(pb.RpcObjectSearchSubscribeRequest{
resp, err := fx.Search(SubscribeRequest{
SubId: subId,
Keys: []string{bundle.RelationKeyName.String(), bundle.RelationKeyAuthor.String()},
})
@ -135,14 +129,8 @@ func TestService_Search(t *testing.T) {
},
nil,
)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyAuthor.String()).Return(&model.Relation{
Key: bundle.RelationKeyAuthor.String(),
Format: model.RelationFormat_object,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyAuthor.String()).Return(model.RelationFormat_object, nil).AnyTimes()
fx.store.EXPECT().QueryByID([]string{"force1", "force2"}).Return([]database.Record{
{Details: &types.Struct{Fields: map[string]*types.Value{
@ -155,7 +143,7 @@ func TestService_Search(t *testing.T) {
}}},
}, nil)
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: "subId",
Keys: []string{bundle.RelationKeyName.String(), bundle.RelationKeyAuthor.String()},
Filters: []*model.BlockContentDataviewFilter{
@ -194,12 +182,9 @@ func TestService_Search(t *testing.T) {
},
nil,
)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
resp, err := fx.Search(pb.RpcObjectSearchSubscribeRequest{
resp, err := fx.Search(SubscribeRequest{
SubId: "test",
Sorts: []*model.BlockContentDataviewSort{
{
@ -261,12 +246,9 @@ func TestService_Search(t *testing.T) {
},
nil,
)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
resp, err := fx.Search(pb.RpcObjectSearchSubscribeRequest{
resp, err := fx.Search(SubscribeRequest{
SubId: "test",
Sorts: []*model.BlockContentDataviewSort{
{
@ -309,7 +291,7 @@ func TestService_Search(t *testing.T) {
collectionID := "id"
subscriptionID := "subId"
fx.collectionService.EXPECT().SubscribeForCollection(collectionID, subscriptionID).Return(nil, nil, fmt.Errorf("error"))
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: "subId",
CollectionId: collectionID,
})
@ -326,7 +308,7 @@ func TestService_Search(t *testing.T) {
subscriptionID := "subId"
fx.collectionService.EXPECT().SubscribeForCollection(collectionID, subscriptionID).Return(nil, nil, nil)
fx.collectionService.EXPECT().UnsubscribeFromCollection(collectionID, subscriptionID).Return()
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: subscriptionID,
CollectionId: collectionID,
})
@ -359,16 +341,10 @@ func TestService_Search(t *testing.T) {
}}},
}, nil)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyId.String()).Return(&model.Relation{
Key: bundle.RelationKeyId.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyId.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: subscriptionID,
Keys: []string{bundle.RelationKeyName.String(), bundle.RelationKeyId.String()},
CollectionId: collectionID,
@ -409,16 +385,10 @@ func TestService_Search(t *testing.T) {
}}},
}, nil)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyId.String()).Return(&model.Relation{
Key: bundle.RelationKeyId.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyId.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: subscriptionID,
Keys: []string{bundle.RelationKeyName.String(), bundle.RelationKeyId.String()},
CollectionId: collectionID,
@ -466,16 +436,10 @@ func TestService_Search(t *testing.T) {
}}},
}, nil)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyId.String()).Return(&model.Relation{
Key: bundle.RelationKeyId.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyId.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: subscriptionID,
Keys: []string{bundle.RelationKeyName.String(), bundle.RelationKeyId.String()},
CollectionId: collectionID,
@ -508,20 +472,14 @@ func TestService_Search(t *testing.T) {
}}},
}, nil)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
fx.store.EXPECT().GetRelationByKey(testRelationKey).Return(&model.Relation{
Key: testRelationKey,
Format: model.RelationFormat_object,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(testRelationKey).Return(model.RelationFormat_object, nil).AnyTimes()
s := fx.Service.(*service)
s.ds = newDependencyService(s)
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: subscriptionID,
Keys: []string{bundle.RelationKeyName.String(), bundle.RelationKeyId.String(), testRelationKey},
CollectionId: collectionID,
@ -563,20 +521,14 @@ func TestService_Search(t *testing.T) {
}}},
}, nil)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
fx.store.EXPECT().GetRelationByKey(testRelationKey).Return(&model.Relation{
Key: testRelationKey,
Format: model.RelationFormat_object,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(testRelationKey).Return(model.RelationFormat_object, nil).AnyTimes()
s := fx.Service.(*service)
s.ds = newDependencyService(s)
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: subscriptionID,
Keys: []string{bundle.RelationKeyName.String(), bundle.RelationKeyId.String(), testRelationKey},
CollectionId: collectionID,
@ -618,16 +570,10 @@ func TestService_Search(t *testing.T) {
}}},
}, nil)
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyName.String()).Return(&model.Relation{
Key: bundle.RelationKeyName.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationByKey(bundle.RelationKeyId.String()).Return(&model.Relation{
Key: bundle.RelationKeyId.String(),
Format: model.RelationFormat_shorttext,
}, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyName.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
fx.store.EXPECT().GetRelationFormatByKey(bundle.RelationKeyId.String()).Return(model.RelationFormat_shorttext, nil).AnyTimes()
var resp, err = fx.Search(pb.RpcObjectSearchSubscribeRequest{
var resp, err = fx.Search(SubscribeRequest{
SubId: subscriptionID,
Keys: []string{bundle.RelationKeyName.String(), bundle.RelationKeyId.String()},
CollectionId: collectionID,
@ -1066,8 +1012,8 @@ func xTestNestedSubscription(t *testing.T) {
func testCreateSubscriptionWithNestedFilter(t *testing.T) *fixtureRealStore {
fx := newFixtureWithRealObjectStore(t)
// fx.store.EXPECT().GetRelationByKey(mock.Anything).Return(&model.Relation{}, nil)
resp, err := fx.Search(pb.RpcObjectSearchSubscribeRequest{
// fx.store.EXPECT().GetRelationFormatByKey(mock.Anything).Return(&model.Relation{}, nil)
resp, err := fx.Search(SubscribeRequest{
SubId: "test",
Filters: []*model.BlockContentDataviewFilter{
{

View file

@ -133,6 +133,10 @@ func (s *simpleSub) hasDep() bool {
return s.depSub != nil
}
func (s *simpleSub) getDep() subscription {
return s.depSub
}
func (s *simpleSub) close() {
for id := range s.set {
s.cache.RemoveSubId(id, s.id)

View file

@ -373,6 +373,10 @@ func (s *sortedSub) hasDep() bool {
return s.depSub != nil
}
func (s *sortedSub) getDep() subscription {
return s.depSub
}
func (s *sortedSub) close() {
el := s.skl.Front()
for el != nil {

View file

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/anyproto/anytype-heart/pb"
"github.com/anyproto/anytype-heart/pkg/lib/database"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
"github.com/anyproto/anytype-heart/util/pbtypes"
@ -34,7 +35,7 @@ func TestSubscription_Add(t *testing.T) {
assert.Len(t, sub.cache.entries, 9)
ctx := &opCtx{c: sub.cache, entries: newEntries}
ctx := &opCtx{c: sub.cache, entries: newEntries, outputs: map[string][]*pb.EventMessage{}}
sub.onChange(ctx)
assertCtxAdd(t, ctx, "newActiveId1", "")
assertCtxAdd(t, ctx, "newActiveId2", "newActiveId1")

View file

@ -5,8 +5,6 @@ package mock_detailsupdater
import (
app "github.com/anyproto/any-sync/app"
domain "github.com/anyproto/anytype-heart/core/domain"
mock "github.com/stretchr/testify/mock"
)
@ -114,35 +112,69 @@ func (_c *MockSpaceStatusUpdater_Name_Call) RunAndReturn(run func() string) *Moc
return _c
}
// SendUpdate provides a mock function with given fields: status
func (_m *MockSpaceStatusUpdater) SendUpdate(status *domain.SpaceSync) {
_m.Called(status)
// Refresh provides a mock function with given fields: spaceId
func (_m *MockSpaceStatusUpdater) Refresh(spaceId string) {
_m.Called(spaceId)
}
// MockSpaceStatusUpdater_SendUpdate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendUpdate'
type MockSpaceStatusUpdater_SendUpdate_Call struct {
// MockSpaceStatusUpdater_Refresh_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Refresh'
type MockSpaceStatusUpdater_Refresh_Call struct {
*mock.Call
}
// SendUpdate is a helper method to define mock.On call
// - status *domain.SpaceSync
func (_e *MockSpaceStatusUpdater_Expecter) SendUpdate(status interface{}) *MockSpaceStatusUpdater_SendUpdate_Call {
return &MockSpaceStatusUpdater_SendUpdate_Call{Call: _e.mock.On("SendUpdate", status)}
// Refresh is a helper method to define mock.On call
// - spaceId string
func (_e *MockSpaceStatusUpdater_Expecter) Refresh(spaceId interface{}) *MockSpaceStatusUpdater_Refresh_Call {
return &MockSpaceStatusUpdater_Refresh_Call{Call: _e.mock.On("Refresh", spaceId)}
}
func (_c *MockSpaceStatusUpdater_SendUpdate_Call) Run(run func(status *domain.SpaceSync)) *MockSpaceStatusUpdater_SendUpdate_Call {
func (_c *MockSpaceStatusUpdater_Refresh_Call) Run(run func(spaceId string)) *MockSpaceStatusUpdater_Refresh_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*domain.SpaceSync))
run(args[0].(string))
})
return _c
}
func (_c *MockSpaceStatusUpdater_SendUpdate_Call) Return() *MockSpaceStatusUpdater_SendUpdate_Call {
func (_c *MockSpaceStatusUpdater_Refresh_Call) Return() *MockSpaceStatusUpdater_Refresh_Call {
_c.Call.Return()
return _c
}
func (_c *MockSpaceStatusUpdater_SendUpdate_Call) RunAndReturn(run func(*domain.SpaceSync)) *MockSpaceStatusUpdater_SendUpdate_Call {
func (_c *MockSpaceStatusUpdater_Refresh_Call) RunAndReturn(run func(string)) *MockSpaceStatusUpdater_Refresh_Call {
_c.Call.Return(run)
return _c
}
// UpdateMissingIds provides a mock function with given fields: spaceId, ids
func (_m *MockSpaceStatusUpdater) UpdateMissingIds(spaceId string, ids []string) {
_m.Called(spaceId, ids)
}
// MockSpaceStatusUpdater_UpdateMissingIds_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateMissingIds'
type MockSpaceStatusUpdater_UpdateMissingIds_Call struct {
*mock.Call
}
// UpdateMissingIds is a helper method to define mock.On call
// - spaceId string
// - ids []string
func (_e *MockSpaceStatusUpdater_Expecter) UpdateMissingIds(spaceId interface{}, ids interface{}) *MockSpaceStatusUpdater_UpdateMissingIds_Call {
return &MockSpaceStatusUpdater_UpdateMissingIds_Call{Call: _e.mock.On("UpdateMissingIds", spaceId, ids)}
}
func (_c *MockSpaceStatusUpdater_UpdateMissingIds_Call) Run(run func(spaceId string, ids []string)) *MockSpaceStatusUpdater_UpdateMissingIds_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].([]string))
})
return _c
}
func (_c *MockSpaceStatusUpdater_UpdateMissingIds_Call) Return() *MockSpaceStatusUpdater_UpdateMissingIds_Call {
_c.Call.Return()
return _c
}
func (_c *MockSpaceStatusUpdater_UpdateMissingIds_Call) RunAndReturn(run func(string, []string)) *MockSpaceStatusUpdater_UpdateMissingIds_Call {
_c.Call.Return(run)
return _c
}

View file

@ -12,18 +12,18 @@ import (
"github.com/cheggaaa/mb/v3"
"github.com/gogo/protobuf/types"
"github.com/anyproto/anytype-heart/core/block/editor/basic"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/syncstatus/detailsupdater/helper"
"github.com/anyproto/anytype-heart/core/syncstatus/filesyncstatus"
"github.com/anyproto/anytype-heart/core/syncstatus/syncsubscriptions"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
"github.com/anyproto/anytype-heart/pkg/lib/database"
"github.com/anyproto/anytype-heart/pkg/lib/localstore/objectstore"
"github.com/anyproto/anytype-heart/pkg/lib/logging"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
"github.com/anyproto/anytype-heart/space"
"github.com/anyproto/anytype-heart/util/pbtypes"
"github.com/anyproto/anytype-heart/util/slice"
)
var log = logging.Logger(CName)
@ -31,29 +31,31 @@ var log = logging.Logger(CName)
const CName = "core.syncstatus.objectsyncstatus.updater"
type syncStatusDetails struct {
objectIds []string
status domain.ObjectSyncStatus
syncError domain.SyncError
spaceId string
objectId string
status domain.ObjectSyncStatus
spaceId string
}
type Updater interface {
app.ComponentRunnable
UpdateDetails(objectId []string, status domain.ObjectSyncStatus, syncError domain.SyncError, spaceId string)
UpdateSpaceDetails(existing, missing []string, spaceId string)
UpdateDetails(objectId string, status domain.ObjectSyncStatus, spaceId string)
}
type SpaceStatusUpdater interface {
app.Component
SendUpdate(status *domain.SpaceSync)
Refresh(spaceId string)
UpdateMissingIds(spaceId string, ids []string)
}
type syncStatusUpdater struct {
objectStore objectstore.ObjectStore
ctx context.Context
ctxCancel context.CancelFunc
batcher *mb.MB[*syncStatusDetails]
spaceService space.Service
spaceSyncStatus SpaceStatusUpdater
objectStore objectstore.ObjectStore
ctx context.Context
ctxCancel context.CancelFunc
batcher *mb.MB[string]
spaceService space.Service
spaceSyncStatus SpaceStatusUpdater
syncSubscriptions syncsubscriptions.SyncSubscriptions
entries map[string]*syncStatusDetails
mx sync.Mutex
@ -61,9 +63,9 @@ type syncStatusUpdater struct {
finish chan struct{}
}
func NewUpdater() Updater {
func New() Updater {
return &syncStatusUpdater{
batcher: mb.New[*syncStatusDetails](0),
batcher: mb.New[string](0),
finish: make(chan struct{}),
entries: make(map[string]*syncStatusDetails, 0),
}
@ -87,6 +89,7 @@ func (u *syncStatusUpdater) Init(a *app.App) (err error) {
u.objectStore = app.MustComponent[objectstore.ObjectStore](a)
u.spaceService = app.MustComponent[space.Service](a)
u.spaceSyncStatus = app.MustComponent[SpaceStatusUpdater](a)
u.syncSubscriptions = app.MustComponent[syncsubscriptions.SyncSubscriptions](a)
return nil
}
@ -94,94 +97,121 @@ func (u *syncStatusUpdater) Name() (name string) {
return CName
}
func (u *syncStatusUpdater) UpdateDetails(objectId []string, status domain.ObjectSyncStatus, syncError domain.SyncError, spaceId string) {
func (u *syncStatusUpdater) UpdateDetails(objectId string, status domain.ObjectSyncStatus, spaceId string) {
if spaceId == u.spaceService.TechSpaceId() {
return
}
for _, id := range objectId {
u.mx.Lock()
u.entries[id] = &syncStatusDetails{
status: status,
syncError: syncError,
spaceId: spaceId,
}
u.mx.Unlock()
}
err := u.batcher.TryAdd(&syncStatusDetails{
objectIds: objectId,
status: status,
syncError: syncError,
spaceId: spaceId,
err := u.addToQueue(&syncStatusDetails{
objectId: objectId,
status: status,
spaceId: spaceId,
})
if err != nil {
log.Errorf("failed to add sync details update to queue: %s", err)
}
}
func (u *syncStatusUpdater) updateDetails(syncStatusDetails *syncStatusDetails) {
details := u.extractObjectDetails(syncStatusDetails)
for _, detail := range details {
id := pbtypes.GetString(detail.Details, bundle.RelationKeyId.String())
err := u.setObjectDetails(syncStatusDetails, detail.Details, id)
func (u *syncStatusUpdater) addToQueue(details *syncStatusDetails) error {
u.mx.Lock()
u.entries[details.objectId] = details
u.mx.Unlock()
return u.batcher.TryAdd(details.objectId)
}
func (u *syncStatusUpdater) processEvents() {
defer close(u.finish)
for {
objectId, err := u.batcher.WaitOne(u.ctx)
if err != nil {
log.Errorf("failed to update object details %s", err)
return
}
u.updateSpecificObject(objectId)
}
}
func (u *syncStatusUpdater) updateSpecificObject(objectId string) {
u.mx.Lock()
objectStatus := u.entries[objectId]
delete(u.entries, objectId)
u.mx.Unlock()
if objectStatus != nil {
err := u.updateObjectDetails(objectStatus, objectId)
if err != nil {
log.Errorf("failed to update details %s", err)
}
}
}
func (u *syncStatusUpdater) extractObjectDetails(syncStatusDetails *syncStatusDetails) []database.Record {
details, err := u.objectStore.Query(database.Query{
Filters: []*model.BlockContentDataviewFilter{
{
RelationKey: bundle.RelationKeySyncStatus.String(),
Condition: model.BlockContentDataviewFilter_NotEqual,
Value: pbtypes.Int64(int64(syncStatusDetails.status)),
},
{
RelationKey: bundle.RelationKeySpaceId.String(),
Condition: model.BlockContentDataviewFilter_Equal,
Value: pbtypes.String(syncStatusDetails.spaceId),
},
},
})
if err != nil {
log.Errorf("failed to update object details %s", err)
func (u *syncStatusUpdater) UpdateSpaceDetails(existing, missing []string, spaceId string) {
if spaceId == u.spaceService.TechSpaceId() {
return
}
return details
u.spaceSyncStatus.UpdateMissingIds(spaceId, missing)
ids := u.getSyncingObjects(spaceId)
// removed contains ids that are not yet marked as syncing
// added contains ids that were syncing, but appeared as synced, because they are not in existing list
removed, added := slice.DifferenceRemovedAdded(existing, ids)
if len(removed)+len(added) == 0 {
u.spaceSyncStatus.Refresh(spaceId)
return
}
for _, id := range added {
err := u.addToQueue(&syncStatusDetails{
objectId: id,
status: domain.ObjectSyncStatusSynced,
spaceId: spaceId,
})
if err != nil {
log.Errorf("failed to add sync details update to queue: %s", err)
}
}
for _, id := range removed {
err := u.addToQueue(&syncStatusDetails{
objectId: id,
status: domain.ObjectSyncStatusSyncing,
spaceId: spaceId,
})
if err != nil {
log.Errorf("failed to add sync details update to queue: %s", err)
}
}
}
func (u *syncStatusUpdater) getSyncingObjects(spaceId string) []string {
sub, err := u.syncSubscriptions.GetSubscription(spaceId)
if err != nil {
return nil
}
ids := make([]string, 0, sub.GetObjectSubscription().Len())
sub.GetObjectSubscription().Iterate(func(id string, _ struct{}) bool {
ids = append(ids, id)
return true
})
return ids
}
func (u *syncStatusUpdater) updateObjectDetails(syncStatusDetails *syncStatusDetails, objectId string) error {
record, err := u.objectStore.GetDetails(objectId)
if err != nil {
return err
}
return u.setObjectDetails(syncStatusDetails, record.Details, objectId)
}
func (u *syncStatusUpdater) setObjectDetails(syncStatusDetails *syncStatusDetails, record *types.Struct, objectId string) error {
status := syncStatusDetails.status
syncError := syncStatusDetails.syncError
if fileStatus, ok := record.GetFields()[bundle.RelationKeyFileBackupStatus.String()]; ok {
status, syncError = mapFileStatus(filesyncstatus.Status(int(fileStatus.GetNumberValue())))
}
changed := u.hasRelationsChange(record, status, syncError)
if !changed {
return nil
}
if !u.isLayoutSuitableForSyncRelations(record) {
return nil
}
syncError := domain.SyncErrorNull
spc, err := u.spaceService.Get(u.ctx, syncStatusDetails.spaceId)
if err != nil {
return err
}
spaceStatus := mapObjectSyncToSpaceSyncStatus(status, syncError)
defer u.sendSpaceStatusUpdate(err, syncStatusDetails, spaceStatus, syncError)
defer u.spaceSyncStatus.Refresh(syncStatusDetails.spaceId)
err = spc.DoLockedIfNotExists(objectId, func() error {
return u.objectStore.ModifyObjectDetails(objectId, func(details *types.Struct) (*types.Struct, error) {
if details == nil || details.Fields == nil {
details = &types.Struct{Fields: map[string]*types.Value{}}
}
if !u.isLayoutSuitableForSyncRelations(details) {
return details, nil
}
if fileStatus, ok := details.GetFields()[bundle.RelationKeyFileBackupStatus.String()]; ok {
status, syncError = getSyncStatusForFile(status, syncError, filesyncstatus.Status(int(fileStatus.GetNumberValue())))
}
details.Fields[bundle.RelationKeySyncStatus.String()] = pbtypes.Int64(int64(status))
details.Fields[bundle.RelationKeySyncError.String()] = pbtypes.Int64(int64(syncError))
details.Fields[bundle.RelationKeySyncDate.String()] = pbtypes.Int64(time.Now().Unix())
@ -199,120 +229,71 @@ func (u *syncStatusUpdater) setObjectDetails(syncStatusDetails *syncStatusDetail
})
}
func (u *syncStatusUpdater) isLayoutSuitableForSyncRelations(details *types.Struct) bool {
layoutsWithoutSyncRelations := []float64{
float64(model.ObjectType_participant),
float64(model.ObjectType_dashboard),
float64(model.ObjectType_spaceView),
float64(model.ObjectType_space),
float64(model.ObjectType_date),
}
layout := details.Fields[bundle.RelationKeyLayout.String()].GetNumberValue()
return !slices.Contains(layoutsWithoutSyncRelations, layout)
}
func mapObjectSyncToSpaceSyncStatus(status domain.ObjectSyncStatus, syncError domain.SyncError) domain.SpaceSyncStatus {
switch status {
case domain.ObjectSynced:
return domain.Synced
case domain.ObjectSyncing, domain.ObjectQueued:
return domain.Syncing
case domain.ObjectError:
// don't send error to space if file were oversized
if syncError != domain.Oversized {
return domain.Error
}
}
return domain.Synced
}
func (u *syncStatusUpdater) sendSpaceStatusUpdate(err error, syncStatusDetails *syncStatusDetails, status domain.SpaceSyncStatus, syncError domain.SyncError) {
if err == nil {
u.spaceSyncStatus.SendUpdate(domain.MakeSyncStatus(syncStatusDetails.spaceId, status, syncError, domain.Objects))
}
}
func mapFileStatus(status filesyncstatus.Status) (domain.ObjectSyncStatus, domain.SyncError) {
var syncError domain.SyncError
switch status {
case filesyncstatus.Syncing:
return domain.ObjectSyncing, domain.Null
case filesyncstatus.Queued:
return domain.ObjectQueued, domain.Null
case filesyncstatus.Limited:
syncError = domain.Oversized
return domain.ObjectError, syncError
case filesyncstatus.Unknown:
syncError = domain.NetworkError
return domain.ObjectError, syncError
default:
return domain.ObjectSynced, domain.Null
}
}
func (u *syncStatusUpdater) setSyncDetails(sb smartblock.SmartBlock, status domain.ObjectSyncStatus, syncError domain.SyncError) error {
if !slices.Contains(helper.SyncRelationsSmartblockTypes(), sb.Type()) {
return nil
}
if d, ok := sb.(basic.DetailsSettable); ok {
syncStatusDetails := []*model.Detail{
{
Key: bundle.RelationKeySyncStatus.String(),
Value: pbtypes.Int64(int64(status)),
},
}
syncStatusDetails = append(syncStatusDetails, &model.Detail{
Key: bundle.RelationKeySyncError.String(),
Value: pbtypes.Int64(int64(syncError)),
})
syncStatusDetails = append(syncStatusDetails, &model.Detail{
Key: bundle.RelationKeySyncDate.String(),
Value: pbtypes.Int64(time.Now().Unix()),
})
return d.SetDetails(nil, syncStatusDetails, false)
if !u.isLayoutSuitableForSyncRelations(sb.Details()) {
return nil
}
return nil
st := sb.NewState()
if fileStatus, ok := st.Details().GetFields()[bundle.RelationKeyFileBackupStatus.String()]; ok {
status, syncError = getSyncStatusForFile(status, syncError, filesyncstatus.Status(int(fileStatus.GetNumberValue())))
}
st.SetDetailAndBundledRelation(bundle.RelationKeySyncStatus, pbtypes.Int64(int64(status)))
st.SetDetailAndBundledRelation(bundle.RelationKeySyncError, pbtypes.Int64(int64(syncError)))
st.SetDetailAndBundledRelation(bundle.RelationKeySyncDate, pbtypes.Int64(time.Now().Unix()))
return sb.Apply(st, smartblock.KeepInternalFlags /* do not erase flags */)
}
func (u *syncStatusUpdater) hasRelationsChange(record *types.Struct, status domain.ObjectSyncStatus, syncError domain.SyncError) bool {
var changed bool
if record == nil || len(record.GetFields()) == 0 {
changed = true
}
if pbtypes.Get(record, bundle.RelationKeySyncStatus.String()) == nil ||
pbtypes.Get(record, bundle.RelationKeySyncError.String()) == nil {
changed = true
}
if pbtypes.GetInt64(record, bundle.RelationKeySyncStatus.String()) != int64(status) {
changed = true
}
if pbtypes.GetInt64(record, bundle.RelationKeySyncError.String()) != int64(syncError) {
changed = true
}
return changed
var suitableLayouts = map[model.ObjectTypeLayout]struct{}{
model.ObjectType_basic: {},
model.ObjectType_profile: {},
model.ObjectType_todo: {},
model.ObjectType_set: {},
model.ObjectType_objectType: {},
model.ObjectType_relation: {},
model.ObjectType_file: {},
model.ObjectType_image: {},
model.ObjectType_note: {},
model.ObjectType_bookmark: {},
model.ObjectType_relationOption: {},
model.ObjectType_collection: {},
model.ObjectType_audio: {},
model.ObjectType_video: {},
model.ObjectType_pdf: {},
}
func (u *syncStatusUpdater) processEvents() {
defer close(u.finish)
for {
status, err := u.batcher.WaitOne(u.ctx)
if err != nil {
return
}
for _, id := range status.objectIds {
u.mx.Lock()
objectStatus := u.entries[id]
delete(u.entries, id)
u.mx.Unlock()
if objectStatus != nil {
err := u.updateObjectDetails(objectStatus, id)
if err != nil {
log.Errorf("failed to update details %s", err)
}
}
}
if len(status.objectIds) == 0 {
u.updateDetails(status)
}
func (u *syncStatusUpdater) isLayoutSuitableForSyncRelations(details *types.Struct) bool {
layout := model.ObjectTypeLayout(pbtypes.GetInt64(details, bundle.RelationKeyLayout.String()))
_, ok := suitableLayouts[layout]
return ok
}
func getSyncStatusForFile(objectStatus domain.ObjectSyncStatus, objectSyncError domain.SyncError, fileStatus filesyncstatus.Status) (domain.ObjectSyncStatus, domain.SyncError) {
statusFromFile, errFromFile := mapFileStatus(fileStatus)
// If file status is synced, then prioritize object's status, otherwise pick file status
if statusFromFile != domain.ObjectSyncStatusSynced {
objectStatus = statusFromFile
}
if errFromFile != domain.SyncErrorNull {
objectSyncError = errFromFile
}
return objectStatus, objectSyncError
}
func mapFileStatus(status filesyncstatus.Status) (domain.ObjectSyncStatus, domain.SyncError) {
switch status {
case filesyncstatus.Syncing:
return domain.ObjectSyncStatusSyncing, domain.SyncErrorNull
case filesyncstatus.Queued:
return domain.ObjectSyncStatusQueued, domain.SyncErrorNull
case filesyncstatus.Limited:
return domain.ObjectSyncStatusError, domain.SyncErrorOversized
case filesyncstatus.Unknown:
return domain.ObjectSyncStatusError, domain.SyncErrorNetworkError
default:
return domain.ObjectSyncStatusSynced, domain.SyncErrorNull
}
}

View file

@ -3,19 +3,24 @@ package detailsupdater
import (
"context"
"testing"
"time"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/ocache"
"github.com/cheggaaa/mb/v3"
"github.com/gogo/protobuf/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/anyproto/anytype-heart/core/block/editor"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock/smarttest"
"github.com/anyproto/anytype-heart/core/block/editor/state"
domain "github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/subscription"
"github.com/anyproto/anytype-heart/core/syncstatus/detailsupdater/mock_detailsupdater"
"github.com/anyproto/anytype-heart/core/syncstatus/filesyncstatus"
"github.com/anyproto/anytype-heart/core/syncstatus/syncsubscriptions"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
coresb "github.com/anyproto/anytype-heart/pkg/lib/core/smartblock"
"github.com/anyproto/anytype-heart/pkg/lib/localstore/objectstore"
@ -26,232 +31,286 @@ import (
"github.com/anyproto/anytype-heart/util/pbtypes"
)
func TestSyncStatusUpdater_UpdateDetails(t *testing.T) {
t.Run("update sync status and date - no changes", func(t *testing.T) {
// given
fixture := newFixture(t)
fixture.storeFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id"),
bundle.RelationKeySyncStatus: pbtypes.Int64(int64(domain.Synced)),
bundle.RelationKeySyncError: pbtypes.Int64(int64(domain.Null)),
},
})
// when
err := fixture.updater.updateObjectDetails(&syncStatusDetails{[]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId"}, "id")
// then
assert.Nil(t, err)
fixture.service.AssertNotCalled(t, "Get")
})
t.Run("update sync status and date - details exist in store", func(t *testing.T) {
// given
fixture := newFixture(t)
space := mock_clientspace.NewMockSpace(t)
fixture.service.EXPECT().Get(fixture.updater.ctx, "spaceId").Return(space, nil)
fixture.storeFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id"),
},
})
space.EXPECT().DoLockedIfNotExists("id", mock.Anything).Return(nil)
// when
fixture.statusUpdater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Synced, domain.Null, domain.Objects))
err := fixture.updater.updateObjectDetails(&syncStatusDetails{[]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId"}, "id")
// then
assert.Nil(t, err)
})
t.Run("update sync status and date - object not exist in cache", func(t *testing.T) {
// given
fixture := newFixture(t)
space := mock_clientspace.NewMockSpace(t)
fixture.service.EXPECT().Get(fixture.updater.ctx, "spaceId").Return(space, nil)
fixture.storeFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id"),
bundle.RelationKeySyncStatus: pbtypes.Int64(int64(domain.Error)),
bundle.RelationKeySyncError: pbtypes.Int64(int64(domain.NetworkError)),
},
})
space.EXPECT().DoLockedIfNotExists("id", mock.Anything).Return(nil)
// when
fixture.statusUpdater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Synced, domain.Null, domain.Objects))
err := fixture.updater.updateObjectDetails(&syncStatusDetails{[]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId"}, "id")
// then
assert.Nil(t, err)
})
t.Run("update sync status and date - object exist in cache", func(t *testing.T) {
// given
fixture := newFixture(t)
space := mock_clientspace.NewMockSpace(t)
fixture.service.EXPECT().Get(fixture.updater.ctx, "spaceId").Return(space, nil)
space.EXPECT().DoLockedIfNotExists("id", mock.Anything).Return(ocache.ErrExists)
space.EXPECT().DoCtx(fixture.updater.ctx, "id", mock.Anything).Return(nil)
// when
fixture.statusUpdater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Synced, domain.Null, domain.Objects))
err := fixture.updater.updateObjectDetails(&syncStatusDetails{[]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId"}, "id")
// then
assert.Nil(t, err)
})
t.Run("update sync status and date - file status", func(t *testing.T) {
// given
fixture := newFixture(t)
space := mock_clientspace.NewMockSpace(t)
fixture.service.EXPECT().Get(fixture.updater.ctx, "spaceId").Return(space, nil)
fixture.storeFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id"),
bundle.RelationKeyFileBackupStatus: pbtypes.Int64(int64(filesyncstatus.Syncing)),
},
})
space.EXPECT().DoLockedIfNotExists("id", mock.Anything).Return(nil)
// when
fixture.statusUpdater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Syncing, domain.Null, domain.Objects))
err := fixture.updater.updateObjectDetails(&syncStatusDetails{[]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId"}, "id")
// then
assert.Nil(t, err)
})
t.Run("update sync status and date - unknown file status", func(t *testing.T) {
// given
fixture := newFixture(t)
space := mock_clientspace.NewMockSpace(t)
fixture.service.EXPECT().Get(fixture.updater.ctx, "spaceId").Return(space, nil)
fixture.storeFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id"),
bundle.RelationKeyFileBackupStatus: pbtypes.Int64(int64(filesyncstatus.Unknown)),
},
})
space.EXPECT().DoLockedIfNotExists("id", mock.Anything).Return(nil)
// when
fixture.statusUpdater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Error, domain.NetworkError, domain.Objects))
err := fixture.updater.updateObjectDetails(&syncStatusDetails{[]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId"}, "id")
// then
assert.Nil(t, err)
})
t.Run("update sync status and date - queued file status", func(t *testing.T) {
// given
fixture := newFixture(t)
space := mock_clientspace.NewMockSpace(t)
fixture.service.EXPECT().Get(fixture.updater.ctx, "spaceId").Return(space, nil)
fixture.storeFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id"),
bundle.RelationKeyFileBackupStatus: pbtypes.Int64(int64(filesyncstatus.Queued)),
},
})
space.EXPECT().DoLockedIfNotExists("id", mock.Anything).Return(nil)
// when
fixture.statusUpdater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Syncing, domain.Null, domain.Objects))
err := fixture.updater.updateObjectDetails(&syncStatusDetails{[]string{"id"}, domain.ObjectSyncing, domain.Null, "spaceId"}, "id")
// then
assert.Nil(t, err)
})
t.Run("update sync status and date - synced file status", func(t *testing.T) {
// given
fixture := newFixture(t)
space := mock_clientspace.NewMockSpace(t)
fixture.service.EXPECT().Get(fixture.updater.ctx, "spaceId").Return(space, nil)
fixture.storeFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id"),
bundle.RelationKeyFileBackupStatus: pbtypes.Int64(int64(filesyncstatus.Synced)),
},
})
space.EXPECT().DoLockedIfNotExists("id", mock.Anything).Return(nil)
// when
fixture.statusUpdater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Synced, domain.Null, domain.Objects))
err := fixture.updater.updateObjectDetails(&syncStatusDetails{[]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId"}, "id")
// then
assert.Nil(t, err)
})
type updateTester struct {
t *testing.T
waitCh chan struct{}
minEventsCount int
maxEventsCount int
}
func TestSyncStatusUpdater_Run(t *testing.T) {
t.Run("run", func(t *testing.T) {
// given
fixture := newFixture(t)
func newUpdateTester(t *testing.T, minEventsCount int, maxEventsCount int) *updateTester {
return &updateTester{
t: t,
minEventsCount: minEventsCount,
maxEventsCount: maxEventsCount,
waitCh: make(chan struct{}, maxEventsCount),
}
}
func (t *updateTester) done() {
t.waitCh <- struct{}{}
}
// wait waits for at least one event up to t.maxEventsCount events
func (t *updateTester) wait() {
timeout := time.After(1 * time.Second)
minReceivedTimer := time.After(10 * time.Millisecond)
var eventsReceived int
for i := 0; i < t.maxEventsCount; i++ {
select {
case <-minReceivedTimer:
if eventsReceived >= t.minEventsCount {
return
}
case <-t.waitCh:
eventsReceived++
case <-timeout:
t.t.Fatal("timeout")
}
}
}
func newUpdateDetailsFixture(t *testing.T) *fixture {
fx := newFixture(t)
fx.spaceService.EXPECT().TechSpaceId().Return("techSpace")
err := fx.Run(context.Background())
require.NoError(t, err)
t.Cleanup(func() {
err := fx.Close(context.Background())
require.NoError(t, err)
})
return fx
}
func TestSyncStatusUpdater_UpdateDetails(t *testing.T) {
t.Run("ignore tech space", func(t *testing.T) {
fx := newUpdateDetailsFixture(t)
fx.UpdateDetails("spaceView1", domain.ObjectSyncStatusSynced, "techSpace")
})
t.Run("updates to the same object", func(t *testing.T) {
fx := newUpdateDetailsFixture(t)
updTester := newUpdateTester(t, 1, 4)
// when
fixture.service.EXPECT().TechSpaceId().Return("techSpaceId")
space := mock_clientspace.NewMockSpace(t)
fixture.service.EXPECT().Get(mock.Anything, mock.Anything).Return(space, nil).Maybe()
space.EXPECT().DoLockedIfNotExists(mock.Anything, mock.Anything).Return(nil).Maybe()
space.EXPECT().DoCtx(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
err := fixture.updater.Run(context.Background())
fixture.statusUpdater.EXPECT().SendUpdate(mock.Anything).Return().Maybe()
assert.Nil(t, err)
fixture.updater.UpdateDetails([]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId")
fx.spaceService.EXPECT().Get(mock.Anything, "space1").Return(space, nil)
space.EXPECT().DoLockedIfNotExists(mock.Anything, mock.Anything).Return(ocache.ErrExists).Times(0)
space.EXPECT().DoCtx(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, objectId string, apply func(smartblock.SmartBlock) error) {
sb := smarttest.New(objectId)
st := sb.Doc.(*state.State)
st.SetDetailAndBundledRelation(bundle.RelationKeyLayout, pbtypes.Int64(int64(model.ObjectType_basic)))
err := apply(sb)
require.NoError(t, err)
// then
err = fixture.updater.Close(context.Background())
assert.Nil(t, err)
det := sb.Doc.LocalDetails()
assert.Contains(t, det.GetFields(), bundle.RelationKeySyncStatus.String())
assert.Contains(t, det.GetFields(), bundle.RelationKeySyncDate.String())
assert.Contains(t, det.GetFields(), bundle.RelationKeySyncError.String())
fx.spaceStatusUpdater.EXPECT().Refresh("space1")
updTester.done()
}).Return(nil).Times(0)
fx.UpdateDetails("id1", domain.ObjectSyncStatusSyncing, "space1")
fx.UpdateDetails("id1", domain.ObjectSyncStatusError, "space1")
fx.UpdateDetails("id1", domain.ObjectSyncStatusSyncing, "space1")
fx.UpdateDetails("id1", domain.ObjectSyncStatusSynced, "space1")
updTester.wait()
})
t.Run("run 2 time for 1 object", func(t *testing.T) {
// given
fixture := newFixture(t)
t.Run("updates to object not in cache", func(t *testing.T) {
fx := newUpdateDetailsFixture(t)
updTester := newUpdateTester(t, 1, 1)
// when
fixture.service.EXPECT().TechSpaceId().Return("techSpaceId").Times(2)
fixture.updater.UpdateDetails([]string{"id"}, domain.ObjectSynced, domain.Null, "spaceId")
fixture.updater.UpdateDetails([]string{"id"}, domain.ObjectSyncing, domain.Null, "spaceId")
fx.subscriptionService.StoreFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id1"),
bundle.RelationKeySpaceId: pbtypes.String("space1"),
bundle.RelationKeyLayout: pbtypes.Int64(int64(model.ObjectType_basic)),
},
})
// then
assert.Equal(t, &syncStatusDetails{status: domain.ObjectSyncing, syncError: domain.Null, spaceId: "spaceId"}, fixture.updater.entries["id"])
space := mock_clientspace.NewMockSpace(t)
fx.spaceService.EXPECT().Get(mock.Anything, "space1").Return(space, nil)
space.EXPECT().DoLockedIfNotExists(mock.Anything, mock.Anything).Run(func(objectId string, proc func() error) {
err := proc()
require.NoError(t, err)
details, err := fx.objectStore.GetDetails(objectId)
require.NoError(t, err)
assert.True(t, pbtypes.GetInt64(details.Details, bundle.RelationKeySyncStatus.String()) == int64(domain.ObjectSyncStatusError))
assert.True(t, pbtypes.GetInt64(details.Details, bundle.RelationKeySyncError.String()) == int64(domain.SyncErrorNull))
assert.Contains(t, details.Details.GetFields(), bundle.RelationKeySyncDate.String())
updTester.done()
}).Return(nil).Times(0)
fx.UpdateDetails("id1", domain.ObjectSyncStatusError, "space1")
fx.spaceStatusUpdater.EXPECT().Refresh("space1")
updTester.wait()
})
t.Run("updates in file object", func(t *testing.T) {
t.Run("file backup status limited", func(t *testing.T) {
fx := newUpdateDetailsFixture(t)
updTester := newUpdateTester(t, 1, 1)
space := mock_clientspace.NewMockSpace(t)
fx.spaceService.EXPECT().Get(mock.Anything, "space1").Return(space, nil)
space.EXPECT().DoLockedIfNotExists(mock.Anything, mock.Anything).Return(ocache.ErrExists)
space.EXPECT().DoCtx(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, objectId string, apply func(smartblock.SmartBlock) error) {
sb := smarttest.New(objectId)
st := sb.Doc.(*state.State)
st.SetDetailAndBundledRelation(bundle.RelationKeyLayout, pbtypes.Int64(int64(model.ObjectType_file)))
st.SetDetailAndBundledRelation(bundle.RelationKeyFileBackupStatus, pbtypes.Int64(int64(filesyncstatus.Limited)))
err := apply(sb)
require.NoError(t, err)
det := sb.Doc.LocalDetails()
assert.True(t, pbtypes.GetInt64(det, bundle.RelationKeySyncStatus.String()) == int64(domain.ObjectSyncStatusError))
assert.True(t, pbtypes.GetInt64(det, bundle.RelationKeySyncError.String()) == int64(domain.SyncErrorOversized))
assert.Contains(t, det.GetFields(), bundle.RelationKeySyncDate.String())
fx.spaceStatusUpdater.EXPECT().Refresh("space1")
updTester.done()
}).Return(nil)
fx.UpdateDetails("id2", domain.ObjectSyncStatusSynced, "space1")
updTester.wait()
})
t.Run("prioritize object status", func(t *testing.T) {
fx := newUpdateDetailsFixture(t)
updTester := newUpdateTester(t, 1, 1)
space := mock_clientspace.NewMockSpace(t)
fx.spaceService.EXPECT().Get(mock.Anything, "space1").Return(space, nil)
space.EXPECT().DoLockedIfNotExists(mock.Anything, mock.Anything).Return(ocache.ErrExists)
space.EXPECT().DoCtx(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, objectId string, apply func(smartblock.SmartBlock) error) {
sb := smarttest.New(objectId)
st := sb.Doc.(*state.State)
st.SetDetailAndBundledRelation(bundle.RelationKeyLayout, pbtypes.Int64(int64(model.ObjectType_file)))
st.SetDetailAndBundledRelation(bundle.RelationKeyFileBackupStatus, pbtypes.Int64(int64(filesyncstatus.Synced)))
err := apply(sb)
require.NoError(t, err)
det := sb.Doc.LocalDetails()
assert.True(t, pbtypes.GetInt64(det, bundle.RelationKeySyncStatus.String()) == int64(domain.ObjectSyncStatusSyncing))
assert.Contains(t, det.GetFields(), bundle.RelationKeySyncError.String())
assert.Contains(t, det.GetFields(), bundle.RelationKeySyncDate.String())
fx.spaceStatusUpdater.EXPECT().Refresh("space1")
updTester.done()
}).Return(nil)
fx.UpdateDetails("id3", domain.ObjectSyncStatusSyncing, "space1")
updTester.wait()
})
})
// TODO Test DoLockedIfNotExists
}
func TestSyncStatusUpdater_UpdateSpaceDetails(t *testing.T) {
fx := newUpdateDetailsFixture(t)
updTester := newUpdateTester(t, 3, 3)
fx.subscriptionService.StoreFixture.AddObjects(t, []objectstore.TestObject{
{
bundle.RelationKeyId: pbtypes.String("id1"),
bundle.RelationKeySpaceId: pbtypes.String("space1"),
bundle.RelationKeyLayout: pbtypes.Int64(int64(model.ObjectType_basic)),
bundle.RelationKeySyncStatus: pbtypes.Int64(int64(domain.ObjectSyncStatusSyncing)),
},
{
bundle.RelationKeyId: pbtypes.String("id4"),
bundle.RelationKeySpaceId: pbtypes.String("space1"),
bundle.RelationKeyLayout: pbtypes.Int64(int64(model.ObjectType_basic)),
bundle.RelationKeySyncStatus: pbtypes.Int64(int64(domain.ObjectSyncStatusSyncing)),
},
})
space := mock_clientspace.NewMockSpace(t)
fx.spaceService.EXPECT().Get(mock.Anything, "space1").Return(space, nil)
space.EXPECT().DoLockedIfNotExists(mock.Anything, mock.Anything).Return(ocache.ErrExists).Times(0)
assertUpdate := func(objectId string, status domain.ObjectSyncStatus) {
space.EXPECT().DoCtx(mock.Anything, objectId, mock.Anything).Run(func(ctx context.Context, objectId string, apply func(smartblock.SmartBlock) error) {
sb := smarttest.New(objectId)
st := sb.Doc.(*state.State)
st.SetDetailAndBundledRelation(bundle.RelationKeyLayout, pbtypes.Int64(int64(model.ObjectType_basic)))
err := apply(sb)
require.NoError(t, err)
det := sb.Doc.LocalDetails()
assert.True(t, pbtypes.GetInt64(det, bundle.RelationKeySyncStatus.String()) == int64(status))
assert.Contains(t, det.GetFields(), bundle.RelationKeySyncDate.String())
assert.Contains(t, det.GetFields(), bundle.RelationKeySyncError.String())
fx.spaceStatusUpdater.EXPECT().Refresh("space1")
updTester.done()
}).Return(nil).Times(0)
}
assertUpdate("id2", domain.ObjectSyncStatusSyncing)
assertUpdate("id4", domain.ObjectSyncStatusSynced)
fx.spaceStatusUpdater.EXPECT().UpdateMissingIds("space1", []string{"id3"})
fx.UpdateSpaceDetails([]string{"id1", "id2"}, []string{"id3"}, "space1")
fx.spaceStatusUpdater.EXPECT().UpdateMissingIds("space1", []string{"id3"})
fx.spaceStatusUpdater.EXPECT().Refresh("space1")
fx.UpdateSpaceDetails([]string{"id1", "id2"}, []string{"id3"}, "space1")
updTester.wait()
}
func TestSyncStatusUpdater_setSyncDetails(t *testing.T) {
t.Run("set smartblock details", func(t *testing.T) {
// given
fixture := newFixture(t)
fx := newFixture(t)
sb := smarttest.New("id")
// when
err := fixture.updater.setSyncDetails(fixture.sb, domain.ObjectError, domain.NetworkError)
err := fx.setSyncDetails(sb, domain.ObjectSyncStatusError, domain.SyncErrorNetworkError)
assert.Nil(t, err)
// then
details := fixture.sb.NewState().CombinedDetails().GetFields()
details := sb.NewState().CombinedDetails().GetFields()
assert.NotNil(t, details)
assert.Equal(t, pbtypes.Int64(int64(domain.Error)), details[bundle.RelationKeySyncStatus.String()])
assert.Equal(t, pbtypes.Int64(int64(domain.NetworkError)), details[bundle.RelationKeySyncError.String()])
assert.Equal(t, pbtypes.Int64(int64(domain.SpaceSyncStatusError)), details[bundle.RelationKeySyncStatus.String()])
assert.Equal(t, pbtypes.Int64(int64(domain.SyncErrorNetworkError)), details[bundle.RelationKeySyncError.String()])
assert.NotNil(t, details[bundle.RelationKeySyncDate.String()])
})
t.Run("not set smartblock details, because it doesn't implement interface DetailsSettable", func(t *testing.T) {
// given
fixture := newFixture(t)
fx := newFixture(t)
sb := smarttest.New("id")
// when
fixture.sb.SetType(coresb.SmartBlockTypePage)
err := fixture.updater.setSyncDetails(editor.NewMissingObject(fixture.sb), domain.ObjectError, domain.NetworkError)
sb.SetType(coresb.SmartBlockTypePage)
err := fx.setSyncDetails(editor.NewMissingObject(sb), domain.ObjectSyncStatusError, domain.SyncErrorNetworkError)
// then
assert.Nil(t, err)
})
t.Run("not set smartblock details, because it doesn't need details", func(t *testing.T) {
// given
fixture := newFixture(t)
fx := newFixture(t)
sb := smarttest.New("id")
// when
fixture.sb.SetType(coresb.SmartBlockTypeHome)
err := fixture.updater.setSyncDetails(fixture.sb, domain.ObjectError, domain.NetworkError)
sb.SetType(coresb.SmartBlockTypeHome)
err := fx.setSyncDetails(sb, domain.ObjectSyncStatusError, domain.SyncErrorNetworkError)
// then
assert.Nil(t, err)
@ -261,13 +320,13 @@ func TestSyncStatusUpdater_setSyncDetails(t *testing.T) {
func TestSyncStatusUpdater_isLayoutSuitableForSyncRelations(t *testing.T) {
t.Run("isLayoutSuitableForSyncRelations - participant details", func(t *testing.T) {
// given
fixture := newFixture(t)
fx := newFixture(t)
// when
details := &types.Struct{Fields: map[string]*types.Value{
bundle.RelationKeyLayout.String(): pbtypes.Float64(float64(model.ObjectType_participant)),
}}
isSuitable := fixture.updater.isLayoutSuitableForSyncRelations(details)
isSuitable := fx.isLayoutSuitableForSyncRelations(details)
// then
assert.False(t, isSuitable)
@ -275,13 +334,13 @@ func TestSyncStatusUpdater_isLayoutSuitableForSyncRelations(t *testing.T) {
t.Run("isLayoutSuitableForSyncRelations - basic details", func(t *testing.T) {
// given
fixture := newFixture(t)
fx := newFixture(t)
// when
details := &types.Struct{Fields: map[string]*types.Value{
bundle.RelationKeyLayout.String(): pbtypes.Float64(float64(model.ObjectType_basic)),
}}
isSuitable := fixture.updater.isLayoutSuitableForSyncRelations(details)
isSuitable := fx.isLayoutSuitableForSyncRelations(details)
// then
assert.True(t, isSuitable)
@ -289,34 +348,36 @@ func TestSyncStatusUpdater_isLayoutSuitableForSyncRelations(t *testing.T) {
}
func newFixture(t *testing.T) *fixture {
smartTest := smarttest.New("id")
storeFixture := objectstore.NewStoreFixture(t)
service := mock_space.NewMockService(t)
updater := &syncStatusUpdater{
batcher: mb.New[*syncStatusDetails](0),
finish: make(chan struct{}),
entries: map[string]*syncStatusDetails{},
}
updater := New()
statusUpdater := mock_detailsupdater.NewMockSpaceStatusUpdater(t)
syncSub := syncsubscriptions.New()
ctx := context.Background()
a := &app.App{}
a.Register(storeFixture).
Register(testutil.PrepareMock(context.Background(), a, service)).
Register(testutil.PrepareMock(context.Background(), a, statusUpdater))
subscriptionService := subscription.RegisterSubscriptionService(t, a)
a.Register(syncSub)
a.Register(testutil.PrepareMock(ctx, a, service))
a.Register(testutil.PrepareMock(ctx, a, statusUpdater))
err := updater.Init(a)
assert.Nil(t, err)
require.NoError(t, err)
err = a.Start(ctx)
require.NoError(t, err)
return &fixture{
updater: updater,
sb: smartTest,
storeFixture: storeFixture,
service: service,
statusUpdater: statusUpdater,
syncStatusUpdater: updater.(*syncStatusUpdater),
spaceService: service,
spaceStatusUpdater: statusUpdater,
subscriptionService: subscriptionService,
}
}
type fixture struct {
sb *smarttest.SmartTest
updater *syncStatusUpdater
storeFixture *objectstore.StoreFixture
service *mock_space.MockService
statusUpdater *mock_detailsupdater.MockSpaceStatusUpdater
*syncStatusUpdater
spaceService *mock_space.MockService
spaceStatusUpdater *mock_detailsupdater.MockSpaceStatusUpdater
subscriptionService *subscription.InternalTestService
}

View file

@ -3,51 +3,37 @@ package syncstatus
import (
"context"
"fmt"
"time"
"github.com/anyproto/anytype-heart/core/block/cache"
"github.com/anyproto/anytype-heart/core/block/editor/basic"
"github.com/anyproto/anytype-heart/core/block/editor/smartblock"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/syncstatus/filesyncstatus"
"github.com/anyproto/anytype-heart/pkg/lib/bundle"
"github.com/anyproto/anytype-heart/pkg/lib/pb/model"
"github.com/anyproto/anytype-heart/util/pbtypes"
)
const limitReachErrorPercentage = 0.01
func (s *service) onFileUploadStarted(objectId string, _ domain.FullFileId) error {
return s.indexFileSyncStatus(objectId, filesyncstatus.Syncing, 0)
return s.indexFileSyncStatus(objectId, filesyncstatus.Syncing)
}
func (s *service) onFileUploaded(objectId string, _ domain.FullFileId) error {
return s.indexFileSyncStatus(objectId, filesyncstatus.Synced, 0)
return s.indexFileSyncStatus(objectId, filesyncstatus.Synced)
}
func (s *service) onFileLimited(objectId string, _ domain.FullFileId, bytesLeftPercentage float64) error {
return s.indexFileSyncStatus(objectId, filesyncstatus.Limited, bytesLeftPercentage)
return s.indexFileSyncStatus(objectId, filesyncstatus.Limited)
}
func (s *service) OnFileDelete(fileId domain.FullFileId) {
s.sendSpaceStatusUpdate(filesyncstatus.Synced, fileId.SpaceId, 0)
}
func (s *service) indexFileSyncStatus(fileObjectId string, status filesyncstatus.Status, bytesLeftPercentage float64) error {
var spaceId string
func (s *service) indexFileSyncStatus(fileObjectId string, status filesyncstatus.Status) error {
err := cache.Do(s.objectGetter, fileObjectId, func(sb smartblock.SmartBlock) (err error) {
spaceId = sb.SpaceID()
prevStatus := pbtypes.GetInt64(sb.Details(), bundle.RelationKeyFileBackupStatus.String())
newStatus := int64(status)
if prevStatus == newStatus {
return nil
}
detailsSetter, ok := sb.(basic.DetailsSettable)
if !ok {
return fmt.Errorf("setting of details is not supported for %T", sb)
}
details := provideFileStatusDetails(status, newStatus)
return detailsSetter.SetDetails(nil, details, true)
st := sb.NewState()
st.SetDetailAndBundledRelation(bundle.RelationKeyFileBackupStatus, pbtypes.Int64(newStatus))
return sb.Apply(st)
})
if err != nil {
return fmt.Errorf("get object: %w", err)
@ -56,79 +42,5 @@ func (s *service) indexFileSyncStatus(fileObjectId string, status filesyncstatus
if err != nil {
return fmt.Errorf("update tree: %w", err)
}
s.sendSpaceStatusUpdate(status, spaceId, bytesLeftPercentage)
return nil
}
func provideFileStatusDetails(status filesyncstatus.Status, newStatus int64) []*model.Detail {
syncStatus, syncError := getFileObjectStatus(status)
details := make([]*model.Detail, 0, 4)
details = append(details, &model.Detail{
Key: bundle.RelationKeySyncStatus.String(),
Value: pbtypes.Int64(int64(syncStatus)),
})
details = append(details, &model.Detail{
Key: bundle.RelationKeySyncError.String(),
Value: pbtypes.Int64(int64(syncError)),
})
details = append(details, &model.Detail{
Key: bundle.RelationKeySyncDate.String(),
Value: pbtypes.Int64(time.Now().Unix()),
})
details = append(details, &model.Detail{
Key: bundle.RelationKeyFileBackupStatus.String(),
Value: pbtypes.Int64(newStatus),
})
return details
}
func (s *service) sendSpaceStatusUpdate(status filesyncstatus.Status, spaceId string, bytesLeftPercentage float64) {
spaceStatus, spaceError := getSyncStatus(status, bytesLeftPercentage)
syncStatus := domain.MakeSyncStatus(spaceId, spaceStatus, spaceError, domain.Files)
s.spaceSyncStatus.SendUpdate(syncStatus)
}
func getFileObjectStatus(status filesyncstatus.Status) (domain.ObjectSyncStatus, domain.SyncError) {
var (
objectSyncStatus domain.ObjectSyncStatus
objectError domain.SyncError
)
switch status {
case filesyncstatus.Synced:
objectSyncStatus = domain.ObjectSynced
case filesyncstatus.Syncing:
objectSyncStatus = domain.ObjectSyncing
case filesyncstatus.Queued:
objectSyncStatus = domain.ObjectQueued
case filesyncstatus.Limited:
objectError = domain.Oversized
objectSyncStatus = domain.ObjectError
case filesyncstatus.Unknown:
objectSyncStatus = domain.ObjectError
objectError = domain.NetworkError
}
return objectSyncStatus, objectError
}
func getSyncStatus(status filesyncstatus.Status, bytesLeftPercentage float64) (domain.SpaceSyncStatus, domain.SyncError) {
var (
spaceStatus domain.SpaceSyncStatus
spaceError domain.SyncError
)
switch status {
case filesyncstatus.Synced:
spaceStatus = domain.Synced
case filesyncstatus.Syncing, filesyncstatus.Queued:
spaceStatus = domain.Syncing
case filesyncstatus.Limited:
spaceStatus = domain.Synced
if bytesLeftPercentage <= limitReachErrorPercentage {
spaceStatus = domain.Error
spaceError = domain.StorageLimitExceed
}
case filesyncstatus.Unknown:
spaceStatus = domain.Error
spaceError = domain.NetworkError
}
return spaceStatus, spaceError
}

View file

@ -1,79 +0,0 @@
package syncstatus
import (
"testing"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/syncstatus/filesyncstatus"
"github.com/anyproto/anytype-heart/core/syncstatus/spacesyncstatus/mock_spacesyncstatus"
)
func Test_sendSpaceStatusUpdate(t *testing.T) {
t.Run("file limited", func(t *testing.T) {
// given
updater := mock_spacesyncstatus.NewMockUpdater(t)
s := &service{
spaceSyncStatus: updater,
}
// when
updater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Error, domain.StorageLimitExceed, domain.Files)).Return()
s.sendSpaceStatusUpdate(filesyncstatus.Limited, "spaceId", 0)
})
t.Run("file limited, but over 1% of storage is available", func(t *testing.T) {
// given
updater := mock_spacesyncstatus.NewMockUpdater(t)
s := &service{
spaceSyncStatus: updater,
}
// when
updater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Synced, domain.Null, domain.Files)).Return()
s.sendSpaceStatusUpdate(filesyncstatus.Limited, "spaceId", 0.9)
})
t.Run("file synced", func(t *testing.T) {
// given
updater := mock_spacesyncstatus.NewMockUpdater(t)
s := &service{
spaceSyncStatus: updater,
}
// when
updater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Synced, domain.Null, domain.Files)).Return()
s.sendSpaceStatusUpdate(filesyncstatus.Synced, "spaceId", 0)
})
t.Run("file queued", func(t *testing.T) {
// given
updater := mock_spacesyncstatus.NewMockUpdater(t)
s := &service{
spaceSyncStatus: updater,
}
// when
updater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Syncing, domain.Null, domain.Files)).Return()
s.sendSpaceStatusUpdate(filesyncstatus.Queued, "spaceId", 0)
})
t.Run("file syncing", func(t *testing.T) {
// given
updater := mock_spacesyncstatus.NewMockUpdater(t)
s := &service{
spaceSyncStatus: updater,
}
// when
updater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Syncing, domain.Null, domain.Files)).Return()
s.sendSpaceStatusUpdate(filesyncstatus.Syncing, "spaceId", 0)
})
t.Run("file unknown status", func(t *testing.T) {
// given
updater := mock_spacesyncstatus.NewMockUpdater(t)
s := &service{
spaceSyncStatus: updater,
}
// when
updater.EXPECT().SendUpdate(domain.MakeSyncStatus("spaceId", domain.Error, domain.NetworkError, domain.Files)).Return()
s.sendSpaceStatusUpdate(filesyncstatus.Unknown, "spaceId", 0)
})
}

View file

@ -0,0 +1,208 @@
// Code generated by mockery. DO NOT EDIT.
package mock_nodestatus
import (
app "github.com/anyproto/any-sync/app"
mock "github.com/stretchr/testify/mock"
nodestatus "github.com/anyproto/anytype-heart/core/syncstatus/nodestatus"
)
// MockNodeStatus is an autogenerated mock type for the NodeStatus type
type MockNodeStatus struct {
mock.Mock
}
type MockNodeStatus_Expecter struct {
mock *mock.Mock
}
func (_m *MockNodeStatus) EXPECT() *MockNodeStatus_Expecter {
return &MockNodeStatus_Expecter{mock: &_m.Mock}
}
// GetNodeStatus provides a mock function with given fields: spaceId
func (_m *MockNodeStatus) GetNodeStatus(spaceId string) nodestatus.ConnectionStatus {
ret := _m.Called(spaceId)
if len(ret) == 0 {
panic("no return value specified for GetNodeStatus")
}
var r0 nodestatus.ConnectionStatus
if rf, ok := ret.Get(0).(func(string) nodestatus.ConnectionStatus); ok {
r0 = rf(spaceId)
} else {
r0 = ret.Get(0).(nodestatus.ConnectionStatus)
}
return r0
}
// MockNodeStatus_GetNodeStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeStatus'
type MockNodeStatus_GetNodeStatus_Call struct {
*mock.Call
}
// GetNodeStatus is a helper method to define mock.On call
// - spaceId string
func (_e *MockNodeStatus_Expecter) GetNodeStatus(spaceId interface{}) *MockNodeStatus_GetNodeStatus_Call {
return &MockNodeStatus_GetNodeStatus_Call{Call: _e.mock.On("GetNodeStatus", spaceId)}
}
func (_c *MockNodeStatus_GetNodeStatus_Call) Run(run func(spaceId string)) *MockNodeStatus_GetNodeStatus_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockNodeStatus_GetNodeStatus_Call) Return(_a0 nodestatus.ConnectionStatus) *MockNodeStatus_GetNodeStatus_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockNodeStatus_GetNodeStatus_Call) RunAndReturn(run func(string) nodestatus.ConnectionStatus) *MockNodeStatus_GetNodeStatus_Call {
_c.Call.Return(run)
return _c
}
// Init provides a mock function with given fields: a
func (_m *MockNodeStatus) Init(a *app.App) error {
ret := _m.Called(a)
if len(ret) == 0 {
panic("no return value specified for Init")
}
var r0 error
if rf, ok := ret.Get(0).(func(*app.App) error); ok {
r0 = rf(a)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockNodeStatus_Init_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Init'
type MockNodeStatus_Init_Call struct {
*mock.Call
}
// Init is a helper method to define mock.On call
// - a *app.App
func (_e *MockNodeStatus_Expecter) Init(a interface{}) *MockNodeStatus_Init_Call {
return &MockNodeStatus_Init_Call{Call: _e.mock.On("Init", a)}
}
func (_c *MockNodeStatus_Init_Call) Run(run func(a *app.App)) *MockNodeStatus_Init_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*app.App))
})
return _c
}
func (_c *MockNodeStatus_Init_Call) Return(err error) *MockNodeStatus_Init_Call {
_c.Call.Return(err)
return _c
}
func (_c *MockNodeStatus_Init_Call) RunAndReturn(run func(*app.App) error) *MockNodeStatus_Init_Call {
_c.Call.Return(run)
return _c
}
// Name provides a mock function with given fields:
func (_m *MockNodeStatus) Name() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Name")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// MockNodeStatus_Name_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Name'
type MockNodeStatus_Name_Call struct {
*mock.Call
}
// Name is a helper method to define mock.On call
func (_e *MockNodeStatus_Expecter) Name() *MockNodeStatus_Name_Call {
return &MockNodeStatus_Name_Call{Call: _e.mock.On("Name")}
}
func (_c *MockNodeStatus_Name_Call) Run(run func()) *MockNodeStatus_Name_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockNodeStatus_Name_Call) Return(name string) *MockNodeStatus_Name_Call {
_c.Call.Return(name)
return _c
}
func (_c *MockNodeStatus_Name_Call) RunAndReturn(run func() string) *MockNodeStatus_Name_Call {
_c.Call.Return(run)
return _c
}
// SetNodesStatus provides a mock function with given fields: spaceId, status
func (_m *MockNodeStatus) SetNodesStatus(spaceId string, status nodestatus.ConnectionStatus) {
_m.Called(spaceId, status)
}
// MockNodeStatus_SetNodesStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetNodesStatus'
type MockNodeStatus_SetNodesStatus_Call struct {
*mock.Call
}
// SetNodesStatus is a helper method to define mock.On call
// - spaceId string
// - status nodestatus.ConnectionStatus
func (_e *MockNodeStatus_Expecter) SetNodesStatus(spaceId interface{}, status interface{}) *MockNodeStatus_SetNodesStatus_Call {
return &MockNodeStatus_SetNodesStatus_Call{Call: _e.mock.On("SetNodesStatus", spaceId, status)}
}
func (_c *MockNodeStatus_SetNodesStatus_Call) Run(run func(spaceId string, status nodestatus.ConnectionStatus)) *MockNodeStatus_SetNodesStatus_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(nodestatus.ConnectionStatus))
})
return _c
}
func (_c *MockNodeStatus_SetNodesStatus_Call) Return() *MockNodeStatus_SetNodesStatus_Call {
_c.Call.Return()
return _c
}
func (_c *MockNodeStatus_SetNodesStatus_Call) RunAndReturn(run func(string, nodestatus.ConnectionStatus)) *MockNodeStatus_SetNodesStatus_Call {
_c.Call.Return(run)
return _c
}
// NewMockNodeStatus creates a new instance of MockNodeStatus. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockNodeStatus(t interface {
mock.TestingT
Cleanup(func())
}) *MockNodeStatus {
mock := &MockNodeStatus{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View file

@ -1,19 +1,16 @@
package nodestatus
import (
"slices"
"sync"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/nodeconf"
)
const CName = "core.syncstatus.nodestatus"
type nodeStatus struct {
sync.Mutex
configuration nodeconf.NodeConf
nodeStatus map[string]ConnectionStatus
nodeStatus map[string]ConnectionStatus
}
type ConnectionStatus int
@ -26,16 +23,15 @@ const (
type NodeStatus interface {
app.Component
SetNodesStatus(spaceId string, senderId string, status ConnectionStatus)
SetNodesStatus(spaceId string, status ConnectionStatus)
GetNodeStatus(spaceId string) ConnectionStatus
}
func NewNodeStatus() NodeStatus {
return &nodeStatus{nodeStatus: make(map[string]ConnectionStatus, 0)}
return &nodeStatus{nodeStatus: make(map[string]ConnectionStatus)}
}
func (n *nodeStatus) Init(a *app.App) (err error) {
n.configuration = app.MustComponent[nodeconf.NodeConf](a)
return
}
@ -49,17 +45,8 @@ func (n *nodeStatus) GetNodeStatus(spaceId string) ConnectionStatus {
return n.nodeStatus[spaceId]
}
func (n *nodeStatus) SetNodesStatus(spaceId string, senderId string, status ConnectionStatus) {
if !n.isSenderResponsible(senderId, spaceId) {
return
}
func (n *nodeStatus) SetNodesStatus(spaceId string, status ConnectionStatus) {
n.Lock()
defer n.Unlock()
n.nodeStatus[spaceId] = status
}
func (n *nodeStatus) isSenderResponsible(senderId string, spaceId string) bool {
return slices.Contains(n.configuration.NodeIds(spaceId), senderId)
}

View file

@ -3,79 +3,13 @@ package nodestatus
import (
"testing"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/nodeconf/mock_nodeconf"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/stretchr/testify/require"
)
type fixture struct {
*nodeStatus
nodeConf *mock_nodeconf.MockService
}
func TestNodeStatus_SetNodesStatus(t *testing.T) {
t.Run("peer is responsible", func(t *testing.T) {
// given
f := newFixture(t)
f.nodeConf.EXPECT().NodeIds("spaceId").Return([]string{"peerId"})
// when
f.SetNodesStatus("spaceId", "peerId", Online)
// then
assert.Equal(t, Online, f.nodeStatus.nodeStatus["spaceId"])
})
t.Run("peer is not responsible", func(t *testing.T) {
// given
f := newFixture(t)
f.nodeConf.EXPECT().NodeIds("spaceId").Return([]string{"peerId2"})
// when
f.SetNodesStatus("spaceId", "peerId", ConnectionError)
// then
assert.NotEqual(t, ConnectionError, f.nodeStatus.nodeStatus["spaceId"])
})
}
func TestNodeStatus_GetNodeStatus(t *testing.T) {
t.Run("get default status", func(t *testing.T) {
// given
f := newFixture(t)
// when
status := f.GetNodeStatus("")
// then
assert.Equal(t, Online, status)
})
t.Run("get updated status", func(t *testing.T) {
// given
f := newFixture(t)
f.nodeConf.EXPECT().NodeIds("spaceId").Return([]string{"peerId"})
// when
f.SetNodesStatus("spaceId", "peerId", ConnectionError)
status := f.GetNodeStatus("spaceId")
// then
assert.Equal(t, ConnectionError, status)
})
}
func newFixture(t *testing.T) *fixture {
ctrl := gomock.NewController(t)
nodeConf := mock_nodeconf.NewMockService(ctrl)
nodeStatus := &nodeStatus{
nodeStatus: map[string]ConnectionStatus{},
}
a := &app.App{}
a.Register(nodeConf)
err := nodeStatus.Init(a)
assert.Nil(t, err)
return &fixture{
nodeStatus: nodeStatus,
nodeConf: nodeConf,
}
func TestNodeStatus(t *testing.T) {
st := NewNodeStatus()
st.SetNodesStatus("spaceId", Online)
require.Equal(t, Online, st.GetNodeStatus("spaceId"))
st.SetNodesStatus("spaceId", ConnectionError)
require.Equal(t, ConnectionError, st.GetNodeStatus("spaceId"))
}

View file

@ -112,9 +112,9 @@ func (_c *MockUpdater_Name_Call) RunAndReturn(run func() string) *MockUpdater_Na
return _c
}
// UpdateDetails provides a mock function with given fields: objectId, status, syncError, spaceId
func (_m *MockUpdater) UpdateDetails(objectId []string, status domain.ObjectSyncStatus, syncError domain.SyncError, spaceId string) {
_m.Called(objectId, status, syncError, spaceId)
// UpdateDetails provides a mock function with given fields: objectId, status, spaceId
func (_m *MockUpdater) UpdateDetails(objectId string, status domain.ObjectSyncStatus, spaceId string) {
_m.Called(objectId, status, spaceId)
}
// MockUpdater_UpdateDetails_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateDetails'
@ -123,17 +123,16 @@ type MockUpdater_UpdateDetails_Call struct {
}
// UpdateDetails is a helper method to define mock.On call
// - objectId []string
// - objectId string
// - status domain.ObjectSyncStatus
// - syncError domain.SyncError
// - spaceId string
func (_e *MockUpdater_Expecter) UpdateDetails(objectId interface{}, status interface{}, syncError interface{}, spaceId interface{}) *MockUpdater_UpdateDetails_Call {
return &MockUpdater_UpdateDetails_Call{Call: _e.mock.On("UpdateDetails", objectId, status, syncError, spaceId)}
func (_e *MockUpdater_Expecter) UpdateDetails(objectId interface{}, status interface{}, spaceId interface{}) *MockUpdater_UpdateDetails_Call {
return &MockUpdater_UpdateDetails_Call{Call: _e.mock.On("UpdateDetails", objectId, status, spaceId)}
}
func (_c *MockUpdater_UpdateDetails_Call) Run(run func(objectId []string, status domain.ObjectSyncStatus, syncError domain.SyncError, spaceId string)) *MockUpdater_UpdateDetails_Call {
func (_c *MockUpdater_UpdateDetails_Call) Run(run func(objectId string, status domain.ObjectSyncStatus, spaceId string)) *MockUpdater_UpdateDetails_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string), args[1].(domain.ObjectSyncStatus), args[2].(domain.SyncError), args[3].(string))
run(args[0].(string), args[1].(domain.ObjectSyncStatus), args[2].(string))
})
return _c
}
@ -143,7 +142,7 @@ func (_c *MockUpdater_UpdateDetails_Call) Return() *MockUpdater_UpdateDetails_Ca
return _c
}
func (_c *MockUpdater_UpdateDetails_Call) RunAndReturn(run func([]string, domain.ObjectSyncStatus, domain.SyncError, string)) *MockUpdater_UpdateDetails_Call {
func (_c *MockUpdater_UpdateDetails_Call) RunAndReturn(run func(string, domain.ObjectSyncStatus, string)) *MockUpdater_UpdateDetails_Call {
_c.Call.Return(run)
return _c
}

View file

@ -2,29 +2,28 @@ package objectsyncstatus
import (
"context"
"fmt"
"sync"
"time"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/util/periodicsync"
"github.com/anyproto/any-sync/util/slice"
"golang.org/x/exp/slices"
"github.com/anyproto/anytype-heart/core/anytype/config"
"github.com/anyproto/anytype-heart/core/domain"
"github.com/anyproto/anytype-heart/core/syncstatus/nodestatus"
"github.com/anyproto/anytype-heart/util/slice"
)
const (
syncUpdateInterval = 5
syncUpdateInterval = 3
syncTimeout = time.Second
)
@ -43,9 +42,16 @@ const (
StatusNotSynced
)
type treeHeadsEntry struct {
heads []string
syncStatus SyncStatus
}
type StatusUpdater interface {
HeadsChange(treeId string, heads []string)
HeadsReceive(senderId, treeId string, heads []string)
HeadsApply(senderId, treeId string, heads []string, allAdded bool)
ObjectReceive(senderId, treeId string, heads []string)
RemoveAllExcept(senderId string, differentRemoteIds []string)
}
@ -61,12 +67,6 @@ type StatusService interface {
StatusWatcher
}
type treeHeadsEntry struct {
heads []string
stateCounter uint64
syncStatus SyncStatus
}
type treeStatus struct {
treeId string
status SyncStatus
@ -74,22 +74,20 @@ type treeStatus struct {
type Updater interface {
app.Component
UpdateDetails(objectId []string, status domain.ObjectSyncStatus, syncError domain.SyncError, spaceId string)
UpdateDetails(objectId string, status domain.ObjectSyncStatus, spaceId string)
}
type syncStatusService struct {
sync.Mutex
configuration nodeconf.NodeConf
periodicSync periodicsync.PeriodicSync
updateReceiver UpdateReceiver
storage spacestorage.SpaceStorage
spaceId string
treeHeads map[string]treeHeadsEntry
watchers map[string]struct{}
stateCounter uint64
treeStatusBuf []treeStatus
spaceId string
synced []string
tempSynced map[string]struct{}
treeHeads map[string]treeHeadsEntry
watchers map[string]struct{}
updateIntervalSecs int
updateTimeout time.Duration
@ -102,8 +100,9 @@ type syncStatusService struct {
func NewSyncStatusService() StatusService {
return &syncStatusService{
treeHeads: map[string]treeHeadsEntry{},
watchers: map[string]struct{}{},
tempSynced: map[string]struct{}{},
treeHeads: map[string]treeHeadsEntry{},
watchers: map[string]struct{}{},
}
}
@ -112,7 +111,6 @@ func (s *syncStatusService) Init(a *app.App) (err error) {
s.updateIntervalSecs = syncUpdateInterval
s.updateTimeout = syncTimeout
s.spaceId = sharedState.SpaceId
s.configuration = app.MustComponent[nodeconf.NodeConf](a)
s.storage = app.MustComponent[spacestorage.SpaceStorage](a)
s.periodicSync = periodicsync.NewPeriodicSync(
s.updateIntervalSecs,
@ -143,85 +141,101 @@ func (s *syncStatusService) Run(ctx context.Context) error {
}
func (s *syncStatusService) HeadsChange(treeId string, heads []string) {
s.Lock()
s.addTreeHead(treeId, heads, StatusNotSynced)
s.Unlock()
s.updateDetails(treeId, domain.ObjectSyncStatusSyncing)
}
func (s *syncStatusService) ObjectReceive(senderId, treeId string, heads []string) {
s.Lock()
defer s.Unlock()
var headsCopy []string
headsCopy = append(headsCopy, heads...)
s.treeHeads[treeId] = treeHeadsEntry{
heads: headsCopy,
stateCounter: s.stateCounter,
syncStatus: StatusNotSynced,
if len(heads) == 0 || !s.isSenderResponsible(senderId) {
s.tempSynced[treeId] = struct{}{}
return
}
s.synced = append(s.synced, treeId)
}
func (s *syncStatusService) HeadsApply(senderId, treeId string, heads []string, allAdded bool) {
s.Lock()
defer s.Unlock()
if len(heads) == 0 || !s.isSenderResponsible(senderId) {
if allAdded {
s.tempSynced[treeId] = struct{}{}
}
return
}
if !allAdded {
return
}
s.synced = append(s.synced, treeId)
if curTreeHeads, ok := s.treeHeads[treeId]; ok {
// checking if we received the head that we are interested in
for _, head := range heads {
if idx, found := slices.BinarySearch(curTreeHeads.heads, head); found {
curTreeHeads.heads = slice.RemoveIndex(curTreeHeads.heads, idx)
}
}
if len(curTreeHeads.heads) == 0 {
curTreeHeads.syncStatus = StatusSynced
}
s.treeHeads[treeId] = curTreeHeads
}
s.stateCounter++
s.updateDetails(treeId, domain.ObjectSyncing)
}
func (s *syncStatusService) update(ctx context.Context) (err error) {
s.treeStatusBuf = s.treeStatusBuf[:0]
s.Lock()
var (
updateDetailsStatuses = make([]treeStatus, 0, len(s.synced))
updateThreadStatuses = make([]treeStatus, 0, len(s.watchers))
)
if s.updateReceiver == nil {
s.Unlock()
return
}
for _, treeId := range s.synced {
updateDetailsStatuses = append(updateDetailsStatuses, treeStatus{treeId, StatusSynced})
}
for treeId := range s.watchers {
// that means that we haven't yet got the status update
treeHeads, exists := s.treeHeads[treeId]
if !exists {
err = fmt.Errorf("treeHeads should always exist for watchers")
s.Unlock()
return
continue
}
s.treeStatusBuf = append(s.treeStatusBuf, treeStatus{treeId, treeHeads.syncStatus})
updateThreadStatuses = append(updateThreadStatuses, treeStatus{treeId, treeHeads.syncStatus})
}
s.synced = s.synced[:0]
s.Unlock()
s.updateReceiver.UpdateNodeStatus()
for _, entry := range s.treeStatusBuf {
for _, entry := range updateDetailsStatuses {
s.updateDetails(entry.treeId, mapStatus(entry.status))
}
for _, entry := range updateThreadStatuses {
err = s.updateReceiver.UpdateTree(ctx, entry.treeId, entry.status)
if err != nil {
return
}
s.updateDetails(entry.treeId, mapStatus(entry.status))
}
return
}
func mapStatus(status SyncStatus) domain.ObjectSyncStatus {
if status == StatusSynced {
return domain.ObjectSynced
return domain.ObjectSyncStatusSynced
}
return domain.ObjectSyncing
return domain.ObjectSyncStatusSyncing
}
func (s *syncStatusService) HeadsReceive(senderId, treeId string, heads []string) {
s.Lock()
defer s.Unlock()
}
curTreeHeads, ok := s.treeHeads[treeId]
if !ok || curTreeHeads.syncStatus == StatusSynced {
return
func (s *syncStatusService) addTreeHead(treeId string, heads []string, status SyncStatus) {
headsCopy := slice.Copy(heads)
slices.Sort(headsCopy)
s.treeHeads[treeId] = treeHeadsEntry{
heads: headsCopy,
syncStatus: status,
}
// checking if other node is responsible
if len(heads) == 0 || !s.isSenderResponsible(senderId) {
return
}
// checking if we received the head that we are interested in
for _, head := range heads {
if idx, found := slices.BinarySearch(curTreeHeads.heads, head); found {
curTreeHeads.heads[idx] = ""
}
}
curTreeHeads.heads = slice.DiscardFromSlice(curTreeHeads.heads, func(h string) bool {
return h == ""
})
if len(curTreeHeads.heads) == 0 {
curTreeHeads.syncStatus = StatusSynced
}
s.treeHeads[treeId] = curTreeHeads
}
func (s *syncStatusService) Watch(treeId string) (err error) {
@ -241,13 +255,7 @@ func (s *syncStatusService) Watch(treeId string) (err error) {
if err != nil {
return
}
slices.Sort(heads)
s.stateCounter++
s.treeHeads[treeId] = treeHeadsEntry{
heads: heads,
stateCounter: s.stateCounter,
syncStatus: StatusUnknown,
}
s.addTreeHead(treeId, heads, StatusUnknown)
}
s.watchers[treeId] = struct{}{}
@ -271,14 +279,17 @@ func (s *syncStatusService) RemoveAllExcept(senderId string, differentRemoteIds
slices.Sort(differentRemoteIds)
for treeId, entry := range s.treeHeads {
// if the current update is outdated
if entry.stateCounter > s.stateCounter {
continue
}
// if we didn't find our treeId in heads ids which are different from us and node
if _, found := slices.BinarySearch(differentRemoteIds, treeId); !found {
entry.syncStatus = StatusSynced
s.treeHeads[treeId] = entry
if entry.syncStatus != StatusSynced {
entry.syncStatus = StatusSynced
s.treeHeads[treeId] = entry
}
}
}
for treeId := range s.tempSynced {
delete(s.tempSynced, treeId)
if _, found := slices.BinarySearch(differentRemoteIds, treeId); !found {
s.synced = append(s.synced, treeId)
}
}
}
@ -289,18 +300,9 @@ func (s *syncStatusService) Close(ctx context.Context) error {
}
func (s *syncStatusService) isSenderResponsible(senderId string) bool {
return slices.Contains(s.configuration.NodeIds(s.spaceId), senderId)
return slices.Contains(s.nodeConfService.NodeIds(s.spaceId), senderId)
}
func (s *syncStatusService) updateDetails(treeId string, status domain.ObjectSyncStatus) {
var syncErr domain.SyncError
if s.nodeStatus.GetNodeStatus(s.spaceId) != nodestatus.Online || s.config.IsLocalOnlyMode() {
syncErr = domain.NetworkError
status = domain.ObjectError
}
if s.nodeConfService.NetworkCompatibilityStatus() == nodeconf.NetworkCompatibilityStatusIncompatible {
syncErr = domain.IncompatibleVersion
status = domain.ObjectError
}
s.syncDetailsUpdater.UpdateDetails([]string{treeId}, status, syncErr, s.spaceId)
s.syncDetailsUpdater.UpdateDetails(treeId, status, s.spaceId)
}

Some files were not shown because too many files have changed in this diff Show more