代码整理

This commit is contained in:
2024-08-21 15:32:05 +08:00
commit 84e5e65ee7
71 changed files with 29733 additions and 0 deletions

41
hooks/auth/allow_all.go Normal file
View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"testmqtt/mqtt"
"testmqtt/packets"
)
// AllowHook is an authentication hook which allows connection access
// for all users and read and write access to all topics.
type AllowHook struct {
mqtt.HookBase
}
// ID returns the ID of the hook.
func (h *AllowHook) ID() string {
return "allow-all-auth"
}
// Provides indicates which hook methods this hook provides.
func (h *AllowHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
// OnConnectAuthenticate returns true/allowed for all requests.
func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
return true
}
// OnACLCheck returns true/allowed for all checks.
func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return true
}

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/stretchr/testify/require"
"testmqtt/mqtt"
"testmqtt/packets"
)
func TestAllowAllID(t *testing.T) {
h := new(AllowHook)
require.Equal(t, "allow-all-auth", h.ID())
}
func TestAllowAllProvides(t *testing.T) {
h := new(AllowHook)
require.True(t, h.Provides(mqtt.OnACLCheck))
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
require.False(t, h.Provides(mqtt.OnPublished))
}
func TestAllowAllOnConnectAuthenticate(t *testing.T) {
h := new(AllowHook)
require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{}))
}
func TestAllowAllOnACLCheck(t *testing.T) {
h := new(AllowHook)
require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true))
}

103
hooks/auth/auth.go Normal file
View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"testmqtt/mqtt"
"testmqtt/packets"
)
// Options contains the configuration/rules data for the auth ledger.
type Options struct {
Data []byte
Ledger *Ledger
}
// Hook is an authentication hook which implements an auth ledger.
type Hook struct {
mqtt.HookBase
config *Options
ledger *Ledger
}
// ID returns the ID of the hook.
func (h *Hook) ID() string {
return "auth-ledger"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
// Init configures the hook with the auth ledger to be used for checking.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
var err error
if h.config.Ledger != nil {
h.ledger = h.config.Ledger
} else if len(h.config.Data) > 0 {
h.ledger = new(Ledger)
err = h.ledger.Unmarshal(h.config.Data)
}
if err != nil {
return err
}
if h.ledger == nil {
h.ledger = &Ledger{
Auth: AuthRules{},
ACL: ACLRules{},
}
}
h.Log.Info("loaded auth rules",
"authentication", len(h.ledger.Auth),
"acl", len(h.ledger.ACL))
return nil
}
// OnConnectAuthenticate returns true if the connecting client has rules which provide access
// in the auth ledger.
func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
if _, ok := h.ledger.AuthOk(cl, pk); ok {
return true
}
h.Log.Info("client failed authentication check",
"username", string(pk.Connect.Username),
"remote", cl.Net.Remote)
return false
}
// OnACLCheck returns true if the connecting client has matching read or write access to subscribe
// or publish to a given topic.
func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
if _, ok := h.ledger.ACLOk(cl, topic, write); ok {
return true
}
h.Log.Debug("client failed allowed ACL check",
"client", cl.ID,
"username", string(cl.Properties.Username),
"topic", topic)
return false
}

213
hooks/auth/auth_test.go Normal file
View File

@@ -0,0 +1,213 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"log/slog"
"os"
"testing"
"github.com/stretchr/testify/require"
"testmqtt/mqtt"
"testmqtt/packets"
)
var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
// func teardown(t *testing.T, path string, h *Hook) {
// h.Stop()
// }
func TestBasicID(t *testing.T) {
h := new(Hook)
require.Equal(t, "auth-ledger", h.ID())
}
func TestBasicProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnACLCheck))
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
require.False(t, h.Provides(mqtt.OnPublish))
}
func TestBasicInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestBasicInitDefaultConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
}
func TestBasicInitWithLedgerPointer(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
ln := &Ledger{
Auth: []AuthRule{
{
Remote: "127.0.0.1",
Allow: true,
},
},
ACL: []ACLRule{
{
Remote: "127.0.0.1",
Filters: Filters{
"#": ReadWrite,
},
},
},
}
err := h.Init(&Options{
Ledger: ln,
})
require.NoError(t, err)
require.Same(t, ln, h.ledger)
}
func TestBasicInitWithLedgerJSON(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: ledgerJSON,
})
require.NoError(t, err)
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
}
func TestBasicInitWithLedgerYAML(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: ledgerYAML,
})
require.NoError(t, err)
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
}
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: []byte("fdsfdsafasd"),
})
require.Error(t, err)
}
func TestOnConnectAuthenticate(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
ln.ACL = checkLedger.ACL
err := h.Init(
&Options{
Ledger: ln,
},
)
require.NoError(t, err)
require.True(t, h.OnConnectAuthenticate(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
))
require.False(t, h.OnConnectAuthenticate(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
))
require.False(t, h.OnConnectAuthenticate(
&mqtt.Client{},
packets.Packet{},
))
}
func TestOnACL(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
ln.ACL = checkLedger.ACL
err := h.Init(
&Options{
Ledger: ln,
},
)
require.NoError(t, err)
require.True(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"mochi/info",
true,
))
require.False(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"d/j/f",
true,
))
require.True(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"readonly",
false,
))
require.False(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"readonly",
true,
))
}

246
hooks/auth/ledger.go Normal file
View File

@@ -0,0 +1,246 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"encoding/json"
"strings"
"sync"
"gopkg.in/yaml.v3"
"testmqtt/mqtt"
"testmqtt/packets"
)
const (
Deny Access = iota // user cannot access the topic
ReadOnly // user can only subscribe to the topic
WriteOnly // user can only publish to the topic
ReadWrite // user can both publish and subscribe to the topic
)
// Access determines the read/write privileges for an ACL rule.
type Access byte
// Users contains a map of access rules for specific users, keyed on username.
type Users map[string]UserRule
// UserRule defines a set of access rules for a specific user.
type UserRule struct {
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired
Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user
}
// AuthRules defines generic access rules applicable to all users.
type AuthRules []AuthRule
type AuthRule struct {
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users
}
// ACLRules defines generic topic or filter access rules applicable to all users.
type ACLRules []ACLRule
// ACLRule defines access rules for a specific topic or filter.
type ACLRule struct {
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match
}
// Filters is a map of Access rules keyed on filter.
type Filters map[RString]Access
// RString is a rule value string.
type RString string
// Matches returns true if the rule matches a given string.
func (r RString) Matches(a string) bool {
rr := string(r)
if r == "" || r == "*" || a == rr {
return true
}
i := strings.Index(rr, "*")
if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 {
return true
}
return false
}
// FilterMatches returns true if a filter matches a topic rule.
func (r RString) FilterMatches(a string) bool {
_, ok := MatchTopic(string(r), a)
return ok
}
// MatchTopic checks if a given topic matches a filter, accounting for filter
// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c.
func MatchTopic(filter string, topic string) (elements []string, matched bool) {
filterParts := strings.Split(filter, "/")
topicParts := strings.Split(topic, "/")
elements = make([]string, 0)
for i := 0; i < len(filterParts); i++ {
if i >= len(topicParts) {
matched = false
return
}
if filterParts[i] == "+" {
elements = append(elements, topicParts[i])
continue
}
if filterParts[i] == "#" {
matched = true
elements = append(elements, strings.Join(topicParts[i:], "/"))
return
}
if filterParts[i] != topicParts[i] {
matched = false
return
}
}
return elements, true
}
// Ledger is an auth ledger containing access rules for users and topics.
type Ledger struct {
sync.Mutex `json:"-" yaml:"-"`
Users Users `json:"users" yaml:"users"`
Auth AuthRules `json:"auth" yaml:"auth"`
ACL ACLRules `json:"acl" yaml:"acl"`
}
// Update updates the internal values of the ledger.
func (l *Ledger) Update(ln *Ledger) {
l.Lock()
defer l.Unlock()
l.Auth = ln.Auth
l.ACL = ln.ACL
}
// AuthOk returns true if the rules indicate the user is allowed to authenticate.
func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
if l.Users != nil {
if u, ok := l.Users[string(cl.Properties.Username)]; ok &&
u.Password != "" &&
u.Password == RString(pk.Connect.Password) {
return 0, !u.Disallow
}
}
// If there's no users map, or no user was found, attempt to find a matching
// rule (which may also contain a user).
for n, rule := range l.Auth {
if rule.Client.Matches(cl.ID) &&
rule.Username.Matches(string(cl.Properties.Username)) &&
rule.Password.Matches(string(pk.Connect.Password)) &&
rule.Remote.Matches(cl.Net.Remote) {
return n, rule.Allow
}
}
return 0, false
}
// ACLOk returns true if the rules indicate the user is allowed to read or write to
// a specific filter or topic respectively, based on the `write` bool.
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
if l.Users != nil {
if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 {
for filter, access := range u.ACL {
if filter.FilterMatches(topic) {
if !write && (access == ReadOnly || access == ReadWrite) {
return n, true
} else if write && (access == WriteOnly || access == ReadWrite) {
return n, true
} else {
return n, false
}
}
}
}
}
for n, rule := range l.ACL {
if rule.Client.Matches(cl.ID) &&
rule.Username.Matches(string(cl.Properties.Username)) &&
rule.Remote.Matches(cl.Net.Remote) {
if len(rule.Filters) == 0 {
return n, true
}
if write {
for filter, access := range rule.Filters {
if access == WriteOnly || access == ReadWrite {
if filter.FilterMatches(topic) {
return n, true
}
}
}
}
if !write {
for filter, access := range rule.Filters {
if access == ReadOnly || access == ReadWrite {
if filter.FilterMatches(topic) {
return n, true
}
}
}
}
for filter := range rule.Filters {
if filter.FilterMatches(topic) {
return n, false
}
}
}
}
return 0, true
}
// ToJSON encodes the values into a JSON string.
func (l *Ledger) ToJSON() (data []byte, err error) {
return json.Marshal(l)
}
// ToYAML encodes the values into a YAML string.
func (l *Ledger) ToYAML() (data []byte, err error) {
return yaml.Marshal(l)
}
// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct.
func (l *Ledger) Unmarshal(data []byte) error {
l.Lock()
defer l.Unlock()
if len(data) == 0 {
return nil
}
if data[0] == '{' {
return json.Unmarshal(data, l)
}
return yaml.Unmarshal(data, &l)
}

