1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-08 05:57:03 +09:00
any-sync/util/syncqueues/limit.go
2024-08-14 00:32:08 +02:00

110 lines
2.3 KiB
Go

package syncqueues
import (
"fmt"
"sync"
"golang.org/x/exp/slices"
)
// Limit is a tiered token limiter keyed by peer id: the more tokens are held in
// total by non-excluded peers, the fewer tokens a single peer may hold.
type Limit struct {
	// Per-peer limits and total-token thresholds as arranged by NewLimit
	// (peerStep descending, totalStep ascending); both are indexed by counter.
	peerStep  []int
	totalStep []int

	// Peers exempt from the tiered limits, the per-peer limit applied to them
	// instead, and the number of tokens they currently hold together.
	excludeIds    []string
	excludedLimit int
	excludedTotal int

	// counter is the current tier; total is the number of tokens held by
	// non-excluded peers.
	counter int
	total   int

	// tokens holds the number of tokens currently taken by each peer id.
	tokens map[string]int

	// mx guards the mutable state above; every exported method locks it.
	mx sync.Mutex
}
// NewLimit creates a tiered Limit.
//
// peerStep holds the per-peer token limits and totalStep the total-token
// thresholds at which the next (smaller) per-peer limit applies, so peerStep
// must have exactly one element more than totalStep. Both slices are copied
// before being ordered (peerStep descending, totalStep ascending); the caller's
// slices are never modified.
//
// Peers listed in excludeIds are limited to excludedLimit tokens each and do
// not count towards the total.
//
// NewLimit panics if the step configuration is inconsistent.
func NewLimit(peerStep, totalStep []int, excludeIds []string, excludedLimit int) *Limit {
	if len(peerStep) == 0 || len(totalStep) == 0 || len(peerStep) != len(totalStep)+1 {
		panic("incorrect limit configuration")
	}
	// Work on copies: sorting in place and the append below would otherwise
	// reorder the caller's slices or, via spare capacity, overwrite their
	// backing array.
	peerStep = slices.Clone(peerStep)
	totalStep = slices.Clone(totalStep)
	// Largest per-peer limit first: it applies while the total is lowest.
	slices.SortFunc(peerStep, func(a, b int) int {
		switch {
		case a > b:
			return -1
		case a < b:
			return 1
		default:
			return 0
		}
	})
	slices.Sort(totalStep)
	// e.g. peerStep = [3, 2, 1], totalStep = [3, 6]: while fewer than 3 tokens
	// are held in total a peer may hold up to 3, while fewer than 6 up to 2, and
	// from 6 on only 1. Repeating the last threshold makes totalStep as long as
	// peerStep so both can be indexed by the same counter.
	totalStep = append(totalStep, totalStep[len(totalStep)-1])
	return &Limit{
		excludeIds:    excludeIds,
		excludedLimit: excludedLimit,
		peerStep:      peerStep,
		totalStep:     totalStep,
		tokens:        make(map[string]int),
	}
}
// Take tries to acquire one token for id and reports whether it succeeded.
//
// Peers listed in excludeIds are only limited by excludedLimit and do not count
// towards the total. Every other peer may hold at most peerStep[counter]
// tokens, where counter is the tier selected by the total number of tokens
// held by non-excluded peers.
func (l *Limit) Take(id string) bool {
	l.mx.Lock()
	defer l.mx.Unlock()
	held := l.tokens[id]
	if l.isExcluded(id) {
		if held >= l.excludedLimit {
			return false
		}
		l.tokens[id] = held + 1
		l.excludedTotal++
		return true
	}
	if held >= l.peerStep[l.counter] {
		return false
	}
	l.tokens[id] = held + 1
	l.total++
	// Advance through every threshold the new total reaches. A loop (instead of
	// a single step) keeps the tier correct when several thresholds are equal.
	// The last tier is terminal: totalStep's final entry only duplicates the
	// previous threshold so indexing by counter stays in range.
	for l.counter < len(l.totalStep)-1 && l.total >= l.totalStep[l.counter] {
		l.counter++
	}
	return true
}
// Release returns one token previously obtained with Take for id. Calls for an
// id that currently holds no tokens are ignored.
func (l *Limit) Release(id string) {
	l.mx.Lock()
	defer l.mx.Unlock()
	held := l.tokens[id]
	if held <= 0 {
		return
	}
	if held == 1 {
		// Drop idle peers so the map does not keep every id ever seen.
		delete(l.tokens, id)
	} else {
		l.tokens[id] = held - 1
	}
	if l.isExcluded(id) {
		l.excludedTotal--
		return
	}
	l.total--
	// Tier c (c > 0) is entered once total >= totalStep[c-1] (see Take), so it
	// must be left once the total falls below that same threshold, not below
	// totalStep[c]. Comparing against totalStep[c] dropped the tier too early.
	for l.counter > 0 && l.total < l.totalStep[l.counter-1] {
		l.counter--
	}
}
// isExcluded reports whether id appears in excludeIds, i.e. whether the peer is
// exempt from the tiered limits and the shared total.
func (l *Limit) isExcluded(id string) bool {
	return slices.Contains(l.excludeIds, id)
}
// Stats returns a human-readable snapshot for id. It shows the tokens id holds
// against its current per-peer limit (excludedLimit for excluded peers), then
// the excluded total, the non-excluded total and the current tier threshold.
func (l *Limit) Stats(id string) string {
	l.mx.Lock()
	defer l.mx.Unlock()
	held := l.tokens[id]
	threshold := l.totalStep[l.counter]
	if !l.isExcluded(id) {
		return fmt.Sprintf("peer: %d/%d, total: %d/%d/%d", held, l.peerStep[l.counter], l.excludedTotal, l.total, threshold)
	}
	return fmt.Sprintf("excluded peer: %d/%d, total: %d/%d/%d", held, l.excludedLimit, l.excludedTotal, l.total, threshold)
}