代码整理
This commit is contained in:
41
hooks/auth/allow_all.go
Normal file
41
hooks/auth/allow_all.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
// AllowHook is an authentication hook which allows connection access
|
||||
// for all users and read and write access to all topics.
|
||||
type AllowHook struct {
|
||||
mqtt.HookBase
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *AllowHook) ID() string {
|
||||
return "allow-all-auth"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *AllowHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate returns true/allowed for all requests.
|
||||
func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// OnACLCheck returns true/allowed for all checks.
|
||||
func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
35
hooks/auth/allow_all_test.go
Normal file
35
hooks/auth/allow_all_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
func TestAllowAllID(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.Equal(t, "allow-all-auth", h.ID())
|
||||
}
|
||||
|
||||
func TestAllowAllProvides(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
require.False(t, h.Provides(mqtt.OnPublished))
|
||||
}
|
||||
|
||||
func TestAllowAllOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{}))
|
||||
}
|
||||
|
||||
func TestAllowAllOnACLCheck(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true))
|
||||
}
|
||||
103
hooks/auth/auth.go
Normal file
103
hooks/auth/auth.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
// Options contains the configuration/rules data for the auth ledger.
|
||||
type Options struct {
|
||||
Data []byte
|
||||
Ledger *Ledger
|
||||
}
|
||||
|
||||
// Hook is an authentication hook which implements an auth ledger.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options
|
||||
ledger *Ledger
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "auth-ledger"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init configures the hook with the auth ledger to be used for checking.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
|
||||
var err error
|
||||
if h.config.Ledger != nil {
|
||||
h.ledger = h.config.Ledger
|
||||
} else if len(h.config.Data) > 0 {
|
||||
h.ledger = new(Ledger)
|
||||
err = h.ledger.Unmarshal(h.config.Data)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.ledger == nil {
|
||||
h.ledger = &Ledger{
|
||||
Auth: AuthRules{},
|
||||
ACL: ACLRules{},
|
||||
}
|
||||
}
|
||||
|
||||
h.Log.Info("loaded auth rules",
|
||||
"authentication", len(h.ledger.Auth),
|
||||
"acl", len(h.ledger.ACL))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate returns true if the connecting client has rules which provide access
|
||||
// in the auth ledger.
|
||||
func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
if _, ok := h.ledger.AuthOk(cl, pk); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Info("client failed authentication check",
|
||||
"username", string(pk.Connect.Username),
|
||||
"remote", cl.Net.Remote)
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck returns true if the connecting client has matching read or write access to subscribe
|
||||
// or publish to a given topic.
|
||||
func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
if _, ok := h.ledger.ACLOk(cl, topic, write); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Debug("client failed allowed ACL check",
|
||||
"client", cl.ID,
|
||||
"username", string(cl.Properties.Username),
|
||||
"topic", topic)
|
||||
|
||||
return false
|
||||
}
|
||||
213
hooks/auth/auth_test.go
Normal file
213
hooks/auth/auth_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
// func teardown(t *testing.T, path string, h *Hook) {
|
||||
// h.Stop()
|
||||
// }
|
||||
|
||||
func TestBasicID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "auth-ledger", h.ID())
|
||||
}
|
||||
|
||||
func TestBasicProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
require.False(t, h.Provides(mqtt.OnPublish))
|
||||
}
|
||||
|
||||
func TestBasicInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBasicInitDefaultConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerPointer(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := &Ledger{
|
||||
Auth: []AuthRule{
|
||||
{
|
||||
Remote: "127.0.0.1",
|
||||
Allow: true,
|
||||
},
|
||||
},
|
||||
ACL: []ACLRule{
|
||||
{
|
||||
Remote: "127.0.0.1",
|
||||
Filters: Filters{
|
||||
"#": ReadWrite,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := h.Init(&Options{
|
||||
Ledger: ln,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Same(t, ln, h.ledger)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerJSON(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: ledgerJSON,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
|
||||
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerYAML(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: ledgerYAML,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
|
||||
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: []byte("fdsfdsafasd"),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
ln.ACL = checkLedger.ACL
|
||||
err := h.Init(
|
||||
&Options{
|
||||
Ledger: ln,
|
||||
},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
))
|
||||
|
||||
require.False(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
))
|
||||
|
||||
require.False(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{},
|
||||
packets.Packet{},
|
||||
))
|
||||
}
|
||||
|
||||
func TestOnACL(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
ln.ACL = checkLedger.ACL
|
||||
err := h.Init(
|
||||
&Options{
|
||||
Ledger: ln,
|
||||
},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"mochi/info",
|
||||
true,
|
||||
))
|
||||
|
||||
require.False(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"d/j/f",
|
||||
true,
|
||||
))
|
||||
|
||||
require.True(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"readonly",
|
||||
false,
|
||||
))
|
||||
|
||||
require.False(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"readonly",
|
||||
true,
|
||||
))
|
||||
}
|
||||
246
hooks/auth/ledger.go
Normal file
246
hooks/auth/ledger.go
Normal file
@@ -0,0 +1,246 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
Deny Access = iota // user cannot access the topic
|
||||
ReadOnly // user can only subscribe to the topic
|
||||
WriteOnly // user can only publish to the topic
|
||||
ReadWrite // user can both publish and subscribe to the topic
|
||||
)
|
||||
|
||||
// Access determines the read/write privileges for an ACL rule.
|
||||
type Access byte
|
||||
|
||||
// Users contains a map of access rules for specific users, keyed on username.
|
||||
type Users map[string]UserRule
|
||||
|
||||
// UserRule defines a set of access rules for a specific user.
|
||||
type UserRule struct {
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
|
||||
ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired
|
||||
Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user
|
||||
}
|
||||
|
||||
// AuthRules defines generic access rules applicable to all users.
|
||||
type AuthRules []AuthRule
|
||||
|
||||
type AuthRule struct {
|
||||
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
|
||||
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
|
||||
Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users
|
||||
}
|
||||
|
||||
// ACLRules defines generic topic or filter access rules applicable to all users.
|
||||
type ACLRules []ACLRule
|
||||
|
||||
// ACLRule defines access rules for a specific topic or filter.
|
||||
type ACLRule struct {
|
||||
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
|
||||
Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match
|
||||
}
|
||||
|
||||
// Filters is a map of Access rules keyed on filter.
|
||||
type Filters map[RString]Access
|
||||
|
||||
// RString is a rule value string.
|
||||
type RString string
|
||||
|
||||
// Matches returns true if the rule matches a given string.
|
||||
func (r RString) Matches(a string) bool {
|
||||
rr := string(r)
|
||||
if r == "" || r == "*" || a == rr {
|
||||
return true
|
||||
}
|
||||
|
||||
i := strings.Index(rr, "*")
|
||||
if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// FilterMatches returns true if a filter matches a topic rule.
|
||||
func (r RString) FilterMatches(a string) bool {
|
||||
_, ok := MatchTopic(string(r), a)
|
||||
return ok
|
||||
}
|
||||
|
||||
// MatchTopic checks if a given topic matches a filter, accounting for filter
|
||||
// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c.
|
||||
func MatchTopic(filter string, topic string) (elements []string, matched bool) {
|
||||
filterParts := strings.Split(filter, "/")
|
||||
topicParts := strings.Split(topic, "/")
|
||||
|
||||
elements = make([]string, 0)
|
||||
for i := 0; i < len(filterParts); i++ {
|
||||
if i >= len(topicParts) {
|
||||
matched = false
|
||||
return
|
||||
}
|
||||
|
||||
if filterParts[i] == "+" {
|
||||
elements = append(elements, topicParts[i])
|
||||
continue
|
||||
}
|
||||
|
||||
if filterParts[i] == "#" {
|
||||
matched = true
|
||||
elements = append(elements, strings.Join(topicParts[i:], "/"))
|
||||
return
|
||||
}
|
||||
|
||||
if filterParts[i] != topicParts[i] {
|
||||
matched = false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return elements, true
|
||||
}
|
||||
|
||||
// Ledger is an auth ledger containing access rules for users and topics.
|
||||
type Ledger struct {
|
||||
sync.Mutex `json:"-" yaml:"-"`
|
||||
Users Users `json:"users" yaml:"users"`
|
||||
Auth AuthRules `json:"auth" yaml:"auth"`
|
||||
ACL ACLRules `json:"acl" yaml:"acl"`
|
||||
}
|
||||
|
||||
// Update updates the internal values of the ledger.
|
||||
func (l *Ledger) Update(ln *Ledger) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Auth = ln.Auth
|
||||
l.ACL = ln.ACL
|
||||
}
|
||||
|
||||
// AuthOk returns true if the rules indicate the user is allowed to authenticate.
|
||||
func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
|
||||
// If the users map is set, always check for a predefined user first instead
|
||||
// of iterating through global rules.
|
||||
if l.Users != nil {
|
||||
if u, ok := l.Users[string(cl.Properties.Username)]; ok &&
|
||||
u.Password != "" &&
|
||||
u.Password == RString(pk.Connect.Password) {
|
||||
return 0, !u.Disallow
|
||||
}
|
||||
}
|
||||
|
||||
// If there's no users map, or no user was found, attempt to find a matching
|
||||
// rule (which may also contain a user).
|
||||
for n, rule := range l.Auth {
|
||||
if rule.Client.Matches(cl.ID) &&
|
||||
rule.Username.Matches(string(cl.Properties.Username)) &&
|
||||
rule.Password.Matches(string(pk.Connect.Password)) &&
|
||||
rule.Remote.Matches(cl.Net.Remote) {
|
||||
return n, rule.Allow
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// ACLOk returns true if the rules indicate the user is allowed to read or write to
|
||||
// a specific filter or topic respectively, based on the `write` bool.
|
||||
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
|
||||
// If the users map is set, always check for a predefined user first instead
|
||||
// of iterating through global rules.
|
||||
if l.Users != nil {
|
||||
if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 {
|
||||
for filter, access := range u.ACL {
|
||||
if filter.FilterMatches(topic) {
|
||||
if !write && (access == ReadOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else if write && (access == WriteOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else {
|
||||
return n, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for n, rule := range l.ACL {
|
||||
if rule.Client.Matches(cl.ID) &&
|
||||
rule.Username.Matches(string(cl.Properties.Username)) &&
|
||||
rule.Remote.Matches(cl.Net.Remote) {
|
||||
if len(rule.Filters) == 0 {
|
||||
return n, true
|
||||
}
|
||||
|
||||
if write {
|
||||
for filter, access := range rule.Filters {
|
||||
if access == WriteOnly || access == ReadWrite {
|
||||
if filter.FilterMatches(topic) {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !write {
|
||||
for filter, access := range rule.Filters {
|
||||
if access == ReadOnly || access == ReadWrite {
|
||||
if filter.FilterMatches(topic) {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for filter := range rule.Filters {
|
||||
if filter.FilterMatches(topic) {
|
||||
return n, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, true
|
||||
}
|
||||
|
||||
// ToJSON encodes the values into a JSON string.
|
||||
func (l *Ledger) ToJSON() (data []byte, err error) {
|
||||
return json.Marshal(l)
|
||||
}
|
||||
|
||||
// ToYAML encodes the values into a YAML string.
|
||||
func (l *Ledger) ToYAML() (data []byte, err error) {
|
||||
return yaml.Marshal(l)
|
||||
}
|
||||
|
||||
// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct.
|
||||
func (l *Ledger) Unmarshal(data []byte) error {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if data[0] == '{' {
|
||||
return json.Unmarshal(data, l)
|
||||
}
|
||||
|
||||
return yaml.Unmarshal(data, &l)
|
||||
}
|
||||
610
hooks/auth/ledger_test.go
Normal file
610
hooks/auth/ledger_test.go
Normal file
@@ -0,0 +1,610 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
var (
|
||||
checkLedger = Ledger{
|
||||
Users: Users{ // users are allowed by default
|
||||
"mochi-co": {
|
||||
Password: "melon",
|
||||
ACL: Filters{
|
||||
"d/+/f": Deny,
|
||||
"mochi-co/#": ReadWrite,
|
||||
"readonly": ReadOnly,
|
||||
},
|
||||
},
|
||||
"suspended-username": {
|
||||
Password: "any",
|
||||
Disallow: true,
|
||||
},
|
||||
"mochi": { // ACL only, will defer to AuthRules for authentication
|
||||
ACL: Filters{
|
||||
"special/mochi": ReadOnly,
|
||||
"secret/mochi": Deny,
|
||||
"ignored": ReadWrite,
|
||||
},
|
||||
},
|
||||
},
|
||||
Auth: AuthRules{
|
||||
{Username: "banned-user"}, // never allow specific username
|
||||
{Remote: "127.0.0.1", Allow: true}, // always allow localhost
|
||||
{Remote: "123.123.123.123"}, // disallow any from specific address
|
||||
{Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address
|
||||
{Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username)
|
||||
{Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass
|
||||
{Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map)
|
||||
},
|
||||
ACL: ACLRules{
|
||||
{
|
||||
Username: "mochi", // allow matching user/pass
|
||||
Filters: Filters{
|
||||
"a/b/c": Deny,
|
||||
"d/+/f": Deny,
|
||||
"mochi/#": ReadWrite,
|
||||
"updates/#": WriteOnly,
|
||||
"readonly": ReadOnly,
|
||||
"ignored": Deny,
|
||||
},
|
||||
},
|
||||
{Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost
|
||||
{Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin
|
||||
{Remote: "001.002.003.004"}, // Allow all with no filter
|
||||
{Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func TestRStringMatches(t *testing.T) {
|
||||
require.True(t, RString("*").Matches("any"))
|
||||
require.True(t, RString("*").Matches(""))
|
||||
require.True(t, RString("").Matches("any"))
|
||||
require.True(t, RString("").Matches(""))
|
||||
require.False(t, RString("no").Matches("any"))
|
||||
require.False(t, RString("no").Matches(""))
|
||||
}
|
||||
|
||||
func TestCanAuthenticate(t *testing.T) {
|
||||
tt := []struct {
|
||||
desc string
|
||||
client *mqtt.Client
|
||||
pk packets.Packet
|
||||
n int
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
desc: "allow all local 127.0.0.1",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{}},
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 5,
|
||||
},
|
||||
{
|
||||
desc: "deny username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "allow all local 127.0.0.1",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 5,
|
||||
},
|
||||
{
|
||||
desc: "deny username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "deny client from address",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("not-mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "111.144.155.166",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: false,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
desc: "allow remote wildcard",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "111.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: true,
|
||||
n: 4,
|
||||
},
|
||||
{
|
||||
desc: "never allow username",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("banned-user"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "matching user in users",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi-co"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "never user in users",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("suspended-user"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range tt {
|
||||
t.Run(d.desc, func(t *testing.T) {
|
||||
n, ok := checkLedger.AuthOk(d.client, d.pk)
|
||||
require.Equal(t, d.n, n)
|
||||
require.Equal(t, d.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanACL(t *testing.T) {
|
||||
tt := []struct {
|
||||
client *mqtt.Client
|
||||
desc string
|
||||
topic string
|
||||
n int
|
||||
write bool
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
desc: "allow normal write on any other filter",
|
||||
client: &mqtt.Client{},
|
||||
topic: "default/acl/write/access",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow normal read on any other filter",
|
||||
client: &mqtt.Client{},
|
||||
topic: "default/acl/read/access",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny user on literal filter",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "a/b/c",
|
||||
},
|
||||
{
|
||||
desc: "deny user on partial filter",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "d/j/f",
|
||||
},
|
||||
{
|
||||
desc: "allow read/write to user path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "mochi/read/write",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny read on write-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/no/reading",
|
||||
write: false,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "deny read on write-only path ext",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/mochi",
|
||||
write: false,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "allow read on not-acl path (no #)",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow write on write-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/mochi",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny write on read-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "readonly",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "allow read on read-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "readonly",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow $sys access to localhost",
|
||||
client: &mqtt.Client{
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "localhost",
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow $sys access to admin",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("admin"),
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: true,
|
||||
n: 2,
|
||||
},
|
||||
{
|
||||
desc: "deny $sys access to all others",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: false,
|
||||
n: 4,
|
||||
},
|
||||
{
|
||||
desc: "allow all with no filter",
|
||||
client: &mqtt.Client{
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "001.002.003.004",
|
||||
},
|
||||
},
|
||||
topic: "any/path",
|
||||
write: true,
|
||||
ok: true,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl deny",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "secret/mochi",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl any",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "any/mochi",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl write on read-only",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "special/mochi",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl read on read-only",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "special/mochi",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "preference users embedded acl",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "ignored",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range tt {
|
||||
t.Run(d.desc, func(t *testing.T) {
|
||||
n, ok := checkLedger.ACLOk(d.client, d.topic, d.write)
|
||||
require.Equal(t, d.n, n)
|
||||
require.Equal(t, d.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchTopic(t *testing.T) {
|
||||
el, matched := MatchTopic("a/+/c/+", "a/b/c/d")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "d"}, el)
|
||||
|
||||
el, matched = MatchTopic("a/+/+/+", "a/b/c/d")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "c", "d"}, el)
|
||||
|
||||
el, matched = MatchTopic("stuff/#", "stuff/things/yeah")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"things/yeah"}, el)
|
||||
|
||||
el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "c/d/as/dds"}, el)
|
||||
|
||||
el, matched = MatchTopic("test", "test")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic("things/stuff//", "things/stuff/")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic("t", "t2")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic(" ", " ")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
}
|
||||
|
||||
var (
|
||||
ledgerStruct = Ledger{
|
||||
Users: Users{
|
||||
"mochi": {
|
||||
Password: "peach",
|
||||
ACL: Filters{
|
||||
"readonly": ReadOnly,
|
||||
"deny": Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
Auth: AuthRules{
|
||||
{
|
||||
Client: "*",
|
||||
Username: "mochi-co",
|
||||
Password: "melon",
|
||||
Remote: "192.168.1.*",
|
||||
Allow: true,
|
||||
},
|
||||
},
|
||||
ACL: ACLRules{
|
||||
{
|
||||
Client: "*",
|
||||
Username: "mochi-co",
|
||||
Remote: "127.*",
|
||||
Filters: Filters{
|
||||
"readonly": ReadOnly,
|
||||
"writeonly": WriteOnly,
|
||||
"readwrite": ReadWrite,
|
||||
"deny": Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`)
|
||||
ledgerYAML = []byte(`users:
|
||||
mochi:
|
||||
password: peach
|
||||
acl:
|
||||
deny: 0
|
||||
readonly: 1
|
||||
auth:
|
||||
- client: '*'
|
||||
username: mochi-co
|
||||
remote: 192.168.1.*
|
||||
password: melon
|
||||
allow: true
|
||||
acl:
|
||||
- client: '*'
|
||||
username: mochi-co
|
||||
remote: 127.*
|
||||
filters:
|
||||
deny: 0
|
||||
readonly: 1
|
||||
readwrite: 3
|
||||
writeonly: 2
|
||||
`)
|
||||
)
|
||||
|
||||
func TestLedgerUpdate(t *testing.T) {
|
||||
old := &Ledger{
|
||||
Auth: AuthRules{
|
||||
{Remote: "127.0.0.1", Allow: true},
|
||||
},
|
||||
}
|
||||
|
||||
n := &Ledger{
|
||||
Auth: AuthRules{
|
||||
{Remote: "127.0.0.1", Allow: true},
|
||||
{Remote: "192.168.*", Allow: true},
|
||||
},
|
||||
}
|
||||
|
||||
old.Update(n)
|
||||
require.Len(t, old.Auth, 2)
|
||||
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
|
||||
require.NotSame(t, n, old)
|
||||
}
|
||||
|
||||
func TestLedgerToJSON(t *testing.T) {
|
||||
data, err := ledgerStruct.ToJSON()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerJSON, data)
|
||||
}
|
||||
|
||||
func TestLedgerToYAML(t *testing.T) {
|
||||
data, err := ledgerStruct.ToYAML()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerYAML, data)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalFromYAML(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal(ledgerYAML)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &ledgerStruct, l)
|
||||
require.NotSame(t, l, &ledgerStruct)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalFromJSON(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal(ledgerJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &ledgerStruct, l)
|
||||
require.NotSame(t, l, &ledgerStruct)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalNil(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, new(Ledger), l)
|
||||
}
|
||||
237
hooks/debug/debug.go
Normal file
237
hooks/debug/debug.go
Normal file
@@ -0,0 +1,237 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
// Options contains configuration settings for the debug output.
|
||||
type Options struct {
|
||||
Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config
|
||||
ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false)
|
||||
ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false)
|
||||
ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false)
|
||||
}
|
||||
|
||||
// Hook is a debugging hook which logs additional low-level information from the server.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options
|
||||
Log *slog.Logger
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "debug"
|
||||
}
|
||||
|
||||
// Provides indicates that this hook provides all methods.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Init is called when the hook is initialized.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetOpts is called when the hook receives inheritable server parameters.
|
||||
func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) {
|
||||
h.Log = l
|
||||
h.Log.Debug("", "method", "SetOpts")
|
||||
}
|
||||
|
||||
// Stop is called when the hook is stopped.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Debug("", "method", "Stop")
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStarted is called when the server starts.
|
||||
func (h *Hook) OnStarted() {
|
||||
h.Log.Debug("", "method", "OnStarted")
|
||||
}
|
||||
|
||||
// OnStopped is called when the server stops.
|
||||
func (h *Hook) OnStopped() {
|
||||
h.Log.Debug("", "method", "OnStopped")
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a new packet is received from a client.
|
||||
func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet is sent to a client.
|
||||
func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
|
||||
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
|
||||
return
|
||||
}
|
||||
|
||||
h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
h.Log.Debug("retained message on topic", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
h.Log.Debug("inflight out", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug("inflight complete", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug("inflight dropped", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnLWTSent is called when a Will Message has been issued from a disconnecting client.
|
||||
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID)
|
||||
}
|
||||
|
||||
// OnRetainedExpired is called when the server clears expired retained messages.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter)
|
||||
}
|
||||
|
||||
// OnClientExpired is called when the server clears an expired client.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID)
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores clients from a store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
h.Log.Debug("", "method", "StoredClients")
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions is called when the server restores subscriptions from a store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
h.Log.Debug("", "method", "StoredSubscriptions")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages is called when the server restores retained messages from a store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug("", "method", "StoredRetainedMessages")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages is called when the server restores inflight messages from a store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug("", "method", "StoredInflightMessages")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo is called when the server restores system info from a store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
h.Log.Debug("", "method", "StoredSysInfo")
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// packetMeta adds additional type-specific metadata to the debug logs.
|
||||
func (h *Hook) packetMeta(pk packets.Packet) map[string]any {
|
||||
m := map[string]any{}
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
m["id"] = pk.Connect.ClientIdentifier
|
||||
m["clean"] = pk.Connect.Clean
|
||||
m["keepalive"] = pk.Connect.Keepalive
|
||||
m["version"] = pk.ProtocolVersion
|
||||
m["username"] = string(pk.Connect.Username)
|
||||
if h.config.ShowPasswords {
|
||||
m["password"] = string(pk.Connect.Password)
|
||||
}
|
||||
if pk.Connect.WillFlag {
|
||||
m["will_topic"] = pk.Connect.WillTopic
|
||||
m["will_payload"] = string(pk.Connect.WillPayload)
|
||||
}
|
||||
case packets.Publish:
|
||||
m["topic"] = pk.TopicName
|
||||
m["payload"] = string(pk.Payload)
|
||||
m["raw"] = pk.Payload
|
||||
m["qos"] = pk.FixedHeader.Qos
|
||||
m["id"] = pk.PacketID
|
||||
case packets.Connack:
|
||||
fallthrough
|
||||
case packets.Disconnect:
|
||||
fallthrough
|
||||
case packets.Puback:
|
||||
fallthrough
|
||||
case packets.Pubrec:
|
||||
fallthrough
|
||||
case packets.Pubrel:
|
||||
fallthrough
|
||||
case packets.Pubcomp:
|
||||
m["id"] = pk.PacketID
|
||||
m["reason"] = int(pk.ReasonCode)
|
||||
if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 {
|
||||
m["reason_string"] = pk.Properties.ReasonString
|
||||
}
|
||||
case packets.Subscribe:
|
||||
f := map[string]int{}
|
||||
ids := map[string]int{}
|
||||
for _, v := range pk.Filters {
|
||||
f[v.Filter] = int(v.Qos)
|
||||
ids[v.Filter] = v.Identifier
|
||||
}
|
||||
m["filters"] = f
|
||||
m["subids"] = f
|
||||
|
||||
case packets.Unsubscribe:
|
||||
f := []string{}
|
||||
for _, v := range pk.Filters {
|
||||
f = append(f, v.Filter)
|
||||
}
|
||||
m["filters"] = f
|
||||
case packets.Suback:
|
||||
fallthrough
|
||||
case packets.Unsuback:
|
||||
r := []int{}
|
||||
for _, v := range pk.ReasonCodes {
|
||||
r = append(r, int(v))
|
||||
}
|
||||
m["reasons"] = r
|
||||
case packets.Auth:
|
||||
// tbd
|
||||
}
|
||||
|
||||
if h.config.ShowPacketData {
|
||||
m["packet"] = pk
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
576
hooks/storage/badger/badger.go
Normal file
576
hooks/storage/badger/badger.go
Normal file
@@ -0,0 +1,576 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co, gsagula, werbenhu
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
badgerdb "github.com/dgraph-io/badger/v4"
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultDbFile is the default file path for the badger db file.
|
||||
defaultDbFile = ".badger"
|
||||
defaultGcInterval = 5 * 60 // gc interval in seconds
|
||||
defaultGcDiscardRatio = 0.5
|
||||
)
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return storage.ClientKey + "_" + cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return storage.RetainedKey + "_" + topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Serializable is an interface for objects that can be serialized and deserialized.
|
||||
type Serializable interface {
|
||||
UnmarshalBinary([]byte) error
|
||||
MarshalBinary() (data []byte, err error)
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the BadgerDB instance.
|
||||
type Options struct {
|
||||
Options *badgerdb.Options
|
||||
Path string `yaml:"path" json:"path"`
|
||||
// GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard.
|
||||
// Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value
|
||||
// would result in more space reclaims at the cost of increased activity on the LSM tree.
|
||||
// discardRatio must be in the range (0.0, 1.0), both endpoints excluded, otherwise, it will be set to the default value of 0.5.
|
||||
GcDiscardRatio float64 `yaml:"gc_discard_ratio" json:"gc_discard_ratio"`
|
||||
GcInterval int64 `yaml:"gc_interval" json:"gc_interval"`
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using BadgerDB file store as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for configuring the BadgerDB instance.
|
||||
gcTicker *time.Ticker // Ticker for BadgerDB garbage collection.
|
||||
db *badgerdb.DB // the BadgerDB instance.
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "badger-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// GcLoop periodically runs the garbage collection process to reclaim space in the value log files.
|
||||
// It uses a ticker to trigger the garbage collection at regular intervals specified by the configuration.
|
||||
// Refer to: https://dgraph.io/docs/badger/get-started/#garbage-collection
|
||||
func (h *Hook) gcLoop() {
|
||||
for range h.gcTicker.C {
|
||||
again:
|
||||
// Run the garbage collection process with a threshold.
|
||||
// If the process returns nil (success), repeat the process.
|
||||
err := h.db.RunValueLogGC(h.config.GcDiscardRatio)
|
||||
if err == nil {
|
||||
goto again // Retry garbage collection if successful.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes and connects to the badger instance.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
h.config = new(Options)
|
||||
} else {
|
||||
h.config = config.(*Options)
|
||||
}
|
||||
|
||||
if len(h.config.Path) == 0 {
|
||||
h.config.Path = defaultDbFile
|
||||
}
|
||||
|
||||
if h.config.GcInterval == 0 {
|
||||
h.config.GcInterval = defaultGcInterval
|
||||
}
|
||||
|
||||
if h.config.GcDiscardRatio <= 0.0 || h.config.GcDiscardRatio >= 1.0 {
|
||||
h.config.GcDiscardRatio = defaultGcDiscardRatio
|
||||
}
|
||||
|
||||
if h.config.Options == nil {
|
||||
defaultOpts := badgerdb.DefaultOptions(h.config.Path)
|
||||
h.config.Options = &defaultOpts
|
||||
}
|
||||
h.config.Options.Logger = h
|
||||
|
||||
var err error
|
||||
h.db, err = badgerdb.Open(*h.config.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.gcTicker = time.NewTicker(time.Duration(h.config.GcInterval) * time.Second)
|
||||
go h.gcLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the badger instance.
|
||||
func (h *Hook) Stop() error {
|
||||
if h.gcTicker != nil {
|
||||
h.gcTicker.Stop()
|
||||
}
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: cl.ID,
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
|
||||
err := h.setKv(clientKey(cl), in)
|
||||
if err != nil {
|
||||
h.Log.Error("failed to upsert client data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if their session has expired.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
h.updateClient(cl)
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) {
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.delKv(clientKey(cl))
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
_ = h.setKv(in.ID, in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
_ = h.delKv(subscriptionKey(cl, pk.Filters[i].Filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
_ = h.delKv(retainedKey(pk.TopicName))
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
_ = h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
PacketID: pk.PacketID,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
_ = h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.delKv(inflightKey(cl, pk))
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys.Clone(),
|
||||
}
|
||||
|
||||
_ = h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.delKv(retainedKey(filter))
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.delKv(clientKey(cl))
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.iterKv(storage.ClientKey, func(value []byte) error {
|
||||
obj := storage.Client{}
|
||||
err = obj.UnmarshalBinary(value)
|
||||
if err == nil {
|
||||
v = append(v, obj)
|
||||
}
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
v = make([]storage.Subscription, 0)
|
||||
err = h.iterKv(storage.SubscriptionKey, func(value []byte) error {
|
||||
obj := storage.Subscription{}
|
||||
err = obj.UnmarshalBinary(value)
|
||||
if err == nil {
|
||||
v = append(v, obj)
|
||||
}
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
v = make([]storage.Message, 0)
|
||||
err = h.iterKv(storage.RetainedKey, func(value []byte) error {
|
||||
obj := storage.Message{}
|
||||
err = obj.UnmarshalBinary(value)
|
||||
if err == nil {
|
||||
v = append(v, obj)
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil && !errors.Is(err, badgerdb.ErrKeyNotFound) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
v = make([]storage.Message, 0)
|
||||
err = h.iterKv(storage.InflightKey, func(value []byte) error {
|
||||
obj := storage.Message{}
|
||||
err = obj.UnmarshalBinary(value)
|
||||
if err == nil {
|
||||
v = append(v, obj)
|
||||
}
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.getKv(storage.SysInfoKey, &v)
|
||||
if err != nil && !errors.Is(err, badgerdb.ErrKeyNotFound) {
|
||||
return
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Errorf satisfies the badger interface for an error logger.
|
||||
func (h *Hook) Errorf(m string, v ...any) {
|
||||
h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
|
||||
}
|
||||
|
||||
// Warningf satisfies the badger interface for a warning logger.
|
||||
func (h *Hook) Warningf(m string, v ...any) {
|
||||
h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// Infof satisfies the badger interface for an info logger.
|
||||
func (h *Hook) Infof(m string, v ...any) {
|
||||
h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// Debugf satisfies the badger interface for a debug logger.
|
||||
func (h *Hook) Debugf(m string, v ...any) {
|
||||
h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// setKv stores a key-value pair in the database.
|
||||
func (h *Hook) setKv(k string, v storage.Serializable) error {
|
||||
err := h.db.Update(func(txn *badgerdb.Txn) error {
|
||||
data, _ := v.MarshalBinary()
|
||||
return txn.Set([]byte(k), data)
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error("failed to upsert data", "error", err, "key", k)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// delKv deletes a key-value pair from the database.
|
||||
func (h *Hook) delKv(k string) error {
|
||||
err := h.db.Update(func(txn *badgerdb.Txn) error {
|
||||
return txn.Delete([]byte(k))
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete data", "error", err, "key", k)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// getKv retrieves the value associated with a key from the database.
|
||||
func (h *Hook) getKv(k string, v storage.Serializable) error {
|
||||
return h.db.View(func(txn *badgerdb.Txn) error {
|
||||
item, err := txn.Get([]byte(k))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value, err := item.ValueCopy(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return v.UnmarshalBinary(value)
|
||||
})
|
||||
}
|
||||
|
||||
// iterKv iterates over key-value pairs with keys having the specified prefix in the database.
|
||||
func (h *Hook) iterKv(prefix string, visit func([]byte) error) error {
|
||||
err := h.db.View(func(txn *badgerdb.Txn) error {
|
||||
iterator := txn.NewIterator(badgerdb.DefaultIteratorOptions)
|
||||
defer iterator.Close()
|
||||
|
||||
for iterator.Seek([]byte(prefix)); iterator.ValidForPrefix([]byte(prefix)); iterator.Next() {
|
||||
item := iterator.Item()
|
||||
value, err := item.ValueCopy(nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := visit(value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error("failed to find data", "error", err, "prefix", prefix)
|
||||
}
|
||||
return err
|
||||
}
|
||||
809
hooks/storage/badger/badger_test.go
Normal file
809
hooks/storage/badger/badger_test.go
Normal file
@@ -0,0 +1,809 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co, werbenhu
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
badgerdb "github.com/dgraph-io/badger/v4"
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
_ = h.Stop()
|
||||
_ = h.db.Close()
|
||||
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSetGetDelKv(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Init(nil)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
key := "testKey"
|
||||
value := &storage.Client{ID: "cl1"}
|
||||
err := h.setKv(key, value)
|
||||
require.NoError(t, err)
|
||||
|
||||
var client storage.Client
|
||||
err = h.getKv(key, &client)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cl1", client.ID)
|
||||
|
||||
err = h.delKv(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.getKv(key, &client)
|
||||
require.ErrorIs(t, badgerdb.ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, storage.ClientKey+"_cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, storage.InflightKey+"_cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "badger-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitBadOption(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Options: &badgerdb.Options{
|
||||
NumCompactors: 1,
|
||||
},
|
||||
})
|
||||
// Cannot have 1 compactor. Need at least 2
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
require.Equal(t, defaultDbFile, h.config.Path)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r3)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, badgerdb.ErrKeyNotFound, err)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err = h.setKv(clientKey, &storage.Client{ID: cl.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey, r)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ID, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
err = h.getKv(clientKey, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, badgerdb.ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
r := new(storage.Subscription)
|
||||
|
||||
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, badgerdb.ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err = h.setKv(m.ID, m)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(m.ID, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
err = h.getKv(m.ID, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(inflightKey(client, pk), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
err = h.getKv(inflightKey(client, pk), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
err = h.getKv(storage.SysInfoKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with clients
|
||||
err = h.setKv(storage.ClientKey+"_cl1", &storage.Client{ID: "cl1", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.ClientKey+"_cl2", &storage.Client{ID: "cl2", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.ClientKey+"_cl3", &storage.Client{ID: "cl3", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err = h.setKv(storage.SubscriptionKey+"_sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.SubscriptionKey+"_sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.SubscriptionKey+"_sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_m2", &storage.Message{ID: "m2", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_m3", &storage.Message{ID: "m3", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.setKv(storage.InflightKey+"_i1", &storage.Message{ID: "i1", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_i2", &storage.Message{ID: "i2", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestErrorf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Errorf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestWarningf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Warningf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestInfof(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Infof("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Debugf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestGcLoop(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
opts := badgerdb.DefaultOptions(defaultDbFile)
|
||||
opts.ValueLogFileSize = 1 << 20
|
||||
h.Init(&Options{
|
||||
GcInterval: 2, // Set the interval for garbage collection.
|
||||
Options: &opts,
|
||||
})
|
||||
defer teardown(t, defaultDbFile, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
h.OnDisconnect(client, nil, true)
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
|
||||
func TestGetSetDelKv(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv("testKey", &storage.Client{ID: "testId"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var obj storage.Client
|
||||
err = h.getKv("testKey", &obj)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.delKv("testKey")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.getKv("testKey", &obj)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, badgerdb.ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestIterKv(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
require.NoError(t, err)
|
||||
|
||||
h.setKv("prefix_a_1", &storage.Client{ID: "1"})
|
||||
h.setKv("prefix_a_2", &storage.Client{ID: "2"})
|
||||
h.setKv("prefix_b_2", &storage.Client{ID: "3"})
|
||||
|
||||
var clients []storage.Client
|
||||
err = h.iterKv("prefix_a", func(data []byte) error {
|
||||
var item storage.Client
|
||||
item.UnmarshalBinary(data)
|
||||
clients = append(clients, item)
|
||||
return nil
|
||||
})
|
||||
require.Equal(t, 2, len(clients))
|
||||
require.NoError(t, err)
|
||||
|
||||
visitErr := errors.New("iter visit error")
|
||||
err = h.iterKv("prefix_b", func(data []byte) error {
|
||||
return visitErr
|
||||
})
|
||||
require.ErrorIs(t, visitErr, err)
|
||||
}
|
||||
525
hooks/storage/bolt/bolt.go
Normal file
525
hooks/storage/bolt/bolt.go
Normal file
@@ -0,0 +1,525 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co, werbenhu
|
||||
|
||||
// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBucketNotFound = errors.New("bucket not found")
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultDbFile is the default file path for the boltdb file.
|
||||
defaultDbFile = ".bolt"
|
||||
|
||||
// defaultTimeout is the default time to hold a connection to the file.
|
||||
defaultTimeout = 250 * time.Millisecond
|
||||
|
||||
defaultBucket = "mochi"
|
||||
)
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return storage.ClientKey + "_" + cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return storage.RetainedKey + "_" + topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the bolt instance.
|
||||
type Options struct {
|
||||
Options *bbolt.Options
|
||||
Bucket string `yaml:"bucket" json:"bucket"`
|
||||
Path string `yaml:"path" json:"path"`
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using boltdb file store as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for configuring the boltdb instance.
|
||||
db *bbolt.DB // the boltdb instance.
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "bolt-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init initializes and connects to the boltdb instance.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
if h.config.Options == nil {
|
||||
h.config.Options = &bbolt.Options{
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
if len(h.config.Path) == 0 {
|
||||
h.config.Path = defaultDbFile
|
||||
}
|
||||
|
||||
if len(h.config.Bucket) == 0 {
|
||||
h.config.Bucket = defaultBucket
|
||||
}
|
||||
|
||||
var err error
|
||||
h.db, err = bbolt.Open(h.config.Path, 0600, h.config.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = h.db.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists([]byte(h.config.Bucket))
|
||||
return err
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop closes the boltdb instance.
|
||||
func (h *Hook) Stop() error {
|
||||
err := h.db.Close()
|
||||
h.db = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: cl.ID,
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
|
||||
_ = h.setKv(clientKey(cl), in)
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
if cl.StopCause() == packets.ErrSessionTakenOver {
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.delKv(clientKey(cl))
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
_ = h.setKv(in.ID, in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
_ = h.delKv(subscriptionKey(cl, pk.Filters[i].Filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
_ = h.delKv(retainedKey(pk.TopicName))
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
_ = h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
_ = h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.delKv(inflightKey(cl, pk))
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
}
|
||||
|
||||
_ = h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
_ = h.delKv(retainedKey(filter))
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
_ = h.delKv(clientKey(cl))
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return v, storage.ErrDBFileNotOpen
|
||||
}
|
||||
|
||||
err = h.iterKv(storage.ClientKey, func(value []byte) error {
|
||||
obj := storage.Client{}
|
||||
err = obj.UnmarshalBinary(value)
|
||||
if err == nil {
|
||||
v = append(v, obj)
|
||||
}
|
||||
return err
|
||||
})
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return v, storage.ErrDBFileNotOpen
|
||||
}
|
||||
|
||||
v = make([]storage.Subscription, 0)
|
||||
err = h.iterKv(storage.SubscriptionKey, func(value []byte) error {
|
||||
obj := storage.Subscription{}
|
||||
err = obj.UnmarshalBinary(value)
|
||||
if err == nil {
|
||||
v = append(v, obj)
|
||||
}
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return v, storage.ErrDBFileNotOpen
|
||||
}
|
||||
|
||||
v = make([]storage.Message, 0)
|
||||
err = h.iterKv(storage.RetainedKey, func(value []byte) error {
|
||||
obj := storage.Message{}
|
||||
err = obj.UnmarshalBinary(value)
|
||||
if err == nil {
|
||||
v = append(v, obj)
|
||||
}
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return v, storage.ErrDBFileNotOpen
|
||||
}
|
||||
|
||||
v = make([]storage.Message, 0)
|
||||
err = h.iterKv(storage.InflightKey, func(value []byte) error {
|
||||
obj := storage.Message{}
|
||||
err = obj.UnmarshalBinary(value)
|
||||
if err == nil {
|
||||
v = append(v, obj)
|
||||
}
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return v, storage.ErrDBFileNotOpen
|
||||
}
|
||||
|
||||
err = h.getKv(storage.SysInfoKey, &v)
|
||||
if err != nil && !errors.Is(err, ErrKeyNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// setKv stores a key-value pair in the database.
|
||||
func (h *Hook) setKv(k string, v storage.Serializable) error {
|
||||
err := h.db.Update(func(tx *bbolt.Tx) error {
|
||||
|
||||
bucket := tx.Bucket([]byte(h.config.Bucket))
|
||||
data, _ := v.MarshalBinary()
|
||||
err := bucket.Put([]byte(k), data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error("failed to upsert data", "error", err, "key", k)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// delKv deletes a key-value pair from the database.
|
||||
func (h *Hook) delKv(k string) error {
|
||||
err := h.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(h.config.Bucket))
|
||||
err := bucket.Delete([]byte(k))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete data", "error", err, "key", k)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// getKv retrieves the value associated with a key from the database.
|
||||
func (h *Hook) getKv(k string, v storage.Serializable) error {
|
||||
err := h.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(h.config.Bucket))
|
||||
|
||||
value := bucket.Get([]byte(k))
|
||||
if value == nil {
|
||||
return ErrKeyNotFound
|
||||
}
|
||||
|
||||
return v.UnmarshalBinary(value)
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error("failed to get data", "error", err, "key", k)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// iterKv iterates over key-value pairs with keys having the specified prefix in the database.
|
||||
func (h *Hook) iterKv(prefix string, visit func([]byte) error) error {
|
||||
err := h.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(h.config.Bucket))
|
||||
|
||||
c := bucket.Cursor()
|
||||
for k, v := c.Seek([]byte(prefix)); k != nil && string(k[:len(prefix)]) == prefix; k, v = c.Next() {
|
||||
if err := visit(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
h.Log.Error("failed to iter data", "error", err, "prefix", prefix)
|
||||
}
|
||||
return err
|
||||
}
|
||||
791
hooks/storage/bolt/bolt_test.go
Normal file
791
hooks/storage/bolt/bolt_test.go
Normal file
@@ -0,0 +1,791 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co, werbenhu
|
||||
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
_ = h.Stop()
|
||||
err := os.Remove(path)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, "CL_cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, storage.InflightKey+"_cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "bolt-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
require.Equal(t, defaultTimeout, h.config.Options.Timeout)
|
||||
require.Equal(t, defaultDbFile, h.config.Path)
|
||||
}
|
||||
|
||||
func TestInitBadPath(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Path: "..",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r3)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrKeyNotFound, err)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err = h.setKv(clientKey, &storage.Client{ID: cl.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ID, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
err = h.getKv(clientKey, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
r := new(storage.Subscription)
|
||||
|
||||
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, ErrKeyNotFound, err)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err = h.setKv(m.ID, m)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(m.ID, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
err = h.getKv(m.ID, r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(inflightKey(client, pk), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved to bolt
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
err = h.getKv(inflightKey(client, pk), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
err = h.getKv(storage.SysInfoKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with clients
|
||||
err = h.setKv(storage.ClientKey+"_"+"cl1", &storage.Client{ID: "cl1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.ClientKey+"_"+"cl2", &storage.Client{ID: "cl2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.ClientKey+"_"+"cl3", &storage.Client{ID: "cl3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
|
||||
}
|
||||
|
||||
func TestStoredClientsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err = h.setKv(storage.SubscriptionKey+"_"+"sub1", &storage.Subscription{ID: "sub1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.SubscriptionKey+"_"+"sub2", &storage.Subscription{ID: "sub2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.SubscriptionKey+"_"+"sub3", &storage.Subscription{ID: "sub3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.setKv(storage.RetainedKey+"_"+"m1", &storage.Message{ID: "m1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_"+"m2", &storage.Message{ID: "m2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_"+"m3", &storage.Message{ID: "m3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_"+"i3", &storage.Message{ID: "i3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.setKv(storage.InflightKey+"_"+"i1", &storage.Message{ID: "i1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_"+"i2", &storage.Message{ID: "i2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_"+"i3", &storage.Message{ID: "i3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_"+"m1", &storage.Message{ID: "m1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with sys info
|
||||
err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetSetDelKv(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv("testId", &storage.Client{ID: "testId"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var obj storage.Client
|
||||
err = h.getKv("testId", &obj)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.delKv("testId")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.getKv("testId", &obj)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestIterKv(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
require.NoError(t, err)
|
||||
|
||||
h.setKv("prefix_a_1", &storage.Client{ID: "1"})
|
||||
h.setKv("prefix_a_2", &storage.Client{ID: "2"})
|
||||
h.setKv("prefix_b_2", &storage.Client{ID: "3"})
|
||||
|
||||
var clients []storage.Client
|
||||
err = h.iterKv("prefix_a", func(data []byte) error {
|
||||
var item storage.Client
|
||||
item.UnmarshalBinary(data)
|
||||
clients = append(clients, item)
|
||||
return nil
|
||||
})
|
||||
require.Equal(t, 2, len(clients))
|
||||
require.NoError(t, err)
|
||||
|
||||
visitErr := errors.New("iter visit error")
|
||||
err = h.iterKv("prefix_b", func(data []byte) error {
|
||||
return visitErr
|
||||
})
|
||||
require.ErrorIs(t, visitErr, err)
|
||||
}
|
||||
524
hooks/storage/pebble/pebble.go
Normal file
524
hooks/storage/pebble/pebble.go
Normal file
@@ -0,0 +1,524 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: werbenhu
|
||||
|
||||
package pebble
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
pebbledb "github.com/cockroachdb/pebble"
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultDbFile is the default file path for the pebble db file.
|
||||
defaultDbFile = ".pebble"
|
||||
)
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return storage.ClientKey + "_" + cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return storage.RetainedKey + "_" + topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// keyUpperBound returns the upper bound for a given byte slice by incrementing the last byte.
|
||||
// It returns nil if all bytes are incremented and equal to 0.
|
||||
func keyUpperBound(b []byte) []byte {
|
||||
end := make([]byte, len(b))
|
||||
copy(end, b)
|
||||
for i := len(end) - 1; i >= 0; i-- {
|
||||
end[i] = end[i] + 1
|
||||
if end[i] != 0 {
|
||||
return end[:i+1]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
NoSync = "NoSync" // NoSync specifies the default write options for writes which do not synchronize to disk.
|
||||
Sync = "Sync" // Sync specifies the default write options for writes which synchronize to disk.
|
||||
)
|
||||
|
||||
// Options contains configuration settings for the pebble DB instance.
|
||||
type Options struct {
|
||||
Options *pebbledb.Options
|
||||
Mode string `yaml:"mode" json:"mode"`
|
||||
Path string `yaml:"path" json:"path"`
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using pebble DB file store as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for configuring the pebble DB instance.
|
||||
db *pebbledb.DB // the pebble DB instance
|
||||
mode *pebbledb.WriteOptions // mode holds the optional per-query parameters for Set and Delete operations
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "pebble-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init initializes and connects to the pebble instance.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
h.config = new(Options)
|
||||
} else {
|
||||
h.config = config.(*Options)
|
||||
}
|
||||
|
||||
if len(h.config.Path) == 0 {
|
||||
h.config.Path = defaultDbFile
|
||||
}
|
||||
|
||||
if h.config.Options == nil {
|
||||
h.config.Options = &pebbledb.Options{}
|
||||
}
|
||||
|
||||
h.mode = pebbledb.NoSync
|
||||
if strings.EqualFold(h.config.Mode, "Sync") {
|
||||
h.mode = pebbledb.Sync
|
||||
}
|
||||
|
||||
var err error
|
||||
h.db, err = pebbledb.Open(h.config.Path, h.config.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the pebble instance.
|
||||
func (h *Hook) Stop() error {
|
||||
err := h.db.Close()
|
||||
h.db = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: cl.ID,
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
h.setKv(clientKey(cl), in)
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if their session has expired.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
h.updateClient(cl)
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) {
|
||||
return
|
||||
}
|
||||
|
||||
h.delKv(clientKey(cl))
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
h.setKv(in.ID, in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
h.delKv(subscriptionKey(cl, pk.Filters[i].Filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
h.delKv(retainedKey(pk.TopicName))
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
PacketID: pk.PacketID,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
h.delKv(inflightKey(cl, pk))
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys.Clone(),
|
||||
}
|
||||
h.setKv(in.ID, in)
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
h.delKv(retainedKey(filter))
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
h.delKv(clientKey(cl))
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
iter, _ := h.db.NewIter(&pebbledb.IterOptions{
|
||||
LowerBound: []byte(storage.ClientKey),
|
||||
UpperBound: keyUpperBound([]byte(storage.ClientKey)),
|
||||
})
|
||||
|
||||
for iter.First(); iter.Valid(); iter.Next() {
|
||||
item := storage.Client{}
|
||||
if err := item.UnmarshalBinary(iter.Value()); err == nil {
|
||||
v = append(v, item)
|
||||
}
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
iter, _ := h.db.NewIter(&pebbledb.IterOptions{
|
||||
LowerBound: []byte(storage.SubscriptionKey),
|
||||
UpperBound: keyUpperBound([]byte(storage.SubscriptionKey)),
|
||||
})
|
||||
|
||||
for iter.First(); iter.Valid(); iter.Next() {
|
||||
item := storage.Subscription{}
|
||||
if err := item.UnmarshalBinary(iter.Value()); err == nil {
|
||||
v = append(v, item)
|
||||
}
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
iter, _ := h.db.NewIter(&pebbledb.IterOptions{
|
||||
LowerBound: []byte(storage.RetainedKey),
|
||||
UpperBound: keyUpperBound([]byte(storage.RetainedKey)),
|
||||
})
|
||||
|
||||
for iter.First(); iter.Valid(); iter.Next() {
|
||||
item := storage.Message{}
|
||||
if err := item.UnmarshalBinary(iter.Value()); err == nil {
|
||||
v = append(v, item)
|
||||
}
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
iter, _ := h.db.NewIter(&pebbledb.IterOptions{
|
||||
LowerBound: []byte(storage.InflightKey),
|
||||
UpperBound: keyUpperBound([]byte(storage.InflightKey)),
|
||||
})
|
||||
|
||||
for iter.First(); iter.Valid(); iter.Next() {
|
||||
item := storage.Message{}
|
||||
if err := item.UnmarshalBinary(iter.Value()); err == nil {
|
||||
v = append(v, item)
|
||||
}
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.getKv(sysInfoKey(), &v)
|
||||
if errors.Is(err, pebbledb.ErrNotFound) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Errorf satisfies the pebble interface for an error logger.
|
||||
func (h *Hook) Errorf(m string, v ...any) {
|
||||
h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
|
||||
}
|
||||
|
||||
// Warningf satisfies the pebble interface for a warning logger.
|
||||
func (h *Hook) Warningf(m string, v ...any) {
|
||||
h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// Infof satisfies the pebble interface for an info logger.
|
||||
func (h *Hook) Infof(m string, v ...any) {
|
||||
h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// Debugf satisfies the pebble interface for a debug logger.
|
||||
func (h *Hook) Debugf(m string, v ...any) {
|
||||
h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// delKv deletes a key-value pair from the database.
|
||||
func (h *Hook) delKv(k string) error {
|
||||
err := h.db.Delete([]byte(k), h.mode)
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete data", "error", err, "key", k)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setKv stores a key-value pair in the database.
|
||||
func (h *Hook) setKv(k string, v storage.Serializable) error {
|
||||
bs, _ := v.MarshalBinary()
|
||||
err := h.db.Set([]byte(k), bs, h.mode)
|
||||
if err != nil {
|
||||
h.Log.Error("failed to update data", "error", err, "key", k)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getKv retrieves the value associated with a key from the database.
|
||||
func (h *Hook) getKv(k string, v storage.Serializable) error {
|
||||
value, closer, err := h.db.Get([]byte(k))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if closer != nil {
|
||||
closer.Close()
|
||||
}
|
||||
}()
|
||||
return v.UnmarshalBinary(value)
|
||||
}
|
||||
812
hooks/storage/pebble/pebble_test.go
Normal file
812
hooks/storage/pebble/pebble_test.go
Normal file
@@ -0,0 +1,812 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: werbenhu
|
||||
|
||||
package pebble
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pebbledb "github.com/cockroachdb/pebble"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
_ = h.Stop()
|
||||
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestKeyUpperBound(t *testing.T) {
|
||||
// Test case 1: Non-nil case
|
||||
input1 := []byte{97, 98, 99} // "abc"
|
||||
require.NotNil(t, keyUpperBound(input1))
|
||||
|
||||
// Test case 2: All bytes are 255
|
||||
input2 := []byte{255, 255, 255}
|
||||
require.Nil(t, keyUpperBound(input2))
|
||||
|
||||
// Test case 3: Empty slice
|
||||
input3 := []byte{}
|
||||
require.Nil(t, keyUpperBound(input3))
|
||||
|
||||
// Test case 4: Nil case
|
||||
input4 := []byte{255, 255, 255}
|
||||
require.Nil(t, keyUpperBound(input4))
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, storage.ClientKey+"_cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, storage.InflightKey+"_cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "pebble-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitErr(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Options: &pebbledb.Options{
|
||||
ReadOnly: true,
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
require.Equal(t, defaultDbFile, h.config.Path)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r3)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, pebbledb.ErrNotFound, err)
|
||||
require.Empty(t, r3.ID)
|
||||
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err = h.setKv(clientKey, &storage.Client{ID: cl.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ID, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
err = h.getKv(clientKey, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, pebbledb.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.getKv(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
teardown(t, h.config.Path, h)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
r := new(storage.Subscription)
|
||||
|
||||
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, pebbledb.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, pebbledb.ErrNotFound)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.getKv(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, pebbledb.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err = h.setKv(m.ID, m)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(m.ID, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
err = h.getKv(m.ID, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, pebbledb.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.getKv(inflightKey(client, pk), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
err = h.getKv(inflightKey(client, pk), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, pebbledb.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
err = h.getKv(storage.SysInfoKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with clients
|
||||
err = h.setKv(storage.ClientKey+"_cl1", &storage.Client{ID: "cl1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.ClientKey+"_cl2", &storage.Client{ID: "cl2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.ClientKey+"_cl3", &storage.Client{ID: "cl3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err = h.setKv(storage.SubscriptionKey+"_sub1", &storage.Subscription{ID: "sub1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.SubscriptionKey+"_sub2", &storage.Subscription{ID: "sub2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.SubscriptionKey+"_sub3", &storage.Subscription{ID: "sub3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_m2", &storage.Message{ID: "m2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_m3", &storage.Message{ID: "m3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.setKv(storage.InflightKey+"_i1", &storage.Message{ID: "i1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_i2", &storage.Message{ID: "i2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
|
||||
// populate with messages
|
||||
err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err = h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestErrorf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Errorf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestWarningf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Warningf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestInfof(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Infof("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Debugf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestGetSetDelKv(t *testing.T) {
|
||||
opts := []struct {
|
||||
name string
|
||||
opt *Options
|
||||
}{
|
||||
{
|
||||
name: "NoSync",
|
||||
opt: &Options{
|
||||
Mode: NoSync,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Sync",
|
||||
opt: &Options{
|
||||
Mode: Sync,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range opts {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(tt.opt)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv("testKey", &storage.Client{ID: "testId"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var obj storage.Client
|
||||
err = h.getKv("testKey", &obj)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.delKv("testKey")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.getKv("testKey", &obj)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, pebbledb.ErrNotFound, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSetDelKvErr(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Mode: Sync,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.setKv("testKey", &storage.Client{ID: "testId"})
|
||||
require.NoError(t, err)
|
||||
h.Stop()
|
||||
|
||||
h = new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err = h.Init(&Options{
|
||||
Mode: Sync,
|
||||
Options: &pebbledb.Options{
|
||||
ReadOnly: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
err = h.setKv("testKey", &storage.Client{ID: "testId"})
|
||||
require.Error(t, err)
|
||||
|
||||
err = h.delKv("testKey")
|
||||
require.Error(t, err)
|
||||
}
|
||||
532
hooks/storage/redis/redis.go
Normal file
532
hooks/storage/redis/redis.go
Normal file
@@ -0,0 +1,532 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
// defaultAddr 默认Redis地址
|
||||
// defaultAddr is the default address to the redis service.
|
||||
const defaultAddr = "localhost:6379"
|
||||
|
||||
// defaultHPrefix 默认前缀
|
||||
// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt.
|
||||
const defaultHPrefix = "mochi-"
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the bolt instance.
|
||||
type Options struct {
|
||||
Address string `yaml:"address" json:"address"`
|
||||
Username string `yaml:"username" json:"username"`
|
||||
Password string `yaml:"password" json:"password"`
|
||||
Database int `yaml:"database" json:"database"`
|
||||
HPrefix string `yaml:"h_prefix" json:"h_prefix"`
|
||||
Options *redis.Options
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using Redis as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for connecting to the Redis instance.
|
||||
db *redis.Client // the Redis instance
|
||||
ctx context.Context // a context for the connection
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "redis-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// hKey returns a hash set key with a unique prefix.
|
||||
func (h *Hook) hKey(s string) string {
|
||||
return h.config.HPrefix + s
|
||||
}
|
||||
|
||||
// Init 初始化并连接到 Redis 服务
|
||||
// Init initializes and connects to the redis service.
|
||||
func (h *Hook) Init(config any) error {
|
||||
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
h.ctx = context.Background()
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
h.config = config.(*Options)
|
||||
if h.config.Options == nil {
|
||||
h.config.Options = &redis.Options{
|
||||
Addr: defaultAddr,
|
||||
}
|
||||
h.config.Options.Addr = h.config.Address
|
||||
h.config.Options.DB = h.config.Database
|
||||
//h.config.Options.Username = h.config.Username
|
||||
h.config.Options.Password = h.config.Password
|
||||
}
|
||||
|
||||
if h.config.HPrefix == "" {
|
||||
h.config.HPrefix = defaultHPrefix
|
||||
}
|
||||
|
||||
h.Log.Info(
|
||||
"connecting to redis service",
|
||||
"prefix", h.config.HPrefix,
|
||||
"address", h.config.Options.Addr,
|
||||
"username", h.config.Options.Username,
|
||||
"password-len", len(h.config.Options.Password),
|
||||
"db", h.config.Options.DB,
|
||||
)
|
||||
|
||||
h.db = redis.NewClient(h.config.Options)
|
||||
_, err := h.db.Ping(context.Background()).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to ping service: %w", err)
|
||||
}
|
||||
|
||||
h.Log.Info("connected to redis service")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the redis connection.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Info("disconnecting from redis service")
|
||||
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: clientKey(cl),
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to hset client data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
if cl.StopCause() == packets.ErrSessionTakenOver {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to hset subscription data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to hset retained message data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk))
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to hset server info data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error("failed to HGetAll client data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Client
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error("failed to unmarshal client data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error("failed to HGetAll subscription data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Subscription
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error("failed to HGetAll retained message data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error("failed to HGetAll inflight message data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = v.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
834
hooks/storage/redis/redis_test.go
Normal file
834
hooks/storage/redis/redis_test.go
Normal file
@@ -0,0 +1,834 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/mqtt"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
|
||||
miniredis "github.com/alicebob/miniredis/v2"
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func newHook(t *testing.T, addr string) *Hook {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
Addr: addr,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func teardown(t *testing.T, h *Hook) {
|
||||
if h.db != nil {
|
||||
err := h.db.FlushAll(h.ctx).Err()
|
||||
require.NoError(t, err)
|
||||
h.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, "cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, "cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, "a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, "cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
require.Equal(t, "redis-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestHKey(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.SetOpts(logger, nil)
|
||||
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
s.StartAddr(defaultAddr)
|
||||
defer s.Close()
|
||||
|
||||
h := newHook(t, defaultAddr)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h)
|
||||
|
||||
require.Equal(t, defaultHPrefix, h.config.HPrefix)
|
||||
require.Equal(t, defaultAddr, h.config.Options.Addr)
|
||||
}
|
||||
|
||||
func TestInitUsePassConfig(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
s.StartAddr(defaultAddr)
|
||||
defer s.Close()
|
||||
|
||||
h := newHook(t, "")
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Address: defaultAddr,
|
||||
Username: "username",
|
||||
Password: "password",
|
||||
Database: 2,
|
||||
})
|
||||
require.Error(t, err)
|
||||
h.db.FlushAll(h.ctx)
|
||||
|
||||
require.Equal(t, defaultAddr, h.config.Options.Addr)
|
||||
require.Equal(t, "username", h.config.Options.Username)
|
||||
require.Equal(t, "password", h.config.Options.Password)
|
||||
require.Equal(t, 2, h.config.Options.DB)
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitBadAddr(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
Addr: "abc:123",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r2.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
|
||||
h.db = nil
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientKey, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, redis.Nil, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
|
||||
r := new(storage.Subscription)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved to bolt
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with clients
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredClientsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with messages
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with messages
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with sys info
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey,
|
||||
&storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
213
hooks/storage/storage.go
Normal file
213
hooks/storage/storage.go
Normal file
@@ -0,0 +1,213 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
)
|
||||
|
||||
const (
|
||||
// SubscriptionKey 唯一标识订阅信息。在存储系统中,这个键用于检索和管理客户端的订阅信息
|
||||
// SubscriptionKey unique key to denote Subscriptions in a store
|
||||
SubscriptionKey = "SUB"
|
||||
// SysInfoKey 唯一标识服务器系统信息。在存储系统中,这个键用于存储和检索与服务器状态或系统配置相关的信息。
|
||||
// SysInfoKey unique key to denote server system information in a store
|
||||
SysInfoKey = "SYS"
|
||||
// RetainedKey 唯一标识保留的消息。保留消息是在订阅时立即发送的消息,即使订阅者在消息发布时不在线。这个键用于存储这些保留消息。
|
||||
// RetainedKey unique key to denote retained messages in a store
|
||||
RetainedKey = "RET"
|
||||
// InflightKey 唯一标识飞行中的消息。飞行中的消息指的是已经发布但尚未确认的消息,这个键用于跟踪这些消息的状态。
|
||||
// InflightKey unique key to denote inflight messages in a store
|
||||
InflightKey = "IFM"
|
||||
// ClientKey 唯一标识客户端信息。在存储系统中,这个键用于检索和管理有关客户端的信息,如客户端状态、连接信息等。
|
||||
// ClientKey unique key to denote clients in a store
|
||||
ClientKey = "CL"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDBFileNotOpen 数据库文件没有打开
|
||||
// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading.
|
||||
ErrDBFileNotOpen = errors.New("数据库文件没有打开")
|
||||
)
|
||||
|
||||
// Serializable is an interface for objects that can be serialized and deserialized.
|
||||
type Serializable interface {
|
||||
UnmarshalBinary([]byte) error
|
||||
MarshalBinary() (data []byte, err error)
|
||||
}
|
||||
|
||||
// Client is a storable representation of an MQTT client.
|
||||
type Client struct {
|
||||
Will ClientWill `json:"will"` // will topic and payload data if applicable
|
||||
Properties ClientProperties `json:"properties"` // the connect properties for the client
|
||||
Username []byte `json:"username"` // the username of the client
|
||||
ID string `json:"id" storm:"id"` // the client id / storage key
|
||||
T string `json:"t"` // the data type (client)
|
||||
Remote string `json:"remote"` // the remote address of the client
|
||||
Listener string `json:"listener"` // the listener the client connected on
|
||||
ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client
|
||||
Clean bool `json:"clean"` // if the client requested a clean start/session
|
||||
}
|
||||
|
||||
// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection.
|
||||
type ClientProperties struct {
|
||||
AuthenticationData []byte `json:"authenticationData,omitempty"`
|
||||
User []packets.UserProperty `json:"user,omitempty"`
|
||||
AuthenticationMethod string `json:"authenticationMethod,omitempty"`
|
||||
SessionExpiryInterval uint32 `json:"sessionExpiryInterval,omitempty"`
|
||||
MaximumPacketSize uint32 `json:"maximumPacketSize,omitempty"`
|
||||
ReceiveMaximum uint16 `json:"receiveMaximum,omitempty"`
|
||||
TopicAliasMaximum uint16 `json:"topicAliasMaximum,omitempty"`
|
||||
SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag,omitempty"`
|
||||
RequestProblemInfo byte `json:"requestProblemInfo,omitempty"`
|
||||
RequestProblemInfoFlag bool `json:"requestProblemInfoFlag,omitempty"`
|
||||
RequestResponseInfo byte `json:"requestResponseInfo,omitempty"`
|
||||
}
|
||||
|
||||
// ClientWill contains a will message for a client, and limited mqtt v5 properties.
|
||||
type ClientWill struct {
|
||||
Payload []byte `json:"payload,omitempty"`
|
||||
User []packets.UserProperty `json:"user,omitempty"`
|
||||
TopicName string `json:"topicName,omitempty"`
|
||||
Flag uint32 `json:"flag,omitempty"`
|
||||
WillDelayInterval uint32 `json:"willDelayInterval,omitempty"`
|
||||
Qos byte `json:"qos,omitempty"`
|
||||
Retain bool `json:"retain,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Client) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Client) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// Message is a storable representation of an MQTT message (specifically publish).
|
||||
type Message struct {
|
||||
Properties MessageProperties `json:"properties"` // -
|
||||
Payload []byte `json:"payload"` // the message payload (if retained)
|
||||
T string `json:"t,omitempty"` // the data type
|
||||
ID string `json:"id,omitempty" storm:"id"` // the storage key
|
||||
Origin string `json:"origin,omitempty"` // the id of the client who sent the message
|
||||
TopicName string `json:"topic_name,omitempty"` // the topic the message was sent to (if retained)
|
||||
FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message
|
||||
Created int64 `json:"created,omitempty"` // the time the message was created in unixtime
|
||||
Sent int64 `json:"sent,omitempty"` // the last time the message was sent (for retries) in unixtime (if inflight)
|
||||
PacketID uint16 `json:"packet_id,omitempty"` // the unique id of the packet (if inflight)
|
||||
}
|
||||
|
||||
// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages.
|
||||
type MessageProperties struct {
|
||||
CorrelationData []byte `json:"correlationData,omitempty"`
|
||||
SubscriptionIdentifier []int `json:"subscriptionIdentifier,omitempty"`
|
||||
User []packets.UserProperty `json:"user,omitempty"`
|
||||
ContentType string `json:"contentType,omitempty"`
|
||||
ResponseTopic string `json:"responseTopic,omitempty"`
|
||||
MessageExpiryInterval uint32 `json:"messageExpiry,omitempty"`
|
||||
TopicAlias uint16 `json:"topicAlias,omitempty"`
|
||||
PayloadFormat byte `json:"payloadFormat,omitempty"`
|
||||
PayloadFormatFlag bool `json:"payloadFormatFlag,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Message) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Message) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// ToPacket converts a storage.Message to a standard packet.
|
||||
func (d *Message) ToPacket() packets.Packet {
|
||||
pk := packets.Packet{
|
||||
FixedHeader: d.FixedHeader,
|
||||
PacketID: d.PacketID,
|
||||
TopicName: d.TopicName,
|
||||
Payload: d.Payload,
|
||||
Origin: d.Origin,
|
||||
Created: d.Created,
|
||||
Properties: packets.Properties{
|
||||
PayloadFormat: d.Properties.PayloadFormat,
|
||||
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
|
||||
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
|
||||
ContentType: d.Properties.ContentType,
|
||||
ResponseTopic: d.Properties.ResponseTopic,
|
||||
CorrelationData: d.Properties.CorrelationData,
|
||||
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
|
||||
TopicAlias: d.Properties.TopicAlias,
|
||||
User: d.Properties.User,
|
||||
},
|
||||
}
|
||||
|
||||
// Return a deep copy of the packet data otherwise the slices will
|
||||
// continue pointing at the values from the storage packet.
|
||||
pk = pk.Copy(true)
|
||||
pk.FixedHeader.Dup = d.FixedHeader.Dup
|
||||
|
||||
return pk
|
||||
}
|
||||
|
||||
// Subscription is a storable representation of an MQTT subscription.
|
||||
type Subscription struct {
|
||||
T string `json:"t,omitempty"`
|
||||
ID string `json:"id,omitempty" storm:"id"`
|
||||
Client string `json:"client,omitempty"`
|
||||
Filter string `json:"filter"`
|
||||
Identifier int `json:"identifier,omitempty"`
|
||||
RetainHandling byte `json:"retain_handling,omitempty"`
|
||||
Qos byte `json:"qos"`
|
||||
RetainAsPublished bool `json:"retain_as_pub,omitempty"`
|
||||
NoLocal bool `json:"no_local,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Subscription) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Subscription) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// SystemInfo is a storable representation of the system information values.
|
||||
type SystemInfo struct {
|
||||
system.Info // embed the system info struct
|
||||
T string `json:"t"` // the data type
|
||||
ID string `json:"id" storm:"id"` // the storage key
|
||||
}
|
||||
|
||||
// MarshalBinary 将值编码为json字符串
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d SystemInfo) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary 将json字符串解码为结构体
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *SystemInfo) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
228
hooks/storage/storage_test.go
Normal file
228
hooks/storage/storage_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
)
|
||||
|
||||
var (
|
||||
clientStruct = Client{
|
||||
ID: "test",
|
||||
T: "client",
|
||||
Remote: "remote",
|
||||
Listener: "listener",
|
||||
Username: []byte("mochi"),
|
||||
Clean: true,
|
||||
Properties: ClientProperties{
|
||||
SessionExpiryInterval: 2,
|
||||
SessionExpiryIntervalFlag: true,
|
||||
AuthenticationMethod: "a",
|
||||
AuthenticationData: []byte("test"),
|
||||
RequestProblemInfo: 1,
|
||||
RequestProblemInfoFlag: true,
|
||||
RequestResponseInfo: 1,
|
||||
ReceiveMaximum: 128,
|
||||
TopicAliasMaximum: 256,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k", Val: "v"},
|
||||
},
|
||||
MaximumPacketSize: 120,
|
||||
},
|
||||
Will: ClientWill{
|
||||
Qos: 1,
|
||||
Payload: []byte("abc"),
|
||||
TopicName: "a/b/c",
|
||||
Flag: 1,
|
||||
Retain: true,
|
||||
WillDelayInterval: 2,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k2", Val: "v2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)
|
||||
|
||||
messageStruct = Message{
|
||||
T: "message",
|
||||
Payload: []byte("payload"),
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Remaining: 2,
|
||||
Type: 3,
|
||||
Qos: 1,
|
||||
Dup: true,
|
||||
Retain: true,
|
||||
},
|
||||
ID: "id",
|
||||
Origin: "mochi",
|
||||
TopicName: "topic",
|
||||
Properties: MessageProperties{
|
||||
PayloadFormat: 1,
|
||||
PayloadFormatFlag: true,
|
||||
MessageExpiryInterval: 20,
|
||||
ContentType: "type",
|
||||
ResponseTopic: "a/b/r",
|
||||
CorrelationData: []byte("r"),
|
||||
SubscriptionIdentifier: []int{1},
|
||||
TopicAlias: 2,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k2", Val: "v2"},
|
||||
},
|
||||
},
|
||||
Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
|
||||
Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
|
||||
PacketID: 100,
|
||||
}
|
||||
messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)
|
||||
|
||||
subscriptionStruct = Subscription{
|
||||
T: "subscription",
|
||||
ID: "id",
|
||||
Client: "mochi",
|
||||
Filter: "a/b/c",
|
||||
Qos: 1,
|
||||
}
|
||||
subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","qos":1}`)
|
||||
|
||||
sysInfoStruct = SystemInfo{
|
||||
T: "info",
|
||||
ID: "id",
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
Started: 1,
|
||||
Uptime: 2,
|
||||
BytesReceived: 3,
|
||||
BytesSent: 4,
|
||||
ClientsConnected: 5,
|
||||
ClientsMaximum: 7,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
PacketsReceived: 12,
|
||||
PacketsSent: 13,
|
||||
Retained: 15,
|
||||
Inflight: 16,
|
||||
InflightDropped: 17,
|
||||
},
|
||||
}
|
||||
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
|
||||
)
|
||||
|
||||
func TestClientMarshalBinary(t *testing.T) {
|
||||
data, err := clientStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, string(clientJSON), string(data))
|
||||
}
|
||||
|
||||
func TestClientUnmarshalBinary(t *testing.T) {
|
||||
d := clientStruct
|
||||
err := d.UnmarshalBinary(clientJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientStruct, d)
|
||||
}
|
||||
|
||||
func TestClientUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Client{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Client{}, d)
|
||||
}
|
||||
|
||||
func TestMessageMarshalBinary(t *testing.T) {
|
||||
data, err := messageStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, string(messageJSON), string(data))
|
||||
}
|
||||
|
||||
func TestMessageUnmarshalBinary(t *testing.T) {
|
||||
d := messageStruct
|
||||
err := d.UnmarshalBinary(messageJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, messageStruct, d)
|
||||
}
|
||||
|
||||
func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Message{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Message{}, d)
|
||||
}
|
||||
|
||||
func TestSubscriptionMarshalBinary(t *testing.T) {
|
||||
data, err := subscriptionStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, string(subscriptionJSON), string(data))
|
||||
}
|
||||
|
||||
func TestSubscriptionUnmarshalBinary(t *testing.T) {
|
||||
d := subscriptionStruct
|
||||
err := d.UnmarshalBinary(subscriptionJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, subscriptionStruct, d)
|
||||
}
|
||||
|
||||
func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Subscription{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Subscription{}, d)
|
||||
}
|
||||
|
||||
func TestSysInfoMarshalBinary(t *testing.T) {
|
||||
data, err := sysInfoStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, string(sysInfoJSON), string(data))
|
||||
}
|
||||
|
||||
func TestSysInfoUnmarshalBinary(t *testing.T) {
|
||||
d := sysInfoStruct
|
||||
err := d.UnmarshalBinary(sysInfoJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sysInfoStruct, d)
|
||||
}
|
||||
|
||||
func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := SystemInfo{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SystemInfo{}, d)
|
||||
}
|
||||
|
||||
func TestMessageToPacket(t *testing.T) {
|
||||
d := messageStruct
|
||||
pk := d.ToPacket()
|
||||
|
||||
require.Equal(t, packets.Packet{
|
||||
Payload: []byte("payload"),
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Remaining: d.FixedHeader.Remaining,
|
||||
Type: d.FixedHeader.Type,
|
||||
Qos: d.FixedHeader.Qos,
|
||||
Dup: d.FixedHeader.Dup,
|
||||
Retain: d.FixedHeader.Retain,
|
||||
},
|
||||
Origin: d.Origin,
|
||||
TopicName: d.TopicName,
|
||||
Properties: packets.Properties{
|
||||
PayloadFormat: d.Properties.PayloadFormat,
|
||||
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
|
||||
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
|
||||
ContentType: d.Properties.ContentType,
|
||||
ResponseTopic: d.Properties.ResponseTopic,
|
||||
CorrelationData: d.Properties.CorrelationData,
|
||||
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
|
||||
TopicAlias: d.Properties.TopicAlias,
|
||||
User: d.Properties.User,
|
||||
},
|
||||
PacketID: 100,
|
||||
Created: d.Created,
|
||||
}, pk)
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user