610
hooks/auth/ledger_test.go Normal file
View File

@@ -0,0 +1,610 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/stretchr/testify/require"
"testmqtt/mqtt"
"testmqtt/packets"
)
var (
checkLedger = Ledger{
Users: Users{ // users are allowed by default
"mochi-co": {
Password: "melon",
ACL: Filters{
"d/+/f": Deny,
"mochi-co/#": ReadWrite,
"readonly": ReadOnly,
},
},
"suspended-username": {
Password: "any",
Disallow: true,
},
"mochi": { // ACL only, will defer to AuthRules for authentication
ACL: Filters{
"special/mochi": ReadOnly,
"secret/mochi": Deny,
"ignored": ReadWrite,
},
},
},
Auth: AuthRules{
{Username: "banned-user"}, // never allow specific username
{Remote: "127.0.0.1", Allow: true}, // always allow localhost
{Remote: "123.123.123.123"}, // disallow any from specific address
{Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address
{Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username)
{Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass
{Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map)
},
ACL: ACLRules{
{
Username: "mochi", // allow matching user/pass
Filters: Filters{
"a/b/c": Deny,
"d/+/f": Deny,
"mochi/#": ReadWrite,
"updates/#": WriteOnly,
"readonly": ReadOnly,
"ignored": Deny,
},
},
{Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost
{Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin
{Remote: "001.002.003.004"}, // Allow all with no filter
{Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others
},
}
)
func TestRStringMatches(t *testing.T) {
require.True(t, RString("*").Matches("any"))
require.True(t, RString("*").Matches(""))
require.True(t, RString("").Matches("any"))
require.True(t, RString("").Matches(""))
require.False(t, RString("no").Matches("any"))
require.False(t, RString("no").Matches(""))
}
func TestCanAuthenticate(t *testing.T) {
tt := []struct {
desc string
client *mqtt.Client
pk packets.Packet
n int
ok bool
}{
{
desc: "allow all local 127.0.0.1",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{Connect: packets.ConnectParams{}},
ok: true,
n: 1,
},
{
desc: "allow username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 5,
},
{
desc: "deny username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: false,
n: 0,
},
{
desc: "allow all local 127.0.0.1",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: true,
n: 1,
},
{
desc: "allow username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 5,
},
{
desc: "deny username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: false,
n: 0,
},
{
desc: "deny client from address",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("not-mochi"),
},
Net: mqtt.ClientConnection{
Remote: "111.144.155.166",
},
},
pk: packets.Packet{},
ok: false,
n: 3,
},
{
desc: "allow remote wildcard",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "111.0.0.1",
},
},
pk: packets.Packet{},
ok: true,
n: 4,
},
{
desc: "never allow username",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("banned-user"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{},
ok: false,
n: 0,
},
{
desc: "matching user in users",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi-co"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 0,
},
{
desc: "never user in users",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("suspended-user"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}},
ok: false,
n: 0,
},
}
for _, d := range tt {
t.Run(d.desc, func(t *testing.T) {
n, ok := checkLedger.AuthOk(d.client, d.pk)
require.Equal(t, d.n, n)
require.Equal(t, d.ok, ok)
})
}
}
func TestCanACL(t *testing.T) {
tt := []struct {
client *mqtt.Client
desc string
topic string
n int
write bool
ok bool
}{
{
desc: "allow normal write on any other filter",
client: &mqtt.Client{},
topic: "default/acl/write/access",
write: true,
ok: true,
},
{
desc: "allow normal read on any other filter",
client: &mqtt.Client{},
topic: "default/acl/read/access",
write: false,
ok: true,
},
{
desc: "deny user on literal filter",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "a/b/c",
},
{
desc: "deny user on partial filter",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "d/j/f",
},
{
desc: "allow read/write to user path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "mochi/read/write",
write: true,
ok: true,
},
{
desc: "deny read on write-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/no/reading",
write: false,
ok: false,
},
{
desc: "deny read on write-only path ext",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/mochi",
write: false,
ok: false,
},
{
desc: "allow read on not-acl path (no #)",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates",
write: false,
ok: true,
},
{
desc: "allow write on write-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/mochi",
write: true,
ok: true,
},
{
desc: "deny write on read-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "readonly",
write: true,
ok: false,
},
{
desc: "allow read on read-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "readonly",
write: false,
ok: true,
},
{
desc: "allow $sys access to localhost",
client: &mqtt.Client{
Net: mqtt.ClientConnection{
Remote: "localhost",
},
},
topic: "$SYS/test",
write: false,
ok: true,
n: 1,
},
{
desc: "allow $sys access to admin",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("admin"),
},
},
topic: "$SYS/test",
write: false,
ok: true,
n: 2,
},
{
desc: "deny $sys access to all others",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "$SYS/test",
write: false,
ok: false,
n: 4,
},
{
desc: "allow all with no filter",
client: &mqtt.Client{
Net: mqtt.ClientConnection{
Remote: "001.002.003.004",
},
},
topic: "any/path",
write: true,
ok: true,
n: 3,
},
{
desc: "use users embedded acl deny",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "secret/mochi",
write: true,
ok: false,
},
{
desc: "use users embedded acl any",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "any/mochi",
write: true,
ok: true,
},
{
desc: "use users embedded acl write on read-only",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "special/mochi",
write: true,
ok: false,
},
{
desc: "use users embedded acl read on read-only",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "special/mochi",
write: false,
ok: true,
},
{
desc: "preference users embedded acl",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "ignored",
write: true,
ok: true,
},
}
for _, d := range tt {
t.Run(d.desc, func(t *testing.T) {
n, ok := checkLedger.ACLOk(d.client, d.topic, d.write)
require.Equal(t, d.n, n)
require.Equal(t, d.ok, ok)
})
}
}
func TestMatchTopic(t *testing.T) {
el, matched := MatchTopic("a/+/c/+", "a/b/c/d")
require.True(t, matched)
require.Equal(t, []string{"b", "d"}, el)
el, matched = MatchTopic("a/+/+/+", "a/b/c/d")
require.True(t, matched)
require.Equal(t, []string{"b", "c", "d"}, el)
el, matched = MatchTopic("stuff/#", "stuff/things/yeah")
require.True(t, matched)
require.Equal(t, []string{"things/yeah"}, el)
el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds")
require.True(t, matched)
require.Equal(t, []string{"b", "c/d/as/dds"}, el)
el, matched = MatchTopic("test", "test")
require.True(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic("things/stuff//", "things/stuff/")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic("t", "t2")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic(" ", " ")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
}
var (
ledgerStruct = Ledger{
Users: Users{
"mochi": {
Password: "peach",
ACL: Filters{
"readonly": ReadOnly,
"deny": Deny,
},
},
},
Auth: AuthRules{
{
Client: "*",
Username: "mochi-co",
Password: "melon",
Remote: "192.168.1.*",
Allow: true,
},
},
ACL: ACLRules{
{
Client: "*",
Username: "mochi-co",
Remote: "127.*",
Filters: Filters{
"readonly": ReadOnly,
"writeonly": WriteOnly,
"readwrite": ReadWrite,
"deny": Deny,
},
},
},
}
ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`)
ledgerYAML = []byte(`users:
mochi:
password: peach
acl:
deny: 0
readonly: 1
auth:
- client: '*'
username: mochi-co
remote: 192.168.1.*
password: melon
allow: true
acl:
- client: '*'
username: mochi-co
remote: 127.*
filters:
deny: 0
readonly: 1
readwrite: 3
writeonly: 2
`)
)
func TestLedgerUpdate(t *testing.T) {
old := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
},
}
n := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
{Remote: "192.168.*", Allow: true},
},
}
old.Update(n)
require.Len(t, old.Auth, 2)
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
require.NotSame(t, n, old)
}
func TestLedgerToJSON(t *testing.T) {
data, err := ledgerStruct.ToJSON()
require.NoError(t, err)
require.Equal(t, ledgerJSON, data)
}
func TestLedgerToYAML(t *testing.T) {
data, err := ledgerStruct.ToYAML()
require.NoError(t, err)
require.Equal(t, ledgerYAML, data)
}
func TestLedgerUnmarshalFromYAML(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal(ledgerYAML)
require.NoError(t, err)
require.Equal(t, &ledgerStruct, l)
require.NotSame(t, l, &ledgerStruct)
}
func TestLedgerUnmarshalFromJSON(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal(ledgerJSON)
require.NoError(t, err)
require.Equal(t, &ledgerStruct, l)
require.NotSame(t, l, &ledgerStruct)
}
func TestLedgerUnmarshalNil(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal([]byte{})
require.NoError(t, err)
require.Equal(t, new(Ledger), l)
}

237
hooks/debug/debug.go Normal file
View File

@@ -0,0 +1,237 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package debug
import (
"fmt"
"log/slog"
"strings"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
)
// Options contains configuration settings for the debug output.
type Options struct {
Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config
ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false)
ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false)
ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false)
}
// Hook is a debugging hook which logs additional low-level information from the server.
type Hook struct {
mqtt.HookBase
config *Options
Log *slog.Logger
}
// ID returns the ID of the hook.
func (h *Hook) ID() string {
return "debug"
}
// Provides indicates that this hook provides all methods.
func (h *Hook) Provides(b byte) bool {
return true
}
// Init is called when the hook is initialized.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
return nil
}
// SetOpts is called when the hook receives inheritable server parameters.
func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) {
h.Log = l
h.Log.Debug("", "method", "SetOpts")
}
// Stop is called when the hook is stopped.
func (h *Hook) Stop() error {
h.Log.Debug("", "method", "Stop")
return nil
}
// OnStarted is called when the server starts.
func (h *Hook) OnStarted() {
h.Log.Debug("", "method", "OnStarted")
}
// OnStopped is called when the server stops.
func (h *Hook) OnStopped() {
h.Log.Debug("", "method", "OnStopped")
}
// OnPacketRead is called when a new packet is received from a client.
func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
return pk, nil
}
h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
return pk, nil
}
// OnPacketSent is called when a packet is sent to a client.
func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
return
}
h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
}
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
h.Log.Debug("retained message on topic", "m", h.packetMeta(pk))
}
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
h.Log.Debug("inflight out", "m", h.packetMeta(pk))
}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug("inflight complete", "m", h.packetMeta(pk))
}
// OnQosDropped is called the Qos flow for a message expires.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug("inflight dropped", "m", h.packetMeta(pk))
}
// OnLWTSent is called when a Will Message has been issued from a disconnecting client.
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID)
}
// OnRetainedExpired is called when the server clears expired retained messages.
func (h *Hook) OnRetainedExpired(filter string) {
h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter)
}
// OnClientExpired is called when the server clears an expired client.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID)
}
// StoredClients is called when the server restores clients from a store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
h.Log.Debug("", "method", "StoredClients")
return v, nil
}
// StoredSubscriptions is called when the server restores subscriptions from a store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
h.Log.Debug("", "method", "StoredSubscriptions")
return v, nil
}
// StoredRetainedMessages is called when the server restores retained messages from a store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
h.Log.Debug("", "method", "StoredRetainedMessages")
return v, nil
}
// StoredInflightMessages is called when the server restores inflight messages from a store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
h.Log.Debug("", "method", "StoredInflightMessages")
return v, nil
}
// StoredSysInfo is called when the server restores system info from a store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
h.Log.Debug("", "method", "StoredSysInfo")
return v, nil
}
// packetMeta adds additional type-specific metadata to the debug logs.
func (h *Hook) packetMeta(pk packets.Packet) map[string]any {
m := map[string]any{}
switch pk.FixedHeader.Type {
case packets.Connect:
m["id"] = pk.Connect.ClientIdentifier
m["clean"] = pk.Connect.Clean
m["keepalive"] = pk.Connect.Keepalive
m["version"] = pk.ProtocolVersion
m["username"] = string(pk.Connect.Username)
if h.config.ShowPasswords {
m["password"] = string(pk.Connect.Password)
}
if pk.Connect.WillFlag {
m["will_topic"] = pk.Connect.WillTopic
m["will_payload"] = string(pk.Connect.WillPayload)
}
case packets.Publish:
m["topic"] = pk.TopicName
m["payload"] = string(pk.Payload)
m["raw"] = pk.Payload
m["qos"] = pk.FixedHeader.Qos
m["id"] = pk.PacketID
case packets.Connack:
fallthrough
case packets.Disconnect:
fallthrough
case packets.Puback:
fallthrough
case packets.Pubrec:
fallthrough
case packets.Pubrel:
fallthrough
case packets.Pubcomp:
m["id"] = pk.PacketID
m["reason"] = int(pk.ReasonCode)
if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 {
m["reason_string"] = pk.Properties.ReasonString
}
case packets.Subscribe:
f := map[string]int{}
ids := map[string]int{}
for _, v := range pk.Filters {
f[v.Filter] = int(v.Qos)
ids[v.Filter] = v.Identifier
}
m["filters"] = f
m["subids"] = f
case packets.Unsubscribe:
f := []string{}
for _, v := range pk.Filters {
f = append(f, v.Filter)
}
m["filters"] = f
case packets.Suback:
fallthrough
case packets.Unsuback:
r := []int{}
for _, v := range pk.ReasonCodes {
r = append(r, int(v))
}
m["reasons"] = r
case packets.Auth:
// tbd
}
if h.config.ShowPacketData {
m["packet"] = pk
}
return m
}

View File

@@ -0,0 +1,576 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co, gsagula, werbenhu
package badger
import (
"bytes"
"errors"
"fmt"
"strings"
"time"
badgerdb "github.com/dgraph-io/badger/v4"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
"testmqtt/system"
)
const (
// defaultDbFile is the default file path for the badger db file.
defaultDbFile = ".badger"
defaultGcInterval = 5 * 60 // gc interval in seconds
defaultGcDiscardRatio = 0.5
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return storage.ClientKey + "_" + cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Serializable is an interface for objects that can be serialized and deserialized.
type Serializable interface {
UnmarshalBinary([]byte) error
MarshalBinary() (data []byte, err error)
}
// Options contains configuration settings for the BadgerDB instance.
type Options struct {
Options *badgerdb.Options
Path string `yaml:"path" json:"path"`
// GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard.
// Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value
// would result in more space reclaims at the cost of increased activity on the LSM tree.
// discardRatio must be in the range (0.0, 1.0), both endpoints excluded, otherwise, it will be set to the default value of 0.5.
GcDiscardRatio float64 `yaml:"gc_discard_ratio" json:"gc_discard_ratio"`
GcInterval int64 `yaml:"gc_interval" json:"gc_interval"`
}
// Hook is a persistent storage hook based using BadgerDB file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the BadgerDB instance.
gcTicker *time.Ticker // Ticker for BadgerDB garbage collection.
db *badgerdb.DB // the BadgerDB instance.
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "badger-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// GcLoop periodically runs the garbage collection process to reclaim space in the value log files.
// It uses a ticker to trigger the garbage collection at regular intervals specified by the configuration.
// Refer to: https://dgraph.io/docs/badger/get-started/#garbage-collection
func (h *Hook) gcLoop() {
for range h.gcTicker.C {
again:
// Run the garbage collection process with a threshold.
// If the process returns nil (success), repeat the process.
err := h.db.RunValueLogGC(h.config.GcDiscardRatio)
if err == nil {
goto again // Retry garbage collection if successful.
}
}
}
// Init initializes and connects to the badger instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
h.config = new(Options)
} else {
h.config = config.(*Options)
}
if len(h.config.Path) == 0 {
h.config.Path = defaultDbFile
}
if h.config.GcInterval == 0 {
h.config.GcInterval = defaultGcInterval
}
if h.config.GcDiscardRatio <= 0.0 || h.config.GcDiscardRatio >= 1.0 {
h.config.GcDiscardRatio = defaultGcDiscardRatio
}
if h.config.Options == nil {
defaultOpts := badgerdb.DefaultOptions(h.config.Path)
h.config.Options = &defaultOpts
}
h.config.Options.Logger = h
var err error
h.db, err = badgerdb.Open(*h.config.Options)
if err != nil {
return err
}
h.gcTicker = time.NewTicker(time.Duration(h.config.GcInterval) * time.Second)
go h.gcLoop()
return nil
}
// Stop closes the badger instance.
func (h *Hook) Stop() error {
if h.gcTicker != nil {
h.gcTicker.Stop()
}
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: cl.ID,
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.setKv(clientKey(cl), in)
if err != nil {
h.Log.Error("failed to upsert client data", "error", err, "data", in)
}
}
// OnDisconnect removes a client from the store if their session has expired.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
h.updateClient(cl)
if !expire {
return
}
if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) {
return
}
_ = h.delKv(clientKey(cl))
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
_ = h.setKv(in.ID, in)
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
_ = h.delKv(subscriptionKey(cl, pk.Filters[i].Filter))
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if r == -1 {
_ = h.delKv(retainedKey(pk.TopicName))
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
_ = h.setKv(in.ID, in)
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
PacketID: pk.PacketID,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
_ = h.setKv(in.ID, in)
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
_ = h.delKv(inflightKey(cl, pk))
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys.Clone(),
}
_ = h.setKv(in.ID, in)
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
_ = h.delKv(retainedKey(filter))
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
_ = h.delKv(clientKey(cl))
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err = h.iterKv(storage.ClientKey, func(value []byte) error {
obj := storage.Client{}
err = obj.UnmarshalBinary(value)
if err == nil {
v = append(v, obj)
}
return err
})
return
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
v = make([]storage.Subscription, 0)
err = h.iterKv(storage.SubscriptionKey, func(value []byte) error {
obj := storage.Subscription{}
err = obj.UnmarshalBinary(value)
if err == nil {
v = append(v, obj)
}
return err
})
return
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
v = make([]storage.Message, 0)
err = h.iterKv(storage.RetainedKey, func(value []byte) error {
obj := storage.Message{}
err = obj.UnmarshalBinary(value)
if err == nil {
v = append(v, obj)
}
return err
})
if err != nil && !errors.Is(err, badgerdb.ErrKeyNotFound) {
return
}
return
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
v = make([]storage.Message, 0)
err = h.iterKv(storage.InflightKey, func(value []byte) error {
obj := storage.Message{}
err = obj.UnmarshalBinary(value)
if err == nil {
v = append(v, obj)
}
return err
})
return
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err = h.getKv(storage.SysInfoKey, &v)
if err != nil && !errors.Is(err, badgerdb.ErrKeyNotFound) {
return
}
return v, nil
}
// Errorf satisfies the badger interface for an error logger.
func (h *Hook) Errorf(m string, v ...any) {
h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Warningf satisfies the badger interface for a warning logger.
func (h *Hook) Warningf(m string, v ...any) {
h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Infof satisfies the badger interface for an info logger.
func (h *Hook) Infof(m string, v ...any) {
h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Debugf satisfies the badger interface for a debug logger.
func (h *Hook) Debugf(m string, v ...any) {
h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// setKv stores a key-value pair in the database.
func (h *Hook) setKv(k string, v storage.Serializable) error {
err := h.db.Update(func(txn *badgerdb.Txn) error {
data, _ := v.MarshalBinary()
return txn.Set([]byte(k), data)
})
if err != nil {
h.Log.Error("failed to upsert data", "error", err, "key", k)
}
return err
}
// delKv deletes a key-value pair from the database.
func (h *Hook) delKv(k string) error {
err := h.db.Update(func(txn *badgerdb.Txn) error {
return txn.Delete([]byte(k))
})
if err != nil {
h.Log.Error("failed to delete data", "error", err, "key", k)
}
return err
}
// getKv retrieves the value associated with a key from the database.
func (h *Hook) getKv(k string, v storage.Serializable) error {
return h.db.View(func(txn *badgerdb.Txn) error {
item, err := txn.Get([]byte(k))
if err != nil {
return err
}
value, err := item.ValueCopy(nil)
if err != nil {
return err
}
return v.UnmarshalBinary(value)
})
}
// iterKv iterates over key-value pairs with keys having the specified prefix in the database.
func (h *Hook) iterKv(prefix string, visit func([]byte) error) error {
err := h.db.View(func(txn *badgerdb.Txn) error {
iterator := txn.NewIterator(badgerdb.DefaultIteratorOptions)
defer iterator.Close()
for iterator.Seek([]byte(prefix)); iterator.ValidForPrefix([]byte(prefix)); iterator.Next() {
item := iterator.Item()
value, err := item.ValueCopy(nil)
if err != nil {
return err
}
if err := visit(value); err != nil {
return err
}
}
return nil
})
if err != nil {
h.Log.Error("failed to find data", "error", err, "prefix", prefix)
}
return err
}

View File

@@ -0,0 +1,809 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co, werbenhu
package badger
import (
"errors"
"log/slog"
"os"
"strings"
"testing"
"time"
badgerdb "github.com/dgraph-io/badger/v4"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
"testmqtt/system"
"github.com/stretchr/testify/require"
)
var (
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
_ = h.Stop()
_ = h.db.Close()
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
require.NoError(t, err)
}
func TestSetGetDelKv(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.Init(nil)
defer teardown(t, h.config.Path, h)
key := "testKey"
value := &storage.Client{ID: "cl1"}
err := h.setKv(key, value)
require.NoError(t, err)
var client storage.Client
err = h.getKv(key, &client)
require.NoError(t, err)
require.Equal(t, "cl1", client.ID)
err = h.delKv(key)
require.NoError(t, err)
err = h.getKv(key, &client)
require.ErrorIs(t, badgerdb.ErrKeyNotFound, err)
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, storage.ClientKey+"_cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "badger-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitBadOption(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Options: &badgerdb.Options{
NumCompactors: 1,
},
})
// Cannot have 1 compactor. Need at least 2
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.getKv(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.getKv(clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.getKv(clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, badgerdb.ErrKeyNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.setKv(clientKey, &storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.getKv(clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.getKv(clientKey, r)
require.Error(t, err)
require.ErrorIs(t, badgerdb.ErrKeyNotFound, err)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.getKv(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, badgerdb.ErrKeyNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.getKv(retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.getKv(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.getKv(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.setKv(m.ID, m)
require.NoError(t, err)
r := new(storage.Message)
err = h.getKv(m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.getKv(m.ID, r)
require.Error(t, err)
require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.getKv(inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.getKv(inflightKey(client, pk), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.getKv(storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.setKv(storage.ClientKey+"_cl1", &storage.Client{ID: "cl1", T: storage.ClientKey})
require.NoError(t, err)
err = h.setKv(storage.ClientKey+"_cl2", &storage.Client{ID: "cl2", T: storage.ClientKey})
require.NoError(t, err)
err = h.setKv(storage.ClientKey+"_cl3", &storage.Client{ID: "cl3", T: storage.ClientKey})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.setKv(storage.SubscriptionKey+"_sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.setKv(storage.SubscriptionKey+"_sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.setKv(storage.SubscriptionKey+"_sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_m2", &storage.Message{ID: "m2", T: storage.RetainedKey})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_m3", &storage.Message{ID: "m3", T: storage.RetainedKey})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.setKv(storage.InflightKey+"_i1", &storage.Message{ID: "i1", T: storage.InflightKey})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_i2", &storage.Message{ID: "i2", T: storage.InflightKey})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestErrorf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(logger, nil)
h.Errorf("test", 1, 2, 3)
}
func TestWarningf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(logger, nil)
h.Warningf("test", 1, 2, 3)
}
func TestInfof(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(logger, nil)
h.Infof("test", 1, 2, 3)
}
func TestDebugf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(logger, nil)
h.Debugf("test", 1, 2, 3)
}
func TestGcLoop(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
opts := badgerdb.DefaultOptions(defaultDbFile)
opts.ValueLogFileSize = 1 << 20
h.Init(&Options{
GcInterval: 2, // Set the interval for garbage collection.
Options: &opts,
})
defer teardown(t, defaultDbFile, h)
h.OnSessionEstablished(client, packets.Packet{})
h.OnDisconnect(client, nil, true)
time.Sleep(3 * time.Second)
}
func TestGetSetDelKv(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
defer teardown(t, h.config.Path, h)
require.NoError(t, err)
err = h.setKv("testKey", &storage.Client{ID: "testId"})
require.NoError(t, err)
var obj storage.Client
err = h.getKv("testKey", &obj)
require.NoError(t, err)
err = h.delKv("testKey")
require.NoError(t, err)
err = h.getKv("testKey", &obj)
require.Error(t, err)
require.ErrorIs(t, badgerdb.ErrKeyNotFound, err)
}
func TestIterKv(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
defer teardown(t, h.config.Path, h)
require.NoError(t, err)
h.setKv("prefix_a_1", &storage.Client{ID: "1"})
h.setKv("prefix_a_2", &storage.Client{ID: "2"})
h.setKv("prefix_b_2", &storage.Client{ID: "3"})
var clients []storage.Client
err = h.iterKv("prefix_a", func(data []byte) error {
var item storage.Client
item.UnmarshalBinary(data)
clients = append(clients, item)
return nil
})
require.Equal(t, 2, len(clients))
require.NoError(t, err)
visitErr := errors.New("iter visit error")
err = h.iterKv("prefix_b", func(data []byte) error {
return visitErr
})
require.ErrorIs(t, visitErr, err)
}

525
hooks/storage/bolt/bolt.go Normal file
View File

@@ -0,0 +1,525 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co, werbenhu
// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
package bolt
import (
"bytes"
"errors"
"time"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
"testmqtt/system"
"go.etcd.io/bbolt"
)
var (
ErrBucketNotFound = errors.New("bucket not found")
ErrKeyNotFound = errors.New("key not found")
)
const (
// defaultDbFile is the default file path for the boltdb file.
defaultDbFile = ".bolt"
// defaultTimeout is the default time to hold a connection to the file.
defaultTimeout = 250 * time.Millisecond
defaultBucket = "mochi"
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return storage.ClientKey + "_" + cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the bolt instance.
type Options struct {
Options *bbolt.Options
Bucket string `yaml:"bucket" json:"bucket"`
Path string `yaml:"path" json:"path"`
}
// Hook is a persistent storage hook based using boltdb file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the boltdb instance.
db *bbolt.DB // the boltdb instance.
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "bolt-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// Init initializes and connects to the boltdb instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
if h.config.Options == nil {
h.config.Options = &bbolt.Options{
Timeout: defaultTimeout,
}
}
if len(h.config.Path) == 0 {
h.config.Path = defaultDbFile
}
if len(h.config.Bucket) == 0 {
h.config.Bucket = defaultBucket
}
var err error
h.db, err = bbolt.Open(h.config.Path, 0600, h.config.Options)
if err != nil {
return err
}
err = h.db.Update(func(tx *bbolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(h.config.Bucket))
return err
})
return err
}
// Stop closes the boltdb instance.
func (h *Hook) Stop() error {
err := h.db.Close()
h.db = nil
return err
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: cl.ID,
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
_ = h.setKv(clientKey(cl), in)
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if !expire {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
_ = h.delKv(clientKey(cl))
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
_ = h.setKv(in.ID, in)
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
_ = h.delKv(subscriptionKey(cl, pk.Filters[i].Filter))
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if r == -1 {
_ = h.delKv(retainedKey(pk.TopicName))
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
_ = h.setKv(in.ID, in)
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
_ = h.setKv(in.ID, in)
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
_ = h.delKv(inflightKey(cl, pk))
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
_ = h.setKv(in.ID, in)
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
_ = h.delKv(retainedKey(filter))
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
_ = h.delKv(clientKey(cl))
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return v, storage.ErrDBFileNotOpen
}
err = h.iterKv(storage.ClientKey, func(value []byte) error {
obj := storage.Client{}
err = obj.UnmarshalBinary(value)
if err == nil {
v = append(v, obj)
}
return err
})
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return v, storage.ErrDBFileNotOpen
}
v = make([]storage.Subscription, 0)
err = h.iterKv(storage.SubscriptionKey, func(value []byte) error {
obj := storage.Subscription{}
err = obj.UnmarshalBinary(value)
if err == nil {
v = append(v, obj)
}
return err
})
return
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return v, storage.ErrDBFileNotOpen
}
v = make([]storage.Message, 0)
err = h.iterKv(storage.RetainedKey, func(value []byte) error {
obj := storage.Message{}
err = obj.UnmarshalBinary(value)
if err == nil {
v = append(v, obj)
}
return err
})
return
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return v, storage.ErrDBFileNotOpen
}
v = make([]storage.Message, 0)
err = h.iterKv(storage.InflightKey, func(value []byte) error {
obj := storage.Message{}
err = obj.UnmarshalBinary(value)
if err == nil {
v = append(v, obj)
}
return err
})
return
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return v, storage.ErrDBFileNotOpen
}
err = h.getKv(storage.SysInfoKey, &v)
if err != nil && !errors.Is(err, ErrKeyNotFound) {
return
}
return v, nil
}
// setKv stores a key-value pair in the database.
func (h *Hook) setKv(k string, v storage.Serializable) error {
err := h.db.Update(func(tx *bbolt.Tx) error {
bucket := tx.Bucket([]byte(h.config.Bucket))
data, _ := v.MarshalBinary()
err := bucket.Put([]byte(k), data)
if err != nil {
return err
}
return nil
})
if err != nil {
h.Log.Error("failed to upsert data", "error", err, "key", k)
}
return err
}
// delKv deletes a key-value pair from the database.
func (h *Hook) delKv(k string) error {
err := h.db.Update(func(tx *bbolt.Tx) error {
bucket := tx.Bucket([]byte(h.config.Bucket))
err := bucket.Delete([]byte(k))
if err != nil {
return err
}
return nil
})
if err != nil {
h.Log.Error("failed to delete data", "error", err, "key", k)
}
return err
}
// getKv retrieves the value associated with a key from the database.
func (h *Hook) getKv(k string, v storage.Serializable) error {
err := h.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket([]byte(h.config.Bucket))
value := bucket.Get([]byte(k))
if value == nil {
return ErrKeyNotFound
}
return v.UnmarshalBinary(value)
})
if err != nil {
h.Log.Error("failed to get data", "error", err, "key", k)
}
return err
}
// iterKv iterates over key-value pairs with keys having the specified prefix in the database.
func (h *Hook) iterKv(prefix string, visit func([]byte) error) error {
err := h.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket([]byte(h.config.Bucket))
c := bucket.Cursor()
for k, v := c.Seek([]byte(prefix)); k != nil && string(k[:len(prefix)]) == prefix; k, v = c.Next() {
if err := visit(v); err != nil {
return err
}
}
return nil
})
if err != nil {
h.Log.Error("failed to iter data", "error", err, "prefix", prefix)
}
return err
}

View File

@@ -0,0 +1,791 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co, werbenhu
package bolt
import (
"errors"
"log/slog"
"os"
"testing"
"time"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
"testmqtt/system"
"github.com/stretchr/testify/require"
)
var (
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
_ = h.Stop()
err := os.Remove(path)
require.NoError(t, err)
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "CL_cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "bolt-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultTimeout, h.config.Options.Timeout)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestInitBadPath(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Path: "..",
})
require.Error(t, err)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.getKv(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.getKv(clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.getKv(clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, ErrKeyNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.getKv(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.setKv(clientKey, &storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.getKv(clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.getKv(clientKey, r)
require.Error(t, err)
require.ErrorIs(t, ErrKeyNotFound, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, ErrKeyNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.getKv(retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.getKv(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.Equal(t, ErrKeyNotFound, err)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.getKv(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.Equal(t, ErrKeyNotFound, err)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.setKv(m.ID, m)
require.NoError(t, err)
r := new(storage.Message)
err = h.getKv(m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.getKv(m.ID, r)
require.Error(t, err)
require.Equal(t, ErrKeyNotFound, err)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.getKv(inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved to bolt
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.getKv(inflightKey(client, pk), r)
require.Error(t, err)
require.Equal(t, ErrKeyNotFound, err)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.getKv(storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.setKv(storage.ClientKey+"_"+"cl1", &storage.Client{ID: "cl1"})
require.NoError(t, err)
err = h.setKv(storage.ClientKey+"_"+"cl2", &storage.Client{ID: "cl2"})
require.NoError(t, err)
err = h.setKv(storage.ClientKey+"_"+"cl3", &storage.Client{ID: "cl3"})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
}
func TestStoredClientsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredClients()
require.Empty(t, v)
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.setKv(storage.SubscriptionKey+"_"+"sub1", &storage.Subscription{ID: "sub1"})
require.NoError(t, err)
err = h.setKv(storage.SubscriptionKey+"_"+"sub2", &storage.Subscription{ID: "sub2"})
require.NoError(t, err)
err = h.setKv(storage.SubscriptionKey+"_"+"sub3", &storage.Subscription{ID: "sub3"})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
}
func TestStoredSubscriptionsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.setKv(storage.RetainedKey+"_"+"m1", &storage.Message{ID: "m1"})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_"+"m2", &storage.Message{ID: "m2"})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_"+"m3", &storage.Message{ID: "m3"})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_"+"i3", &storage.Message{ID: "i3"})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
}
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.setKv(storage.InflightKey+"_"+"i1", &storage.Message{ID: "i1"})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_"+"i2", &storage.Message{ID: "i2"})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_"+"i3", &storage.Message{ID: "i3"})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_"+"m1", &storage.Message{ID: "m1"})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
}
func TestStoredInflightMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with sys info
err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
}
func TestStoredSysInfoClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.Error(t, err)
}
func TestGetSetDelKv(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
defer teardown(t, h.config.Path, h)
require.NoError(t, err)
err = h.setKv("testId", &storage.Client{ID: "testId"})
require.NoError(t, err)
var obj storage.Client
err = h.getKv("testId", &obj)
require.NoError(t, err)
err = h.delKv("testId")
require.NoError(t, err)
err = h.getKv("testId", &obj)
require.Error(t, err)
require.ErrorIs(t, ErrKeyNotFound, err)
}
func TestIterKv(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
defer teardown(t, h.config.Path, h)
require.NoError(t, err)
h.setKv("prefix_a_1", &storage.Client{ID: "1"})
h.setKv("prefix_a_2", &storage.Client{ID: "2"})
h.setKv("prefix_b_2", &storage.Client{ID: "3"})
var clients []storage.Client
err = h.iterKv("prefix_a", func(data []byte) error {
var item storage.Client
item.UnmarshalBinary(data)
clients = append(clients, item)
return nil
})
require.Equal(t, 2, len(clients))
require.NoError(t, err)
visitErr := errors.New("iter visit error")
err = h.iterKv("prefix_b", func(data []byte) error {
return visitErr
})
require.ErrorIs(t, visitErr, err)
}

View File

@@ -0,0 +1,524 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: werbenhu
package pebble
import (
"bytes"
"errors"
"fmt"
"strings"
pebbledb "github.com/cockroachdb/pebble"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
"testmqtt/system"
)
const (
// defaultDbFile is the default file path for the pebble db file.
defaultDbFile = ".pebble"
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return storage.ClientKey + "_" + cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// keyUpperBound returns the upper bound for a given byte slice by incrementing the last byte.
// It returns nil if all bytes are incremented and equal to 0.
func keyUpperBound(b []byte) []byte {
end := make([]byte, len(b))
copy(end, b)
for i := len(end) - 1; i >= 0; i-- {
end[i] = end[i] + 1
if end[i] != 0 {
return end[:i+1]
}
}
return nil
}
const (
NoSync = "NoSync" // NoSync specifies the default write options for writes which do not synchronize to disk.
Sync = "Sync" // Sync specifies the default write options for writes which synchronize to disk.
)
// Options contains configuration settings for the pebble DB instance.
type Options struct {
Options *pebbledb.Options
Mode string `yaml:"mode" json:"mode"`
Path string `yaml:"path" json:"path"`
}
// Hook is a persistent storage hook based using pebble DB file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the pebble DB instance.
db *pebbledb.DB // the pebble DB instance
mode *pebbledb.WriteOptions // mode holds the optional per-query parameters for Set and Delete operations
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "pebble-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// Init initializes and connects to the pebble instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
h.config = new(Options)
} else {
h.config = config.(*Options)
}
if len(h.config.Path) == 0 {
h.config.Path = defaultDbFile
}
if h.config.Options == nil {
h.config.Options = &pebbledb.Options{}
}
h.mode = pebbledb.NoSync
if strings.EqualFold(h.config.Mode, "Sync") {
h.mode = pebbledb.Sync
}
var err error
h.db, err = pebbledb.Open(h.config.Path, h.config.Options)
if err != nil {
return err
}
return nil
}
// Stop closes the pebble instance.
func (h *Hook) Stop() error {
err := h.db.Close()
h.db = nil
return err
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: cl.ID,
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
h.setKv(clientKey(cl), in)
}
// OnDisconnect removes a client from the store if their session has expired.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
h.updateClient(cl)
if !expire {
return
}
if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) {
return
}
h.delKv(clientKey(cl))
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
h.setKv(in.ID, in)
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
h.delKv(subscriptionKey(cl, pk.Filters[i].Filter))
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if r == -1 {
h.delKv(retainedKey(pk.TopicName))
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
h.setKv(in.ID, in)
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
PacketID: pk.PacketID,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
h.setKv(in.ID, in)
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
h.delKv(inflightKey(cl, pk))
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys.Clone(),
}
h.setKv(in.ID, in)
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
h.delKv(retainedKey(filter))
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
h.delKv(clientKey(cl))
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
iter, _ := h.db.NewIter(&pebbledb.IterOptions{
LowerBound: []byte(storage.ClientKey),
UpperBound: keyUpperBound([]byte(storage.ClientKey)),
})
for iter.First(); iter.Valid(); iter.Next() {
item := storage.Client{}
if err := item.UnmarshalBinary(iter.Value()); err == nil {
v = append(v, item)
}
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
iter, _ := h.db.NewIter(&pebbledb.IterOptions{
LowerBound: []byte(storage.SubscriptionKey),
UpperBound: keyUpperBound([]byte(storage.SubscriptionKey)),
})
for iter.First(); iter.Valid(); iter.Next() {
item := storage.Subscription{}
if err := item.UnmarshalBinary(iter.Value()); err == nil {
v = append(v, item)
}
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
iter, _ := h.db.NewIter(&pebbledb.IterOptions{
LowerBound: []byte(storage.RetainedKey),
UpperBound: keyUpperBound([]byte(storage.RetainedKey)),
})
for iter.First(); iter.Valid(); iter.Next() {
item := storage.Message{}
if err := item.UnmarshalBinary(iter.Value()); err == nil {
v = append(v, item)
}
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
iter, _ := h.db.NewIter(&pebbledb.IterOptions{
LowerBound: []byte(storage.InflightKey),
UpperBound: keyUpperBound([]byte(storage.InflightKey)),
})
for iter.First(); iter.Valid(); iter.Next() {
item := storage.Message{}
if err := item.UnmarshalBinary(iter.Value()); err == nil {
v = append(v, item)
}
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err = h.getKv(sysInfoKey(), &v)
if errors.Is(err, pebbledb.ErrNotFound) {
return v, nil
}
return
}
// Errorf satisfies the pebble interface for an error logger.
func (h *Hook) Errorf(m string, v ...any) {
h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Warningf satisfies the pebble interface for a warning logger.
func (h *Hook) Warningf(m string, v ...any) {
h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Infof satisfies the pebble interface for an info logger.
func (h *Hook) Infof(m string, v ...any) {
h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Debugf satisfies the pebble interface for a debug logger.
func (h *Hook) Debugf(m string, v ...any) {
h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// delKv deletes a key-value pair from the database.
func (h *Hook) delKv(k string) error {
err := h.db.Delete([]byte(k), h.mode)
if err != nil {
h.Log.Error("failed to delete data", "error", err, "key", k)
return err
}
return nil
}
// setKv stores a key-value pair in the database.
func (h *Hook) setKv(k string, v storage.Serializable) error {
bs, _ := v.MarshalBinary()
err := h.db.Set([]byte(k), bs, h.mode)
if err != nil {
h.Log.Error("failed to update data", "error", err, "key", k)
return err
}
return nil
}
// getKv retrieves the value associated with a key from the database.
func (h *Hook) getKv(k string, v storage.Serializable) error {
value, closer, err := h.db.Get([]byte(k))
if err != nil {
return err
}
defer func() {
if closer != nil {
closer.Close()
}
}()
return v.UnmarshalBinary(value)
}

View File

@@ -0,0 +1,812 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: werbenhu
package pebble
import (
"log/slog"
"os"
"strings"
"testing"
"time"
pebbledb "github.com/cockroachdb/pebble"
"github.com/stretchr/testify/require"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
"testmqtt/system"
)
var (
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
_ = h.Stop()
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
require.NoError(t, err)
}
func TestKeyUpperBound(t *testing.T) {
// Test case 1: Non-nil case
input1 := []byte{97, 98, 99} // "abc"
require.NotNil(t, keyUpperBound(input1))
// Test case 2: All bytes are 255
input2 := []byte{255, 255, 255}
require.Nil(t, keyUpperBound(input2))
// Test case 3: Empty slice
input3 := []byte{}
require.Nil(t, keyUpperBound(input3))
// Test case 4: Nil case
input4 := []byte{255, 255, 255}
require.Nil(t, keyUpperBound(input4))
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, storage.ClientKey+"_cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "pebble-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitErr(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Options: &pebbledb.Options{
ReadOnly: true,
},
})
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.getKv(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.getKv(clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.getKv(clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, pebbledb.ErrNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.setKv(clientKey, &storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.getKv(clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.getKv(clientKey, r)
require.Error(t, err)
require.ErrorIs(t, err, pebbledb.ErrNotFound)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.getKv(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
h.OnDisconnect(testClient, nil, true)
teardown(t, h.config.Path, h)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, pebbledb.ErrNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.getKv(retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.getKv(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, pebbledb.ErrNotFound)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.getKv(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, pebbledb.ErrNotFound)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.setKv(m.ID, m)
require.NoError(t, err)
r := new(storage.Message)
err = h.getKv(m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.getKv(m.ID, r)
require.Error(t, err)
require.ErrorIs(t, err, pebbledb.ErrNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.getKv(inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.getKv(inflightKey(client, pk), r)
require.Error(t, err)
require.ErrorIs(t, err, pebbledb.ErrNotFound)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.getKv(storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.setKv(storage.ClientKey+"_cl1", &storage.Client{ID: "cl1"})
require.NoError(t, err)
err = h.setKv(storage.ClientKey+"_cl2", &storage.Client{ID: "cl2"})
require.NoError(t, err)
err = h.setKv(storage.ClientKey+"_cl3", &storage.Client{ID: "cl3"})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.setKv(storage.SubscriptionKey+"_sub1", &storage.Subscription{ID: "sub1"})
require.NoError(t, err)
err = h.setKv(storage.SubscriptionKey+"_sub2", &storage.Subscription{ID: "sub2"})
require.NoError(t, err)
err = h.setKv(storage.SubscriptionKey+"_sub3", &storage.Subscription{ID: "sub3"})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1"})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_m2", &storage.Message{ID: "m2"})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_m3", &storage.Message{ID: "m3"})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3"})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.setKv(storage.InflightKey+"_i1", &storage.Message{ID: "i1"})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_i2", &storage.Message{ID: "i2"})
require.NoError(t, err)
err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3"})
require.NoError(t, err)
err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1"})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
r, err := h.StoredSysInfo()
require.NoError(t, err)
// populate with messages
err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err = h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestErrorf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(logger, nil)
h.Errorf("test", 1, 2, 3)
}
func TestWarningf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(logger, nil)
h.Warningf("test", 1, 2, 3)
}
func TestInfof(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(logger, nil)
h.Infof("test", 1, 2, 3)
}
func TestDebugf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(logger, nil)
h.Debugf("test", 1, 2, 3)
}
func TestGetSetDelKv(t *testing.T) {
opts := []struct {
name string
opt *Options
}{
{
name: "NoSync",
opt: &Options{
Mode: NoSync,
},
},
{
name: "Sync",
opt: &Options{
Mode: Sync,
},
},
}
for _, tt := range opts {
t.Run(tt.name, func(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(tt.opt)
defer teardown(t, h.config.Path, h)
require.NoError(t, err)
err = h.setKv("testKey", &storage.Client{ID: "testId"})
require.NoError(t, err)
var obj storage.Client
err = h.getKv("testKey", &obj)
require.NoError(t, err)
err = h.delKv("testKey")
require.NoError(t, err)
err = h.getKv("testKey", &obj)
require.Error(t, err)
require.ErrorIs(t, pebbledb.ErrNotFound, err)
})
}
}
func TestGetSetDelKvErr(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Mode: Sync,
})
require.NoError(t, err)
err = h.setKv("testKey", &storage.Client{ID: "testId"})
require.NoError(t, err)
h.Stop()
h = new(Hook)
h.SetOpts(logger, nil)
err = h.Init(&Options{
Mode: Sync,
Options: &pebbledb.Options{
ReadOnly: true,
},
})
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.setKv("testKey", &storage.Client{ID: "testId"})
require.Error(t, err)
err = h.delKv("testKey")
require.Error(t, err)
}

View File

@@ -0,0 +1,532 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"bytes"
"context"
"errors"
"fmt"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
"testmqtt/system"
"github.com/go-redis/redis/v8"
)
// defaultAddr 默认Redis地址
// defaultAddr is the default address to the redis service.
const defaultAddr = "localhost:6379"
// defaultHPrefix 默认前缀
// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt.
const defaultHPrefix = "mochi-"
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the bolt instance.
type Options struct {
Address string `yaml:"address" json:"address"`
Username string `yaml:"username" json:"username"`
Password string `yaml:"password" json:"password"`
Database int `yaml:"database" json:"database"`
HPrefix string `yaml:"h_prefix" json:"h_prefix"`
Options *redis.Options
}
// Hook is a persistent storage hook based using Redis as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for connecting to the Redis instance.
db *redis.Client // the Redis instance
ctx context.Context // a context for the connection
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "redis-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnWillSent,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// hKey returns a hash set key with a unique prefix.
func (h *Hook) hKey(s string) string {
return h.config.HPrefix + s
}
// Init 初始化并连接到 Redis 服务
// Init initializes and connects to the redis service.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
h.ctx = context.Background()
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
if h.config.Options == nil {
h.config.Options = &redis.Options{
Addr: defaultAddr,
}
h.config.Options.Addr = h.config.Address
h.config.Options.DB = h.config.Database
//h.config.Options.Username = h.config.Username
h.config.Options.Password = h.config.Password
}
if h.config.HPrefix == "" {
h.config.HPrefix = defaultHPrefix
}
h.Log.Info(
"connecting to redis service",
"prefix", h.config.HPrefix,
"address", h.config.Options.Addr,
"username", h.config.Options.Username,
"password-len", len(h.config.Options.Password),
"db", h.config.Options.DB,
)
h.db = redis.NewClient(h.config.Options)
_, err := h.db.Ping(context.Background()).Result()
if err != nil {
return fmt.Errorf("failed to ping service: %w", err)
}
h.Log.Info("connected to redis service")
return nil
}
// Stop closes the redis connection.
func (h *Hook) Stop() error {
h.Log.Info("disconnecting from redis service")
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
if err != nil {
h.Log.Error("failed to hset client data", "error", err, "data", in)
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if !expire {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
if err != nil {
h.Log.Error("failed to hset subscription data", "error", err, "data", in)
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
if err != nil {
h.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl))
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
if err != nil {
h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName))
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
if err != nil {
h.Log.Error("failed to hset retained message data", "error", err, "data", in)
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
if err != nil {
h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in)
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
if err != nil {
h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk))
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
if err != nil {
h.Log.Error("failed to hset server info data", "error", err, "data", in)
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter))
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error("failed to HGetAll client data", "error", err)
return
}
for _, row := range rows {
var d storage.Client
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error("failed to unmarshal client data", "error", err, "data", row)
}
v = append(v, d)
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error("failed to HGetAll subscription data", "error", err)
return
}
for _, row := range rows {
var d storage.Subscription
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row)
}
v = append(v, d)
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error("failed to HGetAll retained message data", "error", err)
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row)
}
v = append(v, d)
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error("failed to HGetAll inflight message data", "error", err)
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row)
}
v = append(v, d)
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
if err != nil && !errors.Is(err, redis.Nil) {
return
}
if err = v.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row)
}
return v, nil
}

View File

@@ -0,0 +1,834 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"log/slog"
"os"
"sort"
"testing"
"time"
"testmqtt/hooks/storage"
"testmqtt/mqtt"
"testmqtt/packets"
"testmqtt/system"
miniredis "github.com/alicebob/miniredis/v2"
redis "github.com/go-redis/redis/v8"
"github.com/stretchr/testify/require"
)
var (
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func newHook(t *testing.T, addr string) *Hook {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: addr,
},
})
require.NoError(t, err)
return h
}
func teardown(t *testing.T, h *Hook) {
if h.db != nil {
err := h.db.FlushAll(h.ctx).Err()
require.NoError(t, err)
h.Stop()
}
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, "cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, "a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, "cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
require.Equal(t, "redis-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestHKey(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.SetOpts(logger, nil)
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
}
func TestInitUseDefaults(t *testing.T) {
s := miniredis.RunT(t)
s.StartAddr(defaultAddr)
defer s.Close()
h := newHook(t, defaultAddr)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h)
require.Equal(t, defaultHPrefix, h.config.HPrefix)
require.Equal(t, defaultAddr, h.config.Options.Addr)
}
func TestInitUsePassConfig(t *testing.T) {
s := miniredis.RunT(t)
s.StartAddr(defaultAddr)
defer s.Close()
h := newHook(t, "")
h.SetOpts(logger, nil)
err := h.Init(&Options{
Address: defaultAddr,
Username: "username",
Password: "password",
Database: 2,
})
require.Error(t, err)
h.db.FlushAll(h.ctx)
require.Equal(t, defaultAddr, h.config.Options.Addr)
require.Equal(t, "username", h.config.Options.Username)
require.Equal(t, "password", h.config.Options.Password)
require.Equal(t, 2, h.config.Options.DB)
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitBadAddr(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: "abc:123",
},
})
require.Error(t, err)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r2.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
require.Empty(t, r3.ID)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnClientExpired(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err()
require.NoError(t, err)
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, clientKey, r.ID)
h.OnClientExpired(cl)
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
require.Error(t, err)
require.ErrorIs(t, redis.Nil, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnSubscribedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpired(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err()
require.NoError(t, err)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved to bolt
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnQosPublishNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with clients
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err()
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredClientsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredClients()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with subscriptions
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptionsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with messages
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with messages
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessagesClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSysInfo(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with sys info
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey,
&storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
}).Err()
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfoClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.Error(t, err)
}

213
hooks/storage/storage.go Normal file
View File

@@ -0,0 +1,213 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"encoding/json"
"errors"
"testmqtt/packets"
"testmqtt/system"
)
const (
// SubscriptionKey 唯一标识订阅信息。在存储系统中,这个键用于检索和管理客户端的订阅信息
// SubscriptionKey unique key to denote Subscriptions in a store
SubscriptionKey = "SUB"
// SysInfoKey 唯一标识服务器系统信息。在存储系统中,这个键用于存储和检索与服务器状态或系统配置相关的信息。
// SysInfoKey unique key to denote server system information in a store
SysInfoKey = "SYS"
// RetainedKey 唯一标识保留的消息。保留消息是在订阅时立即发送的消息,即使订阅者在消息发布时不在线。这个键用于存储这些保留消息。
// RetainedKey unique key to denote retained messages in a store
RetainedKey = "RET"
// InflightKey 唯一标识飞行中的消息。飞行中的消息指的是已经发布但尚未确认的消息,这个键用于跟踪这些消息的状态。
// InflightKey unique key to denote inflight messages in a store
InflightKey = "IFM"
// ClientKey 唯一标识客户端信息。在存储系统中,这个键用于检索和管理有关客户端的信息,如客户端状态、连接信息等。
// ClientKey unique key to denote clients in a store
ClientKey = "CL"
)
var (
// ErrDBFileNotOpen 数据库文件没有打开
// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading.
ErrDBFileNotOpen = errors.New("数据库文件没有打开")
)
// Serializable is an interface for objects that can be serialized and deserialized.
type Serializable interface {
UnmarshalBinary([]byte) error
MarshalBinary() (data []byte, err error)
}
// Client is a storable representation of an MQTT client.
type Client struct {
Will ClientWill `json:"will"` // will topic and payload data if applicable
Properties ClientProperties `json:"properties"` // the connect properties for the client
Username []byte `json:"username"` // the username of the client
ID string `json:"id" storm:"id"` // the client id / storage key
T string `json:"t"` // the data type (client)
Remote string `json:"remote"` // the remote address of the client
Listener string `json:"listener"` // the listener the client connected on
ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client
Clean bool `json:"clean"` // if the client requested a clean start/session
}
// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection.
type ClientProperties struct {
AuthenticationData []byte `json:"authenticationData,omitempty"`
User []packets.UserProperty `json:"user,omitempty"`
AuthenticationMethod string `json:"authenticationMethod,omitempty"`
SessionExpiryInterval uint32 `json:"sessionExpiryInterval,omitempty"`
MaximumPacketSize uint32 `json:"maximumPacketSize,omitempty"`
ReceiveMaximum uint16 `json:"receiveMaximum,omitempty"`
TopicAliasMaximum uint16 `json:"topicAliasMaximum,omitempty"`
SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag,omitempty"`
RequestProblemInfo byte `json:"requestProblemInfo,omitempty"`
RequestProblemInfoFlag bool `json:"requestProblemInfoFlag,omitempty"`
RequestResponseInfo byte `json:"requestResponseInfo,omitempty"`
}
// ClientWill contains a will message for a client, and limited mqtt v5 properties.
type ClientWill struct {
Payload []byte `json:"payload,omitempty"`
User []packets.UserProperty `json:"user,omitempty"`
TopicName string `json:"topicName,omitempty"`
Flag uint32 `json:"flag,omitempty"`
WillDelayInterval uint32 `json:"willDelayInterval,omitempty"`
Qos byte `json:"qos,omitempty"`
Retain bool `json:"retain,omitempty"`
}
// MarshalBinary encodes the values into a json string.
func (d Client) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Client) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// Message is a storable representation of an MQTT message (specifically publish).
type Message struct {
Properties MessageProperties `json:"properties"` // -
Payload []byte `json:"payload"` // the message payload (if retained)
T string `json:"t,omitempty"` // the data type
ID string `json:"id,omitempty" storm:"id"` // the storage key
Origin string `json:"origin,omitempty"` // the id of the client who sent the message
TopicName string `json:"topic_name,omitempty"` // the topic the message was sent to (if retained)
FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message
Created int64 `json:"created,omitempty"` // the time the message was created in unixtime
Sent int64 `json:"sent,omitempty"` // the last time the message was sent (for retries) in unixtime (if inflight)
PacketID uint16 `json:"packet_id,omitempty"` // the unique id of the packet (if inflight)
}
// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages.
type MessageProperties struct {
CorrelationData []byte `json:"correlationData,omitempty"`
SubscriptionIdentifier []int `json:"subscriptionIdentifier,omitempty"`
User []packets.UserProperty `json:"user,omitempty"`
ContentType string `json:"contentType,omitempty"`
ResponseTopic string `json:"responseTopic,omitempty"`
MessageExpiryInterval uint32 `json:"messageExpiry,omitempty"`
TopicAlias uint16 `json:"topicAlias,omitempty"`
PayloadFormat byte `json:"payloadFormat,omitempty"`
PayloadFormatFlag bool `json:"payloadFormatFlag,omitempty"`
}
// MarshalBinary encodes the values into a json string.
func (d Message) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Message) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// ToPacket converts a storage.Message to a standard packet.
func (d *Message) ToPacket() packets.Packet {
pk := packets.Packet{
FixedHeader: d.FixedHeader,
PacketID: d.PacketID,
TopicName: d.TopicName,
Payload: d.Payload,
Origin: d.Origin,
Created: d.Created,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
}
// Return a deep copy of the packet data otherwise the slices will
// continue pointing at the values from the storage packet.
pk = pk.Copy(true)
pk.FixedHeader.Dup = d.FixedHeader.Dup
return pk
}
// Subscription is a storable representation of an MQTT subscription.
type Subscription struct {
T string `json:"t,omitempty"`
ID string `json:"id,omitempty" storm:"id"`
Client string `json:"client,omitempty"`
Filter string `json:"filter"`
Identifier int `json:"identifier,omitempty"`
RetainHandling byte `json:"retain_handling,omitempty"`
Qos byte `json:"qos"`
RetainAsPublished bool `json:"retain_as_pub,omitempty"`
NoLocal bool `json:"no_local,omitempty"`
}
// MarshalBinary encodes the values into a json string.
func (d Subscription) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Subscription) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// SystemInfo is a storable representation of the system information values.
type SystemInfo struct {
system.Info // embed the system info struct
T string `json:"t"` // the data type
ID string `json:"id" storm:"id"` // the storage key
}
// MarshalBinary 将值编码为json字符串
// MarshalBinary encodes the values into a json string.
func (d SystemInfo) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary 将json字符串解码为结构体
// UnmarshalBinary decodes a json string into a struct.
func (d *SystemInfo) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}

View File

@@ -0,0 +1,228 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"testmqtt/packets"
"testmqtt/system"
)
var (
clientStruct = Client{
ID: "test",
T: "client",
Remote: "remote",
Listener: "listener",
Username: []byte("mochi"),
Clean: true,
Properties: ClientProperties{
SessionExpiryInterval: 2,
SessionExpiryIntervalFlag: true,
AuthenticationMethod: "a",
AuthenticationData: []byte("test"),
RequestProblemInfo: 1,
RequestProblemInfoFlag: true,
RequestResponseInfo: 1,
ReceiveMaximum: 128,
TopicAliasMaximum: 256,
User: []packets.UserProperty{
{Key: "k", Val: "v"},
},
MaximumPacketSize: 120,
},
Will: ClientWill{
Qos: 1,
Payload: []byte("abc"),
TopicName: "a/b/c",
Flag: 1,
Retain: true,
WillDelayInterval: 2,
User: []packets.UserProperty{
{Key: "k2", Val: "v2"},
},
},
}
clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)
messageStruct = Message{
T: "message",
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: 2,
Type: 3,
Qos: 1,
Dup: true,
Retain: true,
},
ID: "id",
Origin: "mochi",
TopicName: "topic",
Properties: MessageProperties{
PayloadFormat: 1,
PayloadFormatFlag: true,
MessageExpiryInterval: 20,
ContentType: "type",
ResponseTopic: "a/b/r",
CorrelationData: []byte("r"),
SubscriptionIdentifier: []int{1},
TopicAlias: 2,
User: []packets.UserProperty{
{Key: "k2", Val: "v2"},
},
},
Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
PacketID: 100,
}
messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)
subscriptionStruct = Subscription{
T: "subscription",
ID: "id",
Client: "mochi",
Filter: "a/b/c",
Qos: 1,
}
subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","qos":1}`)
sysInfoStruct = SystemInfo{
T: "info",
ID: "id",
Info: system.Info{
Version: "2.0.0",
Started: 1,
Uptime: 2,
BytesReceived: 3,
BytesSent: 4,
ClientsConnected: 5,
ClientsMaximum: 7,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
PacketsReceived: 12,
PacketsSent: 13,
Retained: 15,
Inflight: 16,
InflightDropped: 17,
},
}
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)
func TestClientMarshalBinary(t *testing.T) {
data, err := clientStruct.MarshalBinary()
require.NoError(t, err)
require.JSONEq(t, string(clientJSON), string(data))
}
func TestClientUnmarshalBinary(t *testing.T) {
d := clientStruct
err := d.UnmarshalBinary(clientJSON)
require.NoError(t, err)
require.Equal(t, clientStruct, d)
}
func TestClientUnmarshalBinaryEmpty(t *testing.T) {
d := Client{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Client{}, d)
}
func TestMessageMarshalBinary(t *testing.T) {
data, err := messageStruct.MarshalBinary()
require.NoError(t, err)
require.JSONEq(t, string(messageJSON), string(data))
}
func TestMessageUnmarshalBinary(t *testing.T) {
d := messageStruct
err := d.UnmarshalBinary(messageJSON)
require.NoError(t, err)
require.Equal(t, messageStruct, d)
}
func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
d := Message{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Message{}, d)
}
func TestSubscriptionMarshalBinary(t *testing.T) {
data, err := subscriptionStruct.MarshalBinary()
require.NoError(t, err)
require.JSONEq(t, string(subscriptionJSON), string(data))
}
func TestSubscriptionUnmarshalBinary(t *testing.T) {
d := subscriptionStruct
err := d.UnmarshalBinary(subscriptionJSON)
require.NoError(t, err)
require.Equal(t, subscriptionStruct, d)
}
func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
d := Subscription{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Subscription{}, d)
}
func TestSysInfoMarshalBinary(t *testing.T) {
data, err := sysInfoStruct.MarshalBinary()
require.NoError(t, err)
require.JSONEq(t, string(sysInfoJSON), string(data))
}
func TestSysInfoUnmarshalBinary(t *testing.T) {
d := sysInfoStruct
err := d.UnmarshalBinary(sysInfoJSON)
require.NoError(t, err)
require.Equal(t, sysInfoStruct, d)
}
func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
d := SystemInfo{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, SystemInfo{}, d)
}
func TestMessageToPacket(t *testing.T) {
d := messageStruct
pk := d.ToPacket()
require.Equal(t, packets.Packet{
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: d.FixedHeader.Remaining,
Type: d.FixedHeader.Type,
Qos: d.FixedHeader.Qos,
Dup: d.FixedHeader.Dup,
Retain: d.FixedHeader.Retain,
},
Origin: d.Origin,
TopicName: d.TopicName,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
PacketID: 100,
Created: d.Created,
}, pk)
}