From 84e5e65ee7a0e1c15edef83ffac3ab6fa6f98dd5 Mon Sep 17 00:00:00 2001
From: iuu <2167162990@qq.com>
Date: Wed, 21 Aug 2024 15:32:05 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=95=B4=E7=90=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.idea/vcs.xml | 6 +
.idea/workspace.xml | 72 +
config-bac.yaml | 101 +
config.yaml | 25 +
config/config.go | 47 +
config/config_test.go | 238 ++
config/hook.go | 123 +
config/logger.go | 21 +
go.mod | 54 +
go.sum | 573 ++++
hooks/auth/allow_all.go | 41 +
hooks/auth/allow_all_test.go | 35 +
hooks/auth/auth.go | 103 +
hooks/auth/auth_test.go | 213 ++
hooks/auth/ledger.go | 246 ++
hooks/auth/ledger_test.go | 610 ++++
hooks/debug/debug.go | 237 ++
hooks/storage/badger/badger.go | 576 ++++
hooks/storage/badger/badger_test.go | 809 ++++++
hooks/storage/bolt/bolt.go | 525 ++++
hooks/storage/bolt/bolt_test.go | 791 ++++++
hooks/storage/pebble/pebble.go | 524 ++++
hooks/storage/pebble/pebble_test.go | 812 ++++++
hooks/storage/redis/redis.go | 532 ++++
hooks/storage/redis/redis_test.go | 834 ++++++
hooks/storage/storage.go | 213 ++
hooks/storage/storage_test.go | 228 ++
listeners/http_healthcheck.go | 99 +
listeners/http_healthcheck_test.go | 137 +
listeners/http_sysinfo.go | 122 +
listeners/http_sysinfo_test.go | 149 +
listeners/listeners.go | 135 +
listeners/listeners_test.go | 182 ++
listeners/mock.go | 105 +
listeners/mock_test.go | 99 +
listeners/net.go | 92 +
listeners/net_test.go | 105 +
listeners/tcp.go | 109 +
listeners/tcp_test.go | 124 +
listeners/unixsock.go | 102 +
listeners/unixsock_test.go | 102 +
listeners/websocket.go | 199 ++
listeners/websocket_test.go | 173 ++
main.go | 103 +
mempool/bufpool.go | 83 +
mempool/bufpool_test.go | 96 +
mqtt/clients.go | 649 +++++
mqtt/clients_test.go | 930 ++++++
mqtt/hooks.go | 864 ++++++
mqtt/hooks_test.go | 667 +++++
mqtt/inflight.go | 156 ++
mqtt/inflight_test.go | 199 ++
mqtt/server.go | 1759 ++++++++++++
mqtt/server_test.go | 3915 ++++++++++++++++++++++++++
mqtt/topics.go | 824 ++++++
mqtt/topics_test.go | 1068 +++++++
packets/codec.go | 172 ++
packets/codec_test.go | 422 +++
packets/codes.go | 149 +
packets/codes_test.go | 29 +
packets/fixedheader.go | 63 +
packets/fixedheader_test.go | 237 ++
packets/packets.go | 1173 ++++++++
packets/packets_test.go | 505 ++++
packets/properties.go | 481 ++++
packets/properties_test.go | 333 +++
packets/tpackets.go | 4031 +++++++++++++++++++++++++++
packets/tpackets_test.go | 33 +
system/config.yaml | 70 +
system/system.go | 62 +
system/system_test.go | 37 +
71 files changed, 29733 insertions(+)
create mode 100644 .idea/vcs.xml
create mode 100644 .idea/workspace.xml
create mode 100644 config-bac.yaml
create mode 100644 config.yaml
create mode 100644 config/config.go
create mode 100644 config/config_test.go
create mode 100644 config/hook.go
create mode 100644 config/logger.go
create mode 100644 go.mod
create mode 100644 go.sum
create mode 100644 hooks/auth/allow_all.go
create mode 100644 hooks/auth/allow_all_test.go
create mode 100644 hooks/auth/auth.go
create mode 100644 hooks/auth/auth_test.go
create mode 100644 hooks/auth/ledger.go
create mode 100644 hooks/auth/ledger_test.go
create mode 100644 hooks/debug/debug.go
create mode 100644 hooks/storage/badger/badger.go
create mode 100644 hooks/storage/badger/badger_test.go
create mode 100644 hooks/storage/bolt/bolt.go
create mode 100644 hooks/storage/bolt/bolt_test.go
create mode 100644 hooks/storage/pebble/pebble.go
create mode 100644 hooks/storage/pebble/pebble_test.go
create mode 100644 hooks/storage/redis/redis.go
create mode 100644 hooks/storage/redis/redis_test.go
create mode 100644 hooks/storage/storage.go
create mode 100644 hooks/storage/storage_test.go
create mode 100644 listeners/http_healthcheck.go
create mode 100644 listeners/http_healthcheck_test.go
create mode 100644 listeners/http_sysinfo.go
create mode 100644 listeners/http_sysinfo_test.go
create mode 100644 listeners/listeners.go
create mode 100644 listeners/listeners_test.go
create mode 100644 listeners/mock.go
create mode 100644 listeners/mock_test.go
create mode 100644 listeners/net.go
create mode 100644 listeners/net_test.go
create mode 100644 listeners/tcp.go
create mode 100644 listeners/tcp_test.go
create mode 100644 listeners/unixsock.go
create mode 100644 listeners/unixsock_test.go
create mode 100644 listeners/websocket.go
create mode 100644 listeners/websocket_test.go
create mode 100644 main.go
create mode 100644 mempool/bufpool.go
create mode 100644 mempool/bufpool_test.go
create mode 100644 mqtt/clients.go
create mode 100644 mqtt/clients_test.go
create mode 100644 mqtt/hooks.go
create mode 100644 mqtt/hooks_test.go
create mode 100644 mqtt/inflight.go
create mode 100644 mqtt/inflight_test.go
create mode 100644 mqtt/server.go
create mode 100644 mqtt/server_test.go
create mode 100644 mqtt/topics.go
create mode 100644 mqtt/topics_test.go
create mode 100644 packets/codec.go
create mode 100644 packets/codec_test.go
create mode 100644 packets/codes.go
create mode 100644 packets/codes_test.go
create mode 100644 packets/fixedheader.go
create mode 100644 packets/fixedheader_test.go
create mode 100644 packets/packets.go
create mode 100644 packets/packets_test.go
create mode 100644 packets/properties.go
create mode 100644 packets/properties_test.go
create mode 100644 packets/tpackets.go
create mode 100644 packets/tpackets_test.go
create mode 100644 system/config.yaml
create mode 100644 system/system.go
create mode 100644 system/system_test.go
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..dbeae59
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
new file mode 100644
index 0000000..a99ea19
--- /dev/null
+++ b/.idea/workspace.xml
@@ -0,0 +1,72 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "associatedIndex": 8
+}
+
+
+
+
+
+ {
+ "keyToString": {
+ "DefaultGoTemplateProperty": "Go File",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "RunOnceActivity.go.formatter.settings.were.checked": "true",
+ "RunOnceActivity.go.migrated.go.modules.settings": "true",
+ "RunOnceActivity.go.modules.automatic.dependencies.download": "true",
+ "RunOnceActivity.go.modules.go.list.on.any.changes.was.set": "true",
+ "git-widget-placeholder": "main",
+ "go.import.settings.migrated": "true",
+ "last_opened_file_path": "E:/Work/Go/src/testmqtt",
+ "node.js.detected.package.eslint": "true",
+ "node.js.selected.package.eslint": "(autodetect)",
+ "nodejs_package_manager_path": "npm"
+ }
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ true
+
+
\ No newline at end of file
diff --git a/config-bac.yaml b/config-bac.yaml
new file mode 100644
index 0000000..b1aeaac
--- /dev/null
+++ b/config-bac.yaml
@@ -0,0 +1,101 @@
+
+
+# 监听端口
+listeners:
+ - type: "tcp"
+ id: "file-tcp1"
+ address: ":1883"
+ - type: "ws"
+ id: "file-websocket"
+ address: ":1882"
+ - type: "healthcheck"
+ id: "file-healthcheck"
+ address: ":1880"
+
+
+hooks:
+ debug:
+ enable: true
+ storage:
+ badger:
+ path: badger.db
+ gc_interval: 3
+ gc_discard_ratio: 0.5
+ pebble:
+ path: pebble.db
+ mode: "NoSync"
+ bolt:
+ path: bolt.db
+ bucket: "mochi"
+ redis:
+ h_prefix: "mc"
+ username: "mochi"
+ password: "melon"
+ address: "localhost:6379"
+ database: 1
+ auth:
+ allow_all: false
+ ledger:
+ auth:
+ - username: peach
+ password: password1
+ allow: true
+ acl:
+ - remote: 127.0.0.1:*
+ - username: melon
+ filters:
+ melon/#: 3
+ updates/#: 2
+
+# MQTT协议相关参数配置
+options:
+ # 客户端网络写缓冲区的大小,单位为字节。此配置项设置了发送数据到客户端时的缓冲区大小,设置为 2048 字节。
+ client_net_write_buffer_size: 2048
+ # 客户端网络读缓冲区的大小,单位为字节。此配置项设置了从客户端接收数据时的缓冲区大小,设置为 2048 字节。
+ client_net_read_buffer_size: 2048
+ # 系统主题消息重新发送的间隔时间,单位为秒。系统主题通常包含监控和管理信息,这个选项设定了这些信息更新的频率。
+ sys_topic_resend_interval: 10
+  # 如果设为 true,将启用内联客户端(inline client),允许嵌入此 Broker 的应用程序直接发布和订阅消息,而无需建立网络连接。
+  inline_client: true
+ # 这些选项定义了 MQTT Broker 支持的功能及其限制。
+ capabilities:
+ # 消息过期的最大时间间隔,单位为秒。超过这个时间的消息将被丢弃。
+ maximum_message_expiry_interval: 100
+ # 允许挂起的客户端写操作的最大数量。此选项设置了在发送消息给客户端之前,Broker 可以排队的最大消息数量。
+ maximum_client_writes_pending: 8192
+ # 客户端会话的最大过期时间,单位为秒。这个配置决定了在客户端断开连接后,Broker 会保留客户端会话状态的最大时间。
+ maximum_session_expiry_interval: 86400
+ # 接收的 MQTT 包的最大大小,单位为字节。如果设置为 0,表示不限制包的大小。
+ maximum_packet_size: 0
+ # Broker 能够同时接收的最大 QoS 1 和 QoS 2 消息数量
+ receive_maximum: 1024
+ # 允许的最大“飞行中”(尚未完成确认)的消息数量。
+ maximum_inflight: 8192
+ # 允许的最大主题别名数量。主题别名用于减少传输过程中重复主题字符串的传输,特别在长主题中有助于节省带宽。
+ topic_alias_maximum: 65535
+ # 指定是否支持共享订阅,1 表示支持。
+ shared_sub_available: 1
+ # 支持的最小 MQTT 协议版本。3 表示 MQTT 3.1.1。
+ minimum_protocol_version: 3
+ # 支持的最高服务质量(QoS)级别。2 表示 Broker 支持 QoS 0、QoS 1 和 QoS 2。
+ maximum_qos: 2
+ # 指定是否支持保留消息,1 表示支持。
+ retain_available: 1
+ # 指定是否支持通配符订阅,1 表示支持。
+ wildcard_sub_available: 1
+ # 指定是否支持订阅标识符,1 表示支持。
+ sub_id_available: 1
+ # 这些选项是为了保持与旧版本或特殊客户端的兼容性。
+ compatibilities:
+ # 是否隐藏未授权错误的具体原因。true 表示隐藏,可能只返回一个泛泛的错误信息以提高安全性。
+ obscure_not_authorized: true
+ # 是否在客户端无响应时被动断开连接。false 表示不被动断开,可能会主动断开连接。
+ passive_client_disconnect: false
+ # 是否总是返回响应信息。false 表示只在需要时返回响应信息。
+ always_return_response_info: false
+ # 是否在重启时恢复系统信息。false 表示重启时不会恢复之前的系统状态信息。
+ restore_sys_info_on_restart: false
+ # 是否在 ACK 时不继承属性。false 表示继承属性。
+ no_inherited_properties_on_ack: false
+logging:
+ level: INFO
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..666ad35
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,25 @@
+listeners:
+ - type: "tcp"
+ id: "file-tcp1"
+ address: ":1883"
+ - type: "ws"
+ id: "file-websocket"
+ address: ":1882"
+ - type: "healthcheck"
+ id: "file-healthcheck"
+ address: ":1880"
+hooks:
+ debug:
+ enable: true
+ storage:
+ redis:
+ h_prefix: "m-"
+ password: "iuuiuuiuu"
+ address: "192.168.0.9:6379"
+ database: 5
+ auth:
+ allow_all: true
+
+
+logging:
+ level: DEBUG
diff --git a/config/config.go b/config/config.go
new file mode 100644
index 0000000..31c48d1
--- /dev/null
+++ b/config/config.go
@@ -0,0 +1,47 @@
+package config
+
+import (
+ "encoding/json"
+ "gopkg.in/yaml.v3"
+ "testmqtt/listeners"
+ "testmqtt/mqtt"
+)
+
+// config 定义配置结构
+// config defines the structure of configuration data to be parsed from a config source.
+type config struct {
+ Options mqtt.Options // Mqtt协议配置
+ Listeners []listeners.Config `yaml:"listeners" json:"listeners"` // 监听端口配置
+ HookConfigs HookConfigs `yaml:"hooks" json:"hooks"` // 钩子配置
+ LoggingConfig LoggingConfig `yaml:"logging" json:"logging"` // 日志记录配置
+}
+
+// FromBytes unmarshals a byte slice of JSON or YAML config data into a valid server options value.
+// Any hooks configurations are converted into Hooks using the ToHooks method in this package.
+func FromBytes(b []byte) (*mqtt.Options, error) {
+ c := new(config)
+ o := mqtt.Options{}
+
+ if len(b) == 0 {
+ return nil, nil
+ }
+
+ if b[0] == '{' {
+ err := json.Unmarshal(b, c)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ err := yaml.Unmarshal(b, c)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ o = c.Options
+ o.Hooks = c.HookConfigs.ToHooks()
+ o.Listeners = c.Listeners
+ o.Logger = c.LoggingConfig.ToLogger()
+
+ return &o, nil
+}
diff --git a/config/config_test.go b/config/config_test.go
new file mode 100644
index 0000000..a1b9698
--- /dev/null
+++ b/config/config_test.go
@@ -0,0 +1,238 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package config
+
+import (
+ "log/slog"
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "testmqtt/hooks/auth"
+ "testmqtt/hooks/storage/badger"
+ "testmqtt/hooks/storage/bolt"
+ "testmqtt/hooks/storage/pebble"
+ "testmqtt/hooks/storage/redis"
+ "testmqtt/listeners"
+
+ "testmqtt/mqtt"
+)
+
+var (
+ yamlBytes = []byte(`
+listeners:
+ - type: "tcp"
+ id: "file-tcp1"
+ address: ":1883"
+hooks:
+ auth:
+ allow_all: true
+options:
+ client_net_write_buffer_size: 2048
+ capabilities:
+ minimum_protocol_version: 3
+ compatibilities:
+ restore_sys_info_on_restart: true
+`)
+
+ jsonBytes = []byte(`{
+ "listeners": [
+ {
+ "type": "tcp",
+ "id": "file-tcp1",
+ "address": ":1883"
+ }
+ ],
+ "hooks": {
+ "auth": {
+ "allow_all": true
+ }
+ },
+ "options": {
+ "client_net_write_buffer_size": 2048,
+ "capabilities": {
+ "minimum_protocol_version": 3,
+ "compatibilities": {
+ "restore_sys_info_on_restart": true
+ }
+ }
+ }
+}
+`)
+ parsedOptions = mqtt.Options{
+ Listeners: []listeners.Config{
+ {
+ Type: listeners.TypeTCP,
+ ID: "file-tcp1",
+ Address: ":1883",
+ },
+ },
+ Hooks: []mqtt.HookLoadConfig{
+ {
+ Hook: new(auth.AllowHook),
+ },
+ },
+ ClientNetWriteBufferSize: 2048,
+ Capabilities: &mqtt.Capabilities{
+ MinimumProtocolVersion: 3,
+ Compatibilities: mqtt.Compatibilities{
+ RestoreSysInfoOnRestart: true,
+ },
+ },
+ Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
+ Level: new(slog.LevelVar),
+ })),
+ }
+)
+
+func TestFromBytesEmpty(t *testing.T) {
+ _, err := FromBytes([]byte{})
+ require.NoError(t, err)
+}
+
+func TestFromBytesYAML(t *testing.T) {
+ o, err := FromBytes(yamlBytes)
+ require.NoError(t, err)
+ require.Equal(t, parsedOptions, *o)
+}
+
+func TestFromBytesYAMLError(t *testing.T) {
+ _, err := FromBytes(append(yamlBytes, 'a'))
+ require.Error(t, err)
+}
+
+func TestFromBytesJSON(t *testing.T) {
+ o, err := FromBytes(jsonBytes)
+ require.NoError(t, err)
+ require.Equal(t, parsedOptions, *o)
+}
+
+func TestFromBytesJSONError(t *testing.T) {
+ _, err := FromBytes(append(jsonBytes, 'a'))
+ require.Error(t, err)
+}
+
+func TestToHooksAuthAllowAll(t *testing.T) {
+ hc := HookConfigs{
+ Auth: &HookAuthConfig{
+ AllowAll: true,
+ },
+ }
+
+ th := hc.toHooksAuth()
+ expect := []mqtt.HookLoadConfig{
+ {Hook: new(auth.AllowHook)},
+ }
+ require.Equal(t, expect, th)
+}
+
+func TestToHooksAuthAllowLedger(t *testing.T) {
+ hc := HookConfigs{
+ Auth: &HookAuthConfig{
+ Ledger: auth.Ledger{
+ Auth: auth.AuthRules{
+ {Username: "peach", Password: "password1", Allow: true},
+ },
+ },
+ },
+ }
+
+ th := hc.toHooksAuth()
+ expect := []mqtt.HookLoadConfig{
+ {
+ Hook: new(auth.Hook),
+ Config: &auth.Options{
+ Ledger: &auth.Ledger{ // avoid copying sync.Locker
+ Auth: auth.AuthRules{
+ {Username: "peach", Password: "password1", Allow: true},
+ },
+ },
+ },
+ },
+ }
+ require.Equal(t, expect, th)
+}
+
+func TestToHooksStorageBadger(t *testing.T) {
+ hc := HookConfigs{
+ Storage: &HookStorageConfig{
+ Badger: &badger.Options{
+ Path: "badger",
+ },
+ },
+ }
+
+ th := hc.toHooksStorage()
+ expect := []mqtt.HookLoadConfig{
+ {
+ Hook: new(badger.Hook),
+ Config: hc.Storage.Badger,
+ },
+ }
+
+ require.Equal(t, expect, th)
+}
+
+func TestToHooksStorageBolt(t *testing.T) {
+ hc := HookConfigs{
+ Storage: &HookStorageConfig{
+ Bolt: &bolt.Options{
+ Path: "bolt",
+ Bucket: "mochi",
+ },
+ },
+ }
+
+ th := hc.toHooksStorage()
+ expect := []mqtt.HookLoadConfig{
+ {
+ Hook: new(bolt.Hook),
+ Config: hc.Storage.Bolt,
+ },
+ }
+
+ require.Equal(t, expect, th)
+}
+
+func TestToHooksStorageRedis(t *testing.T) {
+ hc := HookConfigs{
+ Storage: &HookStorageConfig{
+ Redis: &redis.Options{
+ Username: "test",
+ },
+ },
+ }
+
+ th := hc.toHooksStorage()
+ expect := []mqtt.HookLoadConfig{
+ {
+ Hook: new(redis.Hook),
+ Config: hc.Storage.Redis,
+ },
+ }
+
+ require.Equal(t, expect, th)
+}
+
+func TestToHooksStoragePebble(t *testing.T) {
+ hc := HookConfigs{
+ Storage: &HookStorageConfig{
+ Pebble: &pebble.Options{
+ Path: "pebble",
+ },
+ },
+ }
+
+ th := hc.toHooksStorage()
+ expect := []mqtt.HookLoadConfig{
+ {
+ Hook: new(pebble.Hook),
+ Config: hc.Storage.Pebble,
+ },
+ }
+
+ require.Equal(t, expect, th)
+}
diff --git a/config/hook.go b/config/hook.go
new file mode 100644
index 0000000..ddc627a
--- /dev/null
+++ b/config/hook.go
@@ -0,0 +1,123 @@
+package config
+
+import (
+ "testmqtt/hooks/auth"
+ "testmqtt/hooks/debug"
+ "testmqtt/hooks/storage/badger"
+ "testmqtt/hooks/storage/bolt"
+ "testmqtt/hooks/storage/pebble"
+ "testmqtt/hooks/storage/redis"
+ "testmqtt/mqtt"
+)
+
+// HookConfigs 全部Hook的配置
+// HookConfigs contains configurations to enable individual hooks.
+type HookConfigs struct {
+ Auth *HookAuthConfig `yaml:"auth" json:"auth"` // Auth AuthHook配置
+ Storage *HookStorageConfig `yaml:"storage" json:"storage"` // Storage StorageHook配置
+ Debug *debug.Options `yaml:"debug" json:"debug"` // Debug DebugHook配置
+}
+
+// HookAuthConfig AuthHook的配置
+// HookAuthConfig contains configurations for the auth hook.
+type HookAuthConfig struct {
+ Ledger auth.Ledger `yaml:"ledger" json:"ledger"`
+ AllowAll bool `yaml:"allow_all" json:"allow_all"`
+}
+
+// HookStorageConfig StorageHook的配置
+// HookStorageConfig contains configurations for the different storage hooks.
+type HookStorageConfig struct {
+ Badger *badger.Options `yaml:"badger" json:"badger"`
+ Bolt *bolt.Options `yaml:"bolt" json:"bolt"`
+ Pebble *pebble.Options `yaml:"pebble" json:"pebble"`
+ Redis *redis.Options `yaml:"redis" json:"redis"`
+}
+
+// HookDebugConfig DebugHook的配置
+// HookDebugConfig contains configuration settings for the debug output.
+type HookDebugConfig struct {
+ Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config
+ ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false)
+ ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false)
+ ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false)
+}
+
+// ToHooks converts Hook file configurations into Hooks to be added to the server.
+func (hc HookConfigs) ToHooks() []mqtt.HookLoadConfig {
+ var hlc []mqtt.HookLoadConfig
+
+ if hc.Auth != nil {
+ hlc = append(hlc, hc.toHooksAuth()...)
+ }
+
+ if hc.Storage != nil {
+ hlc = append(hlc, hc.toHooksStorage()...)
+ }
+
+ if hc.Debug != nil {
+ hlc = append(hlc, mqtt.HookLoadConfig{
+ Hook: new(debug.Hook),
+ Config: hc.Debug,
+ })
+ }
+
+ return hlc
+}
+
+// toHooksAuth converts auth hook configurations into auth hooks.
+func (hc HookConfigs) toHooksAuth() []mqtt.HookLoadConfig {
+ var hlc []mqtt.HookLoadConfig
+ if hc.Auth.AllowAll {
+ hlc = append(hlc, mqtt.HookLoadConfig{
+ Hook: new(auth.AllowHook),
+ })
+ } else {
+ hlc = append(hlc, mqtt.HookLoadConfig{
+ Hook: new(auth.Hook),
+ Config: &auth.Options{
+ Ledger: &auth.Ledger{ // avoid copying sync.Locker
+ Users: hc.Auth.Ledger.Users,
+ Auth: hc.Auth.Ledger.Auth,
+ ACL: hc.Auth.Ledger.ACL,
+ },
+ },
+ })
+ }
+ return hlc
+}
+
+// toHooksStorage converts storage hook configurations into storage hooks.
+func (hc HookConfigs) toHooksStorage() []mqtt.HookLoadConfig {
+
+ var hlc []mqtt.HookLoadConfig
+
+ if hc.Storage.Badger != nil {
+ hlc = append(hlc, mqtt.HookLoadConfig{
+ Hook: new(badger.Hook),
+ Config: hc.Storage.Badger,
+ })
+ }
+
+ if hc.Storage.Bolt != nil {
+ hlc = append(hlc, mqtt.HookLoadConfig{
+ Hook: new(bolt.Hook),
+ Config: hc.Storage.Bolt,
+ })
+ }
+
+ if hc.Storage.Redis != nil {
+ hlc = append(hlc, mqtt.HookLoadConfig{
+ Hook: new(redis.Hook),
+ Config: hc.Storage.Redis,
+ })
+ }
+
+ if hc.Storage.Pebble != nil {
+ hlc = append(hlc, mqtt.HookLoadConfig{
+ Hook: new(pebble.Hook),
+ Config: hc.Storage.Pebble,
+ })
+ }
+ return hlc
+}
diff --git a/config/logger.go b/config/logger.go
new file mode 100644
index 0000000..0309d5c
--- /dev/null
+++ b/config/logger.go
@@ -0,0 +1,21 @@
+package config
+
+import "os"
+import "log/slog"
+
+type LoggingConfig struct {
+ Level string
+}
+
+func (lc LoggingConfig) ToLogger() *slog.Logger {
+ var level slog.Level
+ if err := level.UnmarshalText([]byte(lc.Level)); err != nil {
+ level = slog.LevelInfo
+ }
+
+ leveler := new(slog.LevelVar)
+ leveler.Set(level)
+ return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
+ Level: leveler,
+ }))
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..36182ea
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,54 @@
+module testmqtt
+
+go 1.22
+
+require (
+ github.com/jinzhu/copier v0.4.0
+ github.com/mochi-mqtt/server/v2 v2.6.5
+ github.com/rs/xid v1.5.0
+ github.com/stretchr/testify v1.9.0
+)
+
+require (
+ github.com/DataDog/zstd v1.4.5 // indirect
+ github.com/beorn7/perks v1.0.1 // indirect
+ github.com/cespare/xxhash/v2 v2.2.0 // indirect
+ github.com/cockroachdb/errors v1.11.1 // indirect
+ github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect
+ github.com/cockroachdb/pebble v1.1.0 // indirect
+ github.com/cockroachdb/redact v1.1.5 // indirect
+ github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/dgraph-io/badger/v4 v4.2.0 // indirect
+ github.com/dgraph-io/ristretto v0.1.1 // indirect
+ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
+ github.com/dustin/go-humanize v1.0.0 // indirect
+ github.com/getsentry/sentry-go v0.18.0 // indirect
+ github.com/go-redis/redis/v8 v8.11.5 // indirect
+ github.com/gogo/protobuf v1.3.2 // indirect
+ github.com/golang/glog v1.0.0 // indirect
+ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
+ github.com/golang/protobuf v1.5.2 // indirect
+ github.com/golang/snappy v0.0.4 // indirect
+ github.com/google/flatbuffers v1.12.1 // indirect
+ github.com/gorilla/websocket v1.5.0 // indirect
+ github.com/klauspost/compress v1.15.15 // indirect
+ github.com/kr/pretty v0.3.1 // indirect
+ github.com/kr/text v0.2.0 // indirect
+ github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
+ github.com/pkg/errors v0.9.1 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/prometheus/client_golang v1.12.0 // indirect
+ github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a // indirect
+ github.com/prometheus/common v0.32.1 // indirect
+ github.com/prometheus/procfs v0.7.3 // indirect
+ github.com/rogpeppe/go-internal v1.9.0 // indirect
+ go.etcd.io/bbolt v1.3.5 // indirect
+ go.opencensus.io v0.22.5 // indirect
+ golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df // indirect
+ golang.org/x/net v0.23.0 // indirect
+ golang.org/x/sys v0.18.0 // indirect
+ golang.org/x/text v0.14.0 // indirect
+ google.golang.org/protobuf v1.33.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..a585726
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,573 @@
+cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
+cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU=
+cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
+cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc=
+cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0=
+cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To=
+cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4=
+cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M=
+cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc=
+cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk=
+cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs=
+cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc=
+cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY=
+cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
+cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE=
+cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc=
+cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg=
+cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
+cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
+cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
+cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
+cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
+cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw=
+cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA=
+cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU=
+cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
+cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos=
+cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
+cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
+cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
+dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
+github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ=
+github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
+github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
+github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
+github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
+github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
+github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
+github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
+github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
+github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw=
+github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88=
+github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
+github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
+github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
+github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
+github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
+github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
+github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
+github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
+github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
+github.com/cockroachdb/datadriven v1.0.3-0.20230413201302-be42291fc80f h1:otljaYPt5hWxV3MUfO5dFPFiOXg9CyG5/kCfayTqsJ4=
+github.com/cockroachdb/datadriven v1.0.3-0.20230413201302-be42291fc80f/go.mod h1:a9RdTaap04u637JoCzcUoIcDmvwSUtcUFtT/C3kJlTU=
+github.com/cockroachdb/errors v1.11.1 h1:xSEW75zKaKCWzR3OfxXUxgrk/NtT4G1MiOv5lWZazG8=
+github.com/cockroachdb/errors v1.11.1/go.mod h1:8MUxA3Gi6b25tYlFEBGLf+D8aISL+M4MIpiWMSNRfxw=
+github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZeQy818SGhaone5OnYfxFR/+AzdY3sf5aE=
+github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs=
+github.com/cockroachdb/pebble v1.1.0 h1:pcFh8CdCIt2kmEpK0OIatq67Ln9uGDYY3d5XnE0LJG4=
+github.com/cockroachdb/pebble v1.1.0/go.mod h1:sEHm5NOXxyiAoKWhoFxT8xMgd/f3RA6qUqQ1BXKrh2E=
+github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwPJ30=
+github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg=
+github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 h1:zuQyyAKVxetITBuuhv3BI9cMrmStnpT18zmgmTxunpo=
+github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06/go.mod h1:7nc4anLGjupUW/PeY5qiNYsdNXj7zopG+eqsS7To5IQ=
+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dgraph-io/badger/v4 v4.2.0 h1:kJrlajbXXL9DFTNuhhu9yCx7JJa4qpYWxtE8BzuWsEs=
+github.com/dgraph-io/badger/v4 v4.2.0/go.mod h1:qfCqhPoWDFJRx1gp5QwwyGo8xk1lbHUxvK9nK0OGAak=
+github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8=
+github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA=
+github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
+github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
+github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
+github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
+github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
+github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
+github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
+github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
+github.com/getsentry/sentry-go v0.18.0 h1:MtBW5H9QgdcJabtZcuJG80BMOwaBpkRDZkxRkNC1sN0=
+github.com/getsentry/sentry-go v0.18.0/go.mod h1:Kgon4Mby+FJ7ZWHFUAZgVaIa8sxHtnRJRLTXZr51aKQ=
+github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
+github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
+github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
+github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
+github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
+github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
+github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
+github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
+github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
+github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
+github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
+github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
+github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
+github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
+github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
+github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
+github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ=
+github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4=
+github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
+github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
+github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
+github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
+github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
+github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
+github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
+github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
+github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
+github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
+github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
+github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
+github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
+github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
+github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
+github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
+github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
+github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
+github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
+github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
+github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
+github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
+github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
+github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
+github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
+github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
+github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
+github.com/google/flatbuffers v1.12.1 h1:MVlul7pQNoDzWRLTw5imwYsl+usrS1TXG2H4jg6ImGw=
+github.com/google/flatbuffers v1.12.1/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
+github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
+github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
+github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
+github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
+github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
+github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
+github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
+github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
+github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
+github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
+github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
+github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
+github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
+github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
+github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
+github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
+github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
+github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
+github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
+github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
+github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
+github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
+github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
+github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
+github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
+github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
+github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
+github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
+github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
+github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
+github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
+github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw=
+github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4=
+github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
+github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI=
+github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
+github.com/mochi-mqtt/server/v2 v2.6.5 h1:9PiQ6EJt/Dx0ut0Fuuir4F6WinO/5Bpz9szujNwm+q8=
+github.com/mochi-mqtt/server/v2 v2.6.5/go.mod h1:TqztjKGO0/ArOjJt9x9idk0kqPT3CVN8Pb+l+PS5Gdo=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
+github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
+github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
+github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
+github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
+github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
+github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
+github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
+github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
+github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
+github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
+github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
+github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
+github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
+github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg=
+github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
+github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
+github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
+github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
+github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
+github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a h1:CmF68hwI0XsOQ5UwlBopMi2Ow4Pbg32akc4KIVCOm+Y=
+github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
+github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
+github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
+github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
+github.com/prometheus/common v0.32.1 h1:hWIdL3N2HoUx3B8j3YN9mWor0qhY/NlEKZEaXxuIRh4=
+github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls=
+github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
+github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
+github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
+github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
+github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
+github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
+github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
+github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
+github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
+github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
+github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
+github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
+github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw=
+github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
+go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
+go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
+go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
+go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
+go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
+go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
+go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
+go.opencensus.io v0.22.5 h1:dntmOdLpSpHlVqbW5Eay97DelsZHe+55D+xC6i0dDS0=
+go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
+golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
+golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
+golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
+golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
+golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY=
+golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
+golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
+golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
+golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
+golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
+golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME=
+golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
+golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
+golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
+golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
+golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
+golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs=
+golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
+golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
+golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
+golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
+golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
+golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
+golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
+golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
+golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
+golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
+golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
+golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
+golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
+golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
+golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
+golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
+golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
+golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
+golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
+golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
+golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
+golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
+golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
+golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
+golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
+golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
+golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
+golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
+golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
+golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
+golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
+golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
+golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
+golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw=
+golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw=
+golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8=
+golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
+golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
+golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
+golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
+google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
+google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
+google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
+google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
+google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
+google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
+google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
+google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
+google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
+google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
+google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
+google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
+google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
+google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
+google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc=
+google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
+google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
+google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
+google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
+google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
+google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
+google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
+google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
+google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
+google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
+google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
+google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
+google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
+google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
+google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA=
+google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
+google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
+google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
+google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
+google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
+google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
+google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
+google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
+google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U=
+google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
+google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA=
+google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
+google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
+google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
+google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
+google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
+google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
+google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
+google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
+google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
+google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
+google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
+google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60=
+google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
+google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
+google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
+google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
+google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
+google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
+google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
+google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
+google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
+google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
+google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
+google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
+gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
+gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
+honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
+honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
+rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
+rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
+rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
diff --git a/hooks/auth/allow_all.go b/hooks/auth/allow_all.go
new file mode 100644
index 0000000..319a18f
--- /dev/null
+++ b/hooks/auth/allow_all.go
@@ -0,0 +1,41 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package auth
+
+import (
+ "bytes"
+
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+)
+
// AllowHook is an authentication hook which allows connection access
// for all users and read and write access to all topics.
//
// It performs no credential or topic checks at all, so it is only
// appropriate for development, testing, or deliberately open brokers.
type AllowHook struct {
	mqtt.HookBase // embedded base presumably supplies stub implementations for the remaining hook methods — confirm against mqtt.HookBase
}
+
+// ID returns the ID of the hook.
+func (h *AllowHook) ID() string {
+ return "allow-all-auth"
+}
+
+// Provides indicates which hook methods this hook provides.
+func (h *AllowHook) Provides(b byte) bool {
+ return bytes.Contains([]byte{
+ mqtt.OnConnectAuthenticate,
+ mqtt.OnACLCheck,
+ }, []byte{b})
+}
+
// OnConnectAuthenticate returns true/allowed for all requests.
// The client and packet arguments are intentionally ignored.
func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
	return true
}
+
// OnACLCheck returns true/allowed for all checks.
// The client, topic, and read/write direction are intentionally ignored.
func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
	return true
}
diff --git a/hooks/auth/allow_all_test.go b/hooks/auth/allow_all_test.go
new file mode 100644
index 0000000..6d2fe43
--- /dev/null
+++ b/hooks/auth/allow_all_test.go
@@ -0,0 +1,35 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package auth
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+)
+
+func TestAllowAllID(t *testing.T) {
+ h := new(AllowHook)
+ require.Equal(t, "allow-all-auth", h.ID())
+}
+
+func TestAllowAllProvides(t *testing.T) {
+ h := new(AllowHook)
+ require.True(t, h.Provides(mqtt.OnACLCheck))
+ require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
+ require.False(t, h.Provides(mqtt.OnPublished))
+}
+
+func TestAllowAllOnConnectAuthenticate(t *testing.T) {
+ h := new(AllowHook)
+ require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{}))
+}
+
+func TestAllowAllOnACLCheck(t *testing.T) {
+ h := new(AllowHook)
+ require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true))
+}
diff --git a/hooks/auth/auth.go b/hooks/auth/auth.go
new file mode 100644
index 0000000..0f6233f
--- /dev/null
+++ b/hooks/auth/auth.go
@@ -0,0 +1,103 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package auth
+
+import (
+ "bytes"
+
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+)
+
// Options contains the configuration/rules data for the auth ledger.
type Options struct {
	// Data holds raw ledger rules to be unmarshalled during Init when
	// no Ledger pointer is supplied.
	Data []byte
	// Ledger, when non-nil, is used directly by Init and takes
	// precedence over Data.
	Ledger *Ledger
}
+
// Hook is an authentication hook which implements an auth ledger.
type Hook struct {
	mqtt.HookBase
	// config retains the options passed to Init.
	config *Options
	// ledger holds the resolved auth/ACL rule set consulted by
	// OnConnectAuthenticate and OnACLCheck.
	ledger *Ledger
}
+
+// ID returns the ID of the hook.
+func (h *Hook) ID() string {
+ return "auth-ledger"
+}
+
+// Provides indicates which hook methods this hook provides.
+func (h *Hook) Provides(b byte) bool {
+ return bytes.Contains([]byte{
+ mqtt.OnConnectAuthenticate,
+ mqtt.OnACLCheck,
+ }, []byte{b})
+}
+
+// Init configures the hook with the auth ledger to be used for checking.
+func (h *Hook) Init(config any) error {
+ if _, ok := config.(*Options); !ok && config != nil {
+ return mqtt.ErrInvalidConfigType
+ }
+
+ if config == nil {
+ config = new(Options)
+ }
+
+ h.config = config.(*Options)
+
+ var err error
+ if h.config.Ledger != nil {
+ h.ledger = h.config.Ledger
+ } else if len(h.config.Data) > 0 {
+ h.ledger = new(Ledger)
+ err = h.ledger.Unmarshal(h.config.Data)
+ }
+ if err != nil {
+ return err
+ }
+
+ if h.ledger == nil {
+ h.ledger = &Ledger{
+ Auth: AuthRules{},
+ ACL: ACLRules{},
+ }
+ }
+
+ h.Log.Info("loaded auth rules",
+ "authentication", len(h.ledger.Auth),
+ "acl", len(h.ledger.ACL))
+
+ return nil
+}
+
+// OnConnectAuthenticate returns true if the connecting client has rules which provide access
+// in the auth ledger.
+func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
+ if _, ok := h.ledger.AuthOk(cl, pk); ok {
+ return true
+ }
+
+ h.Log.Info("client failed authentication check",
+ "username", string(pk.Connect.Username),
+ "remote", cl.Net.Remote)
+ return false
+}
+
+// OnACLCheck returns true if the connecting client has matching read or write access to subscribe
+// or publish to a given topic.
+func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
+ if _, ok := h.ledger.ACLOk(cl, topic, write); ok {
+ return true
+ }
+
+ h.Log.Debug("client failed allowed ACL check",
+ "client", cl.ID,
+ "username", string(cl.Properties.Username),
+ "topic", topic)
+
+ return false
+}
diff --git a/hooks/auth/auth_test.go b/hooks/auth/auth_test.go
new file mode 100644
index 0000000..1c8b41e
--- /dev/null
+++ b/hooks/auth/auth_test.go
@@ -0,0 +1,213 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package auth
+
+import (
+ "log/slog"
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+)
+
+var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
+
+// func teardown(t *testing.T, path string, h *Hook) {
+// h.Stop()
+// }
+
+func TestBasicID(t *testing.T) {
+ h := new(Hook)
+ require.Equal(t, "auth-ledger", h.ID())
+}
+
+func TestBasicProvides(t *testing.T) {
+ h := new(Hook)
+ require.True(t, h.Provides(mqtt.OnACLCheck))
+ require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
+ require.False(t, h.Provides(mqtt.OnPublish))
+}
+
+func TestBasicInitBadConfig(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ err := h.Init(map[string]any{})
+ require.Error(t, err)
+}
+
+func TestBasicInitDefaultConfig(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ err := h.Init(nil)
+ require.NoError(t, err)
+}
+
+func TestBasicInitWithLedgerPointer(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ ln := &Ledger{
+ Auth: []AuthRule{
+ {
+ Remote: "127.0.0.1",
+ Allow: true,
+ },
+ },
+ ACL: []ACLRule{
+ {
+ Remote: "127.0.0.1",
+ Filters: Filters{
+ "#": ReadWrite,
+ },
+ },
+ },
+ }
+
+ err := h.Init(&Options{
+ Ledger: ln,
+ })
+
+ require.NoError(t, err)
+ require.Same(t, ln, h.ledger)
+}
+
+func TestBasicInitWithLedgerJSON(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ require.Nil(t, h.ledger)
+ err := h.Init(&Options{
+ Data: ledgerJSON,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
+ require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
+}
+
+func TestBasicInitWithLedgerYAML(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ require.Nil(t, h.ledger)
+ err := h.Init(&Options{
+ Data: ledgerYAML,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
+ require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
+}
+
+func TestBasicInitWithLedgerBadDAta(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ require.Nil(t, h.ledger)
+ err := h.Init(&Options{
+ Data: []byte("fdsfdsafasd"),
+ })
+
+ require.Error(t, err)
+}
+
+func TestOnConnectAuthenticate(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ ln := new(Ledger)
+ ln.Auth = checkLedger.Auth
+ ln.ACL = checkLedger.ACL
+ err := h.Init(
+ &Options{
+ Ledger: ln,
+ },
+ )
+
+ require.NoError(t, err)
+
+ require.True(t, h.OnConnectAuthenticate(
+ &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
+ ))
+
+ require.False(t, h.OnConnectAuthenticate(
+ &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
+ ))
+
+ require.False(t, h.OnConnectAuthenticate(
+ &mqtt.Client{},
+ packets.Packet{},
+ ))
+}
+
+func TestOnACL(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ ln := new(Ledger)
+ ln.Auth = checkLedger.Auth
+ ln.ACL = checkLedger.ACL
+ err := h.Init(
+ &Options{
+ Ledger: ln,
+ },
+ )
+
+ require.NoError(t, err)
+
+ require.True(t, h.OnACLCheck(
+ &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ "mochi/info",
+ true,
+ ))
+
+ require.False(t, h.OnACLCheck(
+ &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ "d/j/f",
+ true,
+ ))
+
+ require.True(t, h.OnACLCheck(
+ &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ "readonly",
+ false,
+ ))
+
+ require.False(t, h.OnACLCheck(
+ &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ "readonly",
+ true,
+ ))
+}
diff --git a/hooks/auth/ledger.go b/hooks/auth/ledger.go
new file mode 100644
index 0000000..90c5fb8
--- /dev/null
+++ b/hooks/auth/ledger.go
@@ -0,0 +1,246 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package auth
+
+import (
+ "encoding/json"
+ "strings"
+ "sync"
+
+ "gopkg.in/yaml.v3"
+
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+)
+
// Access level constants. Deny is the zero value, so a filter entry with an
// unspecified level denies access by default.
const (
	Deny      Access = iota // user cannot access the topic
	ReadOnly                // user can only subscribe to the topic
	WriteOnly               // user can only publish to the topic
	ReadWrite               // user can both publish and subscribe to the topic
)

// Access determines the read/write privileges for an ACL rule.
type Access byte
+
// Users contains a map of access rules for specific users, keyed on username.
type Users map[string]UserRule

// UserRule defines a set of access rules for a specific user.
// When ACL is non-empty, a matching filter in it decides access conclusively
// and the ledger's generic ACL rules are not consulted (see Ledger.ACLOk).
type UserRule struct {
	Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
	Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
	ACL      Filters `json:"acl,omitempty" yaml:"acl,omitempty"`           // filters to match, if desired
	Disallow bool    `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user
}
+
// AuthRules defines generic access rules applicable to all users.
type AuthRules []AuthRule

// AuthRule defines a connection-time access rule. Empty fields match
// anything (see RString.Matches).
type AuthRule struct {
	Client   RString `json:"client,omitempty" yaml:"client,omitempty"`     // the id of a connecting client
	Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
	Remote   RString `json:"remote,omitempty" yaml:"remote,omitempty"`     // the remote address of the client; '*' acts as a prefix wildcard
	Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
	Allow    bool    `json:"allow,omitempty" yaml:"allow,omitempty"`       // allow or disallow the users
}
+
// ACLRules defines generic topic or filter access rules applicable to all users.
type ACLRules []ACLRule

// ACLRule defines access rules for a specific topic or filter.
// Empty matcher fields match anything (see RString.Matches).
type ACLRule struct {
	Client   RString `json:"client,omitempty" yaml:"client,omitempty"`     // the id of a connecting client
	Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
	Remote   RString `json:"remote,omitempty" yaml:"remote,omitempty"`     // the remote address of the client; '*' acts as a prefix wildcard
	Filters  Filters `json:"filters,omitempty" yaml:"filters,omitempty"`   // filters to match
}
+
// Filters is a map of Access rules keyed on filter. Keys may contain MQTT
// wildcards (+, #), which are honoured by RString.FilterMatches.
type Filters map[RString]Access
+
// RString is a rule value string.
type RString string

// Matches returns true if the rule matches a given string. An empty rule or
// "*" matches anything; otherwise the rule matches on equality, or as a
// prefix when it contains a '*' after the first character.
func (r RString) Matches(a string) bool {
	pattern := string(r)
	if pattern == "" || pattern == "*" || pattern == a {
		return true
	}

	// A '*' at index > 0 makes everything before it a prefix pattern; the
	// candidate must be strictly longer than the prefix for a match.
	star := strings.Index(pattern, "*")
	return star > 0 && len(a) > star && a[:star] == pattern[:star]
}
+
// FilterMatches returns true if the rule, treated as an MQTT filter
// (supporting the + and # wildcards), matches topic a. See MatchTopic.
func (r RString) FilterMatches(a string) bool {
	_, ok := MatchTopic(string(r), a)
	return ok
}
+
// MatchTopic checks if a given topic matches a filter, accounting for filter
// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c.
// elements collects the topic segments captured by + wildcards, and the
// remainder of the topic captured by a # wildcard.
func MatchTopic(filter string, topic string) (elements []string, matched bool) {
	filterParts := strings.Split(filter, "/")
	topicParts := strings.Split(topic, "/")

	elements = make([]string, 0)
	for i, part := range filterParts {
		// A filter with more segments than the topic cannot match.
		if i >= len(topicParts) {
			return elements, false
		}

		switch part {
		case "+":
			// Single-level wildcard: capture this topic segment.
			elements = append(elements, topicParts[i])
		case "#":
			// Multi-level wildcard: capture the rest of the topic and stop.
			return append(elements, strings.Join(topicParts[i:], "/")), true
		default:
			if part != topicParts[i] {
				return elements, false
			}
		}
	}

	return elements, true
}
+
// Ledger is an auth ledger containing access rules for users and topics.
// The embedded mutex is held by Update and Unmarshal while the rule sets
// are replaced.
// NOTE(review): AuthOk and ACLOk read Auth/ACL without taking the lock —
// confirm rules are not replaced concurrently with access checks.
type Ledger struct {
	sync.Mutex `json:"-" yaml:"-"`
	Users      Users     `json:"users" yaml:"users"`
	Auth       AuthRules `json:"auth" yaml:"auth"`
	ACL        ACLRules  `json:"acl" yaml:"acl"`
}
+
// Update updates the internal values of the ledger under the ledger lock.
// Only the Auth and ACL rule sets are replaced; the Users map is left
// untouched. NOTE(review): confirm that skipping Users here is intentional.
func (l *Ledger) Update(ln *Ledger) {
	l.Lock()
	defer l.Unlock()
	l.Auth = ln.Auth
	l.ACL = ln.ACL
}
+
+// AuthOk returns true if the rules indicate the user is allowed to authenticate.
+func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
+ // If the users map is set, always check for a predefined user first instead
+ // of iterating through global rules.
+ if l.Users != nil {
+ if u, ok := l.Users[string(cl.Properties.Username)]; ok &&
+ u.Password != "" &&
+ u.Password == RString(pk.Connect.Password) {
+ return 0, !u.Disallow
+ }
+ }
+
+ // If there's no users map, or no user was found, attempt to find a matching
+ // rule (which may also contain a user).
+ for n, rule := range l.Auth {
+ if rule.Client.Matches(cl.ID) &&
+ rule.Username.Matches(string(cl.Properties.Username)) &&
+ rule.Password.Matches(string(pk.Connect.Password)) &&
+ rule.Remote.Matches(cl.Net.Remote) {
+ return n, rule.Allow
+ }
+ }
+
+ return 0, false
+}
+
// ACLOk returns true if the rules indicate the user is allowed to read or write to
// a specific filter or topic respectively, based on the `write` bool.
// n is the index of the deciding ACL rule (0 for user-map decisions or when
// nothing matched). The loop ordering below is significant: because map
// iteration order is random, direction-appropriate filters are scanned in
// full before any matching filter is treated as a deny.
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
	// If the users map is set, always check for a predefined user first instead
	// of iterating through global rules.
	if l.Users != nil {
		if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 {
			for filter, access := range u.ACL {
				if filter.FilterMatches(topic) {
					// The first matching user filter decides conclusively;
					// the generic rules below are not consulted.
					if !write && (access == ReadOnly || access == ReadWrite) {
						return n, true
					} else if write && (access == WriteOnly || access == ReadWrite) {
						return n, true
					} else {
						return n, false
					}
				}
			}
		}
	}

	for n, rule := range l.ACL {
		if rule.Client.Matches(cl.ID) &&
			rule.Username.Matches(string(cl.Properties.Username)) &&
			rule.Remote.Matches(cl.Net.Remote) {
			// A matching rule with no filters grants access to everything.
			if len(rule.Filters) == 0 {
				return n, true
			}

			// Scan all write-capable filters first ...
			if write {
				for filter, access := range rule.Filters {
					if access == WriteOnly || access == ReadWrite {
						if filter.FilterMatches(topic) {
							return n, true
						}
					}
				}
			}

			// ... or all read-capable filters for subscriptions.
			if !write {
				for filter, access := range rule.Filters {
					if access == ReadOnly || access == ReadWrite {
						if filter.FilterMatches(topic) {
							return n, true
						}
					}
				}
			}

			// Any remaining filter match on this rule is a deny (covers Deny
			// entries and wrong-direction access levels).
			for filter := range rule.Filters {
				if filter.FilterMatches(topic) {
					return n, false
				}
			}
		}
	}

	// No user entry and no rule matched: access is allowed by default.
	return 0, true
}
+
// ToJSON encodes the values into a JSON string.
// NOTE(review): marshalling does not take the ledger lock — confirm callers
// do not race with Update/Unmarshal.
func (l *Ledger) ToJSON() (data []byte, err error) {
	return json.Marshal(l)
}

// ToYAML encodes the values into a YAML string.
// NOTE(review): same locking caveat as ToJSON.
func (l *Ledger) ToYAML() (data []byte, err error) {
	return yaml.Marshal(l)
}
+
+// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct.
+func (l *Ledger) Unmarshal(data []byte) error {
+ l.Lock()
+ defer l.Unlock()
+ if len(data) == 0 {
+ return nil
+ }
+
+ if data[0] == '{' {
+ return json.Unmarshal(data, l)
+ }
+
+ return yaml.Unmarshal(data, &l)
+}
diff --git a/hooks/auth/ledger_test.go b/hooks/auth/ledger_test.go
new file mode 100644
index 0000000..ab4fb3d
--- /dev/null
+++ b/hooks/auth/ledger_test.go
@@ -0,0 +1,610 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package auth
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+)
+
+var (
+ checkLedger = Ledger{
+ Users: Users{ // users are allowed by default
+ "mochi-co": {
+ Password: "melon",
+ ACL: Filters{
+ "d/+/f": Deny,
+ "mochi-co/#": ReadWrite,
+ "readonly": ReadOnly,
+ },
+ },
+ "suspended-username": {
+ Password: "any",
+ Disallow: true,
+ },
+ "mochi": { // ACL only, will defer to AuthRules for authentication
+ ACL: Filters{
+ "special/mochi": ReadOnly,
+ "secret/mochi": Deny,
+ "ignored": ReadWrite,
+ },
+ },
+ },
+ Auth: AuthRules{
+ {Username: "banned-user"}, // never allow specific username
+ {Remote: "127.0.0.1", Allow: true}, // always allow localhost
+ {Remote: "123.123.123.123"}, // disallow any from specific address
+ {Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address
+ {Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username)
+ {Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass
+ {Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map)
+ },
+ ACL: ACLRules{
+ {
+ Username: "mochi", // allow matching user/pass
+ Filters: Filters{
+ "a/b/c": Deny,
+ "d/+/f": Deny,
+ "mochi/#": ReadWrite,
+ "updates/#": WriteOnly,
+ "readonly": ReadOnly,
+ "ignored": Deny,
+ },
+ },
+ {Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost
+ {Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin
+ {Remote: "001.002.003.004"}, // Allow all with no filter
+ {Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others
+ },
+ }
+)
+
+func TestRStringMatches(t *testing.T) {
+ require.True(t, RString("*").Matches("any"))
+ require.True(t, RString("*").Matches(""))
+ require.True(t, RString("").Matches("any"))
+ require.True(t, RString("").Matches(""))
+ require.False(t, RString("no").Matches("any"))
+ require.False(t, RString("no").Matches(""))
+}
+
+func TestCanAuthenticate(t *testing.T) {
+ tt := []struct {
+ desc string
+ client *mqtt.Client
+ pk packets.Packet
+ n int
+ ok bool
+ }{
+ {
+ desc: "allow all local 127.0.0.1",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ Net: mqtt.ClientConnection{
+ Remote: "127.0.0.1",
+ },
+ },
+ pk: packets.Packet{Connect: packets.ConnectParams{}},
+ ok: true,
+ n: 1,
+ },
+ {
+ desc: "allow username/password",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
+ ok: true,
+ n: 5,
+ },
+ {
+ desc: "deny username/password",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
+ ok: false,
+ n: 0,
+ },
+ {
+ desc: "allow all local 127.0.0.1",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ Net: mqtt.ClientConnection{
+ Remote: "127.0.0.1",
+ },
+ },
+ pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
+ ok: true,
+ n: 1,
+ },
+ {
+ desc: "allow username/password",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
+ ok: true,
+ n: 5,
+ },
+ {
+ desc: "deny username/password",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
+ ok: false,
+ n: 0,
+ },
+ {
+ desc: "deny client from address",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("not-mochi"),
+ },
+ Net: mqtt.ClientConnection{
+ Remote: "111.144.155.166",
+ },
+ },
+ pk: packets.Packet{},
+ ok: false,
+ n: 3,
+ },
+ {
+ desc: "allow remote wildcard",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ Net: mqtt.ClientConnection{
+ Remote: "111.0.0.1",
+ },
+ },
+ pk: packets.Packet{},
+ ok: true,
+ n: 4,
+ },
+ {
+ desc: "never allow username",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("banned-user"),
+ },
+ Net: mqtt.ClientConnection{
+ Remote: "127.0.0.1",
+ },
+ },
+ pk: packets.Packet{},
+ ok: false,
+ n: 0,
+ },
+ {
+ desc: "matching user in users",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi-co"),
+ },
+ },
+ pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
+ ok: true,
+ n: 0,
+ },
+ {
+ desc: "never user in users",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("suspended-user"),
+ },
+ },
+ pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}},
+ ok: false,
+ n: 0,
+ },
+ }
+
+ for _, d := range tt {
+ t.Run(d.desc, func(t *testing.T) {
+ n, ok := checkLedger.AuthOk(d.client, d.pk)
+ require.Equal(t, d.n, n)
+ require.Equal(t, d.ok, ok)
+ })
+ }
+}
+
+func TestCanACL(t *testing.T) {
+ tt := []struct {
+ client *mqtt.Client
+ desc string
+ topic string
+ n int
+ write bool
+ ok bool
+ }{
+ {
+ desc: "allow normal write on any other filter",
+ client: &mqtt.Client{},
+ topic: "default/acl/write/access",
+ write: true,
+ ok: true,
+ },
+ {
+ desc: "allow normal read on any other filter",
+ client: &mqtt.Client{},
+ topic: "default/acl/read/access",
+ write: false,
+ ok: true,
+ },
+ {
+ desc: "deny user on literal filter",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "a/b/c",
+ },
+ {
+ desc: "deny user on partial filter",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "d/j/f",
+ },
+ {
+ desc: "allow read/write to user path",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "mochi/read/write",
+ write: true,
+ ok: true,
+ },
+ {
+ desc: "deny read on write-only path",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "updates/no/reading",
+ write: false,
+ ok: false,
+ },
+ {
+ desc: "deny read on write-only path ext",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "updates/mochi",
+ write: false,
+ ok: false,
+ },
+ {
+ desc: "allow read on not-acl path (no #)",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "updates",
+ write: false,
+ ok: true,
+ },
+ {
+ desc: "allow write on write-only path",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "updates/mochi",
+ write: true,
+ ok: true,
+ },
+ {
+ desc: "deny write on read-only path",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "readonly",
+ write: true,
+ ok: false,
+ },
+ {
+ desc: "allow read on read-only path",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "readonly",
+ write: false,
+ ok: true,
+ },
+ {
+ desc: "allow $sys access to localhost",
+ client: &mqtt.Client{
+ Net: mqtt.ClientConnection{
+ Remote: "localhost",
+ },
+ },
+ topic: "$SYS/test",
+ write: false,
+ ok: true,
+ n: 1,
+ },
+ {
+ desc: "allow $sys access to admin",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("admin"),
+ },
+ },
+ topic: "$SYS/test",
+ write: false,
+ ok: true,
+ n: 2,
+ },
+ {
+ desc: "deny $sys access to all others",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "$SYS/test",
+ write: false,
+ ok: false,
+ n: 4,
+ },
+ {
+ desc: "allow all with no filter",
+ client: &mqtt.Client{
+ Net: mqtt.ClientConnection{
+ Remote: "001.002.003.004",
+ },
+ },
+ topic: "any/path",
+ write: true,
+ ok: true,
+ n: 3,
+ },
+ {
+ desc: "use users embedded acl deny",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "secret/mochi",
+ write: true,
+ ok: false,
+ },
+ {
+ desc: "use users embedded acl any",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "any/mochi",
+ write: true,
+ ok: true,
+ },
+ {
+ desc: "use users embedded acl write on read-only",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "special/mochi",
+ write: true,
+ ok: false,
+ },
+ {
+ desc: "use users embedded acl read on read-only",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "special/mochi",
+ write: false,
+ ok: true,
+ },
+ {
+ desc: "preference users embedded acl",
+ client: &mqtt.Client{
+ Properties: mqtt.ClientProperties{
+ Username: []byte("mochi"),
+ },
+ },
+ topic: "ignored",
+ write: true,
+ ok: true,
+ },
+ }
+
+ for _, d := range tt {
+ t.Run(d.desc, func(t *testing.T) {
+ n, ok := checkLedger.ACLOk(d.client, d.topic, d.write)
+ require.Equal(t, d.n, n)
+ require.Equal(t, d.ok, ok)
+ })
+ }
+}
+
+func TestMatchTopic(t *testing.T) {
+ el, matched := MatchTopic("a/+/c/+", "a/b/c/d")
+ require.True(t, matched)
+ require.Equal(t, []string{"b", "d"}, el)
+
+ el, matched = MatchTopic("a/+/+/+", "a/b/c/d")
+ require.True(t, matched)
+ require.Equal(t, []string{"b", "c", "d"}, el)
+
+ el, matched = MatchTopic("stuff/#", "stuff/things/yeah")
+ require.True(t, matched)
+ require.Equal(t, []string{"things/yeah"}, el)
+
+ el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds")
+ require.True(t, matched)
+ require.Equal(t, []string{"b", "c/d/as/dds"}, el)
+
+ el, matched = MatchTopic("test", "test")
+ require.True(t, matched)
+ require.Equal(t, make([]string, 0), el)
+
+ el, matched = MatchTopic("things/stuff//", "things/stuff/")
+ require.False(t, matched)
+ require.Equal(t, make([]string, 0), el)
+
+ el, matched = MatchTopic("t", "t2")
+ require.False(t, matched)
+ require.Equal(t, make([]string, 0), el)
+
+ el, matched = MatchTopic(" ", " ")
+ require.False(t, matched)
+ require.Equal(t, make([]string, 0), el)
+}
+
+var (
+ ledgerStruct = Ledger{
+ Users: Users{
+ "mochi": {
+ Password: "peach",
+ ACL: Filters{
+ "readonly": ReadOnly,
+ "deny": Deny,
+ },
+ },
+ },
+ Auth: AuthRules{
+ {
+ Client: "*",
+ Username: "mochi-co",
+ Password: "melon",
+ Remote: "192.168.1.*",
+ Allow: true,
+ },
+ },
+ ACL: ACLRules{
+ {
+ Client: "*",
+ Username: "mochi-co",
+ Remote: "127.*",
+ Filters: Filters{
+ "readonly": ReadOnly,
+ "writeonly": WriteOnly,
+ "readwrite": ReadWrite,
+ "deny": Deny,
+ },
+ },
+ },
+ }
+
+ ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`)
+ ledgerYAML = []byte(`users:
+ mochi:
+ password: peach
+ acl:
+ deny: 0
+ readonly: 1
+auth:
+ - client: '*'
+ username: mochi-co
+ remote: 192.168.1.*
+ password: melon
+ allow: true
+acl:
+ - client: '*'
+ username: mochi-co
+ remote: 127.*
+ filters:
+ deny: 0
+ readonly: 1
+ readwrite: 3
+ writeonly: 2
+`)
+)
+
+func TestLedgerUpdate(t *testing.T) {
+ old := &Ledger{
+ Auth: AuthRules{
+ {Remote: "127.0.0.1", Allow: true},
+ },
+ }
+
+ n := &Ledger{
+ Auth: AuthRules{
+ {Remote: "127.0.0.1", Allow: true},
+ {Remote: "192.168.*", Allow: true},
+ },
+ }
+
+ old.Update(n)
+ require.Len(t, old.Auth, 2)
+ require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
+ require.NotSame(t, n, old)
+}
+
+func TestLedgerToJSON(t *testing.T) {
+ data, err := ledgerStruct.ToJSON()
+ require.NoError(t, err)
+ require.Equal(t, ledgerJSON, data)
+}
+
+func TestLedgerToYAML(t *testing.T) {
+ data, err := ledgerStruct.ToYAML()
+ require.NoError(t, err)
+ require.Equal(t, ledgerYAML, data)
+}
+
+func TestLedgerUnmarshalFromYAML(t *testing.T) {
+ l := new(Ledger)
+ err := l.Unmarshal(ledgerYAML)
+ require.NoError(t, err)
+ require.Equal(t, &ledgerStruct, l)
+ require.NotSame(t, l, &ledgerStruct)
+}
+
+func TestLedgerUnmarshalFromJSON(t *testing.T) {
+ l := new(Ledger)
+ err := l.Unmarshal(ledgerJSON)
+ require.NoError(t, err)
+ require.Equal(t, &ledgerStruct, l)
+ require.NotSame(t, l, &ledgerStruct)
+}
+
+func TestLedgerUnmarshalNil(t *testing.T) {
+ l := new(Ledger)
+ err := l.Unmarshal([]byte{})
+ require.NoError(t, err)
+ require.Equal(t, new(Ledger), l)
+}
diff --git a/hooks/debug/debug.go b/hooks/debug/debug.go
new file mode 100644
index 0000000..5fa852b
--- /dev/null
+++ b/hooks/debug/debug.go
@@ -0,0 +1,237 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package debug
+
+import (
+ "fmt"
+ "log/slog"
+ "strings"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+)
+
+// Options contains configuration settings for the debug output.
+type Options struct {
+ Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config
+ ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false)
+ ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false)
+ ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false)
+}
+
// Hook is a debugging hook which logs additional low-level information from the server.
type Hook struct {
	mqtt.HookBase
	config *Options // debug output options, set in Init
	// Log is this hook's logger, assigned in SetOpts.
	// NOTE(review): this field shadows any logger on the embedded
	// mqtt.HookBase — confirm the shadowing is intentional.
	Log *slog.Logger
}
+
+// ID returns the ID of the hook.
+func (h *Hook) ID() string {
+ return "debug"
+}
+
// Provides indicates that this hook provides all methods; b is ignored
// because every hook point is accepted.
func (h *Hook) Provides(b byte) bool {
	return true
}
+
+// Init is called when the hook is initialized.
+func (h *Hook) Init(config any) error {
+ if _, ok := config.(*Options); !ok && config != nil {
+ return mqtt.ErrInvalidConfigType
+ }
+
+ if config == nil {
+ config = new(Options)
+ }
+
+ h.config = config.(*Options)
+
+ return nil
+}
+
// SetOpts is called when the hook receives inheritable server parameters.
// Only the logger is retained; the opts parameter is currently unused.
func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) {
	h.Log = l
	h.Log.Debug("", "method", "SetOpts")
}
+
// Stop is called when the hook is stopped. It only emits a trace line;
// there are no resources to release.
func (h *Hook) Stop() error {
	h.Log.Debug("", "method", "Stop")
	return nil
}

// OnStarted is called when the server starts; trace only.
func (h *Hook) OnStarted() {
	h.Log.Debug("", "method", "OnStarted")
}

// OnStopped is called when the server stops; trace only.
func (h *Hook) OnStopped() {
	h.Log.Debug("", "method", "OnStopped")
}
+
+// OnPacketRead is called when a new packet is received from a client.
+func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
+ if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
+ return pk, nil
+ }
+
+ h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
+ return pk, nil
+}
+
+// OnPacketSent is called when a packet is sent to a client.
+func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
+ if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
+ return
+ }
+
+ h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
+}
+
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
// The r count is not included in the log output.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
	h.Log.Debug("retained message on topic", "m", h.packetMeta(pk))
}

// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
// The sent timestamp and resends count are not included in the log output.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
	h.Log.Debug("inflight out", "m", h.packetMeta(pk))
}

// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
	h.Log.Debug("inflight complete", "m", h.packetMeta(pk))
}

// OnQosDropped is called the Qos flow for a message expires.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
	h.Log.Debug("inflight dropped", "m", h.packetMeta(pk))
}

// OnLWTSent is called when a Will Message has been issued from a disconnecting client.
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
	h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID)
}

// OnRetainedExpired is called when the server clears expired retained messages.
func (h *Hook) OnRetainedExpired(filter string) {
	h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter)
}

// OnClientExpired is called when the server clears an expired client.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
	h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID)
}
+
// StoredClients is called when the server restores clients from a store.
// This hook persists nothing, so an empty slice is returned.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
	h.Log.Debug("", "method", "StoredClients")

	return v, nil
}

// StoredSubscriptions is called when the server restores subscriptions from a store.
// Always empty for this hook.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
	h.Log.Debug("", "method", "StoredSubscriptions")
	return v, nil
}

// StoredRetainedMessages is called when the server restores retained messages from a store.
// Always empty for this hook.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
	h.Log.Debug("", "method", "StoredRetainedMessages")
	return v, nil
}

// StoredInflightMessages is called when the server restores inflight messages from a store.
// Always empty for this hook.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
	h.Log.Debug("", "method", "StoredInflightMessages")
	return v, nil
}

// StoredSysInfo is called when the server restores system info from a store.
// Always the zero value for this hook.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
	h.Log.Debug("", "method", "StoredSysInfo")

	return v, nil
}
+
+// packetMeta adds additional type-specific metadata to the debug logs.
+func (h *Hook) packetMeta(pk packets.Packet) map[string]any {
+ m := map[string]any{}
+ switch pk.FixedHeader.Type {
+ case packets.Connect:
+ m["id"] = pk.Connect.ClientIdentifier
+ m["clean"] = pk.Connect.Clean
+ m["keepalive"] = pk.Connect.Keepalive
+ m["version"] = pk.ProtocolVersion
+ m["username"] = string(pk.Connect.Username)
+ if h.config.ShowPasswords {
+ m["password"] = string(pk.Connect.Password)
+ }
+ if pk.Connect.WillFlag {
+ m["will_topic"] = pk.Connect.WillTopic
+ m["will_payload"] = string(pk.Connect.WillPayload)
+ }
+ case packets.Publish:
+ m["topic"] = pk.TopicName
+ m["payload"] = string(pk.Payload)
+ m["raw"] = pk.Payload
+ m["qos"] = pk.FixedHeader.Qos
+ m["id"] = pk.PacketID
+ case packets.Connack:
+ fallthrough
+ case packets.Disconnect:
+ fallthrough
+ case packets.Puback:
+ fallthrough
+ case packets.Pubrec:
+ fallthrough
+ case packets.Pubrel:
+ fallthrough
+ case packets.Pubcomp:
+ m["id"] = pk.PacketID
+ m["reason"] = int(pk.ReasonCode)
+ if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 {
+ m["reason_string"] = pk.Properties.ReasonString
+ }
+ case packets.Subscribe:
+ f := map[string]int{}
+ ids := map[string]int{}
+ for _, v := range pk.Filters {
+ f[v.Filter] = int(v.Qos)
+ ids[v.Filter] = v.Identifier
+ }
+ m["filters"] = f
+ m["subids"] = f
+
+ case packets.Unsubscribe:
+ f := []string{}
+ for _, v := range pk.Filters {
+ f = append(f, v.Filter)
+ }
+ m["filters"] = f
+ case packets.Suback:
+ fallthrough
+ case packets.Unsuback:
+ r := []int{}
+ for _, v := range pk.ReasonCodes {
+ r = append(r, int(v))
+ }
+ m["reasons"] = r
+ case packets.Auth:
+ // tbd
+ }
+
+ if h.config.ShowPacketData {
+ m["packet"] = pk
+ }
+
+ return m
+}
diff --git a/hooks/storage/badger/badger.go b/hooks/storage/badger/badger.go
new file mode 100644
index 0000000..5cafa38
--- /dev/null
+++ b/hooks/storage/badger/badger.go
@@ -0,0 +1,576 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co, gsagula, werbenhu
+
+package badger
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ badgerdb "github.com/dgraph-io/badger/v4"
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+ "testmqtt/system"
+)
+
+const (
+	// defaultDbFile is the default file path for the badger db file.
+	defaultDbFile = ".badger"
+	// defaultGcInterval is how often the value-log garbage collector runs.
+	defaultGcInterval = 5 * 60 // gc interval in seconds
+	// defaultGcDiscardRatio is used when Options.GcDiscardRatio is out of range.
+	defaultGcDiscardRatio = 0.5
+)
+
+// clientKey returns a primary key for a client, namespaced under the
+// shared storage.ClientKey prefix.
+func clientKey(cl *mqtt.Client) string {
+	key := storage.ClientKey + "_" + cl.ID
+	return key
+}
+
+// subscriptionKey returns a primary key for a subscription, combining the
+// client id and topic filter under the storage.SubscriptionKey prefix.
+func subscriptionKey(cl *mqtt.Client, filter string) string {
+	key := storage.SubscriptionKey + "_" + cl.ID + ":" + filter
+	return key
+}
+
+// retainedKey returns a primary key for a retained message, namespaced
+// under the storage.RetainedKey prefix.
+func retainedKey(topic string) string {
+	key := storage.RetainedKey + "_" + topic
+	return key
+}
+
+// inflightKey returns a primary key for an inflight message, combining the
+// client id and the packet id under the storage.InflightKey prefix.
+func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
+	key := storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
+	return key
+}
+
+// sysInfoKey returns the primary key under which system info is stored.
+// There is only ever one system-info record, so the key is a constant.
+func sysInfoKey() string {
+	key := storage.SysInfoKey
+	return key
+}
+
+// Serializable is an interface for objects that can be serialized and deserialized.
+// NOTE(review): this local interface appears unused within this file — setKv and
+// getKv take storage.Serializable instead; confirm no external usage before removing.
+type Serializable interface {
+	UnmarshalBinary([]byte) error
+	MarshalBinary() (data []byte, err error)
+}
+
+// Options contains configuration settings for the BadgerDB instance.
+type Options struct {
+	// Options optionally overrides the full badger configuration; when nil,
+	// badgerdb.DefaultOptions(Path) is used (see Init).
+	Options *badgerdb.Options
+	// Path is the on-disk location of the database (defaults to ".badger").
+	Path string `yaml:"path" json:"path"`
+	// GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard.
+	// Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value
+	// would result in more space reclaims at the cost of increased activity on the LSM tree.
+	// discardRatio must be in the range (0.0, 1.0), both endpoints excluded, otherwise, it will be set to the default value of 0.5.
+	GcDiscardRatio float64 `yaml:"gc_discard_ratio" json:"gc_discard_ratio"`
+	// GcInterval is the garbage-collection period in seconds (defaults to 300).
+	GcInterval int64 `yaml:"gc_interval" json:"gc_interval"`
+}
+
+// Hook is a persistent storage hook based using BadgerDB file store as a backend.
+// It also satisfies badger's Logger interface (Errorf/Warningf/Infof/Debugf below),
+// so badger's internal messages are routed through the hook's logger.
+type Hook struct {
+	mqtt.HookBase
+	config   *Options     // options for configuring the BadgerDB instance.
+	gcTicker *time.Ticker // Ticker for BadgerDB garbage collection.
+	db       *badgerdb.DB // the BadgerDB instance.
+}
+
+// ID returns the id of the hook.
+func (h *Hook) ID() string {
+ return "badger-db"
+}
+
+// Provides indicates which hook methods this hook provides.
+// It reports true only for the lifecycle and storage events listed below;
+// all other hook ids return false.
+func (h *Hook) Provides(b byte) bool {
+	return bytes.Contains([]byte{
+		mqtt.OnSessionEstablished,
+		mqtt.OnDisconnect,
+		mqtt.OnSubscribed,
+		mqtt.OnUnsubscribed,
+		mqtt.OnRetainMessage,
+		mqtt.OnWillSent,
+		mqtt.OnQosPublish,
+		mqtt.OnQosComplete,
+		mqtt.OnQosDropped,
+		mqtt.OnSysInfoTick,
+		mqtt.OnClientExpired,
+		mqtt.OnRetainedExpired,
+		mqtt.StoredClients,
+		mqtt.StoredInflightMessages,
+		mqtt.StoredRetainedMessages,
+		mqtt.StoredSubscriptions,
+		mqtt.StoredSysInfo,
+	}, []byte{b})
+}
+
+// gcLoop periodically runs the garbage collection process to reclaim space in the value log files.
+// It uses a ticker to trigger the garbage collection at regular intervals specified by the configuration.
+// The loop exits when gcTicker is stopped and its channel drained (see Stop).
+// Refer to: https://dgraph.io/docs/badger/get-started/#garbage-collection
+func (h *Hook) gcLoop() {
+	for range h.gcTicker.C {
+	again:
+		// Run the garbage collection process with a threshold.
+		// If the process returns nil (success), repeat the process.
+		// A non-nil error (typically badger.ErrNoRewrite) ends this tick's work.
+		err := h.db.RunValueLogGC(h.config.GcDiscardRatio)
+		if err == nil {
+			goto again // Retry garbage collection if successful.
+		}
+	}
+}
+
+// Init initializes and connects to the badger instance.
+// config may be nil (all defaults) or a *Options; any other type returns
+// mqtt.ErrInvalidConfigType. A GcDiscardRatio outside the open interval
+// (0, 1) is replaced with the default of 0.5.
+func (h *Hook) Init(config any) error {
+	if _, ok := config.(*Options); !ok && config != nil {
+		return mqtt.ErrInvalidConfigType
+	}
+
+	if config == nil {
+		h.config = new(Options)
+	} else {
+		h.config = config.(*Options)
+	}
+
+	if len(h.config.Path) == 0 {
+		h.config.Path = defaultDbFile
+	}
+
+	if h.config.GcInterval == 0 {
+		h.config.GcInterval = defaultGcInterval
+	}
+
+	if h.config.GcDiscardRatio <= 0.0 || h.config.GcDiscardRatio >= 1.0 {
+		h.config.GcDiscardRatio = defaultGcDiscardRatio
+	}
+
+	if h.config.Options == nil {
+		defaultOpts := badgerdb.DefaultOptions(h.config.Path)
+		h.config.Options = &defaultOpts
+	}
+	// Route badger's internal log output through this hook's logger.
+	h.config.Options.Logger = h
+
+	var err error
+	h.db, err = badgerdb.Open(*h.config.Options)
+	if err != nil {
+		return err
+	}
+
+	// Start the periodic value-log garbage collector.
+	h.gcTicker = time.NewTicker(time.Duration(h.config.GcInterval) * time.Second)
+	go h.gcLoop()
+
+	return nil
+}
+
+// Stop halts the garbage-collection ticker and closes the badger instance.
+// It is safe to call on a hook whose Init was never run or failed before
+// opening the database (previously this would panic on a nil db).
+func (h *Hook) Stop() error {
+	if h.gcTicker != nil {
+		h.gcTicker.Stop()
+	}
+	if h.db == nil {
+		// Init never opened the database; nothing to close.
+		return nil
+	}
+	return h.db.Close()
+}
+
+// OnSessionEstablished adds a client to the store when their session is established.
+func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
+	h.updateClient(cl)
+}
+
+// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
+// The client record is re-written so the stored will state stays current.
+func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
+	h.updateClient(cl)
+}
+
+// updateClient writes the client data to the store.
+// It snapshots the client's connection info, properties and will into a
+// storage.Client record keyed by clientKey(cl). Errors are logged, not returned.
+func (h *Hook) updateClient(cl *mqtt.Client) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	// Copy(false) takes a snapshot of the properties so the stored record
+	// is decoupled from the live client.
+	props := cl.Properties.Props.Copy(false)
+	in := &storage.Client{
+		ID:              cl.ID,
+		T:               storage.ClientKey,
+		Remote:          cl.Net.Remote,
+		Listener:        cl.Net.Listener,
+		Username:        cl.Properties.Username,
+		Clean:           cl.Properties.Clean,
+		ProtocolVersion: cl.Properties.ProtocolVersion,
+		Properties: storage.ClientProperties{
+			SessionExpiryInterval: props.SessionExpiryInterval,
+			AuthenticationMethod:  props.AuthenticationMethod,
+			AuthenticationData:    props.AuthenticationData,
+			RequestProblemInfo:    props.RequestProblemInfo,
+			RequestResponseInfo:   props.RequestResponseInfo,
+			ReceiveMaximum:        props.ReceiveMaximum,
+			TopicAliasMaximum:     props.TopicAliasMaximum,
+			User:                  props.User,
+			MaximumPacketSize:     props.MaximumPacketSize,
+		},
+		Will: storage.ClientWill(cl.Properties.Will),
+	}
+
+	err := h.setKv(clientKey(cl), in)
+	if err != nil {
+		h.Log.Error("failed to upsert client data", "error", err, "data", in)
+	}
+}
+
+// OnDisconnect removes a client from the store if their session has expired.
+// The client record is always re-written first; deletion only happens when
+// expire is true and the disconnect was not caused by a session takeover
+// (in which case the new session owns the record).
+func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	h.updateClient(cl)
+
+	if !expire {
+		return
+	}
+
+	if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) {
+		return
+	}
+
+	_ = h.delKv(clientKey(cl))
+}
+
+// OnSubscribed adds one or more client subscriptions to the store.
+func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ var in *storage.Subscription
+ for i := 0; i < len(pk.Filters); i++ {
+ in = &storage.Subscription{
+ ID: subscriptionKey(cl, pk.Filters[i].Filter),
+ T: storage.SubscriptionKey,
+ Client: cl.ID,
+ Qos: reasonCodes[i],
+ Filter: pk.Filters[i].Filter,
+ Identifier: pk.Filters[i].Identifier,
+ NoLocal: pk.Filters[i].NoLocal,
+ RetainHandling: pk.Filters[i].RetainHandling,
+ RetainAsPublished: pk.Filters[i].RetainAsPublished,
+ }
+
+ _ = h.setKv(in.ID, in)
+ }
+}
+
+// OnUnsubscribed removes one or more client subscriptions from the store.
+func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	for _, f := range pk.Filters {
+		_ = h.delKv(subscriptionKey(cl, f.Filter))
+	}
+}
+
+// OnRetainMessage adds a retained message for a topic to the store.
+// A value of r == -1 signals that the retained message for the topic is
+// being cleared, so the record is deleted instead of written.
+func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	if r == -1 {
+		_ = h.delKv(retainedKey(pk.TopicName))
+		return
+	}
+
+	// Snapshot packet properties so the stored record is decoupled from the packet.
+	props := pk.Properties.Copy(false)
+	in := &storage.Message{
+		ID:          retainedKey(pk.TopicName),
+		T:           storage.RetainedKey,
+		FixedHeader: pk.FixedHeader,
+		TopicName:   pk.TopicName,
+		Payload:     pk.Payload,
+		Created:     pk.Created,
+		Origin:      pk.Origin,
+		Properties: storage.MessageProperties{
+			PayloadFormat:          props.PayloadFormat,
+			MessageExpiryInterval:  props.MessageExpiryInterval,
+			ContentType:            props.ContentType,
+			ResponseTopic:          props.ResponseTopic,
+			CorrelationData:        props.CorrelationData,
+			SubscriptionIdentifier: props.SubscriptionIdentifier,
+			TopicAlias:             props.TopicAlias,
+			User:                   props.User,
+		},
+	}
+
+	_ = h.setKv(in.ID, in)
+}
+
+// OnQosPublish adds or updates an inflight message in the store.
+// sent is the unix timestamp at which the message was (re)sent; resends is
+// accepted for interface compatibility but not persisted here.
+func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	props := pk.Properties.Copy(false)
+	in := &storage.Message{
+		ID:          inflightKey(cl, pk),
+		T:           storage.InflightKey,
+		Origin:      pk.Origin,
+		PacketID:    pk.PacketID,
+		FixedHeader: pk.FixedHeader,
+		TopicName:   pk.TopicName,
+		Payload:     pk.Payload,
+		Sent:        sent,
+		Created:     pk.Created,
+		Properties: storage.MessageProperties{
+			PayloadFormat:          props.PayloadFormat,
+			MessageExpiryInterval:  props.MessageExpiryInterval,
+			ContentType:            props.ContentType,
+			ResponseTopic:          props.ResponseTopic,
+			CorrelationData:        props.CorrelationData,
+			SubscriptionIdentifier: props.SubscriptionIdentifier,
+			TopicAlias:             props.TopicAlias,
+			User:                   props.User,
+		},
+	}
+
+	_ = h.setKv(in.ID, in)
+}
+
+// OnQosComplete removes a resolved inflight message from the store.
+func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	_ = h.delKv(inflightKey(cl, pk))
+}
+
+// OnQosDropped removes a dropped inflight message from the store.
+// Deletion is delegated to OnQosComplete, which performs the same cleanup.
+func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		// BUG FIX: previously fell through to OnQosComplete, which performed
+		// the same nil-db check and logged the identical error a second time.
+		return
+	}
+
+	h.OnQosComplete(cl, pk)
+}
+
+// OnSysInfoTick stores the latest system info in the store.
+// The info is cloned before writing so the stored record is a stable snapshot.
+func (h *Hook) OnSysInfoTick(sys *system.Info) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	in := &storage.SystemInfo{
+		ID:   sysInfoKey(),
+		T:    storage.SysInfoKey,
+		Info: *sys.Clone(),
+	}
+
+	_ = h.setKv(in.ID, in)
+}
+
+// OnRetainedExpired deletes expired retained messages from the store.
+// filter is the topic name whose retained record should be removed.
+func (h *Hook) OnRetainedExpired(filter string) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	_ = h.delKv(retainedKey(filter))
+}
+
+// OnClientExpired deletes expired clients from the store.
+func (h *Hook) OnClientExpired(cl *mqtt.Client) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	_ = h.delKv(clientKey(cl))
+}
+
+// StoredClients returns all stored clients from the store.
+// With no open database it logs the error and returns (nil, nil).
+func (h *Hook) StoredClients() (v []storage.Client, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	// Initialize to an empty (non-nil) slice for consistency with the other
+	// Stored* methods, so callers receive [] rather than nil when the store is empty.
+	v = make([]storage.Client, 0)
+	err = h.iterKv(storage.ClientKey, func(value []byte) error {
+		obj := storage.Client{}
+		if err := obj.UnmarshalBinary(value); err != nil {
+			return err
+		}
+		v = append(v, obj)
+		return nil
+	})
+	return
+}
+
+// StoredSubscriptions returns all stored subscriptions from the store.
+// The result is always a non-nil slice; with no open database it logs the
+// error and returns (nil, nil).
+func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	v = make([]storage.Subscription, 0)
+	err = h.iterKv(storage.SubscriptionKey, func(value []byte) error {
+		obj := storage.Subscription{}
+		err = obj.UnmarshalBinary(value)
+		if err == nil {
+			v = append(v, obj)
+		}
+		return err
+	})
+	return
+}
+
+// StoredRetainedMessages returns all stored retained messages from the store.
+// The result is always a non-nil slice; with no open database it logs the
+// error and returns (nil, nil).
+func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	v = make([]storage.Message, 0)
+	err = h.iterKv(storage.RetainedKey, func(value []byte) error {
+		obj := storage.Message{}
+		err = obj.UnmarshalBinary(value)
+		if err == nil {
+			v = append(v, obj)
+		}
+		return err
+	})
+
+	// NOTE: a previous trailing special-case for badgerdb.ErrKeyNotFound was
+	// removed — both of its branches returned (v, err) identically, so the
+	// check was dead code.
+	return
+}
+
+// StoredInflightMessages returns all stored inflight messages from the store.
+// The result is always a non-nil slice; with no open database it logs the
+// error and returns (nil, nil).
+func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	v = make([]storage.Message, 0)
+	err = h.iterKv(storage.InflightKey, func(value []byte) error {
+		obj := storage.Message{}
+		err = obj.UnmarshalBinary(value)
+		if err == nil {
+			v = append(v, obj)
+		}
+		return err
+	})
+	return
+}
+
+// StoredSysInfo returns the system info from the store.
+// A missing record (badgerdb.ErrKeyNotFound) is deliberately swallowed and
+// reported as a zero-value SystemInfo with a nil error.
+func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	err = h.getKv(storage.SysInfoKey, &v)
+	if err != nil && !errors.Is(err, badgerdb.ErrKeyNotFound) {
+		return
+	}
+	return v, nil
+}
+
+// Errorf satisfies the badger interface for an error logger.
+// Badger's printf-style messages are lowercased and trimmed before being
+// passed to the hook's structured logger; the raw args are attached as "v".
+func (h *Hook) Errorf(m string, v ...any) {
+	h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
+
+}
+
+// Warningf satisfies the badger interface for a warning logger.
+func (h *Hook) Warningf(m string, v ...any) {
+	h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
+}
+
+// Infof satisfies the badger interface for an info logger.
+func (h *Hook) Infof(m string, v ...any) {
+	h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
+}
+
+// Debugf satisfies the badger interface for a debug logger.
+func (h *Hook) Debugf(m string, v ...any) {
+	h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
+}
+
+// setKv stores a key-value pair in the database.
+// The value is serialized via MarshalBinary inside the update transaction;
+// any serialization or write error is logged and returned.
+func (h *Hook) setKv(k string, v storage.Serializable) error {
+	err := h.db.Update(func(txn *badgerdb.Txn) error {
+		data, err := v.MarshalBinary()
+		if err != nil {
+			// BUG FIX: previously this error was silently discarded, which
+			// could write a nil/partial value under the key.
+			return err
+		}
+		return txn.Set([]byte(k), data)
+	})
+	if err != nil {
+		h.Log.Error("failed to upsert data", "error", err, "key", k)
+	}
+	return err
+}
+
+// delKv deletes a key-value pair from the database.
+// Deleting a key that does not exist is not an error in badger, so this is
+// safe to call repeatedly. Failures are logged and returned.
+func (h *Hook) delKv(k string) error {
+	err := h.db.Update(func(txn *badgerdb.Txn) error {
+		return txn.Delete([]byte(k))
+	})
+
+	if err != nil {
+		h.Log.Error("failed to delete data", "error", err, "key", k)
+	}
+	return err
+}
+
+// getKv retrieves the value associated with a key from the database and
+// deserializes it into v. Returns badgerdb.ErrKeyNotFound when the key is
+// absent; callers decide whether that is an error.
+func (h *Hook) getKv(k string, v storage.Serializable) error {
+	return h.db.View(func(txn *badgerdb.Txn) error {
+		item, err := txn.Get([]byte(k))
+		if err != nil {
+			return err
+		}
+		// ValueCopy detaches the bytes from the transaction's lifetime.
+		value, err := item.ValueCopy(nil)
+		if err != nil {
+			return err
+		}
+		return v.UnmarshalBinary(value)
+	})
+}
+
+// iterKv iterates over key-value pairs with keys having the specified prefix
+// in the database, calling visit with a detached copy of each value.
+// Iteration stops at the first error from the read or from visit; failures
+// are logged and returned.
+func (h *Hook) iterKv(prefix string, visit func([]byte) error) error {
+	err := h.db.View(func(txn *badgerdb.Txn) error {
+		iterator := txn.NewIterator(badgerdb.DefaultIteratorOptions)
+		defer iterator.Close()
+
+		for iterator.Seek([]byte(prefix)); iterator.ValidForPrefix([]byte(prefix)); iterator.Next() {
+			item := iterator.Item()
+			value, err := item.ValueCopy(nil)
+
+			if err != nil {
+				return err
+			}
+
+			if err := visit(value); err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		h.Log.Error("failed to find data", "error", err, "prefix", prefix)
+	}
+	return err
+}
diff --git a/hooks/storage/badger/badger_test.go b/hooks/storage/badger/badger_test.go
new file mode 100644
index 0000000..5f65fd4
--- /dev/null
+++ b/hooks/storage/badger/badger_test.go
@@ -0,0 +1,809 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co, werbenhu
+
+package badger
+
+import (
+ "errors"
+ "log/slog"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ badgerdb "github.com/dgraph-io/badger/v4"
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ "github.com/stretchr/testify/require"
+)
+
+// Shared fixtures for the badger hook tests.
+var (
+	// logger writes test log output to stdout via slog's text handler.
+	logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
+
+	// client is a canonical test client reused across tests.
+	client = &mqtt.Client{
+		ID: "test",
+		Net: mqtt.ClientConnection{
+			Remote:   "test.addr",
+			Listener: "listener",
+		},
+		Properties: mqtt.ClientProperties{
+			Username: []byte("username"),
+			Clean:    false,
+		},
+	}
+
+	// pkf is a subscribe/unsubscribe packet carrying a single filter.
+	pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
+)
+
+// teardown stops the hook, closes its db, and removes the on-disk store.
+// The strings.Replace guards against path traversal when removing the dir.
+func teardown(t *testing.T, path string, h *Hook) {
+	_ = h.Stop()
+	_ = h.db.Close()
+	err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
+	require.NoError(t, err)
+}
+
+// TestSetGetDelKv exercises the set/get/delete key-value round trip.
+func TestSetGetDelKv(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	// BUG FIX: Init's error was previously ignored.
+	require.NoError(t, h.Init(nil))
+	defer teardown(t, h.config.Path, h)
+
+	key := "testKey"
+	value := &storage.Client{ID: "cl1"}
+	err := h.setKv(key, value)
+	require.NoError(t, err)
+
+	var client storage.Client
+	err = h.getKv(key, &client)
+	require.NoError(t, err)
+	require.Equal(t, "cl1", client.ID)
+
+	err = h.delKv(key)
+	require.NoError(t, err)
+
+	err = h.getKv(key, &client)
+	// BUG FIX: require.ErrorIs takes (t, err, target); the arguments were reversed,
+	// which only passed by accident when err was exactly the sentinel value.
+	require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
+}
+
+// Key-construction, identity and initialization tests.
+
+func TestClientKey(t *testing.T) {
+	k := clientKey(&mqtt.Client{ID: "cl1"})
+	require.Equal(t, storage.ClientKey+"_cl1", k)
+}
+
+func TestSubscriptionKey(t *testing.T) {
+	k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
+	require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
+}
+
+func TestRetainedKey(t *testing.T) {
+	k := retainedKey("a/b/c")
+	require.Equal(t, storage.RetainedKey+"_a/b/c", k)
+}
+
+func TestInflightKey(t *testing.T) {
+	k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
+	require.Equal(t, storage.InflightKey+"_cl1:1", k)
+}
+
+func TestSysInfoKey(t *testing.T) {
+	require.Equal(t, storage.SysInfoKey, sysInfoKey())
+}
+
+func TestID(t *testing.T) {
+	h := new(Hook)
+	require.Equal(t, "badger-db", h.ID())
+}
+
+// TestProvides checks both the provided and deliberately-unprovided hook ids.
+func TestProvides(t *testing.T) {
+	h := new(Hook)
+	require.True(t, h.Provides(mqtt.OnSessionEstablished))
+	require.True(t, h.Provides(mqtt.OnDisconnect))
+	require.True(t, h.Provides(mqtt.OnSubscribed))
+	require.True(t, h.Provides(mqtt.OnUnsubscribed))
+	require.True(t, h.Provides(mqtt.OnRetainMessage))
+	require.True(t, h.Provides(mqtt.OnQosPublish))
+	require.True(t, h.Provides(mqtt.OnQosComplete))
+	require.True(t, h.Provides(mqtt.OnQosDropped))
+	require.True(t, h.Provides(mqtt.OnSysInfoTick))
+	require.True(t, h.Provides(mqtt.StoredClients))
+	require.True(t, h.Provides(mqtt.StoredInflightMessages))
+	require.True(t, h.Provides(mqtt.StoredRetainedMessages))
+	require.True(t, h.Provides(mqtt.StoredSubscriptions))
+	require.True(t, h.Provides(mqtt.StoredSysInfo))
+	require.False(t, h.Provides(mqtt.OnACLCheck))
+	require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
+}
+
+// TestInitBadConfig verifies that a non-*Options config is rejected.
+func TestInitBadConfig(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+
+	err := h.Init(map[string]any{})
+	require.Error(t, err)
+}
+
+// TestInitBadOption verifies badger's own option validation surfaces through Init.
+func TestInitBadOption(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+
+	err := h.Init(&Options{
+		Options: &badgerdb.Options{
+			NumCompactors: 1,
+		},
+	})
+	// Cannot have 1 compactor. Need at least 2
+	require.Error(t, err)
+}
+
+func TestInitUseDefaults(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	require.Equal(t, defaultDbFile, h.config.Path)
+}
+
+// TestOnSessionEstablishedThenOnDisconnect verifies the stored client record
+// is created on session establishment, survives a non-expiring disconnect,
+// and is deleted when the session expires.
+func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	h.OnSessionEstablished(client, packets.Packet{})
+
+	r := new(storage.Client)
+	err = h.getKv(clientKey(client), r)
+	require.NoError(t, err)
+	require.Equal(t, client.ID, r.ID)
+	require.Equal(t, client.Properties.Username, r.Username)
+	require.Equal(t, client.Properties.Clean, r.Clean)
+	require.Equal(t, client.Net.Remote, r.Remote)
+	require.Equal(t, client.Net.Listener, r.Listener)
+	require.NotSame(t, client, r)
+
+	h.OnDisconnect(client, nil, false)
+	r2 := new(storage.Client)
+	err = h.getKv(clientKey(client), r2)
+	require.NoError(t, err)
+	// BUG FIX: previously asserted on r.ID (the earlier read) instead of the
+	// freshly fetched r2, so the post-disconnect record was never checked.
+	require.Equal(t, client.ID, r2.ID)
+
+	h.OnDisconnect(client, nil, true)
+	r3 := new(storage.Client)
+	err = h.getKv(clientKey(client), r3)
+	require.Error(t, err)
+	// BUG FIX: require.ErrorIs takes (t, err, target); arguments were reversed.
+	require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
+	require.Empty(t, r3.ID)
+}
+
+// TestOnClientExpired verifies an expired client's record is removed.
+func TestOnClientExpired(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	cl := &mqtt.Client{ID: "cl1"}
+	clientKey := clientKey(cl)
+
+	err = h.setKv(clientKey, &storage.Client{ID: cl.ID})
+	require.NoError(t, err)
+
+	r := new(storage.Client)
+	err = h.getKv(clientKey, r)
+
+	require.NoError(t, err)
+	require.Equal(t, cl.ID, r.ID)
+
+	h.OnClientExpired(cl)
+	err = h.getKv(clientKey, r)
+	require.Error(t, err)
+	// BUG FIX: require.ErrorIs takes (t, err, target); arguments were reversed.
+	require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
+}
+
+// Nil-db and closed-db smoke tests: each hook method must log and return
+// without panicking when the store is unavailable.
+
+func TestOnClientExpiredNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnClientExpired(client)
+}
+
+func TestOnClientExpiredClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnClientExpired(client)
+}
+
+func TestOnSessionEstablishedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnSessionEstablished(client, packets.Packet{})
+}
+
+func TestOnSessionEstablishedClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnSessionEstablished(client, packets.Packet{})
+}
+
+// TestOnWillSent verifies the will flag is persisted when a will is sent.
+// Note: c1 aliases the shared client fixture, so this mutates shared state.
+func TestOnWillSent(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	c1 := client
+	c1.Properties.Will.Flag = 1
+	h.OnWillSent(c1, packets.Packet{})
+
+	r := new(storage.Client)
+	err = h.getKv(clientKey(client), r)
+	require.NoError(t, err)
+
+	require.Equal(t, uint32(1), r.Will.Flag)
+	require.NotSame(t, client, r)
+}
+
+func TestOnDisconnectNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnDisconnect(client, nil, false)
+}
+
+func TestOnDisconnectClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnDisconnect(client, nil, false)
+}
+
+// TestOnDisconnectSessionTakenOver checks the taken-over path does not panic
+// even when the store has already been torn down.
+func TestOnDisconnectSessionTakenOver(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+
+	testClient := &mqtt.Client{
+		ID: "test",
+		Net: mqtt.ClientConnection{
+			Remote:   "test.addr",
+			Listener: "listener",
+		},
+		Properties: mqtt.ClientProperties{
+			Username: []byte("username"),
+			Clean:    false,
+		},
+	}
+
+	// Stopping with ErrSessionTakenOver makes StopCause() match the takeover guard.
+	testClient.Stop(packets.ErrSessionTakenOver)
+	teardown(t, h.config.Path, h)
+	h.OnDisconnect(testClient, nil, true)
+}
+
+// TestOnSubscribedThenOnUnsubscribed verifies the subscription record round trip.
+func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	h.OnSubscribed(client, pkf, []byte{0})
+	r := new(storage.Subscription)
+
+	err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
+	require.NoError(t, err)
+	require.Equal(t, client.ID, r.Client)
+	require.Equal(t, pkf.Filters[0].Filter, r.Filter)
+	require.Equal(t, byte(0), r.Qos)
+
+	h.OnUnsubscribed(client, pkf)
+	err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
+	require.Error(t, err)
+	require.Equal(t, badgerdb.ErrKeyNotFound, err)
+}
+
+func TestOnSubscribedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnSubscribed(client, pkf, []byte{0})
+}
+
+func TestOnSubscribedClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnSubscribed(client, pkf, []byte{0})
+}
+
+func TestOnUnsubscribedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnUnsubscribed(client, pkf)
+}
+
+func TestOnUnsubscribedClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnUnsubscribed(client, pkf)
+}
+
+// Retained-message tests: set, unset (r == -1), expiry, and store-unavailable paths.
+
+func TestOnRetainMessageThenUnset(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	pk := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Retain: true,
+		},
+		Payload:   []byte("hello"),
+		TopicName: "a/b/c",
+	}
+
+	h.OnRetainMessage(client, pk, 1)
+
+	r := new(storage.Message)
+	err = h.getKv(retainedKey(pk.TopicName), r)
+	require.NoError(t, err)
+	require.Equal(t, pk.TopicName, r.TopicName)
+	require.Equal(t, pk.Payload, r.Payload)
+
+	h.OnRetainMessage(client, pk, -1)
+	err = h.getKv(retainedKey(pk.TopicName), r)
+	require.Error(t, err)
+	require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
+
+	// coverage: delete deleted
+	h.OnRetainMessage(client, pk, -1)
+	err = h.getKv(retainedKey(pk.TopicName), r)
+	require.Error(t, err)
+	require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
+}
+
+func TestOnRetainedExpired(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	m := &storage.Message{
+		ID:        retainedKey("a/b/c"),
+		T:         storage.RetainedKey,
+		TopicName: "a/b/c",
+	}
+
+	err = h.setKv(m.ID, m)
+	require.NoError(t, err)
+
+	r := new(storage.Message)
+	err = h.getKv(m.ID, r)
+	require.NoError(t, err)
+	require.Equal(t, m.TopicName, r.TopicName)
+
+	h.OnRetainedExpired(m.TopicName)
+	err = h.getKv(m.ID, r)
+	require.Error(t, err)
+	require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
+}
+
+func TestOnRetainExpiredNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnRetainedExpired("a/b/c")
+}
+
+func TestOnRetainExpiredClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnRetainedExpired("a/b/c")
+}
+
+func TestOnRetainMessageNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnRetainMessage(client, packets.Packet{}, 0)
+}
+
+func TestOnRetainMessageClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnRetainMessage(client, packets.Packet{}, 0)
+}
+
+// Inflight (QoS) message tests: publish, complete/drop, and store-unavailable paths.
+
+func TestOnQosPublishThenQOSComplete(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	pk := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Retain: true,
+			Qos:    2,
+		},
+		Payload:   []byte("hello"),
+		TopicName: "a/b/c",
+	}
+
+	h.OnQosPublish(client, pk, time.Now().Unix(), 0)
+
+	r := new(storage.Message)
+	err = h.getKv(inflightKey(client, pk), r)
+	require.NoError(t, err)
+	require.Equal(t, pk.TopicName, r.TopicName)
+	require.Equal(t, pk.Payload, r.Payload)
+
+	// ensure dates are properly saved
+	require.True(t, r.Sent > 0)
+	require.True(t, time.Now().Unix()-1 < r.Sent)
+
+	// OnQosDropped is a passthrough to OnQosComplete here
+	h.OnQosDropped(client, pk)
+	err = h.getKv(inflightKey(client, pk), r)
+	require.Error(t, err)
+	require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
+}
+
+func TestOnQosPublishNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
+}
+
+func TestOnQosPublishClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
+}
+
+func TestOnQosCompleteNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosComplete(client, packets.Packet{})
+}
+
+func TestOnQosCompleteClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnQosComplete(client, packets.Packet{})
+}
+
+func TestOnQosDroppedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosDropped(client, packets.Packet{})
+}
+
+// System-info tick tests: persistence and store-unavailable paths.
+
+func TestOnSysInfoTick(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	info := &system.Info{
+		Version:       "2.0.0",
+		BytesReceived: 100,
+	}
+
+	h.OnSysInfoTick(info)
+
+	r := new(storage.SystemInfo)
+	err = h.getKv(storage.SysInfoKey, r)
+	require.NoError(t, err)
+	require.Equal(t, info.Version, r.Version)
+	require.Equal(t, info.BytesReceived, r.BytesReceived)
+	require.NotSame(t, info, r)
+}
+
+func TestOnSysInfoTickNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnSysInfoTick(new(system.Info))
+}
+
+func TestOnSysInfoTickClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnSysInfoTick(new(system.Info))
+}
+
+func TestStoredClients(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ // populate with clients
+ err = h.setKv(storage.ClientKey+"_cl1", &storage.Client{ID: "cl1", T: storage.ClientKey})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.ClientKey+"_cl2", &storage.Client{ID: "cl2", T: storage.ClientKey})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.ClientKey+"_cl3", &storage.Client{ID: "cl3", T: storage.ClientKey})
+ require.NoError(t, err)
+
+ r, err := h.StoredClients()
+ require.NoError(t, err)
+ require.Len(t, r, 3)
+ require.Equal(t, "cl1", r[0].ID)
+ require.Equal(t, "cl2", r[1].ID)
+ require.Equal(t, "cl3", r[2].ID)
+}
+
+func TestStoredClientsNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ v, err := h.StoredClients()
+ require.Empty(t, v)
+ require.NoError(t, err)
+}
+
+func TestStoredSubscriptions(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ // populate with subscriptions
+ err = h.setKv(storage.SubscriptionKey+"_sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.SubscriptionKey+"_sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.SubscriptionKey+"_sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
+ require.NoError(t, err)
+
+ r, err := h.StoredSubscriptions()
+ require.NoError(t, err)
+ require.Len(t, r, 3)
+ require.Equal(t, "sub1", r[0].ID)
+ require.Equal(t, "sub2", r[1].ID)
+ require.Equal(t, "sub3", r[2].ID)
+}
+
+func TestStoredSubscriptionsNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ v, err := h.StoredSubscriptions()
+ require.Empty(t, v)
+ require.NoError(t, err)
+}
+
+// TestStoredRetainedMessages verifies StoredRetainedMessages returns only
+// retained records, excluding records stored under other key prefixes.
+func TestStoredRetainedMessages(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// populate with messages
+	err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.RetainedKey+"_m2", &storage.Message{ID: "m2", T: storage.RetainedKey})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.RetainedKey+"_m3", &storage.Message{ID: "m3", T: storage.RetainedKey})
+	require.NoError(t, err)
+
+	// an inflight record that must NOT appear in the retained results
+	err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3", T: storage.InflightKey})
+	require.NoError(t, err)
+
+	r, err := h.StoredRetainedMessages()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	require.Equal(t, "m1", r[0].ID)
+	require.Equal(t, "m2", r[1].ID)
+	require.Equal(t, "m3", r[2].ID)
+}
+
+// TestStoredRetainedMessagesNoDB verifies StoredRetainedMessages is a safe
+// no-op without an open database.
+func TestStoredRetainedMessagesNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredRetainedMessages()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredInflightMessages verifies StoredInflightMessages returns only
+// inflight records, excluding records stored under other key prefixes.
+func TestStoredInflightMessages(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// populate with messages
+	err = h.setKv(storage.InflightKey+"_i1", &storage.Message{ID: "i1", T: storage.InflightKey})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.InflightKey+"_i2", &storage.Message{ID: "i2", T: storage.InflightKey})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3", T: storage.InflightKey})
+	require.NoError(t, err)
+
+	// a retained record that must NOT appear in the inflight results
+	err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
+	require.NoError(t, err)
+
+	r, err := h.StoredInflightMessages()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	require.Equal(t, "i1", r[0].ID)
+	require.Equal(t, "i2", r[1].ID)
+	require.Equal(t, "i3", r[2].ID)
+}
+
+// TestStoredInflightMessagesNoDB verifies StoredInflightMessages is a safe
+// no-op without an open database.
+func TestStoredInflightMessagesNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredInflightMessages()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredSysInfo verifies the persisted system-info record round-trips
+// through StoredSysInfo.
+func TestStoredSysInfo(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// populate with messages
+	err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
+		ID: storage.SysInfoKey,
+		Info: system.Info{
+			Version: "2.0.0",
+		},
+		T: storage.SysInfoKey,
+	})
+	require.NoError(t, err)
+
+	r, err := h.StoredSysInfo()
+	require.NoError(t, err)
+	require.Equal(t, "2.0.0", r.Info.Version)
+}
+
+// TestStoredSysInfoNoDB verifies StoredSysInfo is a safe no-op without an
+// open database.
+func TestStoredSysInfoNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredSysInfo()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// The four log-level tests below only call the formatting helpers for
+// coverage; no log output is asserted.
+func TestErrorf(t *testing.T) {
+	// coverage: one day check log hook
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.Errorf("test", 1, 2, 3)
+}
+
+func TestWarningf(t *testing.T) {
+	// coverage: one day check log hook
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.Warningf("test", 1, 2, 3)
+}
+
+func TestInfof(t *testing.T) {
+	// coverage: one day check log hook
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.Infof("test", 1, 2, 3)
+}
+
+func TestDebugf(t *testing.T) {
+	// coverage: one day check log hook
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.Debugf("test", 1, 2, 3)
+}
+
+// TestGcLoop exercises the value-log garbage-collection goroutine by forcing
+// a short GC interval and letting at least one tick elapse.
+func TestGcLoop(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	opts := badgerdb.DefaultOptions(defaultDbFile)
+	opts.ValueLogFileSize = 1 << 20
+	// the Init error was previously discarded; a failed open would make the
+	// rest of the test meaningless
+	err := h.Init(&Options{
+		GcInterval: 2, // Set the interval for garbage collection.
+		Options:    &opts,
+	})
+	require.NoError(t, err)
+	defer teardown(t, defaultDbFile, h)
+	h.OnSessionEstablished(client, packets.Packet{})
+	h.OnDisconnect(client, nil, true)
+	time.Sleep(3 * time.Second) // allow at least one GC tick to fire
+}
+
+// TestGetSetDelKv round-trips a value through setKv/getKv and verifies delKv
+// removes it.
+func TestGetSetDelKv(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	defer teardown(t, h.config.Path, h)
+	require.NoError(t, err)
+
+	err = h.setKv("testKey", &storage.Client{ID: "testId"})
+	require.NoError(t, err)
+
+	var obj storage.Client
+	err = h.getKv("testKey", &obj)
+	require.NoError(t, err)
+
+	err = h.delKv("testKey")
+	require.NoError(t, err)
+
+	err = h.getKv("testKey", &obj)
+	require.Error(t, err)
+	// require.ErrorIs takes (t, err, target); the args were swapped, which
+	// only passed by accident while err was the exact sentinel value
+	require.ErrorIs(t, err, badgerdb.ErrKeyNotFound)
+}
+
+// TestIterKv verifies prefix iteration visits only matching keys and that a
+// visitor error is propagated to the caller.
+func TestIterKv(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+
+	err := h.Init(nil)
+	defer teardown(t, h.config.Path, h)
+	require.NoError(t, err)
+
+	// setKv errors were previously ignored; a failed write would silently
+	// invalidate the length assertion below
+	require.NoError(t, h.setKv("prefix_a_1", &storage.Client{ID: "1"}))
+	require.NoError(t, h.setKv("prefix_a_2", &storage.Client{ID: "2"}))
+	require.NoError(t, h.setKv("prefix_b_2", &storage.Client{ID: "3"}))
+
+	var clients []storage.Client
+	err = h.iterKv("prefix_a", func(data []byte) error {
+		var item storage.Client
+		if err := item.UnmarshalBinary(data); err != nil { // decode error was discarded before
+			return err
+		}
+		clients = append(clients, item)
+		return nil
+	})
+	require.NoError(t, err)
+	require.Len(t, clients, 2)
+
+	visitErr := errors.New("iter visit error")
+	err = h.iterKv("prefix_b", func(data []byte) error {
+		return visitErr
+	})
+	require.ErrorIs(t, err, visitErr) // args were swapped (target, err)
+}
diff --git a/hooks/storage/bolt/bolt.go b/hooks/storage/bolt/bolt.go
new file mode 100644
index 0000000..6d6235a
--- /dev/null
+++ b/hooks/storage/bolt/bolt.go
@@ -0,0 +1,525 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co, werbenhu
+
+// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
+package bolt
+
+import (
+ "bytes"
+ "errors"
+ "time"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ "go.etcd.io/bbolt"
+)
+
+// Sentinel errors returned by the bolt-backed storage hook; callers should
+// compare with errors.Is.
+var (
+	ErrBucketNotFound = errors.New("bucket not found")
+	ErrKeyNotFound = errors.New("key not found")
+)
+
+const (
+	// defaultDbFile is the default file path for the boltdb file.
+	defaultDbFile = ".bolt"
+
+	// defaultTimeout is the default time to hold a connection to the file.
+	defaultTimeout = 250 * time.Millisecond
+
+	// defaultBucket is the bbolt bucket in which all records are stored.
+	defaultBucket = "mochi"
+)
+
+// clientKey returns a primary key for a client: the client-key prefix plus
+// "_" plus the client id (e.g. "CL_cl1").
+func clientKey(cl *mqtt.Client) string {
+	return storage.ClientKey + "_" + cl.ID
+}
+
+// subscriptionKey returns a primary key for a subscription, scoped to both
+// the subscribing client and the topic filter.
+func subscriptionKey(cl *mqtt.Client, filter string) string {
+	return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
+}
+
+// retainedKey returns a primary key for a retained message, scoped to its topic.
+func retainedKey(topic string) string {
+	return storage.RetainedKey + "_" + topic
+}
+
+// inflightKey returns a primary key for an inflight message, scoped to the
+// client and the packet id.
+func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
+	return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
+}
+
+// sysInfoKey returns the primary key for the single system info record.
+func sysInfoKey() string {
+	return storage.SysInfoKey
+}
+
+// Options contains configuration settings for the bolt instance.
+type Options struct {
+	// Options are the low-level bbolt open options; when nil, Init applies a
+	// default with a 250ms file-lock timeout.
+	Options *bbolt.Options
+	// Bucket is the bbolt bucket name records are stored in (default "mochi").
+	Bucket string `yaml:"bucket" json:"bucket"`
+	// Path is the db file path (default ".bolt").
+	Path string `yaml:"path" json:"path"`
+}
+
+// Hook is a persistent storage hook based using boltdb file store as a backend.
+// All hook methods treat a nil db handle (never initialized, or stopped) as a
+// logged no-op.
+type Hook struct {
+	mqtt.HookBase
+	config *Options // options for configuring the boltdb instance.
+	db *bbolt.DB // the boltdb instance.
+}
+
+// ID returns the id of the hook, used to identify it in the server's hook list.
+func (h *Hook) ID() string {
+	return "bolt-db"
+}
+
+// Provides indicates which hook methods this hook provides; membership is
+// tested by scanning the byte list of supported hook ids.
+func (h *Hook) Provides(b byte) bool {
+	return bytes.Contains([]byte{
+		mqtt.OnSessionEstablished,
+		mqtt.OnDisconnect,
+		mqtt.OnSubscribed,
+		mqtt.OnUnsubscribed,
+		mqtt.OnRetainMessage,
+		mqtt.OnWillSent,
+		mqtt.OnQosPublish,
+		mqtt.OnQosComplete,
+		mqtt.OnQosDropped,
+		mqtt.OnSysInfoTick,
+		mqtt.OnClientExpired,
+		mqtt.OnRetainedExpired,
+		mqtt.StoredClients,
+		mqtt.StoredInflightMessages,
+		mqtt.StoredRetainedMessages,
+		mqtt.StoredSubscriptions,
+		mqtt.StoredSysInfo,
+	}, []byte{b})
+}
+
+// Init initializes and connects to the boltdb instance, applying defaults for
+// any unset options (file path, bucket name, lock timeout) and creating the
+// storage bucket if it does not yet exist. A non-nil config that is not an
+// *Options is rejected with mqtt.ErrInvalidConfigType.
+func (h *Hook) Init(config any) error {
+	switch c := config.(type) {
+	case nil:
+		h.config = new(Options)
+	case *Options:
+		h.config = c
+	default:
+		return mqtt.ErrInvalidConfigType
+	}
+
+	if h.config.Options == nil {
+		h.config.Options = &bbolt.Options{
+			Timeout: defaultTimeout,
+		}
+	}
+	if h.config.Path == "" {
+		h.config.Path = defaultDbFile
+	}
+	if h.config.Bucket == "" {
+		h.config.Bucket = defaultBucket
+	}
+
+	db, err := bbolt.Open(h.config.Path, 0600, h.config.Options)
+	if err != nil {
+		return err
+	}
+	h.db = db
+
+	// ensure the bucket exists before any reads/writes occur
+	return h.db.Update(func(tx *bbolt.Tx) error {
+		_, err := tx.CreateBucketIfNotExists([]byte(h.config.Bucket))
+		return err
+	})
+}
+
+// Stop closes the boltdb instance. It is safe to call when the hook was never
+// initialized, or to call twice; in that case it returns nil instead of
+// panicking on a nil db handle.
+func (h *Hook) Stop() error {
+	if h.db == nil {
+		return nil
+	}
+
+	err := h.db.Close()
+	h.db = nil
+	return err
+}
+
+// OnSessionEstablished adds a client to the store when their session is established.
+func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
+	h.updateClient(cl)
+}
+
+// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
+// The stored record is refreshed so it reflects the client's current will state.
+func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
+	h.updateClient(cl)
+}
+
+// updateClient writes the client data to the store. It snapshots the client's
+// connection details, CONNECT properties and will message into a
+// storage.Client record keyed by clientKey(cl). Write failures are logged
+// inside setKv and otherwise ignored (best-effort persistence).
+func (h *Hook) updateClient(cl *mqtt.Client) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	props := cl.Properties.Props.Copy(false)
+	in := &storage.Client{
+		ID: cl.ID,
+		T: storage.ClientKey,
+		Remote: cl.Net.Remote,
+		Listener: cl.Net.Listener,
+		Username: cl.Properties.Username,
+		Clean: cl.Properties.Clean,
+		ProtocolVersion: cl.Properties.ProtocolVersion,
+		Properties: storage.ClientProperties{
+			SessionExpiryInterval: props.SessionExpiryInterval,
+			AuthenticationMethod: props.AuthenticationMethod,
+			AuthenticationData: props.AuthenticationData,
+			RequestProblemInfo: props.RequestProblemInfo,
+			RequestResponseInfo: props.RequestResponseInfo,
+			ReceiveMaximum: props.ReceiveMaximum,
+			TopicAliasMaximum: props.TopicAliasMaximum,
+			User: props.User,
+			MaximumPacketSize: props.MaximumPacketSize,
+		},
+		Will: storage.ClientWill(cl.Properties.Will),
+	}
+
+	_ = h.setKv(clientKey(cl), in)
+}
+
+// OnDisconnect removes a client from the store if they were using a clean session.
+// The record is kept when the session is not expiring, and also when a newer
+// connection has taken over the session (the new session owns the record).
+func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	if !expire || cl.StopCause() == packets.ErrSessionTakenOver {
+		return
+	}
+
+	_ = h.delKv(clientKey(cl))
+}
+
+// OnSubscribed adds one or more client subscriptions to the store.
+func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ var in *storage.Subscription
+ for i := 0; i < len(pk.Filters); i++ {
+ in = &storage.Subscription{
+ ID: subscriptionKey(cl, pk.Filters[i].Filter),
+ T: storage.SubscriptionKey,
+ Client: cl.ID,
+ Qos: reasonCodes[i],
+ Filter: pk.Filters[i].Filter,
+ Identifier: pk.Filters[i].Identifier,
+ NoLocal: pk.Filters[i].NoLocal,
+ RetainHandling: pk.Filters[i].RetainHandling,
+ RetainAsPublished: pk.Filters[i].RetainAsPublished,
+ }
+ _ = h.setKv(in.ID, in)
+ }
+}
+
+// OnUnsubscribed removes one or more client subscriptions from the store,
+// deleting the record for every filter in the UNSUBSCRIBE packet.
+func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	for _, f := range pk.Filters {
+		_ = h.delKv(subscriptionKey(cl, f.Filter))
+	}
+}
+
+// OnRetainMessage adds a retained message for a topic to the store.
+// A value of r == -1 signals the retained message for this topic is being
+// cleared, and the stored record is deleted instead.
+func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	if r == -1 {
+		_ = h.delKv(retainedKey(pk.TopicName))
+		return
+	}
+
+	props := pk.Properties.Copy(false)
+	in := &storage.Message{
+		ID: retainedKey(pk.TopicName),
+		T: storage.RetainedKey,
+		FixedHeader: pk.FixedHeader,
+		TopicName: pk.TopicName,
+		Payload: pk.Payload,
+		Created: pk.Created,
+		Origin: pk.Origin,
+		Properties: storage.MessageProperties{
+			PayloadFormat: props.PayloadFormat,
+			MessageExpiryInterval: props.MessageExpiryInterval,
+			ContentType: props.ContentType,
+			ResponseTopic: props.ResponseTopic,
+			CorrelationData: props.CorrelationData,
+			SubscriptionIdentifier: props.SubscriptionIdentifier,
+			TopicAlias: props.TopicAlias,
+			User: props.User,
+		},
+	}
+
+	_ = h.setKv(in.ID, in)
+}
+
+// OnQosPublish adds or updates an inflight message in the store, keyed by
+// client id and packet id. sent is the unix timestamp the message was last
+// sent; the resends counter is not persisted by this hook.
+func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	props := pk.Properties.Copy(false)
+	in := &storage.Message{
+		ID: inflightKey(cl, pk),
+		T: storage.InflightKey,
+		Origin: pk.Origin,
+		FixedHeader: pk.FixedHeader,
+		TopicName: pk.TopicName,
+		Payload: pk.Payload,
+		Sent: sent,
+		Created: pk.Created,
+		Properties: storage.MessageProperties{
+			PayloadFormat: props.PayloadFormat,
+			MessageExpiryInterval: props.MessageExpiryInterval,
+			ContentType: props.ContentType,
+			ResponseTopic: props.ResponseTopic,
+			CorrelationData: props.CorrelationData,
+			SubscriptionIdentifier: props.SubscriptionIdentifier,
+			TopicAlias: props.TopicAlias,
+			User: props.User,
+		},
+	}
+
+	_ = h.setKv(in.ID, in)
+}
+
+// OnQosComplete removes a resolved inflight message from the store.
+// Deletion failures are logged inside delKv and otherwise ignored.
+func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	_ = h.delKv(inflightKey(cl, pk))
+}
+
+// OnQosDropped removes a dropped inflight message from the store by
+// delegating to OnQosComplete.
+func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return // previously fell through, causing OnQosComplete to log the same error a second time
+	}
+
+	h.OnQosComplete(cl, pk)
+}
+
+// OnSysInfoTick stores the latest system info in the store, overwriting the
+// single record kept under sysInfoKey().
+func (h *Hook) OnSysInfoTick(sys *system.Info) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	in := &storage.SystemInfo{
+		ID: sysInfoKey(),
+		T: storage.SysInfoKey,
+		Info: *sys,
+	}
+
+	_ = h.setKv(in.ID, in)
+}
+
+// OnRetainedExpired deletes expired retained messages from the store,
+// keyed by the expired topic filter.
+func (h *Hook) OnRetainedExpired(filter string) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+	_ = h.delKv(retainedKey(filter))
+}
+
+// OnClientExpired deletes expired clients from the store.
+func (h *Hook) OnClientExpired(cl *mqtt.Client) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+	_ = h.delKv(clientKey(cl))
+}
+
+// StoredClients returns all stored clients from the store. It returns
+// storage.ErrDBFileNotOpen when no database is open, and propagates any
+// iteration or decode error (previously `return v, nil` silently discarded
+// such errors, unlike the sibling Stored* methods).
+func (h *Hook) StoredClients() (v []storage.Client, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return v, storage.ErrDBFileNotOpen
+	}
+
+	v = make([]storage.Client, 0) // non-nil empty slice, consistent with the other Stored* methods
+	err = h.iterKv(storage.ClientKey, func(value []byte) error {
+		obj := storage.Client{}
+		err = obj.UnmarshalBinary(value)
+		if err == nil {
+			v = append(v, obj)
+		}
+		return err
+	})
+	return
+}
+
+// StoredSubscriptions returns all stored subscriptions from the store.
+// The result is a non-nil (possibly empty) slice; iteration and decode
+// errors propagate through the named return values.
+func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return v, storage.ErrDBFileNotOpen
+	}
+
+	v = make([]storage.Subscription, 0)
+	err = h.iterKv(storage.SubscriptionKey, func(value []byte) error {
+		obj := storage.Subscription{}
+		err = obj.UnmarshalBinary(value)
+		if err == nil {
+			v = append(v, obj)
+		}
+		return err
+	})
+	return
+}
+
+// StoredRetainedMessages returns all stored retained messages from the store.
+// The result is a non-nil (possibly empty) slice; iteration and decode
+// errors propagate through the named return values.
+func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return v, storage.ErrDBFileNotOpen
+	}
+
+	v = make([]storage.Message, 0)
+	err = h.iterKv(storage.RetainedKey, func(value []byte) error {
+		obj := storage.Message{}
+		err = obj.UnmarshalBinary(value)
+		if err == nil {
+			v = append(v, obj)
+		}
+		return err
+	})
+	return
+}
+
+// StoredInflightMessages returns all stored inflight messages from the store.
+// The result is a non-nil (possibly empty) slice; iteration and decode
+// errors propagate through the named return values.
+func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return v, storage.ErrDBFileNotOpen
+	}
+
+	v = make([]storage.Message, 0)
+	err = h.iterKv(storage.InflightKey, func(value []byte) error {
+		obj := storage.Message{}
+		err = obj.UnmarshalBinary(value)
+		if err == nil {
+			v = append(v, obj)
+		}
+		return err
+	})
+	return
+}
+
+// StoredSysInfo returns the system info from the store. A missing record
+// (ErrKeyNotFound) is not an error: the zero-value SystemInfo is returned
+// with a nil error so first-boot servers start cleanly.
+func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return v, storage.ErrDBFileNotOpen
+	}
+
+	err = h.getKv(storage.SysInfoKey, &v)
+	if err != nil && !errors.Is(err, ErrKeyNotFound) {
+		return
+	}
+
+	return v, nil
+}
+
+// setKv stores a key-value pair in the database, serializing v via its
+// MarshalBinary implementation. Failures are logged and returned.
+func (h *Hook) setKv(k string, v storage.Serializable) error {
+	err := h.db.Update(func(tx *bbolt.Tx) error {
+		data, err := v.MarshalBinary()
+		if err != nil {
+			return err // previously discarded, which could persist empty data for the key
+		}
+		bucket := tx.Bucket([]byte(h.config.Bucket))
+		return bucket.Put([]byte(k), data)
+	})
+	if err != nil {
+		h.Log.Error("failed to upsert data", "error", err, "key", k)
+	}
+	return err
+}
+
+// delKv deletes a key-value pair from the database. Failures are logged and
+// returned; deleting a missing key is not an error in bbolt.
+func (h *Hook) delKv(k string) error {
+	err := h.db.Update(func(tx *bbolt.Tx) error {
+		return tx.Bucket([]byte(h.config.Bucket)).Delete([]byte(k))
+	})
+	if err != nil {
+		h.Log.Error("failed to delete data", "error", err, "key", k)
+	}
+	return err
+}
+
+// getKv retrieves the value associated with a key from the database and
+// deserializes it into v. Returns ErrKeyNotFound when the key is absent;
+// failures are logged and returned.
+func (h *Hook) getKv(k string, v storage.Serializable) error {
+	err := h.db.View(func(tx *bbolt.Tx) error {
+		raw := tx.Bucket([]byte(h.config.Bucket)).Get([]byte(k))
+		if raw == nil {
+			return ErrKeyNotFound
+		}
+		return v.UnmarshalBinary(raw)
+	})
+	if err != nil {
+		h.Log.Error("failed to get data", "error", err, "key", k)
+	}
+	return err
+}
+
+// iterKv iterates over key-value pairs with keys having the specified prefix
+// in the database, calling visit for each value. The first visit error aborts
+// the iteration and is returned.
+func (h *Hook) iterKv(prefix string, visit func([]byte) error) error {
+	err := h.db.View(func(tx *bbolt.Tx) error {
+		c := tx.Bucket([]byte(h.config.Bucket)).Cursor()
+		p := []byte(prefix)
+		// Seek positions at the first key >= prefix; stop at the first key
+		// that no longer carries the prefix. bytes.HasPrefix also guards
+		// against keys shorter than the prefix, which the previous
+		// k[:len(prefix)] slice would panic on.
+		for k, v := c.Seek(p); k != nil && bytes.HasPrefix(k, p); k, v = c.Next() {
+			if err := visit(v); err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+
+	if err != nil {
+		h.Log.Error("failed to iter data", "error", err, "prefix", prefix)
+	}
+	return err
+}
diff --git a/hooks/storage/bolt/bolt_test.go b/hooks/storage/bolt/bolt_test.go
new file mode 100644
index 0000000..780162f
--- /dev/null
+++ b/hooks/storage/bolt/bolt_test.go
@@ -0,0 +1,791 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co, werbenhu
+
+package bolt
+
+import (
+ "errors"
+ "log/slog"
+ "os"
+ "testing"
+ "time"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ "github.com/stretchr/testify/require"
+)
+
+var (
+	// logger writes test log output to stdout.
+	logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
+
+	// client is a shared connected-client fixture reused across tests.
+	client = &mqtt.Client{
+		ID: "test",
+		Net: mqtt.ClientConnection{
+			Remote: "test.addr",
+			Listener: "listener",
+		},
+		Properties: mqtt.ClientProperties{
+			Username: []byte("username"),
+			Clean: false,
+		},
+	}
+
+	// pkf is a minimal packet carrying a single subscription filter.
+	pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
+)
+
+// teardown stops the hook and deletes its db file, failing the test if the
+// file cannot be removed.
+func teardown(t *testing.T, path string, h *Hook) {
+	_ = h.Stop()
+	err := os.Remove(path)
+	require.NoError(t, err)
+}
+
+// TestClientKey pins the client primary-key format, composed from the same
+// storage constant the other key tests use rather than a hard-coded "CL"
+// literal, so a prefix change fails every key test consistently.
+func TestClientKey(t *testing.T) {
+	k := clientKey(&mqtt.Client{ID: "cl1"})
+	require.Equal(t, storage.ClientKey+"_cl1", k)
+}
+
+// The key tests below pin the primary-key formats for each record type.
+func TestSubscriptionKey(t *testing.T) {
+	k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
+	require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
+}
+
+func TestRetainedKey(t *testing.T) {
+	k := retainedKey("a/b/c")
+	require.Equal(t, storage.RetainedKey+"_a/b/c", k)
+}
+
+func TestInflightKey(t *testing.T) {
+	k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
+	require.Equal(t, storage.InflightKey+"_cl1:1", k)
+}
+
+func TestSysInfoKey(t *testing.T) {
+	require.Equal(t, storage.SysInfoKey, sysInfoKey())
+}
+
+// TestID pins the hook's registered id string.
+func TestID(t *testing.T) {
+	h := new(Hook)
+	require.Equal(t, "bolt-db", h.ID())
+}
+
+// TestProvides checks the advertised hook capabilities, including negative
+// cases (auth hooks are not provided).
+func TestProvides(t *testing.T) {
+	h := new(Hook)
+	require.True(t, h.Provides(mqtt.OnSessionEstablished))
+	require.True(t, h.Provides(mqtt.OnDisconnect))
+	require.True(t, h.Provides(mqtt.OnSubscribed))
+	require.True(t, h.Provides(mqtt.OnUnsubscribed))
+	require.True(t, h.Provides(mqtt.OnRetainMessage))
+	require.True(t, h.Provides(mqtt.OnQosPublish))
+	require.True(t, h.Provides(mqtt.OnQosComplete))
+	require.True(t, h.Provides(mqtt.OnQosDropped))
+	require.True(t, h.Provides(mqtt.OnSysInfoTick))
+	require.True(t, h.Provides(mqtt.StoredClients))
+	require.True(t, h.Provides(mqtt.StoredInflightMessages))
+	require.True(t, h.Provides(mqtt.StoredRetainedMessages))
+	require.True(t, h.Provides(mqtt.StoredSubscriptions))
+	require.True(t, h.Provides(mqtt.StoredSysInfo))
+	require.False(t, h.Provides(mqtt.OnACLCheck))
+	require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
+}
+
+// TestInitBadConfig ensures a config value of the wrong type is rejected.
+func TestInitBadConfig(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+
+	err := h.Init(map[string]any{})
+	require.Error(t, err)
+}
+
+// TestInitUseDefaults ensures unset options fall back to package defaults.
+func TestInitUseDefaults(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	require.Equal(t, defaultTimeout, h.config.Options.Timeout)
+	require.Equal(t, defaultDbFile, h.config.Path)
+}
+
+// TestInitBadPath ensures an unopenable db path surfaces as an Init error.
+func TestInitBadPath(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(&Options{
+		Path: "..", // a directory cannot be opened as a db file
+	})
+	require.Error(t, err)
+}
+
+// TestOnSessionEstablishedThenOnDisconnect verifies a session record is
+// persisted on establishment, survives a non-expiring disconnect, and is
+// removed when the disconnect expires the session.
+func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	h.OnSessionEstablished(client, packets.Packet{})
+
+	r := new(storage.Client)
+	err = h.getKv(clientKey(client), r)
+	require.NoError(t, err)
+	require.Equal(t, client.ID, r.ID)
+	require.Equal(t, client.Net.Remote, r.Remote)
+	require.Equal(t, client.Net.Listener, r.Listener)
+	require.Equal(t, client.Properties.Username, r.Username)
+	require.Equal(t, client.Properties.Clean, r.Clean)
+	require.NotSame(t, client, r)
+
+	// a non-expiring disconnect must keep the session record
+	h.OnDisconnect(client, nil, false)
+	r2 := new(storage.Client)
+	err = h.getKv(clientKey(client), r2)
+	require.NoError(t, err)
+	require.Equal(t, client.ID, r2.ID) // previously asserted r.ID, ignoring the re-read record
+
+	// an expiring disconnect must delete the session record
+	h.OnDisconnect(client, nil, true)
+	r3 := new(storage.Client)
+	err = h.getKv(clientKey(client), r3)
+	require.Error(t, err)
+	require.ErrorIs(t, err, ErrKeyNotFound) // args were swapped (target, err)
+	require.Empty(t, r3.ID)
+}
+
+// TestOnSessionEstablishedNoDB ensures the hook is a safe no-op without a db.
+func TestOnSessionEstablishedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnSessionEstablished(client, packets.Packet{})
+}
+
+// TestOnSessionEstablishedClosedDB ensures the hook is safe after Stop.
+func TestOnSessionEstablishedClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnSessionEstablished(client, packets.Packet{})
+}
+
+// TestOnWillSent verifies the client record is refreshed (including the will
+// flag) when a will message is sent.
+func TestOnWillSent(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// NOTE(review): c1 aliases the shared client fixture (pointer copy), so
+	// setting Will.Flag here mutates state seen by later tests — confirm
+	// this is intended.
+	c1 := client
+	c1.Properties.Will.Flag = 1
+	h.OnWillSent(c1, packets.Packet{})
+
+	r := new(storage.Client)
+	err = h.getKv(clientKey(client), r)
+	require.NoError(t, err)
+
+	require.Equal(t, uint32(1), r.Will.Flag)
+	require.NotSame(t, client, r)
+}
+
+// TestOnClientExpired verifies an expired client's record is removed from the
+// store.
+func TestOnClientExpired(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	cl := &mqtt.Client{ID: "cl1"}
+	key := clientKey(cl) // renamed: the previous local shadowed the clientKey helper
+
+	err = h.setKv(key, &storage.Client{ID: cl.ID})
+	require.NoError(t, err)
+
+	r := new(storage.Client)
+	err = h.getKv(key, r)
+	require.NoError(t, err)
+	require.Equal(t, cl.ID, r.ID)
+
+	h.OnClientExpired(cl)
+	err = h.getKv(key, r)
+	require.Error(t, err)
+	require.ErrorIs(t, err, ErrKeyNotFound) // args were swapped (target, err)
+}
+
+// The four tests below ensure expiry/disconnect handlers are safe no-ops when
+// the db is closed or was never opened.
+func TestOnClientExpiredClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnClientExpired(client)
+}
+
+func TestOnClientExpiredNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnClientExpired(client)
+}
+
+func TestOnDisconnectNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnDisconnect(client, nil, false)
+}
+
+func TestOnDisconnectClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnDisconnect(client, nil, false)
+}
+
+// TestOnDisconnectSessionTakenOver verifies that a client whose session was
+// taken over is not processed for deletion on disconnect, even when the
+// session would otherwise expire.
+func TestOnDisconnectSessionTakenOver(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	testClient := &mqtt.Client{
+		ID: "test",
+		Net: mqtt.ClientConnection{
+			Remote:   "test.addr",
+			Listener: "listener",
+		},
+		Properties: mqtt.ClientProperties{
+			Username: []byte("username"),
+			Clean:    false,
+		},
+	}
+
+	testClient.Stop(packets.ErrSessionTakenOver)
+	// previously teardown ran BEFORE this call, so only the closed-db branch
+	// was exercised; keep the db open to hit the ErrSessionTakenOver
+	// early-return in OnDisconnect.
+	h.OnDisconnect(testClient, nil, true)
+}
+
+// TestOnSubscribedThenOnUnsubscribed verifies a subscription record is stored
+// on subscribe and removed on unsubscribe.
+func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	h.OnSubscribed(client, pkf, []byte{0})
+	r := new(storage.Subscription)
+
+	err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
+	require.NoError(t, err)
+	require.Equal(t, client.ID, r.Client)
+	require.Equal(t, pkf.Filters[0].Filter, r.Filter)
+	require.Equal(t, byte(0), r.Qos)
+
+	h.OnUnsubscribed(client, pkf)
+	err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
+	require.Error(t, err)
+	require.Equal(t, ErrKeyNotFound, err)
+}
+
+// The four tests below ensure subscribe/unsubscribe handlers are safe no-ops
+// when the db is closed or was never opened.
+func TestOnSubscribedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnSubscribed(client, pkf, []byte{0})
+}
+
+func TestOnSubscribedClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnSubscribed(client, pkf, []byte{0})
+}
+
+func TestOnUnsubscribedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnUnsubscribed(client, pkf)
+}
+
+func TestOnUnsubscribedClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnUnsubscribed(client, pkf)
+}
+
+// TestOnRetainMessageThenUnset verifies a retained message is stored, removed
+// when r == -1, and that re-deleting an already-deleted record is harmless.
+func TestOnRetainMessageThenUnset(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	pk := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Retain: true,
+		},
+		Payload: []byte("hello"),
+		TopicName: "a/b/c",
+	}
+
+	h.OnRetainMessage(client, pk, 1)
+
+	r := new(storage.Message)
+	err = h.getKv(retainedKey(pk.TopicName), r)
+	require.NoError(t, err)
+	require.Equal(t, pk.TopicName, r.TopicName)
+	require.Equal(t, pk.Payload, r.Payload)
+
+	h.OnRetainMessage(client, pk, -1)
+	err = h.getKv(retainedKey(pk.TopicName), r)
+	require.Error(t, err)
+	require.Equal(t, ErrKeyNotFound, err)
+
+	// coverage: delete deleted
+	h.OnRetainMessage(client, pk, -1)
+	err = h.getKv(retainedKey(pk.TopicName), r)
+	require.Error(t, err)
+	require.Equal(t, ErrKeyNotFound, err)
+}
+
+// TestOnRetainedExpired verifies an expired retained message is removed from
+// the store.
+func TestOnRetainedExpired(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	m := &storage.Message{
+		ID: retainedKey("a/b/c"),
+		T: storage.RetainedKey,
+		TopicName: "a/b/c",
+	}
+
+	err = h.setKv(m.ID, m)
+	require.NoError(t, err)
+
+	r := new(storage.Message)
+	err = h.getKv(m.ID, r)
+	require.NoError(t, err)
+	require.Equal(t, m.TopicName, r.TopicName)
+
+	h.OnRetainedExpired(m.TopicName)
+	err = h.getKv(m.ID, r)
+	require.Error(t, err)
+	require.Equal(t, ErrKeyNotFound, err)
+}
+
+// The four tests below ensure retained-message handlers are safe no-ops when
+// the db is closed or was never opened.
+func TestOnRetainedExpiredClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnRetainedExpired("a/b/c")
+}
+
+func TestOnRetainedExpiredNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnRetainedExpired("a/b/c")
+}
+
+func TestOnRetainMessageNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnRetainMessage(client, packets.Packet{}, 0)
+}
+
+func TestOnRetainMessageClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnRetainMessage(client, packets.Packet{}, 0)
+}
+
+// TestOnQosPublishThenQOSComplete verifies an inflight message is stored with
+// its sent timestamp and removed again when the QoS flow completes (via the
+// OnQosDropped passthrough).
+func TestOnQosPublishThenQOSComplete(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	pk := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Retain: true,
+			Qos: 2,
+		},
+		Payload: []byte("hello"),
+		TopicName: "a/b/c",
+	}
+
+	h.OnQosPublish(client, pk, time.Now().Unix(), 0)
+
+	r := new(storage.Message)
+	err = h.getKv(inflightKey(client, pk), r)
+	require.NoError(t, err)
+	require.Equal(t, pk.TopicName, r.TopicName)
+	require.Equal(t, pk.Payload, r.Payload)
+
+	// ensure dates are properly saved to bolt
+	require.True(t, r.Sent > 0)
+	require.True(t, time.Now().Unix()-1 < r.Sent)
+
+	// OnQosDropped is a passthrough to OnQosComplete here
+	h.OnQosDropped(client, pk)
+	err = h.getKv(inflightKey(client, pk), r)
+	require.Error(t, err)
+	require.Equal(t, ErrKeyNotFound, err)
+}
+
+// The five tests below ensure the QoS handlers are safe no-ops when the db is
+// closed or was never opened.
+func TestOnQosPublishNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
+}
+
+func TestOnQosPublishClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
+}
+
+func TestOnQosCompleteNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosComplete(client, packets.Packet{})
+}
+
+func TestOnQosCompleteClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnQosComplete(client, packets.Packet{})
+}
+
+func TestOnQosDroppedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosDropped(client, packets.Packet{})
+}
+
+// TestOnSysInfoTick verifies the system info snapshot is persisted and can be
+// read back as an independent copy.
+func TestOnSysInfoTick(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	info := &system.Info{
+		Version: "2.0.0",
+		BytesReceived: 100,
+	}
+
+	h.OnSysInfoTick(info)
+
+	r := new(storage.SystemInfo)
+	err = h.getKv(storage.SysInfoKey, r)
+	require.NoError(t, err)
+	require.Equal(t, info.Version, r.Version)
+	require.Equal(t, info.BytesReceived, r.BytesReceived)
+	require.NotSame(t, info, r)
+}
+
+// TestOnSysInfoTickNoDB ensures OnSysInfoTick is a safe no-op without a db.
+func TestOnSysInfoTickNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnSysInfoTick(new(system.Info))
+}
+
+// TestOnSysInfoTickClosedDB ensures OnSysInfoTick is safe after Stop.
+func TestOnSysInfoTickClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnSysInfoTick(new(system.Info))
+}
+
+// TestStoredClients verifies StoredClients returns every persisted client
+// record, in key order.
+func TestStoredClients(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// populate with clients
+	err = h.setKv(storage.ClientKey+"_"+"cl1", &storage.Client{ID: "cl1"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.ClientKey+"_"+"cl2", &storage.Client{ID: "cl2"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.ClientKey+"_"+"cl3", &storage.Client{ID: "cl3"})
+	require.NoError(t, err)
+
+	r, err := h.StoredClients()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	require.Equal(t, "cl1", r[0].ID)
+	require.Equal(t, "cl2", r[1].ID)
+	require.Equal(t, "cl3", r[2].ID)
+}
+
+// TestStoredClientsNoDB verifies StoredClients fails with ErrDBFileNotOpen
+// when no database has been opened.
+func TestStoredClientsNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredClients()
+	require.Empty(t, v)
+	require.ErrorIs(t, err, storage.ErrDBFileNotOpen) // args were swapped (target, err)
+}
+
+// TestStoredClientsClosedDB verifies StoredClients fails with
+// ErrDBFileNotOpen after the database has been closed.
+func TestStoredClientsClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	v, err := h.StoredClients()
+	require.Empty(t, v)
+	require.ErrorIs(t, err, storage.ErrDBFileNotOpen) // args were swapped (target, err)
+}
+
+func TestStoredSubscriptions(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ // populate with subscriptions
+ err = h.setKv(storage.SubscriptionKey+"_"+"sub1", &storage.Subscription{ID: "sub1"})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.SubscriptionKey+"_"+"sub2", &storage.Subscription{ID: "sub2"})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.SubscriptionKey+"_"+"sub3", &storage.Subscription{ID: "sub3"})
+ require.NoError(t, err)
+
+ r, err := h.StoredSubscriptions()
+ require.NoError(t, err)
+ require.Len(t, r, 3)
+ require.Equal(t, "sub1", r[0].ID)
+ require.Equal(t, "sub2", r[1].ID)
+ require.Equal(t, "sub3", r[2].ID)
+}
+
+func TestStoredSubscriptionsNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ v, err := h.StoredSubscriptions()
+ require.Empty(t, v)
+ require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
+}
+
+func TestStoredSubscriptionsClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ v, err := h.StoredSubscriptions()
+ require.Empty(t, v)
+ require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
+}
+
+func TestStoredRetainedMessages(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ // populate with messages
+ err = h.setKv(storage.RetainedKey+"_"+"m1", &storage.Message{ID: "m1"})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.RetainedKey+"_"+"m2", &storage.Message{ID: "m2"})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.RetainedKey+"_"+"m3", &storage.Message{ID: "m3"})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.InflightKey+"_"+"i3", &storage.Message{ID: "i3"})
+ require.NoError(t, err)
+
+ r, err := h.StoredRetainedMessages()
+ require.NoError(t, err)
+ require.Len(t, r, 3)
+ require.Equal(t, "m1", r[0].ID)
+ require.Equal(t, "m2", r[1].ID)
+ require.Equal(t, "m3", r[2].ID)
+}
+
+func TestStoredRetainedMessagesNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ v, err := h.StoredRetainedMessages()
+ require.Empty(t, v)
+ require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
+}
+
+func TestStoredRetainedMessagesClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ v, err := h.StoredRetainedMessages()
+ require.Empty(t, v)
+ require.Error(t, err)
+}
+
+func TestStoredInflightMessages(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ // populate with messages
+ err = h.setKv(storage.InflightKey+"_"+"i1", &storage.Message{ID: "i1"})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.InflightKey+"_"+"i2", &storage.Message{ID: "i2"})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.InflightKey+"_"+"i3", &storage.Message{ID: "i3"})
+ require.NoError(t, err)
+
+ err = h.setKv(storage.RetainedKey+"_"+"m1", &storage.Message{ID: "m1"})
+ require.NoError(t, err)
+
+ r, err := h.StoredInflightMessages()
+ require.NoError(t, err)
+ require.Len(t, r, 3)
+ require.Equal(t, "i1", r[0].ID)
+ require.Equal(t, "i2", r[1].ID)
+ require.Equal(t, "i3", r[2].ID)
+}
+
+func TestStoredInflightMessagesNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ v, err := h.StoredInflightMessages()
+ require.Empty(t, v)
+ require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
+}
+
+func TestStoredInflightMessagesClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ v, err := h.StoredInflightMessages()
+ require.Empty(t, v)
+ require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
+}
+
+func TestStoredSysInfo(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ // populate with sys info
+ err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
+ ID: storage.SysInfoKey,
+ Info: system.Info{
+ Version: "2.0.0",
+ },
+ T: storage.SysInfoKey,
+ })
+ require.NoError(t, err)
+
+ r, err := h.StoredSysInfo()
+ require.NoError(t, err)
+ require.Equal(t, "2.0.0", r.Info.Version)
+}
+
+func TestStoredSysInfoNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ v, err := h.StoredSysInfo()
+ require.Empty(t, v)
+ require.ErrorIs(t, storage.ErrDBFileNotOpen, err)
+}
+
+func TestStoredSysInfoClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ v, err := h.StoredSysInfo()
+ require.Empty(t, v)
+ require.Error(t, err)
+}
+
+func TestGetSetDelKv(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ defer teardown(t, h.config.Path, h)
+ require.NoError(t, err)
+
+ err = h.setKv("testId", &storage.Client{ID: "testId"})
+ require.NoError(t, err)
+
+ var obj storage.Client
+ err = h.getKv("testId", &obj)
+ require.NoError(t, err)
+
+ err = h.delKv("testId")
+ require.NoError(t, err)
+
+ err = h.getKv("testId", &obj)
+ require.Error(t, err)
+ require.ErrorIs(t, ErrKeyNotFound, err)
+}
+
+func TestIterKv(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ err := h.Init(nil)
+ defer teardown(t, h.config.Path, h)
+ require.NoError(t, err)
+
+ h.setKv("prefix_a_1", &storage.Client{ID: "1"})
+ h.setKv("prefix_a_2", &storage.Client{ID: "2"})
+ h.setKv("prefix_b_2", &storage.Client{ID: "3"})
+
+ var clients []storage.Client
+ err = h.iterKv("prefix_a", func(data []byte) error {
+ var item storage.Client
+ item.UnmarshalBinary(data)
+ clients = append(clients, item)
+ return nil
+ })
+ require.Equal(t, 2, len(clients))
+ require.NoError(t, err)
+
+ visitErr := errors.New("iter visit error")
+ err = h.iterKv("prefix_b", func(data []byte) error {
+ return visitErr
+ })
+ require.ErrorIs(t, visitErr, err)
+}
diff --git a/hooks/storage/pebble/pebble.go b/hooks/storage/pebble/pebble.go
new file mode 100644
index 0000000..2066338
--- /dev/null
+++ b/hooks/storage/pebble/pebble.go
@@ -0,0 +1,524 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: werbenhu
+
+package pebble
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "strings"
+
+ pebbledb "github.com/cockroachdb/pebble"
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+ "testmqtt/system"
+)
+
+const (
+ // defaultDbFile is the default file path for the pebble db file.
+ defaultDbFile = ".pebble"
+)
+
+// clientKey returns a primary key for a client.
+func clientKey(cl *mqtt.Client) string {
+ return storage.ClientKey + "_" + cl.ID
+}
+
+// subscriptionKey returns a primary key for a subscription.
+func subscriptionKey(cl *mqtt.Client, filter string) string {
+ return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
+}
+
+// retainedKey returns a primary key for a retained message.
+func retainedKey(topic string) string {
+ return storage.RetainedKey + "_" + topic
+}
+
+// inflightKey returns a primary key for an inflight message.
+func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
+ return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
+}
+
+// sysInfoKey returns a primary key for system info.
+func sysInfoKey() string {
+ return storage.SysInfoKey
+}
+
// keyUpperBound computes the smallest key strictly greater than every key
// that has b as a prefix, suitable as an exclusive iterator upper bound.
// It returns nil when no such key exists (b is empty or all 0xff bytes).
func keyUpperBound(b []byte) []byte {
	end := append([]byte(nil), b...)
	for i := len(end) - 1; i >= 0; i-- {
		end[i]++
		if end[i] != 0 {
			// Carry stopped here; bytes past i are no longer needed.
			return end[:i+1]
		}
	}
	return nil
}
+
const (
	NoSync = "NoSync" // NoSync selects write options that do not synchronize to disk.
	Sync   = "Sync"   // Sync selects write options that synchronize to disk before returning.
)
+
// Options contains configuration settings for the pebble DB instance.
type Options struct {
	Options *pebbledb.Options // native pebble options passed to pebbledb.Open; Init substitutes a zero value when nil
	Mode    string            `yaml:"mode" json:"mode"` // write durability; "Sync" (case-insensitive) syncs to disk, anything else is NoSync
	Path    string            `yaml:"path" json:"path"` // db directory path; Init falls back to defaultDbFile when empty
}
+
// Hook is a persistent storage hook using a pebble DB file store as a backend.
type Hook struct {
	mqtt.HookBase
	config *Options               // configuration; populated by Init
	db     *pebbledb.DB           // open pebble handle; nil before Init and after Stop
	mode   *pebbledb.WriteOptions // durability mode applied to every Set and Delete
}
+
+// ID returns the id of the hook.
+func (h *Hook) ID() string {
+ return "pebble-db"
+}
+
+// Provides indicates which hook methods this hook provides.
+func (h *Hook) Provides(b byte) bool {
+ return bytes.Contains([]byte{
+ mqtt.OnSessionEstablished,
+ mqtt.OnDisconnect,
+ mqtt.OnSubscribed,
+ mqtt.OnUnsubscribed,
+ mqtt.OnRetainMessage,
+ mqtt.OnWillSent,
+ mqtt.OnQosPublish,
+ mqtt.OnQosComplete,
+ mqtt.OnQosDropped,
+ mqtt.OnSysInfoTick,
+ mqtt.OnClientExpired,
+ mqtt.OnRetainedExpired,
+ mqtt.StoredClients,
+ mqtt.StoredInflightMessages,
+ mqtt.StoredRetainedMessages,
+ mqtt.StoredSubscriptions,
+ mqtt.StoredSysInfo,
+ }, []byte{b})
+}
+
+// Init initializes and connects to the pebble instance.
+func (h *Hook) Init(config any) error {
+ if _, ok := config.(*Options); !ok && config != nil {
+ return mqtt.ErrInvalidConfigType
+ }
+
+ if config == nil {
+ h.config = new(Options)
+ } else {
+ h.config = config.(*Options)
+ }
+
+ if len(h.config.Path) == 0 {
+ h.config.Path = defaultDbFile
+ }
+
+ if h.config.Options == nil {
+ h.config.Options = &pebbledb.Options{}
+ }
+
+ h.mode = pebbledb.NoSync
+ if strings.EqualFold(h.config.Mode, "Sync") {
+ h.mode = pebbledb.Sync
+ }
+
+ var err error
+ h.db, err = pebbledb.Open(h.config.Path, h.config.Options)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// Stop closes the pebble instance.
+func (h *Hook) Stop() error {
+ err := h.db.Close()
+ h.db = nil
+ return err
+}
+
+// OnSessionEstablished adds a client to the store when their session is established.
+func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
+ h.updateClient(cl)
+}
+
+// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
+func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
+ h.updateClient(cl)
+}
+
// updateClient snapshots the client's connection details, properties, and
// will into a storage.Client record and writes it under the client's key.
// It logs and returns early when the db handle is not open.
func (h *Hook) updateClient(cl *mqtt.Client) {
	if h.db == nil {
		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
		return
	}

	// Copy the properties so the stored record does not alias live client state.
	props := cl.Properties.Props.Copy(false)
	in := &storage.Client{
		ID:              cl.ID,
		T:               storage.ClientKey,
		Remote:          cl.Net.Remote,
		Listener:        cl.Net.Listener,
		Username:        cl.Properties.Username,
		Clean:           cl.Properties.Clean,
		ProtocolVersion: cl.Properties.ProtocolVersion,
		Properties: storage.ClientProperties{
			SessionExpiryInterval: props.SessionExpiryInterval,
			AuthenticationMethod:  props.AuthenticationMethod,
			AuthenticationData:    props.AuthenticationData,
			RequestProblemInfo:    props.RequestProblemInfo,
			RequestResponseInfo:   props.RequestResponseInfo,
			ReceiveMaximum:        props.ReceiveMaximum,
			TopicAliasMaximum:     props.TopicAliasMaximum,
			User:                  props.User,
			MaximumPacketSize:     props.MaximumPacketSize,
		},
		Will: storage.ClientWill(cl.Properties.Will),
	}
	// Write errors are logged inside setKv; this hook has no way to surface them.
	h.setKv(clientKey(cl), in)
}
+
// OnDisconnect refreshes the stored client record on every disconnect, then
// deletes it when the session is expiring — unless the disconnect was caused
// by a session takeover, in which case the record now belongs to the new
// connection and must be kept.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
	if h.db == nil {
		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
		return
	}

	// Always persist the latest session state first, so a non-expiring
	// disconnect leaves an up-to-date record behind.
	h.updateClient(cl)

	if !expire {
		return
	}

	// A takeover means another connection has adopted this session;
	// deleting here would destroy the successor's record.
	if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) {
		return
	}

	h.delKv(clientKey(cl))
}
+
+// OnSubscribed adds one or more client subscriptions to the store.
+func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ var in *storage.Subscription
+ for i := 0; i < len(pk.Filters); i++ {
+ in = &storage.Subscription{
+ ID: subscriptionKey(cl, pk.Filters[i].Filter),
+ T: storage.SubscriptionKey,
+ Client: cl.ID,
+ Qos: reasonCodes[i],
+ Filter: pk.Filters[i].Filter,
+ Identifier: pk.Filters[i].Identifier,
+ NoLocal: pk.Filters[i].NoLocal,
+ RetainHandling: pk.Filters[i].RetainHandling,
+ RetainAsPublished: pk.Filters[i].RetainAsPublished,
+ }
+ h.setKv(in.ID, in)
+ }
+}
+
+// OnUnsubscribed removes one or more client subscriptions from the store.
+func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ for i := 0; i < len(pk.Filters); i++ {
+ h.delKv(subscriptionKey(cl, pk.Filters[i].Filter))
+ }
+}
+
// OnRetainMessage stores or clears the retained message for a topic.
// A negative r signals that the retained message is being unset, in which
// case the record is deleted instead of written.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
	if h.db == nil {
		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
		return
	}

	// r == -1 means the retained message was cleared (e.g. zero-byte payload).
	if r == -1 {
		h.delKv(retainedKey(pk.TopicName))
		return
	}

	// Copy the properties so the stored record does not alias the live packet.
	props := pk.Properties.Copy(false)
	in := &storage.Message{
		ID:          retainedKey(pk.TopicName),
		T:           storage.RetainedKey,
		FixedHeader: pk.FixedHeader,
		TopicName:   pk.TopicName,
		Payload:     pk.Payload,
		Created:     pk.Created,
		Origin:      pk.Origin,
		Properties: storage.MessageProperties{
			PayloadFormat:          props.PayloadFormat,
			MessageExpiryInterval:  props.MessageExpiryInterval,
			ContentType:            props.ContentType,
			ResponseTopic:          props.ResponseTopic,
			CorrelationData:        props.CorrelationData,
			SubscriptionIdentifier: props.SubscriptionIdentifier,
			TopicAlias:             props.TopicAlias,
			User:                   props.User,
		},
	}

	h.setKv(in.ID, in)
}
+
// OnQosPublish writes (or overwrites) the inflight record for a QoS > 0
// message so it can be resumed after a restart.
//
// sent is the unix timestamp at which the message was (re)sent; resends is
// accepted for interface compatibility but not persisted.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
	if h.db == nil {
		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
		return
	}

	// Copy the properties so the stored record does not alias the live packet.
	props := pk.Properties.Copy(false)
	in := &storage.Message{
		ID:          inflightKey(cl, pk),
		T:           storage.InflightKey,
		Origin:      pk.Origin,
		PacketID:    pk.PacketID,
		FixedHeader: pk.FixedHeader,
		TopicName:   pk.TopicName,
		Payload:     pk.Payload,
		Sent:        sent,
		Created:     pk.Created,
		Properties: storage.MessageProperties{
			PayloadFormat:          props.PayloadFormat,
			MessageExpiryInterval:  props.MessageExpiryInterval,
			ContentType:            props.ContentType,
			ResponseTopic:          props.ResponseTopic,
			CorrelationData:        props.CorrelationData,
			SubscriptionIdentifier: props.SubscriptionIdentifier,
			TopicAlias:             props.TopicAlias,
			User:                   props.User,
		},
	}
	h.setKv(in.ID, in)
}
+
+// OnQosComplete removes a resolved inflight message from the store.
+func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+ h.delKv(inflightKey(cl, pk))
+}
+
+// OnQosDropped removes a dropped inflight message from the store.
+func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ }
+
+ h.OnQosComplete(cl, pk)
+}
+
+// OnSysInfoTick stores the latest system info in the store.
+func (h *Hook) OnSysInfoTick(sys *system.Info) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ in := &storage.SystemInfo{
+ ID: sysInfoKey(),
+ T: storage.SysInfoKey,
+ Info: *sys.Clone(),
+ }
+ h.setKv(in.ID, in)
+}
+
+// OnRetainedExpired deletes expired retained messages from the store.
+func (h *Hook) OnRetainedExpired(filter string) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+ h.delKv(retainedKey(filter))
+}
+
+// OnClientExpired deleted expired clients from the store.
+func (h *Hook) OnClientExpired(cl *mqtt.Client) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+ h.delKv(clientKey(cl))
+}
+
+// StoredClients returns all stored clients from the store.
+func (h *Hook) StoredClients() (v []storage.Client, err error) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ iter, _ := h.db.NewIter(&pebbledb.IterOptions{
+ LowerBound: []byte(storage.ClientKey),
+ UpperBound: keyUpperBound([]byte(storage.ClientKey)),
+ })
+
+ for iter.First(); iter.Valid(); iter.Next() {
+ item := storage.Client{}
+ if err := item.UnmarshalBinary(iter.Value()); err == nil {
+ v = append(v, item)
+ }
+ }
+ return v, nil
+}
+
+// StoredSubscriptions returns all stored subscriptions from the store.
+func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ iter, _ := h.db.NewIter(&pebbledb.IterOptions{
+ LowerBound: []byte(storage.SubscriptionKey),
+ UpperBound: keyUpperBound([]byte(storage.SubscriptionKey)),
+ })
+
+ for iter.First(); iter.Valid(); iter.Next() {
+ item := storage.Subscription{}
+ if err := item.UnmarshalBinary(iter.Value()); err == nil {
+ v = append(v, item)
+ }
+ }
+ return v, nil
+}
+
+// StoredRetainedMessages returns all stored retained messages from the store.
+func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ iter, _ := h.db.NewIter(&pebbledb.IterOptions{
+ LowerBound: []byte(storage.RetainedKey),
+ UpperBound: keyUpperBound([]byte(storage.RetainedKey)),
+ })
+
+ for iter.First(); iter.Valid(); iter.Next() {
+ item := storage.Message{}
+ if err := item.UnmarshalBinary(iter.Value()); err == nil {
+ v = append(v, item)
+ }
+ }
+ return v, nil
+}
+
+// StoredInflightMessages returns all stored inflight messages from the store.
+func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ iter, _ := h.db.NewIter(&pebbledb.IterOptions{
+ LowerBound: []byte(storage.InflightKey),
+ UpperBound: keyUpperBound([]byte(storage.InflightKey)),
+ })
+
+ for iter.First(); iter.Valid(); iter.Next() {
+ item := storage.Message{}
+ if err := item.UnmarshalBinary(iter.Value()); err == nil {
+ v = append(v, item)
+ }
+ }
+ return v, nil
+}
+
+// StoredSysInfo returns the system info from the store.
+func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
+ if h.db == nil {
+ h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+ return
+ }
+
+ err = h.getKv(sysInfoKey(), &v)
+ if errors.Is(err, pebbledb.ErrNotFound) {
+ return v, nil
+ }
+
+ return
+}
+
+// Errorf satisfies the pebble interface for an error logger.
+func (h *Hook) Errorf(m string, v ...any) {
+ h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
+
+}
+
+// Warningf satisfies the pebble interface for a warning logger.
+func (h *Hook) Warningf(m string, v ...any) {
+ h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
+}
+
+// Infof satisfies the pebble interface for an info logger.
+func (h *Hook) Infof(m string, v ...any) {
+ h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
+}
+
+// Debugf satisfies the pebble interface for a debug logger.
+func (h *Hook) Debugf(m string, v ...any) {
+ h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
+}
+
+// delKv deletes a key-value pair from the database.
+func (h *Hook) delKv(k string) error {
+ err := h.db.Delete([]byte(k), h.mode)
+ if err != nil {
+ h.Log.Error("failed to delete data", "error", err, "key", k)
+ return err
+ }
+ return nil
+}
+
+// setKv stores a key-value pair in the database.
+func (h *Hook) setKv(k string, v storage.Serializable) error {
+ bs, _ := v.MarshalBinary()
+ err := h.db.Set([]byte(k), bs, h.mode)
+ if err != nil {
+ h.Log.Error("failed to update data", "error", err, "key", k)
+ return err
+ }
+ return nil
+}
+
+// getKv retrieves the value associated with a key from the database.
+func (h *Hook) getKv(k string, v storage.Serializable) error {
+ value, closer, err := h.db.Get([]byte(k))
+ if err != nil {
+ return err
+ }
+
+ defer func() {
+ if closer != nil {
+ closer.Close()
+ }
+ }()
+ return v.UnmarshalBinary(value)
+}
diff --git a/hooks/storage/pebble/pebble_test.go b/hooks/storage/pebble/pebble_test.go
new file mode 100644
index 0000000..eafc3f4
--- /dev/null
+++ b/hooks/storage/pebble/pebble_test.go
@@ -0,0 +1,812 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: werbenhu
+
+package pebble
+
+import (
+ "log/slog"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ pebbledb "github.com/cockroachdb/pebble"
+ "github.com/stretchr/testify/require"
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+ "testmqtt/system"
+)
+
+var (
+ logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
+
+ client = &mqtt.Client{
+ ID: "test",
+ Net: mqtt.ClientConnection{
+ Remote: "test.addr",
+ Listener: "listener",
+ },
+ Properties: mqtt.ClientProperties{
+ Username: []byte("username"),
+ Clean: false,
+ },
+ }
+
+ pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
+)
+
+func teardown(t *testing.T, path string, h *Hook) {
+ _ = h.Stop()
+ err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
+ require.NoError(t, err)
+}
+
+func TestKeyUpperBound(t *testing.T) {
+ // Test case 1: Non-nil case
+ input1 := []byte{97, 98, 99} // "abc"
+ require.NotNil(t, keyUpperBound(input1))
+
+ // Test case 2: All bytes are 255
+ input2 := []byte{255, 255, 255}
+ require.Nil(t, keyUpperBound(input2))
+
+ // Test case 3: Empty slice
+ input3 := []byte{}
+ require.Nil(t, keyUpperBound(input3))
+
+ // Test case 4: Nil case
+ input4 := []byte{255, 255, 255}
+ require.Nil(t, keyUpperBound(input4))
+}
+
+func TestClientKey(t *testing.T) {
+ k := clientKey(&mqtt.Client{ID: "cl1"})
+ require.Equal(t, storage.ClientKey+"_cl1", k)
+}
+
+func TestSubscriptionKey(t *testing.T) {
+ k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
+ require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
+}
+
+func TestRetainedKey(t *testing.T) {
+ k := retainedKey("a/b/c")
+ require.Equal(t, storage.RetainedKey+"_a/b/c", k)
+}
+
+func TestInflightKey(t *testing.T) {
+ k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
+ require.Equal(t, storage.InflightKey+"_cl1:1", k)
+}
+
+func TestSysInfoKey(t *testing.T) {
+ require.Equal(t, storage.SysInfoKey, sysInfoKey())
+}
+
+func TestID(t *testing.T) {
+ h := new(Hook)
+ require.Equal(t, "pebble-db", h.ID())
+}
+
+func TestProvides(t *testing.T) {
+ h := new(Hook)
+ require.True(t, h.Provides(mqtt.OnSessionEstablished))
+ require.True(t, h.Provides(mqtt.OnDisconnect))
+ require.True(t, h.Provides(mqtt.OnSubscribed))
+ require.True(t, h.Provides(mqtt.OnUnsubscribed))
+ require.True(t, h.Provides(mqtt.OnRetainMessage))
+ require.True(t, h.Provides(mqtt.OnQosPublish))
+ require.True(t, h.Provides(mqtt.OnQosComplete))
+ require.True(t, h.Provides(mqtt.OnQosDropped))
+ require.True(t, h.Provides(mqtt.OnSysInfoTick))
+ require.True(t, h.Provides(mqtt.StoredClients))
+ require.True(t, h.Provides(mqtt.StoredInflightMessages))
+ require.True(t, h.Provides(mqtt.StoredRetainedMessages))
+ require.True(t, h.Provides(mqtt.StoredSubscriptions))
+ require.True(t, h.Provides(mqtt.StoredSysInfo))
+ require.False(t, h.Provides(mqtt.OnACLCheck))
+ require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
+}
+
+func TestInitBadConfig(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ err := h.Init(map[string]any{})
+ require.Error(t, err)
+}
+
+func TestInitErr(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+
+ err := h.Init(&Options{
+ Options: &pebbledb.Options{
+ ReadOnly: true,
+ },
+ })
+ require.Error(t, err)
+}
+
+func TestInitUseDefaults(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ require.Equal(t, defaultDbFile, h.config.Path)
+}
+
+func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ h.OnSessionEstablished(client, packets.Packet{})
+
+ r := new(storage.Client)
+ err = h.getKv(clientKey(client), r)
+ require.NoError(t, err)
+ require.Equal(t, client.ID, r.ID)
+ require.Equal(t, client.Properties.Username, r.Username)
+ require.Equal(t, client.Properties.Clean, r.Clean)
+ require.Equal(t, client.Net.Remote, r.Remote)
+ require.Equal(t, client.Net.Listener, r.Listener)
+ require.NotSame(t, client, r)
+
+ h.OnDisconnect(client, nil, false)
+ r2 := new(storage.Client)
+ err = h.getKv(clientKey(client), r2)
+ require.NoError(t, err)
+ require.Equal(t, client.ID, r.ID)
+
+ h.OnDisconnect(client, nil, true)
+ r3 := new(storage.Client)
+ err = h.getKv(clientKey(client), r3)
+ require.Error(t, err)
+ require.ErrorIs(t, pebbledb.ErrNotFound, err)
+ require.Empty(t, r3.ID)
+
+}
+
+func TestOnClientExpired(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ cl := &mqtt.Client{ID: "cl1"}
+ clientKey := clientKey(cl)
+
+ err = h.setKv(clientKey, &storage.Client{ID: cl.ID})
+ require.NoError(t, err)
+
+ r := new(storage.Client)
+ err = h.getKv(clientKey, r)
+ require.NoError(t, err)
+ require.Equal(t, cl.ID, r.ID)
+
+ h.OnClientExpired(cl)
+ err = h.getKv(clientKey, r)
+ require.Error(t, err)
+ require.ErrorIs(t, err, pebbledb.ErrNotFound)
+}
+
+func TestOnClientExpiredNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ h.OnClientExpired(client)
+}
+
+func TestOnClientExpiredClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ h.OnClientExpired(client)
+}
+
+func TestOnSessionEstablishedNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ h.OnSessionEstablished(client, packets.Packet{})
+}
+
+func TestOnSessionEstablishedClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ h.OnSessionEstablished(client, packets.Packet{})
+}
+
+func TestOnWillSent(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ c1 := client
+ c1.Properties.Will.Flag = 1
+ h.OnWillSent(c1, packets.Packet{})
+
+ r := new(storage.Client)
+ err = h.getKv(clientKey(client), r)
+ require.NoError(t, err)
+
+ require.Equal(t, uint32(1), r.Will.Flag)
+ require.NotSame(t, client, r)
+}
+
+func TestOnDisconnectNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ h.OnDisconnect(client, nil, false)
+}
+
+func TestOnDisconnectClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ h.OnDisconnect(client, nil, false)
+}
+
+func TestOnDisconnectSessionTakenOver(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+
+ testClient := &mqtt.Client{
+ ID: "test",
+ Net: mqtt.ClientConnection{
+ Remote: "test.addr",
+ Listener: "listener",
+ },
+ Properties: mqtt.ClientProperties{
+ Username: []byte("username"),
+ Clean: false,
+ },
+ }
+
+ testClient.Stop(packets.ErrSessionTakenOver)
+ h.OnDisconnect(testClient, nil, true)
+ teardown(t, h.config.Path, h)
+}
+
+func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ h.OnSubscribed(client, pkf, []byte{0})
+ r := new(storage.Subscription)
+
+ err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
+ require.NoError(t, err)
+ require.Equal(t, client.ID, r.Client)
+ require.Equal(t, pkf.Filters[0].Filter, r.Filter)
+ require.Equal(t, byte(0), r.Qos)
+
+ h.OnUnsubscribed(client, pkf)
+ err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r)
+ require.Error(t, err)
+ require.Equal(t, pebbledb.ErrNotFound, err)
+}
+
+func TestOnSubscribedNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ h.OnSubscribed(client, pkf, []byte{0})
+}
+
+func TestOnSubscribedClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ h.OnSubscribed(client, pkf, []byte{0})
+}
+
+func TestOnUnsubscribedNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ h.OnUnsubscribed(client, pkf)
+}
+
+func TestOnUnsubscribedClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ h.OnUnsubscribed(client, pkf)
+}
+
+func TestOnRetainMessageThenUnset(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ pk := packets.Packet{
+ FixedHeader: packets.FixedHeader{
+ Retain: true,
+ },
+ Payload: []byte("hello"),
+ TopicName: "a/b/c",
+ }
+
+ h.OnRetainMessage(client, pk, 1)
+
+ r := new(storage.Message)
+ err = h.getKv(retainedKey(pk.TopicName), r)
+ require.NoError(t, err)
+ require.Equal(t, pk.TopicName, r.TopicName)
+ require.Equal(t, pk.Payload, r.Payload)
+
+ h.OnRetainMessage(client, pk, -1)
+ err = h.getKv(retainedKey(pk.TopicName), r)
+ require.Error(t, err)
+ require.ErrorIs(t, err, pebbledb.ErrNotFound)
+
+ // coverage: delete deleted
+ h.OnRetainMessage(client, pk, -1)
+ err = h.getKv(retainedKey(pk.TopicName), r)
+ require.Error(t, err)
+ require.ErrorIs(t, err, pebbledb.ErrNotFound)
+}
+
+func TestOnRetainedExpired(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ defer teardown(t, h.config.Path, h)
+
+ m := &storage.Message{
+ ID: retainedKey("a/b/c"),
+ T: storage.RetainedKey,
+ TopicName: "a/b/c",
+ }
+
+ err = h.setKv(m.ID, m)
+ require.NoError(t, err)
+
+ r := new(storage.Message)
+ err = h.getKv(m.ID, r)
+ require.NoError(t, err)
+ require.Equal(t, m.TopicName, r.TopicName)
+
+ h.OnRetainedExpired(m.TopicName)
+ err = h.getKv(m.ID, r)
+ require.Error(t, err)
+ require.ErrorIs(t, err, pebbledb.ErrNotFound)
+}
+
+func TestOnRetainExpiredNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ h.OnRetainedExpired("a/b/c")
+}
+
+func TestOnRetainExpiredClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ h.OnRetainedExpired("a/b/c")
+}
+
+func TestOnRetainMessageNoDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ h.OnRetainMessage(client, packets.Packet{}, 0)
+}
+
+func TestOnRetainMessageClosedDB(t *testing.T) {
+ h := new(Hook)
+ h.SetOpts(logger, nil)
+ err := h.Init(nil)
+ require.NoError(t, err)
+ teardown(t, h.config.Path, h)
+ h.OnRetainMessage(client, packets.Packet{}, 0)
+}
+
+// TestOnQosPublishThenQOSComplete stores an inflight QoS message via
+// OnQosPublish, verifies its payload and sent-timestamp, then confirms
+// OnQosDropped (which delegates to OnQosComplete) removes it.
+func TestOnQosPublishThenQOSComplete(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	pk := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Retain: true,
+			Qos:    2,
+		},
+		Payload:   []byte("hello"),
+		TopicName: "a/b/c",
+	}
+
+	h.OnQosPublish(client, pk, time.Now().Unix(), 0)
+
+	r := new(storage.Message)
+	err = h.getKv(inflightKey(client, pk), r)
+	require.NoError(t, err)
+	require.Equal(t, pk.TopicName, r.TopicName)
+	require.Equal(t, pk.Payload, r.Payload)
+
+	// ensure dates are properly saved
+	require.True(t, r.Sent > 0)
+	require.True(t, time.Now().Unix()-1 < r.Sent)
+
+	// OnQosDropped is a passthrough to OnQosComplete here
+	h.OnQosDropped(client, pk)
+	err = h.getKv(inflightKey(client, pk), r)
+	require.Error(t, err)
+	require.ErrorIs(t, err, pebbledb.ErrNotFound)
+}
+
+// TestOnQosPublishNoDB checks OnQosPublish is a no-op without a DB.
+func TestOnQosPublishNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
+}
+
+// TestOnQosPublishClosedDB checks OnQosPublish does not panic on a closed DB.
+func TestOnQosPublishClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
+}
+
+// TestOnQosCompleteNoDB checks OnQosComplete is a no-op without a DB.
+func TestOnQosCompleteNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosComplete(client, packets.Packet{})
+}
+
+// TestOnQosCompleteClosedDB checks OnQosComplete does not panic on a closed DB.
+func TestOnQosCompleteClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnQosComplete(client, packets.Packet{})
+}
+
+// TestOnQosDroppedNoDB checks OnQosDropped is a no-op without a DB.
+func TestOnQosDroppedNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnQosDropped(client, packets.Packet{})
+}
+
+// TestOnSysInfoTick stores a system.Info snapshot via OnSysInfoTick and
+// verifies the persisted copy matches (and is a distinct object).
+func TestOnSysInfoTick(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	info := &system.Info{
+		Version:       "2.0.0",
+		BytesReceived: 100,
+	}
+
+	h.OnSysInfoTick(info)
+
+	r := new(storage.SystemInfo)
+	err = h.getKv(storage.SysInfoKey, r)
+	require.NoError(t, err)
+	require.Equal(t, info.Version, r.Version)
+	require.Equal(t, info.BytesReceived, r.BytesReceived)
+	require.NotSame(t, info, r)
+}
+
+// TestOnSysInfoTickNoDB checks OnSysInfoTick is a no-op without a DB.
+func TestOnSysInfoTickNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.OnSysInfoTick(new(system.Info))
+}
+
+// TestOnSysInfoTickClosedDB checks OnSysInfoTick does not panic on a closed DB.
+func TestOnSysInfoTickClosedDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	teardown(t, h.config.Path, h)
+	h.OnSysInfoTick(new(system.Info))
+}
+
+// TestStoredClients seeds three client records and verifies StoredClients
+// returns all of them (iteration order here matches key order).
+func TestStoredClients(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// populate with clients
+	err = h.setKv(storage.ClientKey+"_cl1", &storage.Client{ID: "cl1"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.ClientKey+"_cl2", &storage.Client{ID: "cl2"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.ClientKey+"_cl3", &storage.Client{ID: "cl3"})
+	require.NoError(t, err)
+
+	r, err := h.StoredClients()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	require.Equal(t, "cl1", r[0].ID)
+	require.Equal(t, "cl2", r[1].ID)
+	require.Equal(t, "cl3", r[2].ID)
+}
+
+// TestStoredClientsNoDB checks StoredClients returns empty/nil-error
+// when no DB is open.
+func TestStoredClientsNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredClients()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredSubscriptions seeds three subscription records and verifies
+// StoredSubscriptions returns all of them.
+func TestStoredSubscriptions(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// populate with subscriptions
+	err = h.setKv(storage.SubscriptionKey+"_sub1", &storage.Subscription{ID: "sub1"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.SubscriptionKey+"_sub2", &storage.Subscription{ID: "sub2"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.SubscriptionKey+"_sub3", &storage.Subscription{ID: "sub3"})
+	require.NoError(t, err)
+
+	r, err := h.StoredSubscriptions()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	require.Equal(t, "sub1", r[0].ID)
+	require.Equal(t, "sub2", r[1].ID)
+	require.Equal(t, "sub3", r[2].ID)
+}
+
+// TestStoredSubscriptionsNoDB checks StoredSubscriptions returns
+// empty/nil-error when no DB is open.
+func TestStoredSubscriptionsNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredSubscriptions()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredRetainedMessages seeds three retained and one inflight message,
+// and verifies StoredRetainedMessages returns only the three retained ones.
+func TestStoredRetainedMessages(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// populate with messages
+	err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.RetainedKey+"_m2", &storage.Message{ID: "m2"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.RetainedKey+"_m3", &storage.Message{ID: "m3"})
+	require.NoError(t, err)
+
+	// decoy: inflight entry must not appear in retained results
+	err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3"})
+	require.NoError(t, err)
+
+	r, err := h.StoredRetainedMessages()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	require.Equal(t, "m1", r[0].ID)
+	require.Equal(t, "m2", r[1].ID)
+	require.Equal(t, "m3", r[2].ID)
+}
+
+// TestStoredRetainedMessagesNoDB checks StoredRetainedMessages returns
+// empty/nil-error when no DB is open.
+func TestStoredRetainedMessagesNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredRetainedMessages()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredInflightMessages seeds three inflight and one retained message,
+// and verifies StoredInflightMessages returns only the three inflight ones.
+func TestStoredInflightMessages(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// populate with messages
+	err = h.setKv(storage.InflightKey+"_i1", &storage.Message{ID: "i1"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.InflightKey+"_i2", &storage.Message{ID: "i2"})
+	require.NoError(t, err)
+
+	err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3"})
+	require.NoError(t, err)
+
+	// decoy: retained entry must not appear in inflight results
+	err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1"})
+	require.NoError(t, err)
+
+	r, err := h.StoredInflightMessages()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	require.Equal(t, "i1", r[0].ID)
+	require.Equal(t, "i2", r[1].ID)
+	require.Equal(t, "i3", r[2].ID)
+}
+
+// TestStoredInflightMessagesNoDB checks StoredInflightMessages returns
+// empty/nil-error when no DB is open.
+func TestStoredInflightMessagesNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredInflightMessages()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredSysInfo verifies StoredSysInfo succeeds on an empty store, then
+// returns the persisted version once a SystemInfo record is written.
+func TestStoredSysInfo(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// empty store: must not error
+	r, err := h.StoredSysInfo()
+	require.NoError(t, err)
+
+	// populate with messages
+	err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{
+		ID: storage.SysInfoKey,
+		Info: system.Info{
+			Version: "2.0.0",
+		},
+		T: storage.SysInfoKey,
+	})
+	require.NoError(t, err)
+
+	r, err = h.StoredSysInfo()
+	require.NoError(t, err)
+	require.Equal(t, "2.0.0", r.Info.Version)
+}
+
+// TestStoredSysInfoNoDB checks StoredSysInfo returns a zero value and nil
+// error when no DB is open.
+func TestStoredSysInfoNoDB(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	v, err := h.StoredSysInfo()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestErrorf only exercises the Errorf logging shim for coverage;
+// no assertions are made on log output.
+func TestErrorf(t *testing.T) {
+	// coverage: one day check log hook
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.Errorf("test", 1, 2, 3)
+}
+
+// TestWarningf only exercises the Warningf logging shim for coverage.
+func TestWarningf(t *testing.T) {
+	// coverage: one day check log hook
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.Warningf("test", 1, 2, 3)
+}
+
+// TestInfof only exercises the Infof logging shim for coverage.
+func TestInfof(t *testing.T) {
+	// coverage: one day check log hook
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.Infof("test", 1, 2, 3)
+}
+
+// TestDebugf only exercises the Debugf logging shim for coverage.
+func TestDebugf(t *testing.T) {
+	// coverage: one day check log hook
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	h.Debugf("test", 1, 2, 3)
+}
+
+// TestGetSetDelKv exercises the setKv/getKv/delKv round trip in both the
+// NoSync and Sync write modes; after delKv the key must report ErrNotFound.
+func TestGetSetDelKv(t *testing.T) {
+	opts := []struct {
+		name string
+		opt  *Options
+	}{
+		{
+			name: "NoSync",
+			opt: &Options{
+				Mode: NoSync,
+			},
+		},
+		{
+			name: "Sync",
+			opt: &Options{
+				Mode: Sync,
+			},
+		},
+	}
+	for _, tt := range opts {
+		t.Run(tt.name, func(t *testing.T) {
+			h := new(Hook)
+			h.SetOpts(logger, nil)
+			err := h.Init(tt.opt)
+			// Check Init succeeded before deferring teardown; previously the
+			// defer was registered before the error was inspected.
+			require.NoError(t, err)
+			defer teardown(t, h.config.Path, h)
+
+			err = h.setKv("testKey", &storage.Client{ID: "testId"})
+			require.NoError(t, err)
+
+			var obj storage.Client
+			err = h.getKv("testKey", &obj)
+			require.NoError(t, err)
+
+			err = h.delKv("testKey")
+			require.NoError(t, err)
+
+			err = h.getKv("testKey", &obj)
+			require.Error(t, err)
+			// require.ErrorIs takes (t, err, target); the arguments were
+			// reversed, which only passes while the error is unwrapped.
+			require.ErrorIs(t, err, pebbledb.ErrNotFound)
+		})
+	}
+}
+
+// TestGetSetDelKvErr reopens the same store read-only to force write errors
+// from setKv and delKv.
+func TestGetSetDelKvErr(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+
+	err := h.Init(&Options{
+		Mode: Sync,
+	})
+	require.NoError(t, err)
+
+	// write a key, then release the store so it can be reopened read-only
+	err = h.setKv("testKey", &storage.Client{ID: "testId"})
+	require.NoError(t, err)
+	h.Stop()
+
+	h = new(Hook)
+	h.SetOpts(logger, nil)
+
+	err = h.Init(&Options{
+		Mode: Sync,
+		Options: &pebbledb.Options{
+			ReadOnly: true,
+		},
+	})
+	require.NoError(t, err)
+	defer teardown(t, h.config.Path, h)
+
+	// writes against a read-only store must fail
+	err = h.setKv("testKey", &storage.Client{ID: "testId"})
+	require.Error(t, err)
+
+	err = h.delKv("testKey")
+	require.Error(t, err)
+}
diff --git a/hooks/storage/redis/redis.go b/hooks/storage/redis/redis.go
new file mode 100644
index 0000000..24ed872
--- /dev/null
+++ b/hooks/storage/redis/redis.go
@@ -0,0 +1,532 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package redis
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ "github.com/go-redis/redis/v8"
+)
+
+// defaultAddr is the default address to the redis service.
+const defaultAddr = "localhost:6379"
+
+// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt.
+const defaultHPrefix = "mochi-"
+
+// clientKey returns a primary key for a client (the client id).
+func clientKey(cl *mqtt.Client) string {
+	return cl.ID
+}
+
+// subscriptionKey returns a primary key for a subscription ("<clientID>:<filter>").
+func subscriptionKey(cl *mqtt.Client, filter string) string {
+	return cl.ID + ":" + filter
+}
+
+// retainedKey returns a primary key for a retained message (the topic itself).
+func retainedKey(topic string) string {
+	return topic
+}
+
+// inflightKey returns a primary key for an inflight message
+// ("<clientID>:<packetID>").
+func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
+	return cl.ID + ":" + pk.FormatID()
+}
+
+// sysInfoKey returns a primary key for system info.
+func sysInfoKey() string {
+	return storage.SysInfoKey
+}
+
+// Options contains configuration settings for the redis instance.
+// (The original comment said "bolt instance" — copied from another hook.)
+type Options struct {
+	Address  string `yaml:"address" json:"address"`   // redis server address, e.g. "localhost:6379"
+	Username string `yaml:"username" json:"username"` // auth username
+	Password string `yaml:"password" json:"password"` // auth password
+	Database int    `yaml:"database" json:"database"` // redis logical database number
+	HPrefix  string `yaml:"h_prefix" json:"h_prefix"` // hash-set key prefix; defaults to defaultHPrefix
+	// Options holds raw go-redis options. When non-nil, Init does not apply
+	// Address/Username/Password/Database from the fields above.
+	Options *redis.Options
+}
+
+// Hook is a persistent storage hook based using Redis as a backend.
+type Hook struct {
+	mqtt.HookBase
+	config *Options        // options for connecting to the Redis instance.
+	db     *redis.Client   // the Redis instance
+	ctx    context.Context // a context for the connection
+}
+
+// ID returns the id of the hook.
+func (h *Hook) ID() string {
+	return "redis-db"
+}
+
+// Provides indicates which hook methods this hook provides.
+// Implemented as a membership test of b in the supported-event byte set.
+func (h *Hook) Provides(b byte) bool {
+	return bytes.Contains([]byte{
+		mqtt.OnSessionEstablished,
+		mqtt.OnDisconnect,
+		mqtt.OnSubscribed,
+		mqtt.OnUnsubscribed,
+		mqtt.OnRetainMessage,
+		mqtt.OnQosPublish,
+		mqtt.OnQosComplete,
+		mqtt.OnQosDropped,
+		mqtt.OnWillSent,
+		mqtt.OnSysInfoTick,
+		mqtt.OnClientExpired,
+		mqtt.OnRetainedExpired,
+		mqtt.StoredClients,
+		mqtt.StoredInflightMessages,
+		mqtt.StoredRetainedMessages,
+		mqtt.StoredSubscriptions,
+		mqtt.StoredSysInfo,
+	}, []byte{b})
+}
+
+// hKey returns a hash set key with a unique prefix.
+func (h *Hook) hKey(s string) string {
+	return h.config.HPrefix + s
+}
+
+// Init initializes and connects to the redis service. config may be nil
+// (defaults are applied) or a *Options; any other type returns
+// mqtt.ErrInvalidConfigType. When config.Options is non-nil it is used
+// verbatim and the flat Address/Username/Password/Database fields are ignored.
+func (h *Hook) Init(config any) error {
+	if _, ok := config.(*Options); !ok && config != nil {
+		return mqtt.ErrInvalidConfigType
+	}
+
+	h.ctx = context.Background()
+
+	if config == nil {
+		config = new(Options)
+	}
+	h.config = config.(*Options)
+	if h.config.Options == nil {
+		h.config.Options = &redis.Options{
+			Addr: defaultAddr,
+		}
+		// Only override the default address when one was configured;
+		// previously an empty Address clobbered defaultAddr.
+		if h.config.Address != "" {
+			h.config.Options.Addr = h.config.Address
+		}
+		h.config.Options.DB = h.config.Database
+		// Apply the configured username; this line was commented out, so a
+		// configured username was silently ignored (while still being logged).
+		h.config.Options.Username = h.config.Username
+		h.config.Options.Password = h.config.Password
+	}
+
+	if h.config.HPrefix == "" {
+		h.config.HPrefix = defaultHPrefix
+	}
+
+	h.Log.Info(
+		"connecting to redis service",
+		"prefix", h.config.HPrefix,
+		"address", h.config.Options.Addr,
+		"username", h.config.Options.Username,
+		"password-len", len(h.config.Options.Password),
+		"db", h.config.Options.DB,
+	)
+
+	h.db = redis.NewClient(h.config.Options)
+	// Use the hook's context rather than a fresh context.Background().
+	_, err := h.db.Ping(h.ctx).Result()
+	if err != nil {
+		return fmt.Errorf("failed to ping service: %w", err)
+	}
+
+	h.Log.Info("connected to redis service")
+
+	return nil
+}
+
+// Stop closes the redis connection. It is safe to call when Init never ran
+// or failed; previously h.db.Close() would panic on the nil client.
+func (h *Hook) Stop() error {
+	if h.db == nil {
+		return storage.ErrDBFileNotOpen
+	}
+
+	h.Log.Info("disconnecting from redis service")
+
+	return h.db.Close()
+}
+
+// OnSessionEstablished adds a client to the store when their session is established.
+func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
+	h.updateClient(cl)
+}
+
+// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
+func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
+	h.updateClient(cl)
+}
+
+// updateClient writes the client data to the store (HSET under the client
+// hash key). Errors are logged, not returned — hooks cannot fail the broker.
+func (h *Hook) updateClient(cl *mqtt.Client) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	// Copy(false) takes a snapshot of the connect properties so the stored
+	// record does not alias live client state.
+	props := cl.Properties.Props.Copy(false)
+	in := &storage.Client{
+		ID:              clientKey(cl),
+		T:               storage.ClientKey,
+		Remote:          cl.Net.Remote,
+		Listener:        cl.Net.Listener,
+		Username:        cl.Properties.Username,
+		Clean:           cl.Properties.Clean,
+		ProtocolVersion: cl.Properties.ProtocolVersion,
+		Properties: storage.ClientProperties{
+			SessionExpiryInterval: props.SessionExpiryInterval,
+			AuthenticationMethod:  props.AuthenticationMethod,
+			AuthenticationData:    props.AuthenticationData,
+			RequestProblemInfo:    props.RequestProblemInfo,
+			RequestResponseInfo:   props.RequestResponseInfo,
+			ReceiveMaximum:        props.ReceiveMaximum,
+			TopicAliasMaximum:     props.TopicAliasMaximum,
+			User:                  props.User,
+			MaximumPacketSize:     props.MaximumPacketSize,
+		},
+		Will: storage.ClientWill(cl.Properties.Will),
+	}
+
+	err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
+	if err != nil {
+		h.Log.Error("failed to hset client data", "error", err, "data", in)
+	}
+}
+
+// OnDisconnect removes a client from the store when the session has expired
+// (expire == true), unless the disconnect was caused by a session takeover,
+// in which case the record belongs to the new connection and is kept.
+func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	// non-expiring sessions keep their stored record
+	if !expire {
+		return
+	}
+
+	if cl.StopCause() == packets.ErrSessionTakenOver {
+		return
+	}
+
+	err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
+	if err != nil {
+		h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
+	}
+}
+
+// OnSubscribed adds one or more client subscriptions to the store.
+// NOTE(review): indexes reasonCodes[i] per filter — assumes the caller
+// supplies one reason code per filter; confirm against the hook dispatcher.
+func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	var in *storage.Subscription
+	for i := 0; i < len(pk.Filters); i++ {
+		in = &storage.Subscription{
+			ID:                subscriptionKey(cl, pk.Filters[i].Filter),
+			T:                 storage.SubscriptionKey,
+			Client:            cl.ID,
+			Qos:               reasonCodes[i],
+			Filter:            pk.Filters[i].Filter,
+			Identifier:        pk.Filters[i].Identifier,
+			NoLocal:           pk.Filters[i].NoLocal,
+			RetainHandling:    pk.Filters[i].RetainHandling,
+			RetainAsPublished: pk.Filters[i].RetainAsPublished,
+		}
+
+		err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
+		if err != nil {
+			h.Log.Error("failed to hset subscription data", "error", err, "data", in)
+		}
+	}
+}
+
+// OnUnsubscribed removes one or more client subscriptions from the store.
+func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	for i := 0; i < len(pk.Filters); i++ {
+		k := subscriptionKey(cl, pk.Filters[i].Filter)
+		err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), k).Err()
+		if err != nil {
+			// Log the subscription key that failed to delete; previously this
+			// logged only clientKey(cl), losing the filter information.
+			h.Log.Error("failed to delete subscription data", "error", err, "id", k)
+		}
+	}
+}
+
+// OnRetainMessage adds a retained message for a topic to the store.
+// A modifier r of -1 means the retained message was cleared; the stored
+// record is deleted instead of written.
+func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	if r == -1 {
+		err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
+		if err != nil {
+			h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName))
+		}
+
+		return
+	}
+
+	// snapshot packet properties so the stored record does not alias the packet
+	props := pk.Properties.Copy(false)
+	in := &storage.Message{
+		ID:          retainedKey(pk.TopicName),
+		T:           storage.RetainedKey,
+		FixedHeader: pk.FixedHeader,
+		TopicName:   pk.TopicName,
+		Payload:     pk.Payload,
+		Created:     pk.Created,
+		Origin:      pk.Origin,
+		Properties: storage.MessageProperties{
+			PayloadFormat:          props.PayloadFormat,
+			MessageExpiryInterval:  props.MessageExpiryInterval,
+			ContentType:            props.ContentType,
+			ResponseTopic:          props.ResponseTopic,
+			CorrelationData:        props.CorrelationData,
+			SubscriptionIdentifier: props.SubscriptionIdentifier,
+			TopicAlias:             props.TopicAlias,
+			User:                   props.User,
+		},
+	}
+
+	err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
+	if err != nil {
+		h.Log.Error("failed to hset retained message data", "error", err, "data", in)
+	}
+}
+
+// OnQosPublish adds or updates an inflight message in the store.
+func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	// snapshot packet properties so the stored record does not alias the packet
+	props := pk.Properties.Copy(false)
+	in := &storage.Message{
+		ID:          inflightKey(cl, pk),
+		T:           storage.InflightKey,
+		Origin:      pk.Origin,
+		FixedHeader: pk.FixedHeader,
+		TopicName:   pk.TopicName,
+		Payload:     pk.Payload,
+		Sent:        sent,
+		Created:     pk.Created,
+		Properties: storage.MessageProperties{
+			PayloadFormat:          props.PayloadFormat,
+			MessageExpiryInterval:  props.MessageExpiryInterval,
+			ContentType:            props.ContentType,
+			ResponseTopic:          props.ResponseTopic,
+			CorrelationData:        props.CorrelationData,
+			SubscriptionIdentifier: props.SubscriptionIdentifier,
+			TopicAlias:             props.TopicAlias,
+			User:                   props.User,
+		},
+	}
+
+	err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
+	if err != nil {
+		h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in)
+	}
+}
+
+// OnQosComplete removes a resolved inflight message from the store.
+func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
+	if err != nil {
+		h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk))
+	}
+}
+
+// OnQosDropped removes a dropped inflight message from the store
+// (delegates to OnQosComplete).
+func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		// Return early; previously this fell through to OnQosComplete, which
+		// logged the identical nil-db error a second time.
+		return
+	}
+
+	h.OnQosComplete(cl, pk)
+}
+
+// OnSysInfoTick stores the latest system info in the store.
+func (h *Hook) OnSysInfoTick(sys *system.Info) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	in := &storage.SystemInfo{
+		ID:   sysInfoKey(),
+		T:    storage.SysInfoKey,
+		Info: *sys, // copy the struct so the stored record does not alias live state
+	}
+
+	err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
+	if err != nil {
+		h.Log.Error("failed to hset server info data", "error", err, "data", in)
+	}
+}
+
+// OnRetainedExpired deletes expired retained messages from the store.
+func (h *Hook) OnRetainedExpired(filter string) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
+	if err != nil {
+		h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter))
+	}
+}
+
+// OnClientExpired deletes expired clients from the store.
+func (h *Hook) OnClientExpired(cl *mqtt.Client) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
+	if err != nil {
+		h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
+	}
+}
+
+// StoredClients returns all stored clients from the store.
+// NOTE(review): on unmarshal failure the error is logged but the zero-value
+// record is still appended, and the returned error is always nil.
+func (h *Hook) StoredClients() (v []storage.Client, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		h.Log.Error("failed to HGetAll client data", "error", err)
+		return
+	}
+
+	for _, row := range rows {
+		var d storage.Client
+		if err = d.UnmarshalBinary([]byte(row)); err != nil {
+			h.Log.Error("failed to unmarshal client data", "error", err, "data", row)
+		}
+
+		v = append(v, d)
+	}
+
+	return v, nil
+}
+
+// StoredSubscriptions returns all stored subscriptions from the store.
+// Same error-handling contract as StoredClients.
+func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		h.Log.Error("failed to HGetAll subscription data", "error", err)
+		return
+	}
+
+	for _, row := range rows {
+		var d storage.Subscription
+		if err = d.UnmarshalBinary([]byte(row)); err != nil {
+			h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row)
+		}
+
+		v = append(v, d)
+	}
+
+	return v, nil
+}
+
+// StoredRetainedMessages returns all stored retained messages from the store.
+// NOTE(review): unmarshal failures are logged, the zero-value record is still
+// appended, and the returned error is always nil.
+func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		h.Log.Error("failed to HGetAll retained message data", "error", err)
+		return
+	}
+
+	for _, row := range rows {
+		var d storage.Message
+		if err = d.UnmarshalBinary([]byte(row)); err != nil {
+			h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row)
+		}
+
+		v = append(v, d)
+	}
+
+	return v, nil
+}
+
+// StoredInflightMessages returns all stored inflight messages from the store.
+// Same error-handling contract as StoredRetainedMessages.
+func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		h.Log.Error("failed to HGetAll inflight message data", "error", err)
+		return
+	}
+
+	for _, row := range rows {
+		var d storage.Message
+		if err = d.UnmarshalBinary([]byte(row)); err != nil {
+			h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row)
+		}
+
+		v = append(v, d)
+	}
+
+	return v, nil
+}
+
+// StoredSysInfo returns the system info from the store. A missing key
+// (redis.Nil) is swallowed: the empty row then fails UnmarshalBinary, which
+// is logged, and a zero SystemInfo is returned with a nil error.
+func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
+	if h.db == nil {
+		h.Log.Error("", "error", storage.ErrDBFileNotOpen)
+		return
+	}
+
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return
+	}
+
+	if err = v.UnmarshalBinary([]byte(row)); err != nil {
+		h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row)
+	}
+
+	return v, nil
+}
diff --git a/hooks/storage/redis/redis_test.go b/hooks/storage/redis/redis_test.go
new file mode 100644
index 0000000..e3a5c90
--- /dev/null
+++ b/hooks/storage/redis/redis_test.go
@@ -0,0 +1,834 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package redis
+
+import (
+ "log/slog"
+ "os"
+ "sort"
+ "testing"
+ "time"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/mqtt"
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ miniredis "github.com/alicebob/miniredis/v2"
+ redis "github.com/go-redis/redis/v8"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+	// logger is a shared test logger writing to stdout.
+	logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
+
+	// client is a shared fixture; note it is a pointer, so tests that mutate
+	// it (e.g. setting Will flags) affect subsequent tests.
+	client = &mqtt.Client{
+		ID: "test",
+		Net: mqtt.ClientConnection{
+			Remote:   "test.addr",
+			Listener: "listener",
+		},
+		Properties: mqtt.ClientProperties{
+			Username: []byte("username"),
+			Clean:    false,
+		},
+	}
+
+	// pkf is a subscribe/unsubscribe packet fixture with a single filter.
+	pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
+)
+
+// newHook builds and initializes a Hook against the given redis address,
+// failing the test if Init cannot connect.
+func newHook(t *testing.T, addr string) *Hook {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+
+	err := h.Init(&Options{
+		Options: &redis.Options{
+			Addr: addr,
+		},
+	})
+	require.NoError(t, err)
+
+	return h
+}
+
+// teardown flushes all redis keys and stops the hook; safe to call when the
+// hook was never initialized (db nil). The Stop error is ignored.
+func teardown(t *testing.T, h *Hook) {
+	if h.db != nil {
+		err := h.db.FlushAll(h.ctx).Err()
+		require.NoError(t, err)
+		h.Stop()
+	}
+}
+
+// TestClientKey verifies the client primary key is the client id.
+func TestClientKey(t *testing.T) {
+	k := clientKey(&mqtt.Client{ID: "cl1"})
+	require.Equal(t, "cl1", k)
+}
+
+// TestSubscriptionKey verifies the "<clientID>:<filter>" key format.
+func TestSubscriptionKey(t *testing.T) {
+	k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
+	require.Equal(t, "cl1:a/b/c", k)
+}
+
+// TestRetainedKey verifies the retained key is the topic itself.
+func TestRetainedKey(t *testing.T) {
+	k := retainedKey("a/b/c")
+	require.Equal(t, "a/b/c", k)
+}
+
+// TestInflightKey verifies the "<clientID>:<packetID>" key format.
+func TestInflightKey(t *testing.T) {
+	k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
+	require.Equal(t, "cl1:1", k)
+}
+
+// TestSysInfoKey verifies the sys-info key constant is used.
+func TestSysInfoKey(t *testing.T) {
+	require.Equal(t, storage.SysInfoKey, sysInfoKey())
+}
+
+// TestID verifies the hook id string.
+func TestID(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	require.Equal(t, "redis-db", h.ID())
+}
+
+// TestProvides spot-checks supported and unsupported hook event bytes.
+func TestProvides(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	require.True(t, h.Provides(mqtt.OnSessionEstablished))
+	require.True(t, h.Provides(mqtt.OnDisconnect))
+	require.True(t, h.Provides(mqtt.OnSubscribed))
+	require.True(t, h.Provides(mqtt.OnUnsubscribed))
+	require.True(t, h.Provides(mqtt.OnRetainMessage))
+	require.True(t, h.Provides(mqtt.OnQosPublish))
+	require.True(t, h.Provides(mqtt.OnQosComplete))
+	require.True(t, h.Provides(mqtt.OnQosDropped))
+	require.True(t, h.Provides(mqtt.OnSysInfoTick))
+	require.True(t, h.Provides(mqtt.StoredClients))
+	require.True(t, h.Provides(mqtt.StoredInflightMessages))
+	require.True(t, h.Provides(mqtt.StoredRetainedMessages))
+	require.True(t, h.Provides(mqtt.StoredSubscriptions))
+	require.True(t, h.Provides(mqtt.StoredSysInfo))
+	require.False(t, h.Provides(mqtt.OnACLCheck))
+	require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
+}
+
+// TestHKey verifies hash keys carry the configured prefix.
+func TestHKey(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.SetOpts(logger, nil)
+	require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
+}
+
+// TestInitUseDefaults re-initializes the hook with a nil config and expects
+// the default prefix and address to apply.
+// NOTE(review): the error from s.StartAddr is ignored — confirm the listener
+// actually bound to defaultAddr. The Addr assertion relies on Init keeping
+// defaultAddr when no Address is configured.
+func TestInitUseDefaults(t *testing.T) {
+	s := miniredis.RunT(t)
+	s.StartAddr(defaultAddr)
+	defer s.Close()
+
+	h := newHook(t, defaultAddr)
+	h.SetOpts(logger, nil)
+	err := h.Init(nil)
+	require.NoError(t, err)
+	defer teardown(t, h)
+
+	require.Equal(t, defaultHPrefix, h.config.HPrefix)
+	require.Equal(t, defaultAddr, h.config.Options.Addr)
+}
+
+// TestInitUsePassConfig passes the flat yaml-style fields and expects Init
+// to copy them into go-redis options. Init is expected to error because the
+// miniredis instance has no auth configured.
+// NOTE(review): asserts Username was applied — verify Init actually sets
+// Options.Username from the config.
+func TestInitUsePassConfig(t *testing.T) {
+	s := miniredis.RunT(t)
+	s.StartAddr(defaultAddr)
+	defer s.Close()
+
+	h := newHook(t, "")
+	h.SetOpts(logger, nil)
+
+	err := h.Init(&Options{
+		Address:  defaultAddr,
+		Username: "username",
+		Password: "password",
+		Database: 2,
+	})
+	require.Error(t, err)
+	h.db.FlushAll(h.ctx)
+
+	require.Equal(t, defaultAddr, h.config.Options.Addr)
+	require.Equal(t, "username", h.config.Options.Username)
+	require.Equal(t, "password", h.config.Options.Password)
+	require.Equal(t, 2, h.config.Options.DB)
+}
+
+// TestInitBadConfig verifies a non-*Options config type is rejected.
+func TestInitBadConfig(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+
+	err := h.Init(map[string]any{})
+	require.Error(t, err)
+}
+
+// TestInitBadAddr verifies Init fails when the redis address is unreachable.
+func TestInitBadAddr(t *testing.T) {
+	h := new(Hook)
+	h.SetOpts(logger, nil)
+	err := h.Init(&Options{
+		Options: &redis.Options{
+			Addr: "abc:123",
+		},
+	})
+	require.Error(t, err)
+}
+
+// TestOnSessionEstablishedThenOnDisconnect verifies the client record is
+// written on session establish, kept on a non-expiring disconnect, and
+// deleted on an expiring disconnect.
+func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	h.OnSessionEstablished(client, packets.Packet{})
+
+	r := new(storage.Client)
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
+	require.NoError(t, err)
+	err = r.UnmarshalBinary([]byte(row))
+	require.NoError(t, err)
+
+	require.Equal(t, client.ID, r.ID)
+	require.Equal(t, client.Net.Remote, r.Remote)
+	require.Equal(t, client.Net.Listener, r.Listener)
+	require.Equal(t, client.Properties.Username, r.Username)
+	require.Equal(t, client.Properties.Clean, r.Clean)
+	require.NotSame(t, client, r)
+
+	// non-expiring disconnect keeps the record
+	h.OnDisconnect(client, nil, false)
+	r2 := new(storage.Client)
+	row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
+	require.NoError(t, err)
+	err = r2.UnmarshalBinary([]byte(row))
+	require.NoError(t, err)
+	require.Equal(t, client.ID, r.ID)
+
+	// expiring disconnect deletes it
+	h.OnDisconnect(client, nil, true)
+	r3 := new(storage.Client)
+	_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
+	require.Error(t, err)
+	require.ErrorIs(t, err, redis.Nil)
+	require.Empty(t, r3.ID)
+}
+
+// TestOnSessionEstablishedNoDB checks OnSessionEstablished is a no-op
+// with a nil db.
+func TestOnSessionEstablishedNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+
+	h.db = nil
+	h.OnSessionEstablished(client, packets.Packet{})
+}
+
+// TestOnSessionEstablishedClosedDB checks OnSessionEstablished does not
+// panic after teardown.
+func TestOnSessionEstablishedClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnSessionEstablished(client, packets.Packet{})
+}
+
+// TestOnWillSent verifies the will flag is persisted with the client record.
+// NOTE(review): c1 := client copies the *pointer*, so setting Will.Flag
+// mutates the shared package-level fixture for later tests.
+func TestOnWillSent(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	c1 := client
+	c1.Properties.Will.Flag = 1
+	h.OnWillSent(c1, packets.Packet{})
+
+	r := new(storage.Client)
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
+	require.NoError(t, err)
+	err = r.UnmarshalBinary([]byte(row))
+	require.NoError(t, err)
+
+	require.Equal(t, uint32(1), r.Will.Flag)
+	require.NotSame(t, client, r)
+}
+
+// TestOnClientExpired seeds a client record, then verifies OnClientExpired
+// removes it from the store.
+func TestOnClientExpired(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	cl := &mqtt.Client{ID: "cl1"}
+	// renamed from "clientKey", which shadowed the package-level helper func
+	key := clientKey(cl)
+
+	err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), key, &storage.Client{ID: cl.ID}).Err()
+	require.NoError(t, err)
+
+	r := new(storage.Client)
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), key).Result()
+	require.NoError(t, err)
+	err = r.UnmarshalBinary([]byte(row))
+	require.NoError(t, err)
+	require.Equal(t, key, r.ID)
+
+	h.OnClientExpired(cl)
+	_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), key).Result()
+	require.Error(t, err)
+	// require.ErrorIs takes (t, err, target); the arguments were reversed,
+	// which only passes while the error is exactly redis.Nil (unwrapped).
+	require.ErrorIs(t, err, redis.Nil)
+}
+
+// TestOnClientExpiredClosedDB checks OnClientExpired does not panic after teardown.
+func TestOnClientExpiredClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnClientExpired(client)
+}
+
+// TestOnClientExpiredNoDB checks OnClientExpired is a no-op with a nil db.
+func TestOnClientExpiredNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnClientExpired(client)
+}
+
+// TestOnDisconnectNoDB checks OnDisconnect is a no-op with a nil db.
+func TestOnDisconnectNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnDisconnect(client, nil, false)
+}
+
+// TestOnDisconnectClosedDB checks OnDisconnect does not panic after teardown.
+func TestOnDisconnectClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnDisconnect(client, nil, false)
+}
+
+// TestOnDisconnectSessionTakenOver exercises the session-takeover branch of
+// OnDisconnect (StopCause == ErrSessionTakenOver) against a closed db;
+// it only checks nothing panics.
+func TestOnDisconnectSessionTakenOver(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+
+	testClient := &mqtt.Client{
+		ID: "test",
+		Net: mqtt.ClientConnection{
+			Remote:   "test.addr",
+			Listener: "listener",
+		},
+		Properties: mqtt.ClientProperties{
+			Username: []byte("username"),
+			Clean:    false,
+		},
+	}
+
+	// stop with takeover cause so OnDisconnect should skip deletion
+	testClient.Stop(packets.ErrSessionTakenOver)
+	teardown(t, h)
+	h.OnDisconnect(testClient, nil, true)
+}
+
+// TestOnSubscribedThenOnUnsubscribed verifies that a subscription is persisted on
+// subscribe and removed on unsubscribe.
+func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	// Subscribing must persist the subscription keyed by client+filter.
+	h.OnSubscribed(client, pkf, []byte{0})
+
+	r := new(storage.Subscription)
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
+	require.NoError(t, err)
+	err = r.UnmarshalBinary([]byte(row))
+	require.NoError(t, err)
+	require.Equal(t, client.ID, r.Client)
+	require.Equal(t, pkf.Filters[0].Filter, r.Filter)
+	require.Equal(t, byte(0), r.Qos)
+
+	// Unsubscribing must remove the stored subscription.
+	h.OnUnsubscribed(client, pkf)
+	_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
+	require.Error(t, err)
+	require.ErrorIs(t, err, redis.Nil)
+}
+
+// TestOnSubscribedNoDB ensures OnSubscribed does not panic when no db is set.
+func TestOnSubscribedNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnSubscribed(client, pkf, []byte{0})
+}
+
+// TestOnSubscribedClosedDB ensures OnSubscribed does not panic once the connection is closed.
+func TestOnSubscribedClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnSubscribed(client, pkf, []byte{0})
+}
+
+// TestOnUnsubscribedNoDB ensures OnUnsubscribed does not panic when no db is set.
+func TestOnUnsubscribedNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnUnsubscribed(client, pkf)
+}
+
+// TestOnUnsubscribedClosedDB ensures OnUnsubscribed does not panic once the connection is closed.
+func TestOnUnsubscribedClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnUnsubscribed(client, pkf)
+}
+
+// TestOnRetainMessageThenUnset verifies a retained message is stored on retain (+1)
+// and removed on unset (-1), and that unsetting twice is a safe no-op.
+func TestOnRetainMessageThenUnset(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	pk := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Retain: true,
+		},
+		Payload: []byte("hello"),
+		TopicName: "a/b/c",
+	}
+
+	h.OnRetainMessage(client, pk, 1)
+
+	r := new(storage.Message)
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
+	require.NoError(t, err)
+	err = r.UnmarshalBinary([]byte(row))
+	require.NoError(t, err) // fixed: removed a duplicated require.NoError on the same err
+	require.Equal(t, pk.TopicName, r.TopicName)
+	require.Equal(t, pk.Payload, r.Payload)
+
+	// Unsetting (-1) must remove the retained message.
+	h.OnRetainMessage(client, pk, -1)
+	_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
+	require.Error(t, err)
+	require.ErrorIs(t, err, redis.Nil)
+
+	// coverage: deleting an already-deleted message must remain a no-op
+	h.OnRetainMessage(client, pk, -1)
+	_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
+	require.Error(t, err)
+	require.ErrorIs(t, err, redis.Nil)
+}
+
+// TestOnRetainedExpired verifies an expired retained message is removed from the store.
+func TestOnRetainedExpired(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	m := &storage.Message{
+		ID: retainedKey("a/b/c"),
+		T: storage.RetainedKey,
+		TopicName: "a/b/c",
+	}
+
+	err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err()
+	require.NoError(t, err)
+
+	r := new(storage.Message)
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
+	require.NoError(t, err)
+	err = r.UnmarshalBinary([]byte(row))
+	require.NoError(t, err) // fixed: removed a duplicated require.NoError on the same err
+	require.Equal(t, m.TopicName, r.TopicName)
+
+	// Expiring the topic must remove the retained message from the store.
+	h.OnRetainedExpired(m.TopicName)
+
+	_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
+	require.Error(t, err)
+	require.ErrorIs(t, err, redis.Nil)
+}
+
+// TestOnRetainedExpiredClosedDB ensures OnRetainedExpired does not panic once the connection is closed.
+func TestOnRetainedExpiredClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnRetainedExpired("a/b/c")
+}
+
+// TestOnRetainedExpiredNoDB ensures OnRetainedExpired does not panic when no db is set.
+func TestOnRetainedExpiredNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnRetainedExpired("a/b/c")
+}
+
+// TestOnRetainMessageNoDB ensures OnRetainMessage does not panic when no db is set.
+func TestOnRetainMessageNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnRetainMessage(client, packets.Packet{}, 0)
+}
+
+// TestOnRetainMessageClosedDB ensures OnRetainMessage does not panic once the connection is closed.
+func TestOnRetainMessageClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnRetainMessage(client, packets.Packet{}, 0)
+}
+
+// TestOnQosPublishThenQOSComplete verifies a qos-2 publish is stored as an inflight
+// message (with a sent timestamp) and removed again when the qos flow completes.
+func TestOnQosPublishThenQOSComplete(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	pk := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Retain: true,
+			Qos: 2,
+		},
+		Payload: []byte("hello"),
+		TopicName: "a/b/c",
+	}
+
+	h.OnQosPublish(client, pk, time.Now().Unix(), 0)
+
+	r := new(storage.Message)
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
+	require.NoError(t, err)
+	err = r.UnmarshalBinary([]byte(row))
+	require.NoError(t, err)
+	require.Equal(t, pk.TopicName, r.TopicName)
+	require.Equal(t, pk.Payload, r.Payload)
+
+	// ensure dates are properly saved to redis (comment previously said "bolt" — copy-paste from the bolt hook test)
+	require.True(t, r.Sent > 0)
+	require.True(t, time.Now().Unix()-1 < r.Sent)
+
+	// OnQosDropped is a passthrough to OnQosComplete here
+	h.OnQosDropped(client, pk)
+	_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
+	require.Error(t, err)
+	require.ErrorIs(t, err, redis.Nil)
+}
+
+// TestOnQosPublishNoDB ensures OnQosPublish does not panic when no db is set.
+func TestOnQosPublishNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
+}
+
+// TestOnQosPublishClosedDB ensures OnQosPublish does not panic once the connection is closed.
+func TestOnQosPublishClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
+}
+
+// TestOnQosCompleteNoDB ensures OnQosComplete does not panic when no db is set.
+func TestOnQosCompleteNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnQosComplete(client, packets.Packet{})
+}
+
+// TestOnQosCompleteClosedDB ensures OnQosComplete does not panic once the connection is closed.
+func TestOnQosCompleteClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnQosComplete(client, packets.Packet{})
+}
+
+// TestOnQosDroppedNoDB ensures OnQosDropped does not panic when no db is set.
+func TestOnQosDroppedNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnQosDropped(client, packets.Packet{})
+}
+
+// TestOnSysInfoTick verifies a sys-info tick persists a copy of the snapshot.
+func TestOnSysInfoTick(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	info := &system.Info{
+		Version: "2.0.0",
+		BytesReceived: 100,
+	}
+
+	h.OnSysInfoTick(info)
+
+	r := new(storage.SystemInfo)
+	row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
+	require.NoError(t, err)
+	err = r.UnmarshalBinary([]byte(row))
+	require.NoError(t, err)
+	require.Equal(t, info.Version, r.Version)
+	require.Equal(t, info.BytesReceived, r.BytesReceived)
+	require.NotSame(t, info, r) // the stored value must be a copy, not the same pointer
+}
+
+// TestOnSysInfoTickClosedDB ensures OnSysInfoTick does not panic once the connection is closed.
+func TestOnSysInfoTickClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+	h.OnSysInfoTick(new(system.Info))
+}
+// TestOnSysInfoTickNoDB ensures OnSysInfoTick does not panic when no db is set.
+func TestOnSysInfoTickNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	h.OnSysInfoTick(new(system.Info))
+}
+
+// TestStoredClients verifies StoredClients returns every stored client record.
+func TestStoredClients(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	// populate with clients
+	err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err()
+	require.NoError(t, err)
+
+	err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err()
+	require.NoError(t, err)
+
+	err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err()
+	require.NoError(t, err)
+
+	r, err := h.StoredClients()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+
+	// r is already a slice; sort it directly (hash iteration order is random). Was sort.Slice(r[:], ...).
+	sort.Slice(r, func(i, j int) bool { return r[i].ID < r[j].ID })
+	require.Equal(t, "cl1", r[0].ID)
+	require.Equal(t, "cl2", r[1].ID)
+	require.Equal(t, "cl3", r[2].ID)
+}
+
+// TestStoredClientsNoDB ensures StoredClients returns empty with no error when no db is set.
+func TestStoredClientsNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	v, err := h.StoredClients()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredClientsClosedDB ensures StoredClients surfaces an error once the connection is closed.
+func TestStoredClientsClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+
+	v, err := h.StoredClients()
+	require.Empty(t, v)
+	require.Error(t, err)
+}
+
+// TestStoredSubscriptions verifies StoredSubscriptions returns every stored subscription.
+func TestStoredSubscriptions(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	// populate with subscriptions
+	err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err()
+	require.NoError(t, err)
+
+	err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err()
+	require.NoError(t, err)
+
+	err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err()
+	require.NoError(t, err)
+
+	r, err := h.StoredSubscriptions()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	// r is already a slice; sort it directly (hash iteration order is random). Was sort.Slice(r[:], ...).
+	sort.Slice(r, func(i, j int) bool { return r[i].ID < r[j].ID })
+	require.Equal(t, "sub1", r[0].ID)
+	require.Equal(t, "sub2", r[1].ID)
+	require.Equal(t, "sub3", r[2].ID)
+}
+
+// TestStoredSubscriptionsNoDB ensures StoredSubscriptions returns empty with no error when no db is set.
+func TestStoredSubscriptionsNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	v, err := h.StoredSubscriptions()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredSubscriptionsClosedDB ensures StoredSubscriptions surfaces an error once the connection is closed.
+func TestStoredSubscriptionsClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+
+	v, err := h.StoredSubscriptions()
+	require.Empty(t, v)
+	require.Error(t, err)
+}
+
+// TestStoredRetainedMessages verifies StoredRetainedMessages returns only retained
+// messages, excluding any inflight messages also present in the store.
+func TestStoredRetainedMessages(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	// populate with messages
+	err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err()
+	require.NoError(t, err)
+
+	err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err()
+	require.NoError(t, err)
+
+	err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
+	require.NoError(t, err)
+
+	// seed an inflight message too, to prove it is excluded from the result
+	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
+	require.NoError(t, err)
+
+	r, err := h.StoredRetainedMessages()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	// r is already a slice; sort it directly (hash iteration order is random). Was sort.Slice(r[:], ...).
+	sort.Slice(r, func(i, j int) bool { return r[i].ID < r[j].ID })
+	require.Equal(t, "m1", r[0].ID)
+	require.Equal(t, "m2", r[1].ID)
+	require.Equal(t, "m3", r[2].ID)
+}
+
+// TestStoredRetainedMessagesNoDB ensures StoredRetainedMessages returns empty with no error when no db is set.
+func TestStoredRetainedMessagesNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	v, err := h.StoredRetainedMessages()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredRetainedMessagesClosedDB ensures StoredRetainedMessages surfaces an error once the connection is closed.
+func TestStoredRetainedMessagesClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+
+	v, err := h.StoredRetainedMessages()
+	require.Empty(t, v)
+	require.Error(t, err)
+}
+
+// TestStoredInflightMessages verifies StoredInflightMessages returns only inflight
+// messages, excluding any retained messages also present in the store.
+func TestStoredInflightMessages(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	// populate with messages
+	err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err()
+	require.NoError(t, err)
+
+	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err()
+	require.NoError(t, err)
+
+	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
+	require.NoError(t, err)
+
+	// seed a retained message too, to prove it is excluded from the result
+	err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
+	require.NoError(t, err)
+
+	r, err := h.StoredInflightMessages()
+	require.NoError(t, err)
+	require.Len(t, r, 3)
+	// r is already a slice; sort it directly (hash iteration order is random). Was sort.Slice(r[:], ...).
+	sort.Slice(r, func(i, j int) bool { return r[i].ID < r[j].ID })
+	require.Equal(t, "i1", r[0].ID)
+	require.Equal(t, "i2", r[1].ID)
+	require.Equal(t, "i3", r[2].ID)
+}
+
+// TestStoredInflightMessagesNoDB ensures StoredInflightMessages returns empty with no error when no db is set.
+func TestStoredInflightMessagesNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	v, err := h.StoredInflightMessages()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredInflightMessagesClosedDB ensures StoredInflightMessages surfaces an error once the connection is closed.
+func TestStoredInflightMessagesClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+
+	v, err := h.StoredInflightMessages()
+	require.Empty(t, v)
+	require.Error(t, err)
+}
+
+// TestStoredSysInfo verifies StoredSysInfo returns the persisted system info record.
+func TestStoredSysInfo(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	defer teardown(t, h)
+
+	// populate with sys info
+	err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey,
+		&storage.SystemInfo{
+			ID: storage.SysInfoKey,
+			Info: system.Info{
+				Version: "2.0.0",
+			},
+			T: storage.SysInfoKey,
+		}).Err()
+	require.NoError(t, err)
+
+	r, err := h.StoredSysInfo()
+	require.NoError(t, err)
+	require.Equal(t, "2.0.0", r.Info.Version)
+}
+
+// TestStoredSysInfoNoDB ensures StoredSysInfo returns empty with no error when no db is set.
+func TestStoredSysInfoNoDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	h.db = nil
+	v, err := h.StoredSysInfo()
+	require.Empty(t, v)
+	require.NoError(t, err)
+}
+
+// TestStoredSysInfoClosedDB ensures StoredSysInfo surfaces an error once the connection is closed.
+func TestStoredSysInfoClosedDB(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	h := newHook(t, s.Addr())
+	teardown(t, h)
+
+	v, err := h.StoredSysInfo()
+	require.Empty(t, v)
+	require.Error(t, err)
+}
diff --git a/hooks/storage/storage.go b/hooks/storage/storage.go
new file mode 100644
index 0000000..7fd3d13
--- /dev/null
+++ b/hooks/storage/storage.go
@@ -0,0 +1,213 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package storage
+
+import (
+ "encoding/json"
+ "errors"
+
+ "testmqtt/packets"
+ "testmqtt/system"
+)
+
+const (
+	// SubscriptionKey is the unique key used to denote Subscriptions in a store.
+	SubscriptionKey = "SUB"
+	// SysInfoKey is the unique key used to denote server system information in a store.
+	SysInfoKey = "SYS"
+	// RetainedKey is the unique key used to denote retained messages in a store.
+	RetainedKey = "RET"
+	// InflightKey is the unique key used to denote inflight (published but not yet acknowledged) messages in a store.
+	InflightKey = "IFM"
+	// ClientKey is the unique key used to denote clients in a store.
+	ClientKey = "CL"
+)
+
+var (
+	// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading.
+	// NOTE(review): the error string itself is runtime-visible and deliberately left unchanged.
+	ErrDBFileNotOpen = errors.New("数据库文件没有打开")
+)
+
+// Serializable is an interface for objects that can be serialized and deserialized.
+type Serializable interface {
+	UnmarshalBinary([]byte) error
+	MarshalBinary() (data []byte, err error)
+}
+
+// Client is a storable representation of an MQTT client.
+type Client struct {
+	Will ClientWill `json:"will"` // will topic and payload data if applicable
+	Properties ClientProperties `json:"properties"` // the connect properties for the client
+	Username []byte `json:"username"` // the username of the client
+	ID string `json:"id" storm:"id"` // the client id / storage key
+	T string `json:"t"` // the data type (client)
+	Remote string `json:"remote"` // the remote address of the client
+	Listener string `json:"listener"` // the listener the client connected on
+	ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client
+	Clean bool `json:"clean"` // if the client requested a clean start/session
+}
+
+// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection.
+type ClientProperties struct {
+	AuthenticationData []byte `json:"authenticationData,omitempty"`
+	User []packets.UserProperty `json:"user,omitempty"`
+	AuthenticationMethod string `json:"authenticationMethod,omitempty"`
+	SessionExpiryInterval uint32 `json:"sessionExpiryInterval,omitempty"`
+	MaximumPacketSize uint32 `json:"maximumPacketSize,omitempty"`
+	ReceiveMaximum uint16 `json:"receiveMaximum,omitempty"`
+	TopicAliasMaximum uint16 `json:"topicAliasMaximum,omitempty"`
+	SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag,omitempty"`
+	RequestProblemInfo byte `json:"requestProblemInfo,omitempty"`
+	RequestProblemInfoFlag bool `json:"requestProblemInfoFlag,omitempty"`
+	RequestResponseInfo byte `json:"requestResponseInfo,omitempty"`
+}
+
+// ClientWill contains a will message for a client, and limited mqtt v5 properties.
+type ClientWill struct {
+	Payload []byte `json:"payload,omitempty"`
+	User []packets.UserProperty `json:"user,omitempty"`
+	TopicName string `json:"topicName,omitempty"`
+	Flag uint32 `json:"flag,omitempty"`
+	WillDelayInterval uint32 `json:"willDelayInterval,omitempty"`
+	Qos byte `json:"qos,omitempty"`
+	Retain bool `json:"retain,omitempty"`
+}
+
+// MarshalBinary encodes the values into a json string.
+func (d Client) MarshalBinary() (data []byte, err error) {
+	return json.Marshal(d)
+}
+
+// UnmarshalBinary decodes a json string into a struct.
+// An empty payload is a no-op, leaving the receiver unchanged.
+func (d *Client) UnmarshalBinary(data []byte) error {
+	if len(data) == 0 {
+		return nil
+	}
+	return json.Unmarshal(data, d)
+}
+
+// Message is a storable representation of an MQTT message (specifically publish).
+type Message struct {
+	Properties MessageProperties `json:"properties"` // -
+	Payload []byte `json:"payload"` // the message payload (if retained)
+	T string `json:"t,omitempty"` // the data type
+	ID string `json:"id,omitempty" storm:"id"` // the storage key
+	Origin string `json:"origin,omitempty"` // the id of the client who sent the message
+	TopicName string `json:"topic_name,omitempty"` // the topic the message was sent to (if retained)
+	FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message
+	Created int64 `json:"created,omitempty"` // the time the message was created in unixtime
+	Sent int64 `json:"sent,omitempty"` // the last time the message was sent (for retries) in unixtime (if inflight)
+	PacketID uint16 `json:"packet_id,omitempty"` // the unique id of the packet (if inflight)
+}
+
+// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages.
+type MessageProperties struct {
+	CorrelationData []byte `json:"correlationData,omitempty"`
+	SubscriptionIdentifier []int `json:"subscriptionIdentifier,omitempty"`
+	User []packets.UserProperty `json:"user,omitempty"`
+	ContentType string `json:"contentType,omitempty"`
+	ResponseTopic string `json:"responseTopic,omitempty"`
+	MessageExpiryInterval uint32 `json:"messageExpiry,omitempty"`
+	TopicAlias uint16 `json:"topicAlias,omitempty"`
+	PayloadFormat byte `json:"payloadFormat,omitempty"`
+	PayloadFormatFlag bool `json:"payloadFormatFlag,omitempty"`
+}
+
+// MarshalBinary encodes the values into a json string.
+func (d Message) MarshalBinary() (data []byte, err error) {
+	return json.Marshal(d)
+}
+
+// UnmarshalBinary decodes a json string into a struct.
+// An empty payload is a no-op, leaving the receiver unchanged.
+func (d *Message) UnmarshalBinary(data []byte) error {
+	if len(data) == 0 {
+		return nil
+	}
+	return json.Unmarshal(data, d)
+}
+
+// ToPacket converts a storage.Message to a standard packet.
+// The result is deep-copied so the caller cannot alias the stored slices,
+// and the Dup flag is restored afterwards (Copy appears to reset it).
+func (d *Message) ToPacket() packets.Packet {
+	pk := packets.Packet{
+		FixedHeader: d.FixedHeader,
+		PacketID: d.PacketID,
+		TopicName: d.TopicName,
+		Payload: d.Payload,
+		Origin: d.Origin,
+		Created: d.Created,
+		Properties: packets.Properties{
+			PayloadFormat: d.Properties.PayloadFormat,
+			PayloadFormatFlag: d.Properties.PayloadFormatFlag,
+			MessageExpiryInterval: d.Properties.MessageExpiryInterval,
+			ContentType: d.Properties.ContentType,
+			ResponseTopic: d.Properties.ResponseTopic,
+			CorrelationData: d.Properties.CorrelationData,
+			SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
+			TopicAlias: d.Properties.TopicAlias,
+			User: d.Properties.User,
+		},
+	}
+
+	// Return a deep copy of the packet data otherwise the slices will
+	// continue pointing at the values from the storage packet.
+	pk = pk.Copy(true)
+	pk.FixedHeader.Dup = d.FixedHeader.Dup
+
+	return pk
+}
+
+// Subscription is a storable representation of an MQTT subscription.
+type Subscription struct {
+	T string `json:"t,omitempty"` // the data type
+	ID string `json:"id,omitempty" storm:"id"` // the storage key
+	Client string `json:"client,omitempty"` // the id of the subscribing client
+	Filter string `json:"filter"` // the topic filter subscribed to
+	Identifier int `json:"identifier,omitempty"` // the mqtt v5 subscription identifier
+	RetainHandling byte `json:"retain_handling,omitempty"` // the mqtt v5 retain-handling option
+	Qos byte `json:"qos"` // the granted quality of service
+	RetainAsPublished bool `json:"retain_as_pub,omitempty"` // the mqtt v5 retain-as-published option
+	NoLocal bool `json:"no_local,omitempty"` // the mqtt v5 no-local option
+}
+
+// MarshalBinary encodes the values into a json string.
+func (d Subscription) MarshalBinary() (data []byte, err error) {
+	return json.Marshal(d)
+}
+
+// UnmarshalBinary decodes a json string into a struct.
+// An empty payload is a no-op, leaving the receiver unchanged.
+func (d *Subscription) UnmarshalBinary(data []byte) error {
+	if len(data) == 0 {
+		return nil
+	}
+	return json.Unmarshal(data, d)
+}
+
+// SystemInfo is a storable representation of the system information values.
+type SystemInfo struct {
+	system.Info // embed the system info struct
+	T string `json:"t"` // the data type
+	ID string `json:"id" storm:"id"` // the storage key
+}
+
+// MarshalBinary encodes the values into a json string.
+func (d SystemInfo) MarshalBinary() (data []byte, err error) {
+	return json.Marshal(d)
+}
+
+// UnmarshalBinary decodes a json string into a struct.
+// An empty payload is a no-op, leaving the receiver unchanged.
+func (d *SystemInfo) UnmarshalBinary(data []byte) error {
+	if len(data) == 0 {
+		return nil
+	}
+	return json.Unmarshal(data, d)
+}
diff --git a/hooks/storage/storage_test.go b/hooks/storage/storage_test.go
new file mode 100644
index 0000000..835cbae
--- /dev/null
+++ b/hooks/storage/storage_test.go
@@ -0,0 +1,228 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package storage
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "testmqtt/packets"
+ "testmqtt/system"
+)
+
+// Shared fixtures: each *Struct value is paired with its expected JSON encoding
+// (*JSON) so the marshal tests can assert round-trip equivalence.
+var (
+	clientStruct = Client{
+		ID: "test",
+		T: "client",
+		Remote: "remote",
+		Listener: "listener",
+		Username: []byte("mochi"),
+		Clean: true,
+		Properties: ClientProperties{
+			SessionExpiryInterval: 2,
+			SessionExpiryIntervalFlag: true,
+			AuthenticationMethod: "a",
+			AuthenticationData: []byte("test"),
+			RequestProblemInfo: 1,
+			RequestProblemInfoFlag: true,
+			RequestResponseInfo: 1,
+			ReceiveMaximum: 128,
+			TopicAliasMaximum: 256,
+			User: []packets.UserProperty{
+				{Key: "k", Val: "v"},
+			},
+			MaximumPacketSize: 120,
+		},
+		Will: ClientWill{
+			Qos: 1,
+			Payload: []byte("abc"),
+			TopicName: "a/b/c",
+			Flag: 1,
+			Retain: true,
+			WillDelayInterval: 2,
+			User: []packets.UserProperty{
+				{Key: "k2", Val: "v2"},
+			},
+		},
+	}
+	clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)
+
+	messageStruct = Message{
+		T: "message",
+		Payload: []byte("payload"),
+		FixedHeader: packets.FixedHeader{
+			Remaining: 2,
+			Type: 3,
+			Qos: 1,
+			Dup: true,
+			Retain: true,
+		},
+		ID: "id",
+		Origin: "mochi",
+		TopicName: "topic",
+		Properties: MessageProperties{
+			PayloadFormat: 1,
+			PayloadFormatFlag: true,
+			MessageExpiryInterval: 20,
+			ContentType: "type",
+			ResponseTopic: "a/b/r",
+			CorrelationData: []byte("r"),
+			SubscriptionIdentifier: []int{1},
+			TopicAlias: 2,
+			User: []packets.UserProperty{
+				{Key: "k2", Val: "v2"},
+			},
+		},
+		Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
+		Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
+		PacketID: 100,
+	}
+	messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)
+
+	subscriptionStruct = Subscription{
+		T: "subscription",
+		ID: "id",
+		Client: "mochi",
+		Filter: "a/b/c",
+		Qos: 1,
+	}
+	subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","qos":1}`)
+
+	sysInfoStruct = SystemInfo{
+		T: "info",
+		ID: "id",
+		Info: system.Info{
+			Version: "2.0.0",
+			Started: 1,
+			Uptime: 2,
+			BytesReceived: 3,
+			BytesSent: 4,
+			ClientsConnected: 5,
+			ClientsMaximum: 7,
+			MessagesReceived: 10,
+			MessagesSent: 11,
+			MessagesDropped: 20,
+			PacketsReceived: 12,
+			PacketsSent: 13,
+			Retained: 15,
+			Inflight: 16,
+			InflightDropped: 17,
+		},
+	}
+	sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
+)
+
+// TestClientMarshalBinary verifies Client encodes to the expected JSON.
+func TestClientMarshalBinary(t *testing.T) {
+	data, err := clientStruct.MarshalBinary()
+	require.NoError(t, err)
+	require.JSONEq(t, string(clientJSON), string(data))
+}
+
+// TestClientUnmarshalBinary verifies Client decodes losslessly from its JSON form.
+func TestClientUnmarshalBinary(t *testing.T) {
+	d := clientStruct
+	err := d.UnmarshalBinary(clientJSON)
+	require.NoError(t, err)
+	require.Equal(t, clientStruct, d)
+}
+
+// TestClientUnmarshalBinaryEmpty verifies an empty payload leaves the value unchanged.
+func TestClientUnmarshalBinaryEmpty(t *testing.T) {
+	d := Client{}
+	err := d.UnmarshalBinary([]byte{})
+	require.NoError(t, err)
+	require.Equal(t, Client{}, d)
+}
+
+// TestMessageMarshalBinary verifies Message encodes to the expected JSON.
+func TestMessageMarshalBinary(t *testing.T) {
+	data, err := messageStruct.MarshalBinary()
+	require.NoError(t, err)
+	require.JSONEq(t, string(messageJSON), string(data))
+}
+
+// TestMessageUnmarshalBinary verifies Message decodes losslessly from its JSON form.
+func TestMessageUnmarshalBinary(t *testing.T) {
+	d := messageStruct
+	err := d.UnmarshalBinary(messageJSON)
+	require.NoError(t, err)
+	require.Equal(t, messageStruct, d)
+}
+
+// TestMessageUnmarshalBinaryEmpty verifies an empty payload leaves the value unchanged.
+func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
+	d := Message{}
+	err := d.UnmarshalBinary([]byte{})
+	require.NoError(t, err)
+	require.Equal(t, Message{}, d)
+}
+
+// TestSubscriptionMarshalBinary verifies Subscription encodes to the expected JSON.
+func TestSubscriptionMarshalBinary(t *testing.T) {
+	data, err := subscriptionStruct.MarshalBinary()
+	require.NoError(t, err)
+	require.JSONEq(t, string(subscriptionJSON), string(data))
+}
+
+// TestSubscriptionUnmarshalBinary verifies Subscription decodes losslessly from its JSON form.
+func TestSubscriptionUnmarshalBinary(t *testing.T) {
+	d := subscriptionStruct
+	err := d.UnmarshalBinary(subscriptionJSON)
+	require.NoError(t, err)
+	require.Equal(t, subscriptionStruct, d)
+}
+
+// TestSubscriptionUnmarshalBinaryEmpty verifies an empty payload leaves the value unchanged.
+func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
+	d := Subscription{}
+	err := d.UnmarshalBinary([]byte{})
+	require.NoError(t, err)
+	require.Equal(t, Subscription{}, d)
+}
+
+// TestSysInfoMarshalBinary verifies SystemInfo encodes to the expected JSON.
+func TestSysInfoMarshalBinary(t *testing.T) {
+	data, err := sysInfoStruct.MarshalBinary()
+	require.NoError(t, err)
+	require.JSONEq(t, string(sysInfoJSON), string(data))
+}
+
+// TestSysInfoUnmarshalBinary verifies SystemInfo decodes losslessly from its JSON form.
+func TestSysInfoUnmarshalBinary(t *testing.T) {
+	d := sysInfoStruct
+	err := d.UnmarshalBinary(sysInfoJSON)
+	require.NoError(t, err)
+	require.Equal(t, sysInfoStruct, d)
+}
+
+// TestSysInfoUnmarshalBinaryEmpty verifies an empty payload leaves the value unchanged.
+func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
+	d := SystemInfo{}
+	err := d.UnmarshalBinary([]byte{})
+	require.NoError(t, err)
+	require.Equal(t, SystemInfo{}, d)
+}
+
+// TestMessageToPacket verifies ToPacket maps every stored field onto the packet.
+func TestMessageToPacket(t *testing.T) {
+	d := messageStruct
+	pk := d.ToPacket()
+
+	require.Equal(t, packets.Packet{
+		Payload: []byte("payload"),
+		FixedHeader: packets.FixedHeader{
+			Remaining: d.FixedHeader.Remaining,
+			Type: d.FixedHeader.Type,
+			Qos: d.FixedHeader.Qos,
+			Dup: d.FixedHeader.Dup,
+			Retain: d.FixedHeader.Retain,
+		},
+		Origin: d.Origin,
+		TopicName: d.TopicName,
+		Properties: packets.Properties{
+			PayloadFormat: d.Properties.PayloadFormat,
+			PayloadFormatFlag: d.Properties.PayloadFormatFlag,
+			MessageExpiryInterval: d.Properties.MessageExpiryInterval,
+			ContentType: d.Properties.ContentType,
+			ResponseTopic: d.Properties.ResponseTopic,
+			CorrelationData: d.Properties.CorrelationData,
+			SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
+			TopicAlias: d.Properties.TopicAlias,
+			User: d.Properties.User,
+		},
+		PacketID: 100,
+		Created: d.Created,
+	}, pk)
+
+}
diff --git a/listeners/http_healthcheck.go b/listeners/http_healthcheck.go
new file mode 100644
index 0000000..fc0f13d
--- /dev/null
+++ b/listeners/http_healthcheck.go
@@ -0,0 +1,99 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: Derek Duncan
+
+package listeners
+
+import (
+ "context"
+ "log/slog"
+ "net/http"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// TypeHealthCheck is the listener type identifier for the healthcheck listener.
+const TypeHealthCheck = "healthcheck"
+
+// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint.
+type HTTPHealthCheck struct {
+	sync.RWMutex
+	id string // the internal id of the listener
+	address string // the network address to bind to
+	config Config // configuration values for the listener
+	listen *http.Server // the http server
+	end uint32 // ensure the close methods are only called once
+}
+
+// NewHTTPHealthCheck initializes and returns a new HTTP listener, listening on an address.
+func NewHTTPHealthCheck(config Config) *HTTPHealthCheck {
+	return &HTTPHealthCheck{
+		id: config.ID,
+		address: config.Address,
+		config: config,
+	}
+}
+
+// ID returns the id of the listener.
+func (l *HTTPHealthCheck) ID() string {
+	return l.id
+}
+
+// Address returns the address of the listener.
+func (l *HTTPHealthCheck) Address() string {
+	return l.address
+}
+
+// Protocol returns the protocol of the listener: "https" when a TLS config
+// is present on the initialized server, otherwise "http".
+// (The previous comment wrongly said it returns the address.)
+func (l *HTTPHealthCheck) Protocol() string {
+	if l.listen != nil && l.listen.TLSConfig != nil {
+		return "https"
+	}
+
+	return "http"
+}
+
+// Init initializes the listener, building the HTTP server and registering the
+// /healthcheck route. GET requests receive the implicit 200 OK; any other
+// method receives 405 Method Not Allowed. The slog.Logger argument is unused.
+func (l *HTTPHealthCheck) Init(_ *slog.Logger) error {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodGet {
+			w.WriteHeader(http.StatusMethodNotAllowed)
+		}
+	})
+	l.listen = &http.Server{
+		ReadTimeout: 5 * time.Second,
+		WriteTimeout: 5 * time.Second,
+		Addr: l.address,
+		Handler: mux,
+	}
+
+	if l.config.TLSConfig != nil {
+		l.listen.TLSConfig = l.config.TLSConfig
+	}
+
+	return nil
+}
+
+// Serve starts listening for new connections and serving responses.
+// The establish callback is unused by this listener; serve errors
+// (including the expected error on Shutdown) are deliberately discarded.
+func (l *HTTPHealthCheck) Serve(establish EstablishFn) {
+	if l.listen.TLSConfig != nil {
+		_ = l.listen.ListenAndServeTLS("", "")
+	} else {
+		_ = l.listen.ListenAndServe()
+	}
+}
+
+// Close closes the listener and any client connections.
+// The CAS on l.end guarantees Shutdown runs at most once; closeClients
+// is always invoked, even on repeated calls.
+func (l *HTTPHealthCheck) Close(closeClients CloseFn) {
+	l.Lock()
+	defer l.Unlock()
+
+	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		_ = l.listen.Shutdown(ctx)
+	}
+
+	closeClients(l.id)
+}
diff --git a/listeners/http_healthcheck_test.go b/listeners/http_healthcheck_test.go
new file mode 100644
index 0000000..66a693c
--- /dev/null
+++ b/listeners/http_healthcheck_test.go
@@ -0,0 +1,137 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: Derek Duncan
+
+package listeners
+
+import (
+ "io"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestNewHTTPHealthCheck checks the constructor copies id and address from the config.
+func TestNewHTTPHealthCheck(t *testing.T) {
+	l := NewHTTPHealthCheck(basicConfig)
+	require.Equal(t, basicConfig.ID, l.id)
+	require.Equal(t, basicConfig.Address, l.address)
+}
+
+// TestHTTPHealthCheckID checks ID returns the configured id.
+func TestHTTPHealthCheckID(t *testing.T) {
+	l := NewHTTPHealthCheck(basicConfig)
+	require.Equal(t, basicConfig.ID, l.ID())
+}
+
+// TestHTTPHealthCheckAddress checks Address returns the configured bind address.
+func TestHTTPHealthCheckAddress(t *testing.T) {
+	l := NewHTTPHealthCheck(basicConfig)
+	require.Equal(t, basicConfig.Address, l.Address())
+}
+
+// TestHTTPHealthCheckProtocol checks the protocol is "http" before Init applies TLS.
+func TestHTTPHealthCheckProtocol(t *testing.T) {
+	l := NewHTTPHealthCheck(basicConfig)
+	require.Equal(t, "http", l.Protocol())
+}
+
+// TestHTTPHealthCheckTLSProtocol checks the protocol becomes "https" after Init with TLS.
+func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
+	l := NewHTTPHealthCheck(tlsConfig)
+	_ = l.Init(logger)
+	require.Equal(t, "https", l.Protocol())
+}
+
+// TestHTTPHealthCheckInit checks Init constructs the http.Server on the configured address.
+func TestHTTPHealthCheckInit(t *testing.T) {
+	l := NewHTTPHealthCheck(basicConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	require.NotNil(t, l.listen)
+	require.Equal(t, basicConfig.Address, l.listen.Addr)
+}
+
+// TestHTTPHealthCheckServeAndClose serves the healthcheck endpoint, hits it
+// once, closes the listener, and verifies the endpoint is then unreachable.
+func TestHTTPHealthCheckServeAndClose(t *testing.T) {
+	// setup http stats listener
+	l := NewHTTPHealthCheck(basicConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	// call healthcheck
+	resp, err := http.Get("http://localhost" + testAddr + "/healthcheck")
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+
+	defer resp.Body.Close()
+	_, err = io.ReadAll(resp.Body)
+	require.NoError(t, err)
+
+	// ensure listening is closed
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.Equal(t, true, closed)
+
+	// The URL was previously built with the port after the path
+	// ("localhost/healthcheck:22222/..."), which errored for the wrong
+	// reason; a correctly-built URL now proves the server is really down.
+	_, err = http.Get("http://localhost" + testAddr + "/healthcheck")
+	require.Error(t, err)
+	<-o
+}
+
+// TestHTTPHealthCheckServeAndCloseMethodNotAllowed verifies that non-GET
+// requests receive 405 Method Not Allowed, then closes the listener.
+func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
+	// setup http stats listener
+	l := NewHTTPHealthCheck(basicConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	// make disallowed method type http request
+	resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody)
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	// previously the status code was never asserted, so the 405 branch
+	// of the handler was exercised but not verified.
+	require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
+
+	defer resp.Body.Close()
+	_, err = io.ReadAll(resp.Body)
+	require.NoError(t, err)
+
+	// ensure listening is closed
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.Equal(t, true, closed)
+
+	// build the URL correctly (port before path) so the error proves the
+	// server is down rather than the URL being malformed.
+	_, err = http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody)
+	require.Error(t, err)
+	<-o
+}
+
+// TestHTTPHealthCheckServeTLSAndClose checks the TLS server starts and shuts down cleanly.
+func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) {
+	l := NewHTTPHealthCheck(tlsConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+	l.Close(MockCloser)
+}
diff --git a/listeners/http_sysinfo.go b/listeners/http_sysinfo.go
new file mode 100644
index 0000000..03b2da8
--- /dev/null
+++ b/listeners/http_sysinfo.go
@@ -0,0 +1,122 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "net/http"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "testmqtt/system"
+)
+
+// TypeSysInfo is the listener type identifier for the $SYS stats listener.
+const TypeSysInfo = "sysinfo"
+
+// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
+type HTTPStats struct {
+	sync.RWMutex
+	id      string       // the internal id of the listener
+	address string       // the network address to bind to
+	config  Config       // configuration values for the listener
+	listen  *http.Server // the http server
+	sysInfo *system.Info // pointers to the server data
+	log     *slog.Logger // server logger
+	end     uint32       // ensure the close methods are only called once
+}
+
+// NewHTTPStats initializes and returns a new HTTP listener, listening on an address.
+// The sysInfo pointer is retained and rendered live by the JSON handler.
+func NewHTTPStats(config Config, sysInfo *system.Info) *HTTPStats {
+	return &HTTPStats{
+		sysInfo: sysInfo,
+		id:      config.ID,
+		address: config.Address,
+		config:  config,
+	}
+}
+
+// ID returns the id of the listener, as taken from Config.ID.
+func (l *HTTPStats) ID() string {
+	return l.id
+}
+
+// Address returns the configured bind address of the listener.
+func (l *HTTPStats) Address() string {
+	return l.address
+}
+
+// Protocol returns the protocol in use by the listener: "https" once Init
+// has applied a TLS configuration, otherwise "http".
+func (l *HTTPStats) Protocol() string {
+	if l.listen != nil && l.listen.TLSConfig != nil {
+		return "https"
+	}
+
+	return "http"
+}
+
+// Init initializes the listener, registering the JSON stats handler on "/"
+// and constructing the http.Server with read/write timeouts.
+func (l *HTTPStats) Init(log *slog.Logger) error {
+	l.log = log
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", l.jsonHandler)
+	l.listen = &http.Server{
+		ReadTimeout:  5 * time.Second,
+		WriteTimeout: 5 * time.Second,
+		Addr:         l.address,
+		Handler:      mux,
+	}
+
+	if l.config.TLSConfig != nil {
+		l.listen.TLSConfig = l.config.TLSConfig
+	}
+
+	return nil
+}
+
+// Serve starts listening for new connections and serving responses.
+// The establish callback is unused by this listener.
+func (l *HTTPStats) Serve(establish EstablishFn) {
+
+	var err error
+	if l.listen.TLSConfig != nil {
+		err = l.listen.ListenAndServeTLS("", "")
+	} else {
+		err = l.listen.ListenAndServe()
+	}
+
+	// After the listener has been shutdown, no need to print the http.ErrServerClosed error.
+	// The end flag is set by Close before Shutdown, so errors seen afterwards are expected.
+	if err != nil && atomic.LoadUint32(&l.end) == 0 {
+		l.log.Error("failed to serve.", "error", err, "listener", l.id)
+	}
+}
+
+// Close closes the listener and any client connections.
+// The CAS on end makes the graceful Shutdown run only once; the
+// closeClients callback is invoked on every call.
+func (l *HTTPStats) Close(closeClients CloseFn) {
+	l.Lock()
+	defer l.Unlock()
+
+	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		_ = l.listen.Shutdown(ctx)
+	}
+
+	closeClients(l.id)
+}
+
+// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
+func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
+	info := *l.sysInfo.Clone() // snapshot so marshalling does not observe live updates
+
+	out, err := json.MarshalIndent(info, "", "\t")
+	if err != nil {
+		_, _ = io.WriteString(w, err.Error())
+		return // previously fell through and issued a spurious Write(nil) after the error
+	}
+
+	_, _ = w.Write(out)
+}
diff --git a/listeners/http_sysinfo_test.go b/listeners/http_sysinfo_test.go
new file mode 100644
index 0000000..c35c077
--- /dev/null
+++ b/listeners/http_sysinfo_test.go
@@ -0,0 +1,149 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "testing"
+ "time"
+
+ "testmqtt/system"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestNewHTTPStats checks the constructor copies id and address from the config.
+func TestNewHTTPStats(t *testing.T) {
+	l := NewHTTPStats(basicConfig, nil)
+	require.Equal(t, "t1", l.id)
+	require.Equal(t, testAddr, l.address)
+}
+
+// TestHTTPStatsID checks ID returns the configured id.
+func TestHTTPStatsID(t *testing.T) {
+	l := NewHTTPStats(basicConfig, nil)
+	require.Equal(t, "t1", l.ID())
+}
+
+// TestHTTPStatsAddress checks Address returns the configured bind address.
+func TestHTTPStatsAddress(t *testing.T) {
+	l := NewHTTPStats(basicConfig, nil)
+	require.Equal(t, testAddr, l.Address())
+}
+
+// TestHTTPStatsProtocol checks the protocol is "http" before Init applies TLS.
+func TestHTTPStatsProtocol(t *testing.T) {
+	l := NewHTTPStats(basicConfig, nil)
+	require.Equal(t, "http", l.Protocol())
+}
+
+// TestHTTPStatsTLSProtocol checks the protocol becomes "https" after Init with TLS.
+func TestHTTPStatsTLSProtocol(t *testing.T) {
+	l := NewHTTPStats(tlsConfig, nil)
+	_ = l.Init(logger)
+	require.Equal(t, "https", l.Protocol())
+}
+
+// TestHTTPStatsInit checks Init retains the sysInfo pointer and builds the server.
+func TestHTTPStatsInit(t *testing.T) {
+	sysInfo := new(system.Info)
+	l := NewHTTPStats(basicConfig, sysInfo)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	require.NotNil(t, l.sysInfo)
+	require.Equal(t, sysInfo, l.sysInfo)
+	require.NotNil(t, l.listen)
+	require.Equal(t, testAddr, l.listen.Addr)
+}
+
+// TestHTTPStatsServeAndClose serves the stats endpoint, fetches and decodes
+// the JSON body, then closes the listener and verifies it is unreachable.
+func TestHTTPStatsServeAndClose(t *testing.T) {
+	sysInfo := &system.Info{
+		Version: "test",
+	}
+
+	// setup http stats listener
+	l := NewHTTPStats(basicConfig, sysInfo)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	// get body from stats address
+	resp, err := http.Get("http://localhost" + testAddr)
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+
+	defer resp.Body.Close()
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+
+	// decode body from json and check data
+	v := new(system.Info)
+	err = json.Unmarshal(body, v)
+	require.NoError(t, err)
+	require.Equal(t, "test", v.Version)
+
+	// ensure listening is closed
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.Equal(t, true, closed)
+
+	_, err = http.Get("http://localhost" + testAddr)
+	require.Error(t, err)
+	<-o
+}
+
+// TestHTTPStatsServeTLSAndClose checks the TLS stats server starts and shuts down cleanly.
+func TestHTTPStatsServeTLSAndClose(t *testing.T) {
+	sysInfo := &system.Info{
+		Version: "test",
+	}
+
+	l := NewHTTPStats(tlsConfig, sysInfo)
+
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+	l.Close(MockCloser)
+}
+
+// TestHTTPStatsFailedToServe checks Serve returns promptly when the bind
+// address is invalid, and that Close still runs cleanly afterwards.
+func TestHTTPStatsFailedToServe(t *testing.T) {
+	sysInfo := &system.Info{
+		Version: "test",
+	}
+
+	// setup http stats listener
+	config := basicConfig
+	config.Address = "wrong_addr"
+	l := NewHTTPStats(config, sysInfo)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	<-o
+	// ensure listening is closed
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+	require.Equal(t, true, closed)
+}
diff --git a/listeners/listeners.go b/listeners/listeners.go
new file mode 100644
index 0000000..ded7c37
--- /dev/null
+++ b/listeners/listeners.go
@@ -0,0 +1,135 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "crypto/tls"
+ "net"
+ "sync"
+
+ "log/slog"
+)
+
+// Config contains configuration values for a listener.
+type Config struct {
+	Type    string // the listener type identifier (e.g. TypeTCP, TypeSysInfo)
+	ID      string // unique id used as the key in the Listeners map
+	Address string // the network address to bind to
+	// TLSConfig is a tls.Config configuration to be used with the listener. See examples folder for basic and mutual-tls use.
+	TLSConfig *tls.Config
+}
+
+// EstablishFn is a callback function for establishing new clients.
+// It receives the listener id and the new connection.
+type EstablishFn func(id string, c net.Conn) error
+
+// CloseFn is a callback function for closing all listener clients,
+// keyed on the listener id.
+type CloseFn func(id string)
+
+// Listener is an interface for network listeners. A network listener listens
+// for incoming client connections and adds them to the server.
+type Listener interface {
+	Init(*slog.Logger) error // open the network address
+	Serve(EstablishFn)       // starting actively listening for new connections
+	ID() string              // return the id of the listener
+	Address() string         // the address of the listener
+	Protocol() string        // the protocol in use by the listener
+	Close(CloseFn)           // stop and close the listener
+}
+
+// Listeners contains the network listeners for the broker.
+// The embedded RWMutex guards the internal map only.
+type Listeners struct {
+	ClientsWg sync.WaitGroup      // a waitgroup that waits for all clients in all listeners to finish.
+	internal  map[string]Listener // a map of active listeners.
+	sync.RWMutex
+}
+
+// New returns a new instance of Listeners with an empty (non-nil) map.
+func New() *Listeners {
+	return &Listeners{
+		internal: map[string]Listener{},
+	}
+}
+
+// Add adds a new listener to the listeners map, keyed on id.
+// An existing listener with the same id is replaced.
+func (l *Listeners) Add(val Listener) {
+	l.Lock()
+	defer l.Unlock()
+	l.internal[val.ID()] = val
+}
+
+// Get returns the value of a listener if it exists, and whether it was found.
+func (l *Listeners) Get(id string) (Listener, bool) {
+	l.RLock()
+	defer l.RUnlock()
+	val, ok := l.internal[id]
+	return val, ok
+}
+
+// Len returns the length of the listeners map.
+func (l *Listeners) Len() int {
+	l.RLock()
+	defer l.RUnlock()
+	return len(l.internal)
+}
+
+// Delete removes a listener from the internal map. It does not close the listener.
+func (l *Listeners) Delete(id string) {
+	l.Lock()
+	defer l.Unlock()
+	delete(l.internal, id)
+}
+
+// Serve starts a listener serving from the internal map in its own goroutine.
+// Unknown ids are ignored, consistent with the behaviour of Close.
+func (l *Listeners) Serve(id string, establisher EstablishFn) {
+	l.RLock()
+	defer l.RUnlock()
+	listener, ok := l.internal[id]
+	if !ok {
+		return // previously a missing id yielded a nil listener and panicked in the goroutine
+	}
+
+	go func(e EstablishFn) {
+		listener.Serve(e)
+	}(establisher)
+}
+
+// ServeAll starts all listeners serving from the internal map.
+func (l *Listeners) ServeAll(establisher EstablishFn) {
+	// Snapshot the ids under the read lock, then serve outside it so each
+	// Serve call can take the lock for itself.
+	l.RLock()
+	ids := make([]string, 0, len(l.internal))
+	for id := range l.internal {
+		ids = append(ids, id)
+	}
+	l.RUnlock()
+
+	for _, id := range ids {
+		l.Serve(id, establisher)
+	}
+}
+
+// Close stops a listener from the internal map. Unknown ids are ignored.
+func (l *Listeners) Close(id string, closer CloseFn) {
+	l.RLock()
+	defer l.RUnlock()
+	if listener, ok := l.internal[id]; ok {
+		listener.Close(closer)
+	}
+}
+
+// CloseAll iterates and closes all registered listeners, then waits for all
+// listener clients to finish via the ClientsWg waitgroup.
+func (l *Listeners) CloseAll(closer CloseFn) {
+	// Snapshot the ids under the read lock, then close outside it so each
+	// Close call can take the lock for itself.
+	l.RLock()
+	ids := make([]string, 0, len(l.internal))
+	for id := range l.internal {
+		ids = append(ids, id)
+	}
+	l.RUnlock()
+
+	for _, id := range ids {
+		l.Close(id, closer)
+	}
+	l.ClientsWg.Wait()
+}
diff --git a/listeners/listeners_test.go b/listeners/listeners_test.go
new file mode 100644
index 0000000..6b4b025
--- /dev/null
+++ b/listeners/listeners_test.go
@@ -0,0 +1,182 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "crypto/tls"
+ "log"
+ "os"
+ "testing"
+ "time"
+
+ "log/slog"
+
+ "github.com/stretchr/testify/require"
+)
+
+// testAddr is the port all listener fixtures bind to.
+const testAddr = ":22222"
+
+var (
+	basicConfig = Config{ID: "t1", Address: testAddr}
+	tlsConfig   = Config{ID: "t1", Address: testAddr, TLSConfig: tlsConfigBasic}
+
+	logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
+
+	// testCertificate/testPrivateKey form a self-signed keypair used only by
+	// these tests; tls.X509KeyPair does not check validity dates.
+	testCertificate = []byte(`-----BEGIN CERTIFICATE-----
+MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
+VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
+BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
+VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
+DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
+AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
+OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
+MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
+gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
+qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
+zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
+-----END CERTIFICATE-----`)
+
+	testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
+FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
+rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
+AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
+UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
+n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
+mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
+INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
+AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
+/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
+WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
+w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
+OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
+-----END RSA PRIVATE KEY-----`)
+
+	tlsConfigBasic *tls.Config
+)
+
+func init() {
+	cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Basic TLS Config
+	tlsConfigBasic = &tls.Config{
+		MinVersion:   tls.VersionTLS12,
+		Certificates: []tls.Certificate{cert},
+	}
+	// tlsConfig's declaration captured tlsConfigBasic while it was still nil,
+	// so the field must be re-assigned here after the config is built.
+	tlsConfig.TLSConfig = tlsConfigBasic
+}
+
+// TestNew checks New allocates a non-nil internal map.
+func TestNew(t *testing.T) {
+	l := New()
+	require.NotNil(t, l.internal)
+}
+
+// TestAddListener checks Add stores a listener keyed on its id.
+func TestAddListener(t *testing.T) {
+	l := New()
+	l.Add(NewMockListener("t1", testAddr))
+	require.Contains(t, l.internal, "t1")
+}
+
+// TestGetListener checks Get retrieves a stored listener by id.
+func TestGetListener(t *testing.T) {
+	l := New()
+	l.Add(NewMockListener("t1", testAddr))
+	l.Add(NewMockListener("t2", testAddr))
+	require.Contains(t, l.internal, "t1")
+	require.Contains(t, l.internal, "t2")
+
+	g, ok := l.Get("t1")
+	require.True(t, ok)
+	require.Equal(t, g.ID(), "t1")
+}
+
+// TestLenListener checks Len reports the number of stored listeners.
+func TestLenListener(t *testing.T) {
+	l := New()
+	l.Add(NewMockListener("t1", testAddr))
+	l.Add(NewMockListener("t2", testAddr))
+	require.Contains(t, l.internal, "t1")
+	require.Contains(t, l.internal, "t2")
+	require.Equal(t, 2, l.Len())
+}
+
+// TestDeleteListener checks Delete removes a listener from the map.
+func TestDeleteListener(t *testing.T) {
+	l := New()
+	l.Add(NewMockListener("t1", testAddr))
+	require.Contains(t, l.internal, "t1")
+	l.Delete("t1")
+	_, ok := l.Get("t1")
+	require.False(t, ok)
+	require.Nil(t, l.internal["t1"])
+}
+
+// TestServeListener checks Serve starts a stored listener and Close stops it.
+func TestServeListener(t *testing.T) {
+	l := New()
+	l.Add(NewMockListener("t1", testAddr))
+	l.Serve("t1", MockEstablisher)
+	time.Sleep(time.Millisecond) // allow the serving goroutine to start
+	require.True(t, l.internal["t1"].(*MockListener).IsServing())
+
+	l.Close("t1", MockCloser)
+	require.False(t, l.internal["t1"].(*MockListener).IsServing())
+}
+
+// TestServeAllListeners checks ServeAll starts every stored listener.
+func TestServeAllListeners(t *testing.T) {
+	l := New()
+	l.Add(NewMockListener("t1", testAddr))
+	l.Add(NewMockListener("t2", testAddr))
+	l.Add(NewMockListener("t3", testAddr))
+	l.ServeAll(MockEstablisher)
+	time.Sleep(time.Millisecond)
+
+	require.True(t, l.internal["t1"].(*MockListener).IsServing())
+	require.True(t, l.internal["t2"].(*MockListener).IsServing())
+	require.True(t, l.internal["t3"].(*MockListener).IsServing())
+
+	l.Close("t1", MockCloser)
+	l.Close("t2", MockCloser)
+	l.Close("t3", MockCloser)
+
+	require.False(t, l.internal["t1"].(*MockListener).IsServing())
+	require.False(t, l.internal["t2"].(*MockListener).IsServing())
+	require.False(t, l.internal["t3"].(*MockListener).IsServing())
+}
+
+// TestCloseListener checks Close invokes the closer callback for a serving listener.
+func TestCloseListener(t *testing.T) {
+	l := New()
+	mocked := NewMockListener("t1", testAddr)
+	l.Add(mocked)
+	l.Serve("t1", MockEstablisher)
+	time.Sleep(time.Millisecond)
+	var closed bool
+	l.Close("t1", func(id string) {
+		closed = true
+	})
+	require.True(t, closed)
+}
+
+// TestCloseAllListeners checks CloseAll invokes the closer for every stored listener.
+func TestCloseAllListeners(t *testing.T) {
+	l := New()
+	l.Add(NewMockListener("t1", testAddr))
+	l.Add(NewMockListener("t2", testAddr))
+	l.Add(NewMockListener("t3", testAddr))
+	l.ServeAll(MockEstablisher)
+	time.Sleep(time.Millisecond)
+	require.True(t, l.internal["t1"].(*MockListener).IsServing())
+	require.True(t, l.internal["t2"].(*MockListener).IsServing())
+	require.True(t, l.internal["t3"].(*MockListener).IsServing())
+
+	closed := make(map[string]bool)
+	l.CloseAll(func(id string) {
+		closed[id] = true
+	})
+	require.Contains(t, closed, "t1")
+	require.Contains(t, closed, "t2")
+	require.Contains(t, closed, "t3")
+	require.True(t, closed["t1"])
+	require.True(t, closed["t2"])
+	require.True(t, closed["t3"])
+}
diff --git a/listeners/mock.go b/listeners/mock.go
new file mode 100644
index 0000000..1a67d89
--- /dev/null
+++ b/listeners/mock.go
@@ -0,0 +1,105 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "fmt"
+ "net"
+ "sync"
+
+ "log/slog"
+)
+
+// TypeMock is the listener type identifier for the mock listener.
+const TypeMock = "mock"
+
+// MockEstablisher is a function signature which can be used in testing.
+// It accepts any connection and always succeeds.
+func MockEstablisher(id string, c net.Conn) error {
+	return nil
+}
+
+// MockCloser is a function signature which can be used in testing. It is a no-op.
+func MockCloser(id string) {}
+
+// MockListener is a mock listener for establishing client connections.
+type MockListener struct {
+	sync.RWMutex
+	id        string    // the id of the listener
+	address   string    // the network address the listener binds to
+	Config    *Config   // configuration for the listener
+	done      chan bool // indicate the listener is done
+	Serving   bool      // indicate the listener is serving
+	Listening bool      // indicate the listener is listening
+	ErrListen bool      // throw an error on listen
+}
+
+// NewMockListener returns a new instance of MockListener.
+// The done channel is closed by Close to unblock Serve.
+func NewMockListener(id, address string) *MockListener {
+	return &MockListener{
+		id:      id,
+		address: address,
+		done:    make(chan bool),
+	}
+}
+
+// Serve marks the mock listener as serving and blocks until the done
+// channel is closed by Close.
+func (l *MockListener) Serve(establisher EstablishFn) {
+	l.Lock()
+	l.Serving = true
+	l.Unlock()
+
+	<-l.done // unblocks when Close closes the channel
+}
+
+// Init initializes the listener, or returns an error when ErrListen is set
+// (used to simulate a listen failure in tests).
+func (l *MockListener) Init(log *slog.Logger) error {
+	if l.ErrListen {
+		return fmt.Errorf("listen failure")
+	}
+
+	l.Lock()
+	defer l.Unlock()
+	l.Listening = true
+	return nil
+}
+
+// ID returns the id of the mock listener.
+func (l *MockListener) ID() string {
+	return l.id
+}
+
+// Address returns the configured address of the listener.
+func (l *MockListener) Address() string {
+	return l.address
+}
+
+// Protocol returns the protocol in use by the listener, always "mock".
+func (l *MockListener) Protocol() string {
+	return "mock"
+}
+
+// Close closes the mock listener, invoking the closer callback and closing
+// the done channel to unblock Serve.
+// NOTE(review): calling Close twice would panic on the repeated close(l.done);
+// the tests only ever call it once per listener.
+func (l *MockListener) Close(closer CloseFn) {
+	l.Lock()
+	defer l.Unlock()
+	l.Serving = false
+	closer(l.id)
+	close(l.done)
+}
+
+// IsServing indicates whether the mock listener is serving (lock-protected read).
+func (l *MockListener) IsServing() bool {
+	l.Lock()
+	defer l.Unlock()
+	return l.Serving
+}
+
+// IsListening indicates whether the mock listener is listening (lock-protected read).
+func (l *MockListener) IsListening() bool {
+	l.Lock()
+	defer l.Unlock()
+	return l.Listening
+}
diff --git a/listeners/mock_test.go b/listeners/mock_test.go
new file mode 100644
index 0000000..46aa922
--- /dev/null
+++ b/listeners/mock_test.go
@@ -0,0 +1,99 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestMockEstablisher checks the mock establisher accepts any connection.
+func TestMockEstablisher(t *testing.T) {
+	_, w := net.Pipe()
+	err := MockEstablisher("t1", w)
+	require.NoError(t, err)
+	_ = w.Close()
+}
+
+// TestNewMockListener checks the constructor copies id and address.
+func TestNewMockListener(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	require.Equal(t, "t1", mocked.id)
+	require.Equal(t, testAddr, mocked.address)
+}
+// TestMockListenerID checks ID returns the configured id.
+func TestMockListenerID(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	require.Equal(t, "t1", mocked.ID())
+}
+
+// TestMockListenerAddress checks Address returns the configured address.
+func TestMockListenerAddress(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	require.Equal(t, testAddr, mocked.Address())
+}
+// TestMockListenerProtocol checks the protocol is always "mock".
+func TestMockListenerProtocol(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	require.Equal(t, "mock", mocked.Protocol())
+}
+
+// TestNewMockListenerIsListening checks a fresh mock listener is not listening.
+func TestNewMockListenerIsListening(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	require.Equal(t, false, mocked.IsListening())
+}
+
+// TestNewMockListenerIsServing checks a fresh mock listener is not serving.
+func TestNewMockListenerIsServing(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	require.Equal(t, false, mocked.IsServing())
+}
+
+// TestNewMockListenerInit checks Init flips the listening flag.
+func TestNewMockListenerInit(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	require.Equal(t, "t1", mocked.id)
+	require.Equal(t, testAddr, mocked.address)
+
+	require.Equal(t, false, mocked.IsListening())
+	err := mocked.Init(nil)
+	require.NoError(t, err)
+	require.Equal(t, true, mocked.IsListening())
+}
+
+// TestNewMockListenerInitFailure checks Init propagates the simulated listen error.
+func TestNewMockListenerInitFailure(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	mocked.ErrListen = true
+	err := mocked.Init(nil)
+	require.Error(t, err)
+}
+
+// TestMockListenerServe checks Serve blocks until Close releases it.
+func TestMockListenerServe(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	require.Equal(t, false, mocked.IsServing())
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		mocked.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
+	require.Equal(t, true, mocked.IsServing())
+
+	var closed bool
+	mocked.Close(func(id string) {
+		closed = true
+	})
+	require.Equal(t, true, closed)
+	<-o
+
+	_ = mocked.Init(nil)
+}
+
+// TestMockListenerClose checks Close invokes the closer even when never serving.
+func TestMockListenerClose(t *testing.T) {
+	mocked := NewMockListener("t1", testAddr)
+	var closed bool
+	mocked.Close(func(id string) {
+		closed = true
+	})
+	require.Equal(t, true, closed)
+}
diff --git a/listeners/net.go b/listeners/net.go
new file mode 100644
index 0000000..fa4ef3d
--- /dev/null
+++ b/listeners/net.go
@@ -0,0 +1,92 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: Jeroen Rinzema
+
+package listeners
+
+import (
+ "net"
+ "sync"
+ "sync/atomic"
+
+ "log/slog"
+)
+
+// Net is a listener for establishing client connections on basic TCP protocol,
+// wrapping a caller-supplied, already-open net.Listener. // [MQTT-4.2.0-1]
+type Net struct {
+	mu       sync.Mutex
+	listener net.Listener // a net.Listener which will listen for new clients
+	id       string       // the internal id of the listener
+	log      *slog.Logger // server logger
+	end      uint32       // ensure the close methods are only called once
+}
+
+// NewNet initializes and returns a listener serving incoming connections on
+// the given net.Listener, which must already be open.
+func NewNet(id string, listener net.Listener) *Net {
+	return &Net{
+		id:       id,
+		listener: listener,
+	}
+}
+
+// ID returns the id of the listener.
+func (l *Net) ID() string {
+	return l.id
+}
+
+// Address returns the actual address of the wrapped net.Listener.
+func (l *Net) Address() string {
+	return l.listener.Addr().String()
+}
+
+// Protocol returns the network of the wrapped listener (e.g. "tcp").
+func (l *Net) Protocol() string {
+	return l.listener.Addr().Network()
+}
+
+// Init initializes the listener by storing the logger; the wrapped
+// net.Listener is already open, so no further setup is required.
+func (l *Net) Init(log *slog.Logger) error {
+	l.log = log
+	return nil
+}
+
+// Serve starts waiting for new TCP connections, and calls the establish
+// connection callback for any received. It returns when Accept fails,
+// typically because Close has closed the underlying listener.
+func (l *Net) Serve(establish EstablishFn) {
+	for {
+		if atomic.LoadUint32(&l.end) == 1 {
+			return
+		}
+
+		conn, err := l.listener.Accept()
+		if err != nil {
+			return
+		}
+
+		if atomic.LoadUint32(&l.end) == 0 {
+			go func() {
+				// Use a goroutine-local error; the previous code assigned to
+				// the accept-loop's shared err from concurrently running
+				// goroutines, which is a data race under -race.
+				if err := establish(l.id, conn); err != nil {
+					l.log.Warn("", "error", err)
+				}
+			}()
+		}
+	}
+}
+
+// Close closes the listener and any client connections. The CAS on end
+// ensures the closeClients callback fires only once; the wrapped listener
+// is then closed, unblocking Serve's Accept loop.
+func (l *Net) Close(closeClients CloseFn) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
+		closeClients(l.id)
+	}
+
+	if l.listener == nil {
+		return
+	}
+	_ = l.listener.Close() // any close error is discarded, as before
+}
diff --git a/listeners/net_test.go b/listeners/net_test.go
new file mode 100644
index 0000000..14a1ad6
--- /dev/null
+++ b/listeners/net_test.go
@@ -0,0 +1,105 @@
+package listeners
+
+import (
+ "errors"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestNewNet checks the constructor stores the id.
+func TestNewNet(t *testing.T) {
+	n, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	l := NewNet("t1", n)
+	require.Equal(t, "t1", l.id)
+}
+
+// TestNetID checks ID returns the configured id.
+func TestNetID(t *testing.T) {
+	n, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	l := NewNet("t1", n)
+	require.Equal(t, "t1", l.ID())
+}
+
+// TestNetAddress checks Address reflects the wrapped listener's bound address.
+func TestNetAddress(t *testing.T) {
+	n, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	l := NewNet("t1", n)
+	require.Equal(t, n.Addr().String(), l.Address())
+}
+
+// TestNetProtocol checks Protocol reflects the wrapped listener's network.
+func TestNetProtocol(t *testing.T) {
+	n, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	l := NewNet("t1", n)
+	require.Equal(t, "tcp", l.Protocol())
+}
+
+// TestNetInit checks Init succeeds for an already-open listener.
+func TestNetInit(t *testing.T) {
+	n, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	l := NewNet("t1", n)
+	err = l.Init(logger)
+	l.Close(MockCloser)
+	require.NoError(t, err)
+}
+
+// TestNetServeAndClose checks the accept loop starts, the closer fires on
+// Close, and repeated Close/Serve calls on a closed listener are safe.
+func TestNetServeAndClose(t *testing.T) {
+	n, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	l := NewNet("t1", n)
+	err = l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.True(t, closed)
+	<-o
+
+	l.Close(MockCloser)      // coverage: close closed
+	l.Serve(MockEstablisher) // coverage: serve closed
+}
+
+// TestNetEstablishThenEnd checks that a dialled connection reaches the
+// establish callback, and that an establish error does not kill the listener.
+func TestNetEstablishThenEnd(t *testing.T) {
+	n, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	l := NewNet("t1", n)
+	err = l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	established := make(chan bool)
+	go func() {
+		l.Serve(func(id string, c net.Conn) error {
+			established <- true
+			return errors.New("ending") // return an error to exit immediately
+		})
+		o <- true
+	}()
+
+	time.Sleep(time.Millisecond)
+	_, _ = net.Dial("tcp", n.Addr().String())
+	require.Equal(t, true, <-established)
+	l.Close(MockCloser)
+	<-o
+}
diff --git a/listeners/tcp.go b/listeners/tcp.go
new file mode 100644
index 0000000..014a182
--- /dev/null
+++ b/listeners/tcp.go
@@ -0,0 +1,109 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "crypto/tls"
+ "net"
+ "sync"
+ "sync/atomic"
+
+ "log/slog"
+)
+
+// TypeTCP is the listener type identifier for the TCP listener.
+const TypeTCP = "tcp"
+
+// TCP is a listener for establishing client connections on basic TCP protocol. // [MQTT-4.2.0-1]
+type TCP struct {
+	sync.RWMutex
+	id      string       // the internal id of the listener
+	address string       // the network address to bind to
+	listen  net.Listener // a net.Listener which will listen for new clients
+	config  Config       // configuration values for the listener
+	log     *slog.Logger // server logger
+	end     uint32       // ensure the close methods are only called once
+}
+
+// NewTCP initializes and returns a new TCP listener, listening on an address.
+// Init must be called afterwards to open the network listener.
+func NewTCP(config Config) *TCP {
+	return &TCP{
+		id:      config.ID,
+		address: config.Address,
+		config:  config,
+	}
+}
+
+// ID returns the id of the listener, as taken from Config.ID.
+func (l *TCP) ID() string {
+	return l.id
+}
+
+// Address returns the address of the listener: the actual bound address once
+// Init has opened the listener (supports ":0" auto-assign), otherwise the
+// configured address.
+func (l *TCP) Address() string {
+	if l.listen != nil {
+		return l.listen.Addr().String()
+	}
+	return l.address
+}
+
+// Protocol returns the protocol in use by the listener, always "tcp"
+// (including when TLS is layered on top).
+func (l *TCP) Protocol() string {
+	return "tcp"
+}
+
+// Init initializes the listener, opening a TLS-wrapped listener when a TLS
+// config is present, otherwise a plain TCP listener.
+func (l *TCP) Init(log *slog.Logger) error {
+	l.log = log
+
+	var err error
+	if l.config.TLSConfig != nil {
+		l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig)
+	} else {
+		l.listen, err = net.Listen("tcp", l.address)
+	}
+
+	return err
+}
+
+// Serve starts waiting for new TCP connections, and calls the establish
+// connection callback for any received. It returns when Accept fails,
+// typically because Close has closed the underlying listener.
+func (l *TCP) Serve(establish EstablishFn) {
+	for {
+		if atomic.LoadUint32(&l.end) == 1 {
+			return
+		}
+
+		conn, err := l.listen.Accept()
+		if err != nil {
+			return
+		}
+
+		if atomic.LoadUint32(&l.end) == 0 {
+			go func() {
+				// Use a goroutine-local error; the previous code assigned to
+				// the accept-loop's shared err from concurrently running
+				// goroutines, which is a data race under -race.
+				if err := establish(l.id, conn); err != nil {
+					l.log.Warn("", "error", err)
+				}
+			}()
+		}
+	}
+}
+
+// Close closes the listener and any client connections. The CAS on end
+// ensures closeClients fires only once; closing the net.Listener then
+// unblocks Serve's Accept loop.
+func (l *TCP) Close(closeClients CloseFn) {
+	l.Lock()
+	defer l.Unlock()
+
+	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
+		closeClients(l.id)
+	}
+
+	if l.listen != nil {
+		err := l.listen.Close()
+		if err != nil {
+			return
+		}
+	}
+}
diff --git a/listeners/tcp_test.go b/listeners/tcp_test.go
new file mode 100644
index 0000000..4b7496f
--- /dev/null
+++ b/listeners/tcp_test.go
@@ -0,0 +1,124 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "errors"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestNewTCP checks that the constructor stores the config id and address.
+func TestNewTCP(t *testing.T) {
+	l := NewTCP(basicConfig)
+	require.Equal(t, "t1", l.id)
+	require.Equal(t, testAddr, l.address)
+}
+
+func TestTCPID(t *testing.T) {
+	l := NewTCP(basicConfig)
+	require.Equal(t, "t1", l.ID())
+}
+
+func TestTCPAddress(t *testing.T) {
+	l := NewTCP(basicConfig)
+	require.Equal(t, testAddr, l.Address())
+}
+
+func TestTCPProtocol(t *testing.T) {
+	l := NewTCP(basicConfig)
+	require.Equal(t, "tcp", l.Protocol())
+}
+
+// TestTCPProtocolTLS checks the protocol name remains "tcp" when TLS is enabled.
+func TestTCPProtocolTLS(t *testing.T) {
+	l := NewTCP(tlsConfig)
+	_ = l.Init(logger)
+	defer l.listen.Close()
+	require.Equal(t, "tcp", l.Protocol())
+}
+
+// TestTCPInit checks both the plain and TLS listener initialization paths.
+func TestTCPInit(t *testing.T) {
+	l := NewTCP(basicConfig)
+	err := l.Init(logger)
+	l.Close(MockCloser)
+	require.NoError(t, err)
+
+	l2 := NewTCP(tlsConfig)
+	err = l2.Init(logger)
+	l2.Close(MockCloser)
+	require.NoError(t, err)
+	require.NotNil(t, l2.config.TLSConfig)
+}
+
+// TestTCPServeAndClose checks Serve unblocks when the listener is closed,
+// and that re-closing and re-serving a closed listener are safe no-ops.
+func TestTCPServeAndClose(t *testing.T) {
+	l := NewTCP(basicConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.True(t, closed)
+	<-o
+
+	l.Close(MockCloser)      // coverage: close closed
+	l.Serve(MockEstablisher) // coverage: serve closed
+}
+
+// TestTCPServeTLSAndClose is the TLS variant of TestTCPServeAndClose.
+func TestTCPServeTLSAndClose(t *testing.T) {
+	l := NewTCP(tlsConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.Equal(t, true, closed)
+	<-o
+}
+
+// TestTCPEstablishThenEnd checks the establish callback fires for an
+// incoming connection and that Serve still terminates on Close.
+func TestTCPEstablishThenEnd(t *testing.T) {
+	l := NewTCP(basicConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	established := make(chan bool)
+	go func() {
+		l.Serve(func(id string, c net.Conn) error {
+			established <- true
+			return errors.New("ending") // return an error to exit immediately
+		})
+		o <- true
+	}()
+
+	time.Sleep(time.Millisecond)
+	_, _ = net.Dial("tcp", l.listen.Addr().String())
+	require.Equal(t, true, <-established)
+	l.Close(MockCloser)
+	<-o
+}
diff --git a/listeners/unixsock.go b/listeners/unixsock.go
new file mode 100644
index 0000000..23df130
--- /dev/null
+++ b/listeners/unixsock.go
@@ -0,0 +1,102 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: jason@zgwit.com
+
+package listeners
+
+import (
+ "net"
+ "os"
+ "sync"
+ "sync/atomic"
+
+ "log/slog"
+)
+
+const TypeUnix = "unix"
+
+// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
+type UnixSock struct {
+	sync.RWMutex
+	id      string       // the internal id of the listener.
+	address string       // the filesystem path of the unix socket to bind to.
+	config  Config       // configuration values for the listener
+	listen  net.Listener // a net.Listener which will listen for new clients.
+	log     *slog.Logger // server logger
+	end     uint32       // ensure the close methods are only called once.
+}
+
+// NewUnixSock initializes and returns a new UnixSock listener, listening on an address.
+// The socket file is not created until Init is called.
+func NewUnixSock(config Config) *UnixSock {
+	return &UnixSock{
+		id:      config.ID,
+		address: config.Address,
+		config:  config,
+	}
+}
+
+// ID returns the id of the listener.
+func (l *UnixSock) ID() string {
+	return l.id
+}
+
+// Address returns the address (socket path) of the listener.
+func (l *UnixSock) Address() string {
+	return l.address
+}
+
+// Protocol returns the network protocol of the listener.
+func (l *UnixSock) Protocol() string {
+	return "unix"
+}
+
+// Init initializes the listener, removing any stale socket file left at the
+// configured path before binding a new unix domain socket.
+func (l *UnixSock) Init(log *slog.Logger) error {
+	l.log = log
+
+	var err error
+	_ = os.Remove(l.address) // best-effort removal of a stale socket file
+	l.listen, err = net.Listen("unix", l.address)
+	return err
+}
+
+// Serve starts waiting for new UnixSock connections, and calls the establish
+// connection callback for any received. It blocks until the listener is
+// closed (Accept returns an error once the socket is closed).
+func (l *UnixSock) Serve(establish EstablishFn) {
+	for {
+		if atomic.LoadUint32(&l.end) == 1 {
+			return
+		}
+
+		conn, err := l.listen.Accept()
+		if err != nil {
+			return // listener closed or fatal accept error; exit the accept loop
+		}
+
+		if atomic.LoadUint32(&l.end) == 0 {
+			go func() {
+				// Use a goroutine-local error. The previous code assigned to
+				// the accept loop's err variable, which raced between the
+				// accept loop and concurrent connection goroutines.
+				if err := establish(l.id, conn); err != nil {
+					l.log.Warn("", "error", err)
+				}
+			}()
+		}
+	}
+}
+
+// Close closes the listener and any client connections. The closeClients
+// callback is invoked only on the first call; subsequent calls only attempt
+// to re-close the net.Listener.
+func (l *UnixSock) Close(closeClients CloseFn) {
+	l.Lock()
+	defer l.Unlock()
+
+	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
+		closeClients(l.id)
+	}
+
+	if l.listen != nil {
+		err := l.listen.Close()
+		if err != nil {
+			return
+		}
+	}
+}
diff --git a/listeners/unixsock_test.go b/listeners/unixsock_test.go
new file mode 100644
index 0000000..a3940e6
--- /dev/null
+++ b/listeners/unixsock_test.go
@@ -0,0 +1,102 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: jason@zgwit.com
+
+package listeners
+
+import (
+ "errors"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// testUnixAddr is the socket file created in the working directory by tests.
+const testUnixAddr = "mochi.sock"
+
+var (
+	unixConfig = Config{ID: "t1", Address: testUnixAddr}
+)
+
+// TestNewUnixSock checks the constructor stores the config id and address.
+func TestNewUnixSock(t *testing.T) {
+	l := NewUnixSock(unixConfig)
+	require.Equal(t, "t1", l.id)
+	require.Equal(t, testUnixAddr, l.address)
+}
+
+func TestUnixSockID(t *testing.T) {
+	l := NewUnixSock(unixConfig)
+	require.Equal(t, "t1", l.ID())
+}
+
+func TestUnixSockAddress(t *testing.T) {
+	l := NewUnixSock(unixConfig)
+	require.Equal(t, testUnixAddr, l.Address())
+}
+
+func TestUnixSockProtocol(t *testing.T) {
+	l := NewUnixSock(unixConfig)
+	require.Equal(t, "unix", l.Protocol())
+}
+
+// TestUnixSockInit checks that a second listener can re-bind the same socket
+// path (Init removes the stale socket file).
+func TestUnixSockInit(t *testing.T) {
+	l := NewUnixSock(unixConfig)
+	err := l.Init(logger)
+	l.Close(MockCloser)
+	require.NoError(t, err)
+
+	t2Config := unixConfig
+	t2Config.ID = "t2"
+	l2 := NewUnixSock(t2Config)
+	err = l2.Init(logger)
+	l2.Close(MockCloser)
+	require.NoError(t, err)
+}
+
+// TestUnixSockServeAndClose checks Serve unblocks when the listener is closed,
+// and that re-closing and re-serving a closed listener are safe no-ops.
+func TestUnixSockServeAndClose(t *testing.T) {
+	l := NewUnixSock(unixConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.True(t, closed)
+	<-o
+
+	l.Close(MockCloser)      // coverage: close closed
+	l.Serve(MockEstablisher) // coverage: serve closed
+}
+
+// TestUnixSockEstablishThenEnd checks the establish callback fires for an
+// incoming connection and that Serve still terminates on Close.
+func TestUnixSockEstablishThenEnd(t *testing.T) {
+	l := NewUnixSock(unixConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	established := make(chan bool)
+	go func() {
+		l.Serve(func(id string, c net.Conn) error {
+			established <- true
+			return errors.New("ending") // return an error to exit immediately
+		})
+		o <- true
+	}()
+
+	time.Sleep(time.Millisecond)
+	_, _ = net.Dial("unix", l.listen.Addr().String())
+	require.Equal(t, true, <-established)
+	l.Close(MockCloser)
+	<-o
+}
diff --git a/listeners/websocket.go b/listeners/websocket.go
new file mode 100644
index 0000000..267daf6
--- /dev/null
+++ b/listeners/websocket.go
@@ -0,0 +1,199 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "log/slog"
+
+ "github.com/gorilla/websocket"
+)
+
+const TypeWS = "ws"
+
+var (
+	// ErrInvalidMessage indicates that a message payload was not valid.
+	ErrInvalidMessage = errors.New("message type not binary")
+)
+
+// Websocket is a listener for establishing websocket connections.
+type Websocket struct { // [MQTT-4.2.0-1]
+	sync.RWMutex
+	id        string              // the internal id of the listener
+	address   string              // the network address to bind to
+	config    Config              // configuration values for the listener
+	listen    *http.Server        // a http server for serving websocket connections
+	log       *slog.Logger        // server logger
+	establish EstablishFn         // the server's establish connection handler
+	upgrader  *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
+	end       uint32              // ensure the close methods are only called once
+}
+
+// NewWebsocket initializes and returns a new Websocket listener, listening on an address.
+// The HTTP server is not created until Init is called.
+func NewWebsocket(config Config) *Websocket {
+	return &Websocket{
+		id:      config.ID,
+		address: config.Address,
+		config:  config,
+		upgrader: &websocket.Upgrader{
+			Subprotocols: []string{"mqtt"}, // advertise the mqtt websocket subprotocol
+			CheckOrigin: func(r *http.Request) bool {
+				return true // accept upgrade requests from any origin
+			},
+		},
+	}
+}
+
+// ID returns the id of the listener.
+func (l *Websocket) ID() string {
+	return l.id
+}
+
+// Address returns the address of the listener.
+func (l *Websocket) Address() string {
+	return l.address
+}
+
+// Protocol returns the network protocol of the listener: "wss" when a TLS
+// configuration is present, otherwise "ws".
+func (l *Websocket) Protocol() string {
+	if l.config.TLSConfig != nil {
+		return "wss"
+	}
+
+	return "ws"
+}
+
+// Init initializes the listener by constructing (but not starting) the HTTP
+// server that upgrades incoming requests to websocket connections.
+func (l *Websocket) Init(log *slog.Logger) error {
+	l.log = log
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", l.handler)
+	l.listen = &http.Server{
+		Addr:         l.address,
+		Handler:      mux,
+		TLSConfig:    l.config.TLSConfig,
+		ReadTimeout:  60 * time.Second,
+		WriteTimeout: 60 * time.Second,
+	}
+
+	return nil
+}
+
+// handler upgrades and handles an incoming websocket connection. The upgraded
+// connection is wrapped in wsConn so it satisfies net.Conn, then handed to the
+// server's establish callback; handler blocks until that callback returns.
+func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
+	c, err := l.upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		return
+	}
+	defer c.Close()
+
+	err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c})
+	if err != nil {
+		l.log.Warn("", "error", err)
+	}
+}
+
+// Serve starts waiting for new Websocket connections, and calls the connection
+// establishment callback for any received. It blocks until the HTTP server is
+// shut down by Close.
+func (l *Websocket) Serve(establish EstablishFn) {
+	var err error
+	l.establish = establish
+
+	if l.listen.TLSConfig != nil {
+		err = l.listen.ListenAndServeTLS("", "")
+	} else {
+		err = l.listen.ListenAndServe()
+	}
+
+	// After the listener has been shutdown, no need to print the http.ErrServerClosed error.
+	if err != nil && atomic.LoadUint32(&l.end) == 0 {
+		l.log.Error("failed to serve.", "error", err, "listener", l.id)
+	}
+}
+
+// Close closes the listener and any client connections. The HTTP server is
+// shut down gracefully (5 second deadline) only on the first call; the
+// closeClients callback runs on every call.
+func (l *Websocket) Close(closeClients CloseFn) {
+	l.Lock()
+	defer l.Unlock()
+
+	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		_ = l.listen.Shutdown(ctx)
+	}
+
+	closeClients(l.id)
+}
+
+// wsConn is a websocket connection which satisfies the net.Conn interface.
+type wsConn struct {
+	net.Conn
+	c *websocket.Conn
+
+	// reader for the current message (can be nil)
+	r io.Reader
+}
+
+// Read reads the next span of bytes from the websocket connection and returns
+// the number of bytes read. It reads across websocket message boundaries:
+// when no message reader is active, it fetches the next (binary-only) message
+// reader, then drains it into p until p is full or the message ends.
+func (ws *wsConn) Read(p []byte) (int, error) {
+	if ws.r == nil {
+		op, r, err := ws.c.NextReader()
+		if err != nil {
+			return 0, err
+		}
+
+		// MQTT-over-websocket payloads must be binary frames; reject others.
+		if op != websocket.BinaryMessage {
+			err = ErrInvalidMessage
+			return 0, err
+		}
+
+		ws.r = r
+	}
+
+	var n int
+	for {
+		// buffer is full, return what we've read so far
+		if n == len(p) {
+			return n, nil
+		}
+
+		br, err := ws.r.Read(p[n:])
+		n += br
+		if err != nil {
+			// when ANY error occurs, we consider this the end of the current message (either because it really is, via
+			// io.EOF, or because something bad happened, in which case we want to drop the remainder)
+			ws.r = nil
+
+			// io.EOF just means "message finished"; it is not an error for the caller.
+			if errors.Is(err, io.EOF) {
+				err = nil
+			}
+			return n, err
+		}
+	}
+}
+
+// Write writes bytes to the websocket connection as a single binary message.
+func (ws *wsConn) Write(p []byte) (int, error) {
+	err := ws.c.WriteMessage(websocket.BinaryMessage, p)
+	if err != nil {
+		return 0, err
+	}
+
+	return len(p), nil
+}
+
+// Close signals the underlying websocket conn to close.
+func (ws *wsConn) Close() error {
+	return ws.Conn.Close()
+}
diff --git a/listeners/websocket_test.go b/listeners/websocket_test.go
new file mode 100644
index 0000000..56d6bc3
--- /dev/null
+++ b/listeners/websocket_test.go
@@ -0,0 +1,173 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package listeners
+
+import (
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/gorilla/websocket"
+ "github.com/stretchr/testify/require"
+)
+
+// TestNewWebsocket checks the constructor stores the config id and address.
+func TestNewWebsocket(t *testing.T) {
+	l := NewWebsocket(basicConfig)
+	require.Equal(t, "t1", l.id)
+	require.Equal(t, testAddr, l.address)
+}
+
+func TestWebsocketID(t *testing.T) {
+	l := NewWebsocket(basicConfig)
+	require.Equal(t, "t1", l.ID())
+}
+
+func TestWebsocketAddress(t *testing.T) {
+	l := NewWebsocket(basicConfig)
+	require.Equal(t, testAddr, l.Address())
+}
+
+func TestWebsocketProtocol(t *testing.T) {
+	l := NewWebsocket(basicConfig)
+	require.Equal(t, "ws", l.Protocol())
+}
+
+// TestWebsocketProtocolTLS checks the protocol becomes "wss" with a TLS config.
+func TestWebsocketProtocolTLS(t *testing.T) {
+	l := NewWebsocket(tlsConfig)
+	require.Equal(t, "wss", l.Protocol())
+}
+
+// TestWebsocketInit checks the internal HTTP server is created by Init.
+func TestWebsocketInit(t *testing.T) {
+	l := NewWebsocket(basicConfig)
+	require.Nil(t, l.listen)
+	err := l.Init(logger)
+	require.NoError(t, err)
+	require.NotNil(t, l.listen)
+}
+
+// TestWebsocketServeAndClose checks Serve unblocks when the server is shut down.
+func TestWebsocketServeAndClose(t *testing.T) {
+	l := NewWebsocket(basicConfig)
+	_ = l.Init(logger)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.True(t, closed)
+	<-o
+}
+
+// TestWebsocketServeTLSAndClose is the TLS variant of TestWebsocketServeAndClose.
+func TestWebsocketServeTLSAndClose(t *testing.T) {
+	l := NewWebsocket(tlsConfig)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+	require.Equal(t, true, closed)
+	<-o
+}
+
+// TestWebsocketFailedToServe checks Serve returns when given an unusable address.
+func TestWebsocketFailedToServe(t *testing.T) {
+	config := tlsConfig
+	config.Address = "wrong_addr"
+	l := NewWebsocket(config)
+	err := l.Init(logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	<-o
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+	require.Equal(t, true, closed)
+}
+
+// TestWebsocketUpgrade checks an HTTP request is upgraded and reaches the
+// establish callback.
+func TestWebsocketUpgrade(t *testing.T) {
+	l := NewWebsocket(basicConfig)
+	_ = l.Init(logger)
+
+	e := make(chan bool)
+	l.establish = func(id string, c net.Conn) error {
+		e <- true
+		return nil
+	}
+
+	s := httptest.NewServer(http.HandlerFunc(l.handler))
+	ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
+	require.NoError(t, err)
+	require.Equal(t, true, <-e)
+
+	s.Close()
+	_ = ws.Close()
+}
+
+// TestWebsocketConnectionReads checks wsConn.Read reassembles a binary
+// message larger than the caller's read buffer.
+func TestWebsocketConnectionReads(t *testing.T) {
+	l := NewWebsocket(basicConfig)
+	_ = l.Init(nil)
+
+	recv := make(chan []byte)
+	l.establish = func(id string, c net.Conn) error {
+		var out []byte
+		for {
+			buf := make([]byte, 2048)
+			n, err := c.Read(buf)
+			require.NoError(t, err)
+			out = append(out, buf[:n]...)
+			if n < 2048 {
+				break
+			}
+		}
+
+		recv <- out
+		return nil
+	}
+
+	s := httptest.NewServer(http.HandlerFunc(l.handler))
+	ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
+	require.NoError(t, err)
+
+	pkt := make([]byte, 3000) // make sure this is >2048
+	for i := 0; i < len(pkt); i++ {
+		pkt[i] = byte(i % 100)
+	}
+
+	err = ws.WriteMessage(websocket.BinaryMessage, pkt)
+	require.NoError(t, err)
+
+	got := <-recv
+	require.Equal(t, 3000, len(got))
+	require.Equal(t, pkt, got)
+
+	s.Close()
+	_ = ws.Close()
+}
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..c7c372b
--- /dev/null
+++ b/main.go
@@ -0,0 +1,103 @@
+package main
+
+import (
+ "log"
+ "os"
+ "os/signal"
+ "syscall"
+ "testmqtt/config"
+ "testmqtt/mqtt"
+)
+
+// main loads broker options from config.yaml, starts the MQTT server, and
+// blocks until SIGINT/SIGTERM, then shuts the server down.
+//
+// The previous version carried ~60 lines of dead, commented-out bootstrap
+// code (flag parsing and manual listener wiring); it has been removed since
+// the YAML-config path fully replaces it and the history lives in VCS.
+func main() {
+
+	// Trap SIGINT/SIGTERM so the broker can shut down gracefully.
+	sigs := make(chan os.Signal, 1)
+	done := make(chan bool, 1)
+	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
+	go func() {
+		<-sigs
+		done <- true
+	}()
+
+	// Load broker options from the YAML config file in the working directory.
+	configBytes, err := os.ReadFile("config.yaml")
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	options, err := config.FromBytes(configBytes)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	server := mqtt.New(options)
+
+	// Serve blocks, so run it in its own goroutine.
+	go func() {
+		err := server.Serve()
+		if err != nil {
+			log.Fatal(err)
+		}
+	}()
+
+	// Block until a shutdown signal arrives, then close the server.
+	<-done
+	server.Log.Warn("caught signal, stopping...")
+	_ = server.Close()
+	server.Log.Info("main.go finished")
+}
diff --git a/mempool/bufpool.go b/mempool/bufpool.go
new file mode 100644
index 0000000..2b35b0c
--- /dev/null
+++ b/mempool/bufpool.go
@@ -0,0 +1,83 @@
+package mempool
+
+import (
+ "bytes"
+ "sync"
+)
+
+// bufPool is the default, uncapped buffer pool used by GetBuffer/PutBuffer.
+var bufPool = NewBuffer(0)
+
+// GetBuffer takes a Buffer from the default buffer pool.
+func GetBuffer() *bytes.Buffer { return bufPool.Get() }
+
+// PutBuffer returns a Buffer to the default buffer pool.
+func PutBuffer(x *bytes.Buffer) { bufPool.Put(x) }
+
+// BufferPool is the interface implemented by buffer pools.
+type BufferPool interface {
+	Get() *bytes.Buffer
+	Put(x *bytes.Buffer)
+}
+
+// NewBuffer returns a buffer pool. The max parameter specifies the maximum
+// capacity of the Buffers the pool will retain. If a Buffer becomes larger
+// than max, it will no longer be returned to the pool. If max <= 0, no limit
+// will be enforced.
+func NewBuffer(max int) BufferPool {
+	if max > 0 {
+		return newBufferWithCap(max)
+	}
+	return newBuffer()
+}
+
+// Buffer is a Buffer pool backed by sync.Pool, with no capacity limit on
+// the Buffers it retains.
+type Buffer struct {
+	pool *sync.Pool
+}
+
+// newBuffer returns an uncapped Buffer pool.
+func newBuffer() *Buffer {
+	return &Buffer{
+		pool: &sync.Pool{
+			New: func() any { return new(bytes.Buffer) },
+		},
+	}
+}
+
+// Get a Buffer from the pool.
+func (b *Buffer) Get() *bytes.Buffer {
+	return b.pool.Get().(*bytes.Buffer)
+}
+
+// Put the Buffer back into pool. It resets the Buffer for reuse.
+func (b *Buffer) Put(x *bytes.Buffer) {
+	x.Reset()
+	b.pool.Put(x)
+}
+
+// BufferWithCap is a Buffer pool that discards Buffers whose capacity has
+// grown beyond a configured maximum, so oversized buffers are not retained.
+type BufferWithCap struct {
+	bp  *Buffer
+	max int
+}
+
+// newBufferWithCap returns a Buffer pool with a per-Buffer capacity limit.
+func newBufferWithCap(max int) *BufferWithCap {
+	return &BufferWithCap{
+		bp:  newBuffer(),
+		max: max,
+	}
+}
+
+// Get a Buffer from the pool.
+func (b *BufferWithCap) Get() *bytes.Buffer {
+	return b.bp.Get()
+}
+
+// Put the Buffer back into the pool if the capacity doesn't exceed the limit. It resets the Buffer
+// for reuse.
+func (b *BufferWithCap) Put(x *bytes.Buffer) {
+	if x.Cap() > b.max {
+		return // drop oversized buffers rather than retaining their memory
+	}
+	b.bp.Put(x)
+}
diff --git a/mempool/bufpool_test.go b/mempool/bufpool_test.go
new file mode 100644
index 0000000..560f2e9
--- /dev/null
+++ b/mempool/bufpool_test.go
@@ -0,0 +1,96 @@
+package mempool
+
+import (
+ "bytes"
+ "reflect"
+ "runtime/debug"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestNewBuffer checks the constructor dispatches to the capped or uncapped
+// implementation depending on max. GC is disabled so sync.Pool keeps entries.
+func TestNewBuffer(t *testing.T) {
+	defer debug.SetGCPercent(debug.SetGCPercent(-1))
+	bp := NewBuffer(1000)
+	require.Equal(t, "*mempool.BufferWithCap", reflect.TypeOf(bp).String())
+
+	bp = NewBuffer(0)
+	require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
+
+	bp = NewBuffer(-1)
+	require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
+}
+
+// TestBuffer checks buffers come back reset from an uncapped pool.
+func TestBuffer(t *testing.T) {
+	defer debug.SetGCPercent(debug.SetGCPercent(-1))
+	Size := 101
+
+	bp := NewBuffer(0)
+	buf := bp.Get()
+
+	for i := 0; i < Size; i++ {
+		buf.WriteByte('a')
+	}
+
+	bp.Put(buf)
+	buf = bp.Get()
+	require.Equal(t, 0, buf.Len())
+}
+
+// TestBufferWithCap checks an over-capacity buffer is dropped (the next Get
+// returns a fresh, zero-capacity buffer).
+func TestBufferWithCap(t *testing.T) {
+	defer debug.SetGCPercent(debug.SetGCPercent(-1))
+	Size := 101
+	bp := NewBuffer(100)
+	buf := bp.Get()
+
+	for i := 0; i < Size; i++ {
+		buf.WriteByte('a')
+	}
+
+	bp.Put(buf)
+	buf = bp.Get()
+	require.Equal(t, 0, buf.Len())
+	require.Equal(t, 0, buf.Cap())
+}
+
+func BenchmarkBufferPool(b *testing.B) {
+	bp := NewBuffer(0)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		b := bp.Get()
+		b.WriteString("this is a test")
+		bp.Put(b)
+	}
+}
+
+func BenchmarkBufferPoolWithCapLarger(b *testing.B) {
+	bp := NewBuffer(64 * 1024)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		b := bp.Get()
+		b.WriteString("this is a test")
+		bp.Put(b)
+	}
+}
+
+func BenchmarkBufferPoolWithCapLesser(b *testing.B) {
+	bp := NewBuffer(10)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		b := bp.Get()
+		b.WriteString("this is a test")
+		bp.Put(b)
+	}
+}
+
+// BenchmarkBufferWithoutPool is the no-pool baseline for comparison.
+func BenchmarkBufferWithoutPool(b *testing.B) {
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		b := new(bytes.Buffer)
+		b.WriteString("this is a test")
+		_ = b
+	}
+}
diff --git a/mqtt/clients.go b/mqtt/clients.go
new file mode 100644
index 0000000..bbfc826
--- /dev/null
+++ b/mqtt/clients.go
@@ -0,0 +1,649 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package mqtt
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/rs/xid"
+
+ "testmqtt/packets"
+)
+
+const (
+	defaultKeepalive             uint16 = 10 // the default connection keepalive value in seconds.
+	defaultClientProtocolVersion byte   = 4  // the default mqtt protocol version of connecting clients (if somehow unspecified).
+	minimumKeepalive             uint16 = 5  // the minimum recommended keepalive - values at or below this trigger a warning.
+)
+
+var (
+	ErrMinimumKeepalive = errors.New("client keepalive is below minimum recommended value and may exhibit connection instability")
+)
+
+// ReadFn is the function signature for the function used for reading and processing new packets.
+type ReadFn func(*Client, packets.Packet) error
+
+// Clients contains a map of the clients known by the broker.
+// The embedded RWMutex guards all access to the internal map.
+type Clients struct {
+	internal map[string]*Client // clients known by the broker, keyed on client id.
+	sync.RWMutex
+}
+
+// NewClients returns an instance of Clients.
+func NewClients() *Clients {
+	return &Clients{
+		internal: make(map[string]*Client),
+	}
+}
+
+// Add adds a new client to the clients map, keyed on client id.
+// An existing client with the same id is replaced.
+func (cl *Clients) Add(val *Client) {
+	cl.Lock()
+	defer cl.Unlock()
+	cl.internal[val.ID] = val
+}
+
+// GetAll returns all the clients as a shallow copy of the internal map,
+// so callers can iterate without holding the lock.
+func (cl *Clients) GetAll() map[string]*Client {
+	cl.RLock()
+	defer cl.RUnlock()
+	m := map[string]*Client{}
+	for k, v := range cl.internal {
+		m[k] = v
+	}
+	return m
+}
+
+// Get returns the value of a client if it exists.
+func (cl *Clients) Get(id string) (*Client, bool) {
+	cl.RLock()
+	defer cl.RUnlock()
+	val, ok := cl.internal[id]
+	return val, ok
+}
+
+// Len returns the length of the clients map.
+func (cl *Clients) Len() int {
+	cl.RLock()
+	defer cl.RUnlock()
+	val := len(cl.internal)
+	return val
+}
+
+// Delete removes a client from the internal map.
+func (cl *Clients) Delete(id string) {
+	cl.Lock()
+	defer cl.Unlock()
+	delete(cl.internal, id)
+}
+
+// GetByListener returns the non-closed clients attached to a listener id.
+func (cl *Clients) GetByListener(id string) []*Client {
+	cl.RLock()
+	defer cl.RUnlock()
+	// Read the map length directly rather than calling cl.Len(): Len()
+	// re-acquires the read lock, and a recursive RLock can deadlock if a
+	// writer (Lock) is queued between the two acquisitions.
+	clients := make([]*Client, 0, len(cl.internal))
+	for _, client := range cl.internal {
+		if client.Net.Listener == id && !client.Closed() {
+			clients = append(clients, client)
+		}
+	}
+	return clients
+}
+
+// Client contains information about a client known by the broker.
+type Client struct {
+	Properties ClientProperties // client properties
+	State      ClientState      // the operational state of the client.
+	Net        ClientConnection // network connection state of the client
+	ID         string           // the client id.
+	ops        *ops             // ops provides a reference to server ops.
+	sync.RWMutex                // mutex
+}
+
+// ClientConnection contains the connection transport and metadata for the client.
+type ClientConnection struct {
+	Conn     net.Conn      // the net.Conn used to establish the connection
+	bconn    *bufio.Reader // a buffered net.Conn for reading packets
+	outbuf   *bytes.Buffer // a buffer for writing packets
+	Remote   string        // the remote address of the client
+	Listener string        // listener id of the client
+	Inline   bool          // if true, the client is the built-in 'inline' embedded client
+}
+
+// ClientProperties contains the properties which define the client behaviour.
+type ClientProperties struct {
+	Props           packets.Properties
+	Will            Will
+	Username        []byte
+	ProtocolVersion byte
+	Clean           bool
+}
+
+// Will contains the last will and testament details for a client connection.
+type Will struct {
+	Payload           []byte                 // -
+	User              []packets.UserProperty // -
+	TopicName         string                 // -
+	Flag              uint32                 // 0,1
+	WillDelayInterval uint32                 // -
+	Qos               byte                   // -
+	Retain            bool                   // -
+}
+
+// ClientState tracks the state of the client.
+type ClientState struct {
+	TopicAliases    TopicAliases         // a map of topic aliases
+	stopCause       atomic.Value         // reason for stopping
+	Inflight        *Inflight            // a map of in-flight qos messages
+	Subscriptions   *Subscriptions       // a map of the subscription filters a client maintains
+	disconnected    int64                // the time the client disconnected in unix time, for calculating expiry
+	outbound        chan *packets.Packet // queue for pending outbound packets
+	endOnce         sync.Once            // only end once
+	isTakenOver     uint32               // used to identify orphaned clients
+	packetID        uint32               // the current highest packetID
+	open            context.Context      // indicate that the client is open for packet exchange
+	cancelOpen      context.CancelFunc   // cancel function for open context
+	outboundQty     int32                // number of messages currently in the outbound queue
+	Keepalive       uint16               // the number of seconds the connection can wait
+	ServerKeepalive bool                 // keepalive was set by the server
+}
+
+// newClient returns a new instance of Client. This is almost exclusively used by Server
+// for creating new clients, but it lives here because it's not dependent.
+// c may be nil, in which case no connection state is attached (e.g. the
+// inline client); o supplies server options, hooks, and logger.
+func newClient(c net.Conn, o *ops) *Client {
+	ctx, cancel := context.WithCancel(context.Background())
+	cl := &Client{
+		State: ClientState{
+			Inflight:      NewInflights(),
+			Subscriptions: NewSubscriptions(),
+			TopicAliases:  NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
+			open:          ctx,
+			cancelOpen:    cancel,
+			Keepalive:     defaultKeepalive,
+			outbound:      make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
+		},
+		Properties: ClientProperties{
+			ProtocolVersion: defaultClientProtocolVersion, // default protocol version
+		},
+		ops: o,
+	}
+
+	if c != nil {
+		cl.Net = ClientConnection{
+			Conn:   c,
+			bconn:  bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
+			Remote: c.RemoteAddr().String(),
+		}
+	}
+
+	return cl
+}
+
+// WriteLoop ranges over pending outbound messages and writes them to the
+// client connection. It runs until the client's open context is cancelled.
+func (cl *Client) WriteLoop() {
+	for {
+		select {
+		case pk := <-cl.State.outbound:
+			if err := cl.WritePacket(*pk); err != nil {
+				// TODO : Figure out what to do with error
+				cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
+			}
+			atomic.AddInt32(&cl.State.outboundQty, -1)
+		case <-cl.State.open.Done():
+			return
+		}
+	}
+}
+
+// ParseConnect parses the connect parameters and properties for a client,
+// populating identity, keepalive, quotas, topic aliases, and the will.
+func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
+	cl.Net.Listener = lid
+
+	cl.Properties.ProtocolVersion = pk.ProtocolVersion
+	cl.Properties.Username = pk.Connect.Username
+	cl.Properties.Clean = pk.Connect.Clean
+	cl.Properties.Props = pk.Properties.Copy(false)
+
+	if cl.Properties.Props.ReceiveMaximum > cl.ops.options.Capabilities.MaximumInflight { // 3.3.4 Non-normative
+		cl.Properties.Props.ReceiveMaximum = cl.ops.options.Capabilities.MaximumInflight
+	}
+
+	// Warn (but do not reject) clients using a very low keepalive.
+	if pk.Connect.Keepalive <= minimumKeepalive {
+		cl.ops.log.Warn(
+			ErrMinimumKeepalive.Error(),
+			"client", cl.ID,
+			"keepalive", pk.Connect.Keepalive,
+			"recommended", minimumKeepalive,
+		)
+	}
+
+	cl.State.Keepalive = pk.Connect.Keepalive                                            // [MQTT-3.2.2-22]
+	cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
+	cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))          // client receive max
+	cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
+
+	// Assign a generated client id if the client did not supply one.
+	cl.ID = pk.Connect.ClientIdentifier
+	if cl.ID == "" {
+		cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
+		cl.Properties.Props.AssignedClientID = cl.ID
+	}
+
+	if pk.Connect.WillFlag {
+		cl.Properties.Will = Will{
+			Qos:               pk.Connect.WillQos,
+			Retain:            pk.Connect.WillRetain,
+			Payload:           pk.Connect.WillPayload,
+			TopicName:         pk.Connect.WillTopic,
+			WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
+			User:              pk.Connect.WillProperties.User,
+		}
+		// The will delay cannot exceed the session expiry interval.
+		if pk.Properties.SessionExpiryIntervalFlag &&
+			pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
+			cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
+		}
+		if pk.Connect.WillFlag {
+			cl.Properties.Will.Flag = 1 // atomic for checking
+		}
+	}
+}
+
+// refreshDeadline refreshes the read/write deadline for the net.Conn connection,
+// allowing 1.5x the keepalive interval before the connection times out.
+func (cl *Client) refreshDeadline(keepalive uint16) {
+	var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
+	if keepalive > 0 {
+		expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
+	}
+
+	if cl.Net.Conn != nil {
+		_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
+	}
+}
+
+// NextPacketID returns the next available (unused) packet id for the client.
+// If no unused packet ids are available, an error is returned and the client
+// should be disconnected. The search wraps around at the maximum packet id
+// and gives up once it revisits the starting position after wrapping.
+func (cl *Client) NextPacketID() (i uint32, err error) {
+	cl.Lock()
+	defer cl.Unlock()
+
+	i = atomic.LoadUint32(&cl.State.packetID)
+	started := i
+	overflowed := false
+	for {
+		// A full wrap back to the start means every id is in flight.
+		if overflowed && i == started {
+			return 0, packets.ErrQuotaExceeded
+		}
+
+		if i >= cl.ops.options.Capabilities.maximumPacketID {
+			overflowed = true
+			i = 0
+			continue
+		}
+
+		i++
+
+		// An id is free if it has no in-flight message associated with it.
+		if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
+			atomic.StoreUint32(&cl.State.packetID, i)
+			return i, nil
+		}
+	}
+}
+
+// ResendInflightMessages attempts to resend any pending inflight messages to
+// connected clients. Publish packets are re-sent with the Dup flag set;
+// completed acknowledgements (Puback/Pubcomp) are deleted after being written.
+func (cl *Client) ResendInflightMessages(force bool) error {
+	if cl.State.Inflight.Len() == 0 {
+		return nil
+	}
+
+	for _, tk := range cl.State.Inflight.GetAll(false) {
+		if tk.FixedHeader.Type == packets.Publish {
+			tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
+		}
+
+		cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
+		err := cl.WritePacket(tk)
+		if err != nil {
+			return err
+		}
+
+		if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp {
+			if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
+				cl.ops.hooks.OnQosComplete(cl, tk)
+				atomic.AddInt64(&cl.ops.info.Inflight, -1)
+			}
+		}
+	}
+
+	return nil
+}
+
+// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session.
+func (cl *Client) ClearInflights() {
+	for _, pending := range cl.State.Inflight.GetAll(false) {
+		if !cl.State.Inflight.Delete(pending.PacketID) {
+			continue // already removed elsewhere; nothing to account for
+		}
+		cl.ops.hooks.OnQosDropped(cl, pending)
+		atomic.AddInt64(&cl.ops.info.Inflight, -1)
+	}
+}
+
+// ClearExpiredInflights deletes any inflight messages which have expired.
+// It returns the packet ids of all deleted messages so callers can update
+// any dependent state (e.g. persistent storage).
+func (cl *Client) ClearExpiredInflights(now, maximumExpiry int64) []uint16 {
+	deleted := []uint16{}
+	for _, tk := range cl.State.Inflight.GetAll(false) {
+		// Per-message expiry only applies to v5 packets with an expiry set.
+		expired := tk.ProtocolVersion == 5 && tk.Expiry > 0 && tk.Expiry < now // [MQTT-3.3.2-5]
+
+		// If the maximum message expiry interval is set (greater than 0), and the message
+		// retention period exceeds the maximum expiry, the message will be forcibly removed.
+		enforced := maximumExpiry > 0 && now-tk.Created > maximumExpiry
+
+		if expired || enforced {
+			if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
+				cl.ops.hooks.OnQosDropped(cl, tk)
+				atomic.AddInt64(&cl.ops.info.Inflight, -1)
+				deleted = append(deleted, tk.PacketID)
+			}
+		}
+	}
+
+	return deleted
+}
+
+// Read reads incoming packets from the connected client and transforms them into
+// packets to be handled by the packetHandler. It loops until the client closes
+// or any read/handler error occurs, which is returned to the caller.
+func (cl *Client) Read(packetHandler ReadFn) error {
+	var err error
+
+	for {
+		if cl.Closed() {
+			return nil
+		}
+
+		// Reset the connection deadline before each packet read.
+		cl.refreshDeadline(cl.State.Keepalive)
+		fh := new(packets.FixedHeader)
+		err = cl.ReadFixedHeader(fh)
+		if err != nil {
+			return err
+		}
+
+		pk, err := cl.ReadPacket(fh)
+		if err != nil {
+			return err
+		}
+
+		err = packetHandler(cl, pk) // Process inbound packet.
+		if err != nil {
+			return err
+		}
+	}
+}
+
+// Stop instructs the client to shut down all processing goroutines and disconnect.
+// The optional err records why the client was stopped (see StopCause). Safe to
+// call multiple times; only the first call takes effect (endOnce).
+func (cl *Client) Stop(err error) {
+	cl.State.endOnce.Do(func() {
+
+		if cl.Net.Conn != nil {
+			_ = cl.Net.Conn.Close() // omit close error
+		}
+
+		// Record the stop cause before cancelling, so readers woken by the
+		// cancellation can observe it.
+		if err != nil {
+			cl.State.stopCause.Store(err)
+		}
+
+		// Cancel the open context; Closed() becomes true from here on.
+		if cl.State.cancelOpen != nil {
+			cl.State.cancelOpen()
+		}
+
+		atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
+	})
+}
+
+// StopCause returns the reason the client connection was stopped, if any.
+func (cl *Client) StopCause() error {
+	cause := cl.State.stopCause.Load()
+	if cause == nil {
+		return nil
+	}
+	return cause.(error)
+}
+
+// StopTime returns the time the client disconnected in unix time, else zero.
+func (cl *Client) StopTime() int64 {
+	return atomic.LoadInt64(&cl.State.disconnected)
+}
+
+// Closed returns true if client connection is closed.
+// True when the open context was never set, or has been cancelled (via Stop).
+func (cl *Client) Closed() bool {
+	return cl.State.open == nil || cl.State.open.Err() != nil
+}
+
+// ReadFixedHeader reads in the values of the next packet's fixed header:
+// one control byte followed by a variable-length remaining-length field.
+// Returns ErrConnectionClosed if the buffered reader is gone, or
+// packets.ErrPacketTooLarge if the packet exceeds the server's maximum.
+func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
+	if cl.Net.bconn == nil {
+		return ErrConnectionClosed
+	}
+
+	b, err := cl.Net.bconn.ReadByte()
+	if err != nil {
+		return err
+	}
+
+	err = fh.Decode(b)
+	if err != nil {
+		return err
+	}
+
+	// bu is the number of bytes consumed by the remaining-length varint.
+	var bu int
+	fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn)
+	if err != nil {
+		return err
+	}
+
+	// +1 accounts for the control byte read above.
+	if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
+		return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
+	}
+
+	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
+	return nil
+}
+
+// ReadPacket reads the remaining buffer into an MQTT packet.
+// The fixed header fh must have been read with ReadFixedHeader first; its
+// Remaining value determines how many bytes are consumed from the connection.
+func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
+	atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
+
+	pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
+	pk.FixedHeader = *fh
+	p := make([]byte, pk.FixedHeader.Remaining)
+	n, err := io.ReadFull(cl.Net.bconn, p)
+	if err != nil {
+		return pk, err
+	}
+
+	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))
+
+	// Decode the remaining packet values using a fresh copy of the bytes,
+	// otherwise the next packet will change the data of this one.
+	px := append([]byte{}, p[:]...)
+	switch pk.FixedHeader.Type {
+	case packets.Connect:
+		err = pk.ConnectDecode(px)
+	case packets.Disconnect:
+		err = pk.DisconnectDecode(px)
+	case packets.Connack:
+		err = pk.ConnackDecode(px)
+	case packets.Publish:
+		err = pk.PublishDecode(px)
+		if err == nil {
+			atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
+		}
+	case packets.Puback:
+		err = pk.PubackDecode(px)
+	case packets.Pubrec:
+		err = pk.PubrecDecode(px)
+	case packets.Pubrel:
+		err = pk.PubrelDecode(px)
+	case packets.Pubcomp:
+		err = pk.PubcompDecode(px)
+	case packets.Subscribe:
+		err = pk.SubscribeDecode(px)
+	case packets.Suback:
+		err = pk.SubackDecode(px)
+	case packets.Unsubscribe:
+		err = pk.UnsubscribeDecode(px)
+	case packets.Unsuback:
+		err = pk.UnsubackDecode(px)
+	case packets.Pingreq: // ping packets carry no payload; nothing to decode
+	case packets.Pingresp:
+	case packets.Auth:
+		err = pk.AuthDecode(px)
+	default:
+		err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
+	}
+
+	if err != nil {
+		return pk, err
+	}
+
+	// Give hooks a chance to inspect/modify the decoded packet.
+	pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
+	return
+}
+
+// WritePacket encodes and writes a packet to the client.
+// The packet is encoded into a scratch buffer, size-checked against the
+// client's maximum packet size, and then either written directly to the
+// connection or staged in the client's outbound buffer for batched flushing.
+func (cl *Client) WritePacket(pk packets.Packet) error {
+	if cl.Closed() {
+		return ErrConnectionClosed
+	}
+
+	if cl.Net.Conn == nil {
+		return nil // no transport attached; nothing to write
+	}
+
+	if pk.Expiry > 0 {
+		pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
+	}
+
+	pk.ProtocolVersion = cl.Properties.ProtocolVersion
+	if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
+		pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
+	}
+
+	if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
+		pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
+	}
+
+	if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
+		pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
+	}
+
+	pk = cl.ops.hooks.OnPacketEncode(cl, pk)
+
+	// Encode into a scratch buffer before any size checks or network writes.
+	var err error
+	buf := new(bytes.Buffer)
+	switch pk.FixedHeader.Type {
+	case packets.Connect:
+		err = pk.ConnectEncode(buf)
+	case packets.Connack:
+		err = pk.ConnackEncode(buf)
+	case packets.Publish:
+		err = pk.PublishEncode(buf)
+	case packets.Puback:
+		err = pk.PubackEncode(buf)
+	case packets.Pubrec:
+		err = pk.PubrecEncode(buf)
+	case packets.Pubrel:
+		err = pk.PubrelEncode(buf)
+	case packets.Pubcomp:
+		err = pk.PubcompEncode(buf)
+	case packets.Subscribe:
+		err = pk.SubscribeEncode(buf)
+	case packets.Suback:
+		err = pk.SubackEncode(buf)
+	case packets.Unsubscribe:
+		err = pk.UnsubscribeEncode(buf)
+	case packets.Unsuback:
+		err = pk.UnsubackEncode(buf)
+	case packets.Pingreq:
+		err = pk.PingreqEncode(buf)
+	case packets.Pingresp:
+		err = pk.PingrespEncode(buf)
+	case packets.Disconnect:
+		err = pk.DisconnectEncode(buf)
+	case packets.Auth:
+		err = pk.AuthEncode(buf)
+	default:
+		err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
+	}
+	if err != nil {
+		return err
+	}
+
+	if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
+		return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
+	}
+
+	// Write under lock: either directly to the connection, or into the shared
+	// outbound buffer, which is flushed once it reaches the configured size.
+	n, err := func() (int64, error) {
+		cl.Lock()
+		defer cl.Unlock()
+		if len(cl.State.outbound) == 0 {
+			if cl.Net.outbuf == nil {
+				return buf.WriteTo(cl.Net.Conn)
+			}
+
+			// first write to buffer, then flush buffer
+			n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
+			err = cl.flushOutbuf()
+			return int64(n), err
+		}
+
+		// there are more writes in the queue
+		if cl.Net.outbuf == nil {
+			if buf.Len() >= cl.ops.options.ClientNetWriteBufferSize {
+				return buf.WriteTo(cl.Net.Conn)
+			}
+			cl.Net.outbuf = new(bytes.Buffer)
+		}
+
+		n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
+		if cl.Net.outbuf.Len() < cl.ops.options.ClientNetWriteBufferSize {
+			return int64(n), nil
+		}
+
+		err = cl.flushOutbuf()
+		return int64(n), err
+	}()
+	if err != nil {
+		return err
+	}
+
+	// Update transmission metrics and notify hooks.
+	atomic.AddInt64(&cl.ops.info.BytesSent, n)
+	atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
+	if pk.FixedHeader.Type == packets.Publish {
+		atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
+	}
+
+	cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
+
+	return err
+}
+
+// flushOutbuf writes any buffered outgoing data to the connection, releasing
+// the buffer on success so subsequent writes start with a fresh one.
+func (cl *Client) flushOutbuf() error {
+	buf := cl.Net.outbuf
+	if buf == nil {
+		return nil
+	}
+
+	if _, err := buf.WriteTo(cl.Net.Conn); err != nil {
+		return err
+	}
+	cl.Net.outbuf = nil
+	return nil
+}
diff --git a/mqtt/clients_test.go b/mqtt/clients_test.go
new file mode 100644
index 0000000..15d8215
--- /dev/null
+++ b/mqtt/clients_test.go
@@ -0,0 +1,930 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package mqtt
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "log/slog"
+ "net"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ "github.com/stretchr/testify/require"
+)
+
+// pkInfo formats packet case identifiers in test failure messages.
+const pkInfo = "packet type %v, %s"
+
+// errClientStop is a sentinel used to stop test clients deliberately.
+var errClientStop = errors.New("test stop")
+
+// newTestClient builds a Client wired to one end of a net.Pipe, returning the
+// client plus the reader (r) and writer (w) ends for driving I/O in tests.
+func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
+	r, w = net.Pipe()
+
+	cl = newClient(w, &ops{
+		info:  new(system.Info),
+		hooks: new(Hooks),
+		log:   logger,
+		options: &Options{
+			Capabilities: &Capabilities{
+				ReceiveMaximum:             10,
+				MaximumInflight:            5,
+				TopicAliasMaximum:          10000,
+				MaximumClientWritesPending: 3,
+				maximumPacketID:            10,
+			},
+		},
+	})
+
+	cl.ID = "mochi"
+	cl.State.Inflight.maximumSendQuota = 5
+	cl.State.Inflight.sendQuota = 5
+	cl.State.Inflight.maximumReceiveQuota = 10
+	cl.State.Inflight.receiveQuota = 10
+	cl.Properties.Props.TopicAliasMaximum = 0
+	cl.Properties.Props.RequestResponseInfo = 0x1
+
+	// Drain the outbound queue in the background, as the server would.
+	go cl.WriteLoop()
+
+	return
+}
+
+func TestNewInflights(t *testing.T) {
+ require.NotNil(t, NewInflights().internal)
+}
+
+func TestNewClients(t *testing.T) {
+ cl := NewClients()
+ require.NotNil(t, cl.internal)
+}
+
+func TestClientsAdd(t *testing.T) {
+ cl := NewClients()
+ cl.Add(&Client{ID: "t1"})
+ require.Contains(t, cl.internal, "t1")
+}
+
+func TestClientsGet(t *testing.T) {
+ cl := NewClients()
+ cl.Add(&Client{ID: "t1"})
+ cl.Add(&Client{ID: "t2"})
+ require.Contains(t, cl.internal, "t1")
+ require.Contains(t, cl.internal, "t2")
+
+ client, ok := cl.Get("t1")
+ require.Equal(t, true, ok)
+ require.Equal(t, "t1", client.ID)
+}
+
+func TestClientsGetAll(t *testing.T) {
+ cl := NewClients()
+ cl.Add(&Client{ID: "t1"})
+ cl.Add(&Client{ID: "t2"})
+ cl.Add(&Client{ID: "t3"})
+ require.Contains(t, cl.internal, "t1")
+ require.Contains(t, cl.internal, "t2")
+ require.Contains(t, cl.internal, "t3")
+
+ clients := cl.GetAll()
+ require.Len(t, clients, 3)
+}
+
+func TestClientsLen(t *testing.T) {
+ cl := NewClients()
+ cl.Add(&Client{ID: "t1"})
+ cl.Add(&Client{ID: "t2"})
+ require.Contains(t, cl.internal, "t1")
+ require.Contains(t, cl.internal, "t2")
+ require.Equal(t, 2, cl.Len())
+}
+
+func TestClientsDelete(t *testing.T) {
+ cl := NewClients()
+ cl.Add(&Client{ID: "t1"})
+ require.Contains(t, cl.internal, "t1")
+
+ cl.Delete("t1")
+ _, ok := cl.Get("t1")
+ require.Equal(t, false, ok)
+ require.Nil(t, cl.internal["t1"])
+}
+
+func TestClientsGetByListener(t *testing.T) {
+ cl := NewClients()
+ cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
+ cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
+ require.Contains(t, cl.internal, "t1")
+ require.Contains(t, cl.internal, "t2")
+
+ clients := cl.GetByListener("tcp1")
+ require.NotEmpty(t, clients)
+ require.Equal(t, 1, len(clients))
+ require.Equal(t, "tcp1", clients[0].Net.Listener)
+}
+
+func TestNewClient(t *testing.T) {
+ cl, _, _ := newTestClient()
+
+ require.NotNil(t, cl)
+ require.NotNil(t, cl.State.Inflight.internal)
+ require.NotNil(t, cl.State.Subscriptions)
+ require.NotNil(t, cl.State.TopicAliases)
+ require.Equal(t, defaultKeepalive, cl.State.Keepalive)
+ require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
+ require.NotNil(t, cl.Net.Conn)
+ require.NotNil(t, cl.Net.bconn)
+ require.NotNil(t, cl.ops)
+ require.NotNil(t, cl.ops.options.Capabilities)
+ require.False(t, cl.Net.Inline)
+}
+
+func TestClientParseConnect(t *testing.T) {
+ cl, _, _ := newTestClient()
+
+ pk := packets.Packet{
+ ProtocolVersion: 4,
+ Connect: packets.ConnectParams{
+ ProtocolName: []byte{'M', 'Q', 'T', 'T'},
+ Clean: true,
+ Keepalive: 60,
+ ClientIdentifier: "mochi",
+ WillFlag: true,
+ WillTopic: "lwt",
+ WillPayload: []byte("lol gg"),
+ WillQos: 1,
+ WillRetain: false,
+ },
+ Properties: packets.Properties{
+ ReceiveMaximum: uint16(5),
+ },
+ }
+
+ cl.ParseConnect("tcp1", pk)
+ require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
+ require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
+ require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
+ require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
+ require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
+ require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
+ require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
+ require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
+ require.Equal(t, uint32(1), cl.Properties.Will.Flag)
+ require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
+ require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
+ require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
+ require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
+}
+
+func TestClientParseConnectReceiveMaxExceedMaxInflight(t *testing.T) {
+ const MaxInflight uint16 = 1
+ cl, _, _ := newTestClient()
+ cl.ops.options.Capabilities.MaximumInflight = MaxInflight
+
+ pk := packets.Packet{
+ ProtocolVersion: 4,
+ Connect: packets.ConnectParams{
+ ProtocolName: []byte{'M', 'Q', 'T', 'T'},
+ Clean: true,
+ Keepalive: 60,
+ ClientIdentifier: "mochi",
+ WillFlag: true,
+ WillTopic: "lwt",
+ WillPayload: []byte("lol gg"),
+ WillQos: 1,
+ WillRetain: false,
+ },
+ Properties: packets.Properties{
+ ReceiveMaximum: uint16(5),
+ },
+ }
+
+ cl.ParseConnect("tcp1", pk)
+ require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
+ require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
+ require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
+ require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
+ require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
+ require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
+ require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
+ require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
+ require.Equal(t, uint32(1), cl.Properties.Will.Flag)
+ require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
+ require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
+ require.Equal(t, int32(MaxInflight), cl.State.Inflight.sendQuota)
+ require.Equal(t, int32(MaxInflight), cl.State.Inflight.maximumSendQuota)
+}
+
+func TestClientParseConnectOverrideWillDelay(t *testing.T) {
+ cl, _, _ := newTestClient()
+
+ pk := packets.Packet{
+ ProtocolVersion: 4,
+ Connect: packets.ConnectParams{
+ ProtocolName: []byte{'M', 'Q', 'T', 'T'},
+ Clean: true,
+ Keepalive: 60,
+ ClientIdentifier: "mochi",
+ WillFlag: true,
+ WillProperties: packets.Properties{
+ WillDelayInterval: 200,
+ },
+ },
+ Properties: packets.Properties{
+ SessionExpiryInterval: 100,
+ SessionExpiryIntervalFlag: true,
+ },
+ }
+
+ cl.ParseConnect("tcp1", pk)
+ require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval)
+}
+
+func TestClientParseConnectNoID(t *testing.T) {
+ cl, _, _ := newTestClient()
+ cl.ParseConnect("tcp1", packets.Packet{})
+ require.NotEmpty(t, cl.ID)
+}
+
+func TestClientParseConnectBelowMinimumKeepalive(t *testing.T) {
+ cl, _, _ := newTestClient()
+ var b bytes.Buffer
+ x := bufio.NewWriter(&b)
+ cl.ops.log = slog.New(slog.NewTextHandler(x, nil))
+
+ pk := packets.Packet{
+ ProtocolVersion: 4,
+ Connect: packets.ConnectParams{
+ ProtocolName: []byte{'M', 'Q', 'T', 'T'},
+ Keepalive: minimumKeepalive - 1,
+ ClientIdentifier: "mochi",
+ },
+ }
+ cl.ParseConnect("tcp1", pk)
+ err := x.Flush()
+ require.NoError(t, err)
+ require.True(t, strings.Contains(b.String(), ErrMinimumKeepalive.Error()))
+ require.NotEmpty(t, cl.ID)
+}
+
+func TestClientNextPacketID(t *testing.T) {
+ cl, _, _ := newTestClient()
+
+ i, err := cl.NextPacketID()
+ require.NoError(t, err)
+ require.Equal(t, uint32(1), i)
+
+ i, err = cl.NextPacketID()
+ require.NoError(t, err)
+ require.Equal(t, uint32(2), i)
+}
+
+func TestClientNextPacketIDInUse(t *testing.T) {
+ cl, _, _ := newTestClient()
+
+ // skip over 2
+ cl.State.Inflight.Set(packets.Packet{PacketID: 2})
+
+ i, err := cl.NextPacketID()
+ require.NoError(t, err)
+ require.Equal(t, uint32(1), i)
+
+ i, err = cl.NextPacketID()
+ require.NoError(t, err)
+ require.Equal(t, uint32(3), i)
+
+ // Skip over overflow
+ cl.State.Inflight.Set(packets.Packet{PacketID: 65535})
+ atomic.StoreUint32(&cl.State.packetID, 65534)
+
+ i, err = cl.NextPacketID()
+ require.NoError(t, err)
+ require.Equal(t, uint32(1), i)
+}
+
+func TestClientNextPacketIDExhausted(t *testing.T) {
+ cl, _, _ := newTestClient()
+ for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
+ cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
+ }
+
+ i, err := cl.NextPacketID()
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrQuotaExceeded)
+ require.Equal(t, uint32(0), i)
+}
+
+func TestClientNextPacketIDOverflow(t *testing.T) {
+ cl, _, _ := newTestClient()
+ for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
+ cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
+ }
+
+ cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1
+ i, err := cl.NextPacketID()
+ require.NoError(t, err)
+ require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
+ cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
+
+ cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
+ _, err = cl.NextPacketID()
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrQuotaExceeded)
+}
+
+func TestClientClearInflights(t *testing.T) {
+ cl, _, _ := newTestClient()
+ n := time.Now().Unix()
+
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1})
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2})
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n})
+
+ require.Equal(t, 5, cl.State.Inflight.Len())
+ cl.ClearInflights()
+ require.Equal(t, 0, cl.State.Inflight.Len())
+}
+
+func TestClientClearExpiredInflights(t *testing.T) {
+ cl, _, _ := newTestClient()
+
+ n := time.Now().Unix()
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1})
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2})
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n})
+ require.Equal(t, 5, cl.State.Inflight.Len())
+
+ deleted := cl.ClearExpiredInflights(n, 4)
+ require.Len(t, deleted, 3)
+ require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
+ require.Equal(t, 2, cl.State.Inflight.Len())
+
+ cl.State.Inflight.Set(packets.Packet{PacketID: 11, Expiry: n - 1})
+ cl.State.Inflight.Set(packets.Packet{PacketID: 12, Expiry: n - 2}) // expiry is ineffective for v3.
+ cl.State.Inflight.Set(packets.Packet{PacketID: 13, Created: n - 3}) // within bounds for v3
+ cl.State.Inflight.Set(packets.Packet{PacketID: 15, Created: n - 5}) // over max server expiry limit
+ require.Equal(t, 6, cl.State.Inflight.Len())
+
+ deleted = cl.ClearExpiredInflights(n, 4)
+ require.Len(t, deleted, 3)
+ require.ElementsMatch(t, []uint16{11, 12, 15}, deleted)
+ require.Equal(t, 3, cl.State.Inflight.Len())
+
+ cl.State.Inflight.Set(packets.Packet{PacketID: 17, Created: n - 1})
+ deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not process abandon messages
+ require.Len(t, deleted, 0)
+ require.Equal(t, 4, cl.State.Inflight.Len())
+
+ cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 18, Expiry: n - 1})
+ deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not abandon messages
+ require.ElementsMatch(t, []uint16{18}, deleted) // expiry is still effective for v5.
+ require.Len(t, deleted, 1)
+ require.Equal(t, 4, cl.State.Inflight.Len())
+}
+
+func TestClientResendInflightMessages(t *testing.T) {
+ pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
+ cl, r, w := newTestClient()
+
+ cl.State.Inflight.Set(*pk1.Packet)
+ require.Equal(t, 1, cl.State.Inflight.Len())
+
+ go func() {
+ err := cl.ResendInflightMessages(true)
+ require.NoError(t, err)
+ time.Sleep(time.Millisecond)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, 0, cl.State.Inflight.Len())
+ require.Equal(t, pk1.RawBytes, buf)
+}
+
+func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
+ pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
+ cl, r, _ := newTestClient()
+ _ = r.Close()
+
+ cl.State.Inflight.Set(*pk1.Packet)
+ require.Equal(t, 1, cl.State.Inflight.Len())
+ err := cl.ResendInflightMessages(true)
+ require.Error(t, err)
+ require.ErrorIs(t, err, io.ErrClosedPipe)
+ require.Equal(t, 1, cl.State.Inflight.Len())
+}
+
+func TestClientResendInflightMessagesNoMessages(t *testing.T) {
+ cl, _, _ := newTestClient()
+ err := cl.ResendInflightMessages(true)
+ require.NoError(t, err)
+}
+
+func TestClientRefreshDeadline(t *testing.T) {
+ cl, _, _ := newTestClient()
+ cl.refreshDeadline(10)
+ require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
+}
+
+func TestClientReadFixedHeader(t *testing.T) {
+ cl, r, _ := newTestClient()
+
+ defer cl.Stop(errClientStop)
+ go func() {
+ _, _ = r.Write([]byte{packets.Connect << 4, 0x00})
+ _ = r.Close()
+ }()
+
+ fh := new(packets.FixedHeader)
+ err := cl.ReadFixedHeader(fh)
+ require.NoError(t, err)
+ require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived))
+}
+
+func TestClientReadFixedHeaderDecodeError(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+
+ go func() {
+ _, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
+ _ = r.Close()
+ }()
+
+ fh := new(packets.FixedHeader)
+ err := cl.ReadFixedHeader(fh)
+ require.Error(t, err)
+}
+
+func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
+ cl, r, _ := newTestClient()
+ cl.ops.options.Capabilities.MaximumPacketSize = 2
+ defer cl.Stop(errClientStop)
+
+ go func() {
+ _, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
+ _ = r.Close()
+ }()
+
+ fh := new(packets.FixedHeader)
+ err := cl.ReadFixedHeader(fh)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrPacketTooLarge)
+}
+
+func TestClientReadFixedHeaderReadEOF(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+
+ go func() {
+ _ = r.Close()
+ }()
+
+ fh := new(packets.FixedHeader)
+ err := cl.ReadFixedHeader(fh)
+ require.Error(t, err)
+ require.Equal(t, io.EOF, err)
+}
+
+func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+
+ go func() {
+ _, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
+ _ = r.Close()
+ }()
+
+ fh := new(packets.FixedHeader)
+ err := cl.ReadFixedHeader(fh)
+ require.Error(t, err)
+}
+
+func TestClientReadOK(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+ go func() {
+ _, _ = r.Write([]byte{
+ packets.Publish << 4, 18, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload,
+ packets.Publish << 4, 11, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'd', '/', 'e', '/', 'f', // Topic Name
+ 'y', 'e', 'a', 'h', // Payload
+ })
+ _ = r.Close()
+ }()
+
+ var pks []packets.Packet
+ o := make(chan error)
+ go func() {
+ o <- cl.Read(func(cl *Client, pk packets.Packet) error {
+ pks = append(pks, pk)
+ return nil
+ })
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, io.EOF)
+ require.Equal(t, 2, len(pks))
+ require.Equal(t, []packets.Packet{
+ {
+ ProtocolVersion: cl.Properties.ProtocolVersion,
+ FixedHeader: packets.FixedHeader{
+ Type: packets.Publish,
+ Remaining: 18,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello mochi"),
+ },
+ {
+ ProtocolVersion: cl.Properties.ProtocolVersion,
+ FixedHeader: packets.FixedHeader{
+ Type: packets.Publish,
+ Remaining: 11,
+ },
+ TopicName: "d/e/f",
+ Payload: []byte("yeah"),
+ },
+ }, pks)
+
+ require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived))
+}
+
+func TestClientReadDone(t *testing.T) {
+ cl, _, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+ cl.State.cancelOpen()
+
+ o := make(chan error)
+ go func() {
+ o <- cl.Read(func(cl *Client, pk packets.Packet) error {
+ return nil
+ })
+ }()
+
+ require.NoError(t, <-o)
+}
+
+func TestClientStop(t *testing.T) {
+ cl, _, _ := newTestClient()
+ require.Equal(t, int64(0), cl.StopTime())
+ cl.Stop(nil)
+ require.Equal(t, nil, cl.State.stopCause.Load())
+ require.InDelta(t, time.Now().Unix(), cl.State.disconnected, 1.0)
+ require.Equal(t, cl.State.disconnected, cl.StopTime())
+ require.True(t, cl.Closed())
+ require.Equal(t, nil, cl.StopCause())
+}
+
+func TestClientClosed(t *testing.T) {
+ cl, _, _ := newTestClient()
+ require.False(t, cl.Closed())
+ cl.Stop(nil)
+ require.True(t, cl.Closed())
+}
+
+// TestClientReadFixedHeaderError checks ReadFixedHeader fails with
+// ErrConnectionClosed when the buffered reader has been removed.
+func TestClientReadFixedHeaderError(t *testing.T) {
+	cl, r, _ := newTestClient()
+	defer cl.Stop(errClientStop)
+	go func() {
+		_, _ = r.Write([]byte{
+			packets.Publish << 4, 11, // Fixed header
+		})
+		_ = r.Close()
+	}()
+
+	cl.Net.bconn = nil
+	fh := new(packets.FixedHeader)
+	err := cl.ReadFixedHeader(fh)
+	require.Error(t, err)
+	require.ErrorIs(t, err, ErrConnectionClosed) // ErrorIs signature is (t, err, target); args were swapped
+}
+
+func TestClientReadReadHandlerErr(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+ go func() {
+ _, _ = r.Write([]byte{
+ packets.Publish << 4, 11, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'd', '/', 'e', '/', 'f', // Topic Name
+ 'y', 'e', 'a', 'h', // Payload
+ })
+ _ = r.Close()
+ }()
+
+ err := cl.Read(func(cl *Client, pk packets.Packet) error {
+ return errors.New("test")
+ })
+
+ require.Error(t, err)
+}
+
+func TestClientReadReadPacketOK(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+ go func() {
+ _, _ = r.Write([]byte{
+ packets.Publish << 4, 11, // Fixed header
+ 0, 5,
+ 'd', '/', 'e', '/', 'f',
+ 'y', 'e', 'a', 'h',
+ })
+ _ = r.Close()
+ }()
+
+ fh := new(packets.FixedHeader)
+ err := cl.ReadFixedHeader(fh)
+ require.NoError(t, err)
+
+ pk, err := cl.ReadPacket(fh)
+ require.NoError(t, err)
+ require.NotNil(t, pk)
+
+ require.Equal(t, packets.Packet{
+ ProtocolVersion: cl.Properties.ProtocolVersion,
+ FixedHeader: packets.FixedHeader{
+ Type: packets.Publish,
+ Remaining: 11,
+ },
+ TopicName: "d/e/f",
+ Payload: []byte("yeah"),
+ }, pk)
+}
+
+func TestClientReadPacket(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+
+ for _, tx := range pkTable {
+ tt := tx // avoid data race
+ t.Run(tt.Desc, func(t *testing.T) {
+ atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
+ go func() {
+ _, _ = r.Write(tt.RawBytes)
+ }()
+
+ fh := new(packets.FixedHeader)
+ err := cl.ReadFixedHeader(fh)
+ require.NoError(t, err)
+
+ if tt.Packet.ProtocolVersion == 5 {
+ cl.Properties.ProtocolVersion = 5
+ } else {
+ cl.Properties.ProtocolVersion = 0
+ }
+
+ pk, err := cl.ReadPacket(fh)
+ require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
+ require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc)
+ require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc)
+
+ if tt.Packet.FixedHeader.Type == packets.Publish {
+ require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc)
+ }
+ })
+ }
+}
+
+func TestClientReadPacketInvalidTypeError(t *testing.T) {
+ cl, _, _ := newTestClient()
+ _ = cl.Net.Conn.Close()
+ _, err := cl.ReadPacket(&packets.FixedHeader{})
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "invalid packet type")
+}
+
+// TestClientWritePacket round-trips every packet in pkTable through
+// WritePacket and verifies the raw bytes and transmission counters.
+// NOTE(review): defer cl.Stop in a loop accumulates until the test function
+// returns; harmless here because Stop is idempotent (endOnce) and also called
+// explicitly below, but consider t.Run subtests to scope each iteration.
+func TestClientWritePacket(t *testing.T) {
+	for _, tt := range pkTable {
+		cl, r, _ := newTestClient()
+		defer cl.Stop(errClientStop)
+
+		cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
+
+		o := make(chan []byte)
+		go func() {
+			buf, err := io.ReadAll(r)
+			require.NoError(t, err)
+			o <- buf
+		}()
+
+		err := cl.WritePacket(*tt.Packet)
+		require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
+
+		time.Sleep(2 * time.Millisecond)
+		_ = cl.Net.Conn.Close()
+
+		require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
+
+		cl.Stop(errClientStop)
+		time.Sleep(time.Millisecond * 1)
+
+		// The stop cause is either the test error, EOF, or a
+		// closed pipe, depending on which goroutine runs first.
+		err = cl.StopCause()
+		require.True(t,
+			errors.Is(err, errClientStop) ||
+				errors.Is(err, io.EOF) ||
+				errors.Is(err, io.ErrClosedPipe))
+
+		require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent))
+		require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent))
+		if tt.Packet.FixedHeader.Type == packets.Publish {
+			require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent))
+		}
+	}
+}
+
+func TestClientWritePacketBuffer(t *testing.T) {
+ r, w := net.Pipe()
+
+ cl := newClient(w, &ops{
+ info: new(system.Info),
+ hooks: new(Hooks),
+ log: logger,
+ options: &Options{
+ Capabilities: &Capabilities{
+ ReceiveMaximum: 10,
+ TopicAliasMaximum: 10000,
+ MaximumClientWritesPending: 3,
+ maximumPacketID: 10,
+ },
+ },
+ })
+
+ cl.ID = "mochi"
+ cl.State.Inflight.maximumSendQuota = 5
+ cl.State.Inflight.sendQuota = 5
+ cl.State.Inflight.maximumReceiveQuota = 10
+ cl.State.Inflight.receiveQuota = 10
+ cl.Properties.Props.TopicAliasMaximum = 0
+ cl.Properties.Props.RequestResponseInfo = 0x1
+
+ cl.ops.options.ClientNetWriteBufferSize = 10
+ defer cl.Stop(errClientStop)
+
+ small := packets.TPacketData[packets.Publish].Get(packets.TPublishNoPayload).Packet
+ large := packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
+
+ cl.State.outbound <- small
+
+ tt := []struct {
+ pks []*packets.Packet
+ size int
+ }{
+ {
+ pks: []*packets.Packet{small, small},
+ size: 18,
+ },
+ {
+ pks: []*packets.Packet{large},
+ size: 20,
+ },
+ {
+ pks: []*packets.Packet{small},
+ size: 0,
+ },
+ }
+
+ go func() {
+ for i, tx := range tt {
+ for _, pk := range tx.pks {
+ cl.Properties.ProtocolVersion = pk.ProtocolVersion
+ err := cl.WritePacket(*pk)
+ require.NoError(t, err, "index: %d", i)
+ if i == len(tt)-1 {
+ cl.Net.Conn.Close()
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+ }
+ }()
+
+ var n int
+ var err error
+ for i, tx := range tt {
+ buf := make([]byte, 100)
+ if i == len(tt)-1 {
+ buf, err = io.ReadAll(r)
+ n = len(buf)
+ } else {
+ n, err = io.ReadAtLeast(r, buf, 1)
+ }
+ require.NoError(t, err, "index: %d", i)
+ require.Equal(t, tx.size, n, "index: %d", i)
+ }
+}
+
+// TestWriteClientOversizePacket checks WritePacket rejects packets exceeding
+// the client's negotiated maximum packet size with ErrPacketTooLarge.
+func TestWriteClientOversizePacket(t *testing.T) {
+	cl, _, _ := newTestClient()
+	cl.Properties.Props.MaximumPacketSize = 2
+	pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
+	err := cl.WritePacket(pk)
+	require.Error(t, err)
+	require.ErrorIs(t, err, packets.ErrPacketTooLarge) // ErrorIs signature is (t, err, target); args were swapped
+}
+
+func TestClientReadPacketReadingError(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+ go func() {
+ _, _ = r.Write([]byte{
+ 0, 11, // Fixed header
+ 0, 5,
+ 'd', '/', 'e', '/', 'f',
+ 'y', 'e', 'a', 'h',
+ })
+ _ = r.Close()
+ }()
+
+ _, err := cl.ReadPacket(&packets.FixedHeader{
+ Type: 0,
+ Remaining: 11,
+ })
+ require.Error(t, err)
+}
+
+func TestClientReadPacketReadUnknown(t *testing.T) {
+ cl, r, _ := newTestClient()
+ defer cl.Stop(errClientStop)
+ go func() {
+ _, _ = r.Write([]byte{
+ 0, 11, // Fixed header
+ 0, 5,
+ 'd', '/', 'e', '/', 'f',
+ 'y', 'e', 'a', 'h',
+ })
+ _ = r.Close()
+ }()
+
+ _, err := cl.ReadPacket(&packets.FixedHeader{
+ Remaining: 1,
+ })
+ require.Error(t, err)
+}
+
+func TestClientWritePacketWriteNoConn(t *testing.T) {
+ cl, _, _ := newTestClient()
+ cl.Stop(errClientStop)
+
+ err := cl.WritePacket(*pkTable[1].Packet)
+ require.Error(t, err)
+ require.Equal(t, ErrConnectionClosed, err)
+}
+
+func TestClientWritePacketWriteError(t *testing.T) {
+ cl, _, _ := newTestClient()
+ _ = cl.Net.Conn.Close()
+
+ err := cl.WritePacket(*pkTable[1].Packet)
+ require.Error(t, err)
+}
+
+func TestClientWritePacketInvalidPacket(t *testing.T) {
+ cl, _, _ := newTestClient()
+ err := cl.WritePacket(packets.Packet{})
+ require.Error(t, err)
+}
+
+var (
+ pkTable = []packets.TPacketCase{
+ packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311),
+ packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5),
+ packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession),
+ packets.TPacketData[packets.Publish].Get(packets.TPublishBasic),
+ packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5),
+ packets.TPacketData[packets.Puback].Get(packets.TPuback),
+ packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
+ packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
+ packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
+ packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe),
+ packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5),
+ packets.TPacketData[packets.Suback].Get(packets.TSuback),
+ packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5),
+ packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe),
+ packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5),
+ packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback),
+ packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5),
+ packets.TPacketData[packets.Pingreq].Get(packets.TPingreq),
+ packets.TPacketData[packets.Pingresp].Get(packets.TPingresp),
+ packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect),
+ packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5),
+ packets.TPacketData[packets.Auth].Get(packets.TAuth),
+ }
+)
diff --git a/mqtt/hooks.go b/mqtt/hooks.go
new file mode 100644
index 0000000..98c8433
--- /dev/null
+++ b/mqtt/hooks.go
@@ -0,0 +1,864 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co, thedevop, dgduncan
+
+package mqtt
+
+import (
+ "errors"
+ "fmt"
+ "log/slog"
+ "sync"
+ "sync/atomic"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/packets"
+ "testmqtt/system"
+)
+
+// Hook event identifiers. Each constant is a byte id for one broker lifecycle
+// event; Hook.Provides(b) reports whether a hook implements the handler for
+// event id b, and the Hooks dispatcher methods below gate each callback on it.
+// The iota ordering is part of the hook ABI — only append new ids at the end.
+const (
+	SetOptions byte = iota
+	OnSysInfoTick
+	OnStarted
+	OnStopped
+	OnConnectAuthenticate
+	OnACLCheck
+	OnConnect
+	OnSessionEstablish
+	OnSessionEstablished
+	OnDisconnect
+	OnAuthPacket
+	OnPacketRead
+	OnPacketEncode
+	OnPacketSent
+	OnPacketProcessed
+	OnSubscribe
+	OnSubscribed
+	OnSelectSubscribers
+	OnUnsubscribe
+	OnUnsubscribed
+	OnPublish
+	OnPublished
+	OnPublishDropped
+	OnRetainMessage
+	OnRetainPublished
+	OnQosPublish
+	OnQosComplete
+	OnQosDropped
+	OnPacketIDExhausted
+	OnWill
+	OnWillSent
+	OnClientExpired
+	OnRetainedExpired
+	StoredClients
+	StoredSubscriptions
+	StoredInflightMessages
+	StoredRetainedMessages
+	StoredSysInfo
+)
+
+var (
+	// ErrInvalidConfigType indicates a different Type of config value was expected to what was received.
+	// NOTE(review): the message text was translated to Chinese ("invalid
+	// config type provided") while all other error strings in this codebase
+	// are lowercase English — confirm this divergence is intentional. Callers
+	// should match this sentinel with errors.Is, never by message text.
+	// ErrInvalidConfigType = errors.New("invalid config type provided")
+	ErrInvalidConfigType = errors.New("提供的配置类型无效")
+)
+
+// HookLoadConfig contains the hook and configuration as loaded from a configuration (usually file).
+type HookLoadConfig struct {
+ Hook Hook
+ Config any
+}
+
+// Hook provides an interface of handlers for different events which occur
+// during the lifecycle of the broker.
+type Hook interface {
+ ID() string
+ Provides(b byte) bool
+ Init(config any) error
+ Stop() error
+ SetOpts(l *slog.Logger, o *HookOptions)
+
+ OnStarted()
+ OnStopped()
+ OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
+ OnACLCheck(cl *Client, topic string, write bool) bool
+ OnSysInfoTick(*system.Info)
+ OnConnect(cl *Client, pk packets.Packet) error
+ OnSessionEstablish(cl *Client, pk packets.Packet)
+ OnSessionEstablished(cl *Client, pk packets.Packet)
+ OnDisconnect(cl *Client, err error, expire bool)
+ OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
+ OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation
+ OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client
+ OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client
+ OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled)
+ OnSubscribe(cl *Client, pk packets.Packet) packets.Packet
+ OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
+ OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers
+ OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet
+ OnUnsubscribed(cl *Client, pk packets.Packet)
+ OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
+ OnPublished(cl *Client, pk packets.Packet)
+ OnPublishDropped(cl *Client, pk packets.Packet)
+ OnRetainMessage(cl *Client, pk packets.Packet, r int64)
+ OnRetainPublished(cl *Client, pk packets.Packet)
+ OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
+ OnQosComplete(cl *Client, pk packets.Packet)
+ OnQosDropped(cl *Client, pk packets.Packet)
+ OnPacketIDExhausted(cl *Client, pk packets.Packet)
+ OnWill(cl *Client, will Will) (Will, error)
+ OnWillSent(cl *Client, pk packets.Packet)
+ OnClientExpired(cl *Client)
+ OnRetainedExpired(filter string)
+ StoredClients() ([]storage.Client, error)
+ StoredSubscriptions() ([]storage.Subscription, error)
+ StoredInflightMessages() ([]storage.Message, error)
+ StoredRetainedMessages() ([]storage.Message, error)
+ StoredSysInfo() (storage.SystemInfo, error)
+}
+
+// HookOptions contains values which are inherited from the server on initialisation.
+type HookOptions struct {
+ Capabilities *Capabilities
+}
+
+// Hooks is a slice of Hook interfaces to be called in sequence.
+type Hooks struct {
+ Log *slog.Logger // a logger for the hook (from the server)
+ internal atomic.Value // a slice of []Hook
+ wg sync.WaitGroup // a waitgroup for syncing hook shutdown
+ qty int64 // the number of hooks in use
+ sync.Mutex // a mutex for locking when adding hooks
+}
+
+// Len returns the number of hooks added.
+func (h *Hooks) Len() int64 {
+ return atomic.LoadInt64(&h.qty)
+}
+
+// Provides returns true if any one hook provides any of the requested hook methods.
+func (h *Hooks) Provides(b ...byte) bool {
+ for _, hook := range h.GetAll() {
+ for _, hb := range b {
+ if hook.Provides(hb) {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+// Add adds and initializes a new hook.
+func (h *Hooks) Add(hook Hook, config any) error {
+ h.Lock()
+ defer h.Unlock()
+
+ err := hook.Init(config)
+ if err != nil {
+ return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
+ }
+
+ i, ok := h.internal.Load().([]Hook)
+ if !ok {
+ i = []Hook{}
+ }
+
+ i = append(i, hook)
+ h.internal.Store(i)
+ atomic.AddInt64(&h.qty, 1)
+ h.wg.Add(1)
+
+ return nil
+}
+
+// GetAll returns a slice of all the hooks.
+func (h *Hooks) GetAll() []Hook {
+ i, ok := h.internal.Load().([]Hook)
+ if !ok {
+ return []Hook{}
+ }
+
+ return i
+}
+
+// Stop indicates all attached hooks to gracefully end.
+func (h *Hooks) Stop() {
+ go func() {
+ for _, hook := range h.GetAll() {
+ h.Log.Info("stopping hook", "hook", hook.ID())
+ if err := hook.Stop(); err != nil {
+ h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID())
+ }
+
+ h.wg.Done()
+ }
+ }()
+
+ h.wg.Wait()
+}
+
+// OnSysInfoTick is called when the $SYS topic values are published out.
+func (h *Hooks) OnSysInfoTick(sys *system.Info) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnSysInfoTick) {
+ hook.OnSysInfoTick(sys)
+ }
+ }
+}
+
+// OnStarted is called when the server has successfully started.
+func (h *Hooks) OnStarted() {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnStarted) {
+ hook.OnStarted()
+ }
+ }
+}
+
+// OnStopped is called when the server has successfully stopped.
+func (h *Hooks) OnStopped() {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnStopped) {
+ hook.OnStopped()
+ }
+ }
+}
+
+// OnConnect is called when a new client connects, and may return a packets.Code as an error to halt the connection.
+func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) error {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnConnect) {
+ err := hook.OnConnect(cl, pk)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// OnSessionEstablish is called right after a new client connects and authenticates and right before
+// the session is established and CONNACK is sent.
+func (h *Hooks) OnSessionEstablish(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnSessionEstablish) {
+ hook.OnSessionEstablish(cl, pk)
+ }
+ }
+}
+
+// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
+func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnSessionEstablished) {
+ hook.OnSessionEstablished(cl, pk)
+ }
+ }
+}
+
+// OnDisconnect is called when a client is disconnected for any reason.
+func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnDisconnect) {
+ hook.OnDisconnect(cl, err, expire)
+ }
+ }
+}
+
+// OnPacketRead is called when a packet is received from a client.
+func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
+ pkx = pk
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnPacketRead) {
+ npk, err := hook.OnPacketRead(cl, pkx)
+ if err != nil && errors.Is(err, packets.ErrRejectPacket) {
+ h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx)
+ return pk, err
+ } else if err != nil {
+ continue
+ }
+
+ pkx = npk
+ }
+ }
+
+ return
+}
+
+// OnAuthPacket is called when an auth packet is received. It is intended to allow developers
+// to create their own auth packet handling mechanisms.
+func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
+ pkx = pk
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnAuthPacket) {
+ npk, err := hook.OnAuthPacket(cl, pkx)
+ if err != nil {
+ return pk, err
+ }
+
+ pkx = npk
+ }
+ }
+
+ return
+}
+
+// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
+func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnPacketEncode) {
+ pk = hook.OnPacketEncode(cl, pk)
+ }
+ }
+
+ return pk
+}
+
+// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
+func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnPacketProcessed) {
+ hook.OnPacketProcessed(cl, pk, err)
+ }
+ }
+}
+
+// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
+// containing the bytes sent.
+func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnPacketSent) {
+ hook.OnPacketSent(cl, pk, b)
+ }
+ }
+}
+
+// OnSubscribe is called when a client subscribes to one or more filters. This method
+// differs from OnSubscribed in that it allows you to modify the subscription values
+// before the packet is processed. The return values of the hook methods are passed-through
+// in the order the hooks were attached.
+func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnSubscribe) {
+ pk = hook.OnSubscribe(cl, pk)
+ }
+ }
+ return pk
+}
+
+// OnSubscribed is called when a client subscribes to one or more filters.
+func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnSubscribed) {
+ hook.OnSubscribed(cl, pk, reasonCodes)
+ }
+ }
+}
+
+// OnSelectSubscribers is called when subscribers have been collected for a topic, but before
+// shared subscription subscribers have been selected. This hook can be used to programmatically
+// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
+// group in a custom manner (such as based on client id, ip, etc).
+func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnSelectSubscribers) {
+ subs = hook.OnSelectSubscribers(subs, pk)
+ }
+ }
+ return subs
+}
+
+// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method
+// differs from OnUnsubscribed in that it allows you to modify the unsubscription values
+// before the packet is processed. The return values of the hook methods are passed-through
+// in the order the hooks were attached.
+func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnUnsubscribe) {
+ pk = hook.OnUnsubscribe(cl, pk)
+ }
+ }
+ return pk
+}
+
+// OnUnsubscribed is called when a client unsubscribes from one or more filters.
+func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnUnsubscribed) {
+ hook.OnUnsubscribed(cl, pk)
+ }
+ }
+}
+
+// OnPublish is called when a client publishes a message. This method differs from OnPublished
+// in that it allows you to modify you to modify the incoming packet before it is processed.
+// The return values of the hook methods are passed-through in the order the hooks were attached.
+func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
+ pkx = pk
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnPublish) {
+ npk, err := hook.OnPublish(cl, pkx)
+ if err != nil {
+ if errors.Is(err, packets.ErrRejectPacket) {
+ h.Log.Debug("publish packet rejected",
+ "error", err,
+ "hook", hook.ID(),
+ "packet", pkx)
+ return pk, err
+ }
+ h.Log.Error("publish packet error",
+ "error", err,
+ "hook", hook.ID(),
+ "packet", pkx)
+ return pk, err
+ }
+ pkx = npk
+ }
+ }
+
+ return
+}
+
+// OnPublished is called when a client has published a message to subscribers.
+func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnPublished) {
+ hook.OnPublished(cl, pk)
+ }
+ }
+}
+
+// OnPublishDropped is called when a message to a client was dropped instead of delivered
+// such as when a client is too slow to respond.
+func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnPublishDropped) {
+ hook.OnPublishDropped(cl, pk)
+ }
+ }
+}
+
+// OnRetainMessage is called then a published message is retained.
+func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnRetainMessage) {
+ hook.OnRetainMessage(cl, pk, r)
+ }
+ }
+}
+
+// OnRetainPublished is called when a retained message is published.
+func (h *Hooks) OnRetainPublished(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnRetainPublished) {
+ hook.OnRetainPublished(cl, pk)
+ }
+ }
+}
+
+// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
+// In other words, this method is called when a new inflight message is created or resent.
+// It is typically used to store a new inflight message.
+func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnQosPublish) {
+ hook.OnQosPublish(cl, pk, sent, resends)
+ }
+ }
+}
+
+// OnQosComplete is called when the Qos flow for a message has been completed.
+// In other words, when an inflight message is resolved.
+// It is typically used to delete an inflight message from a store.
+func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnQosComplete) {
+ hook.OnQosComplete(cl, pk)
+ }
+ }
+}
+
+// OnQosDropped is called the Qos flow for a message expires. In other words, when
+// an inflight message expires or is abandoned. It is typically used to delete an
+// inflight message from a store.
+func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnQosDropped) {
+ hook.OnQosDropped(cl, pk)
+ }
+ }
+}
+
+// OnPacketIDExhausted is called when the client runs out of unused packet ids to
+// assign to a packet.
+func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnPacketIDExhausted) {
+ hook.OnPacketIDExhausted(cl, pk)
+ }
+ }
+}
+
+// OnWill is called when a client disconnects and publishes an LWT message. This method
+// differs from OnWillSent in that it allows you to modify the LWT message before it is
+// published. The return values of the hook methods are passed-through in the order
+// the hooks were attached.
+func (h *Hooks) OnWill(cl *Client, will Will) Will {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnWill) {
+ mlwt, err := hook.OnWill(cl, will)
+ if err != nil {
+ h.Log.Error("parse will error",
+ "error", err,
+ "hook", hook.ID(),
+ "will", will)
+ continue
+ }
+ will = mlwt
+ }
+ }
+
+ return will
+}
+
+// OnWillSent is called when an LWT message has been issued from a disconnecting client.
+func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnWillSent) {
+ hook.OnWillSent(cl, pk)
+ }
+ }
+}
+
+// OnClientExpired is called when a client session has expired and should be deleted.
+func (h *Hooks) OnClientExpired(cl *Client) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnClientExpired) {
+ hook.OnClientExpired(cl)
+ }
+ }
+}
+
+// OnRetainedExpired is called when a retained message has expired and should be deleted.
+func (h *Hooks) OnRetainedExpired(filter string) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnRetainedExpired) {
+ hook.OnRetainedExpired(filter)
+ }
+ }
+}
+
+// StoredClients returns all clients, e.g. from a persistent store, is used to
+// populate the server clients list before start.
+func (h *Hooks) StoredClients() (v []storage.Client, err error) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(StoredClients) {
+ v, err := hook.StoredClients()
+ if err != nil {
+ h.Log.Error("failed to load clients", "error", err, "hook", hook.ID())
+ return v, err
+ }
+
+ if len(v) > 0 {
+ return v, nil
+ }
+ }
+ }
+
+ return
+}
+
+// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
+// used to populate the server subscriptions list before start.
+func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(StoredSubscriptions) {
+ v, err := hook.StoredSubscriptions()
+ if err != nil {
+ h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID())
+ return v, err
+ }
+
+ if len(v) > 0 {
+ return v, nil
+ }
+ }
+ }
+
+ return
+}
+
+// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
+// and is used to populate the restored clients with inflight messages before start.
+func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(StoredInflightMessages) {
+ v, err := hook.StoredInflightMessages()
+ if err != nil {
+ h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID())
+ return v, err
+ }
+
+ if len(v) > 0 {
+ return v, nil
+ }
+ }
+ }
+
+ return
+}
+
+// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
+// and is used to populate the server topics with retained messages before start.
+func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(StoredRetainedMessages) {
+ v, err := hook.StoredRetainedMessages()
+ if err != nil {
+ h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID())
+ return v, err
+ }
+
+ if len(v) > 0 {
+ return v, nil
+ }
+ }
+ }
+
+ return
+}
+
+// StoredSysInfo returns a set of system info values.
+func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(StoredSysInfo) {
+ v, err := hook.StoredSysInfo()
+ if err != nil {
+ h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID())
+ return v, err
+ }
+
+ if v.Version != "" {
+ return v, nil
+ }
+ }
+ }
+
+ return
+}
+
+// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
+// An implementation of this method MUST be used to allow or deny access to the
+// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
+// check connecting users against an existing user database.
+func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnConnectAuthenticate) {
+ if ok := hook.OnConnectAuthenticate(cl, pk); ok {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter.
+// An implementation of this method MUST be used to allow or deny access to the
+// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
+// check publishing and subscribing users against an existing permissions or roles database.
+func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
+ for _, hook := range h.GetAll() {
+ if hook.Provides(OnACLCheck) {
+ if ok := hook.OnACLCheck(cl, topic, write); ok {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+// HookBase provides a set of default methods for each hook. It should be embedded in
+// all hooks.
+type HookBase struct {
+ Hook
+ Log *slog.Logger
+ Opts *HookOptions
+}
+
+// ID returns the ID of the hook.
+func (h *HookBase) ID() string {
+ return "base"
+}
+
+// Provides indicates which methods a hook provides. The default is none - this method
+// should be overridden by the embedding hook.
+func (h *HookBase) Provides(b byte) bool {
+ return false
+}
+
+// Init performs any pre-start initializations for the hook, such as connecting to databases
+// or opening files.
+func (h *HookBase) Init(config any) error {
+ return nil
+}
+
+// SetOpts is called by the server to propagate internal values and generally should
+// not be called manually.
+func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) {
+ h.Log = l
+ h.Opts = opts
+}
+
+// Stop is called to gracefully shut down the hook.
+func (h *HookBase) Stop() error {
+ return nil
+}
+
+// OnStarted is called when the server starts.
+func (h *HookBase) OnStarted() {}
+
+// OnStopped is called when the server stops.
+func (h *HookBase) OnStopped() {}
+
+// OnSysInfoTick is called when the server publishes system info.
+func (h *HookBase) OnSysInfoTick(*system.Info) {}
+
+// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
+func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
+ return false
+}
+
+// OnACLCheck is called when a user attempts to subscribe or publish to a topic.
+func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
+ return false
+}
+
+// OnConnect is called when a new client connects.
+func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) error {
+ return nil
+}
+
+// OnSessionEstablish is called right after a new client connects and authenticates and right before
+// the session is established and CONNACK is sent.
+func (h *HookBase) OnSessionEstablish(cl *Client, pk packets.Packet) {}
+
+// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
+func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
+
+// OnDisconnect is called when a client is disconnected for any reason.
+func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {}
+
+// OnAuthPacket is called when an auth packet is received from the client.
+func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
+ return pk, nil
+}
+
+// OnPacketRead is called when a packet is received.
+func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
+ return pk, nil
+}
+
+// OnPacketEncode is called before a packet is byte-encoded and written to the client.
+func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
+ return pk
+}
+
+// OnPacketSent is called immediately after a packet is written to a client.
+func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {}
+
+// OnPacketProcessed is called immediately after a packet from a client is processed.
+func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {}
+
+// OnSubscribe is called when a client subscribes to one or more filters.
+func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
+ return pk
+}
+
+// OnSubscribed is called when a client subscribes to one or more filters.
+func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {}
+
+// OnSelectSubscribers is called when selecting subscribers to receive a message.
+func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
+ return subs
+}
+
+// OnUnsubscribe is called when a client unsubscribes from one or more filters.
+func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
+ return pk
+}
+
+// OnUnsubscribed is called when a client unsubscribes from one or more filters.
+func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {}
+
+// OnPublish is called when a client publishes a message.
+func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
+ return pk, nil
+}
+
+// OnPublished is called when a client has published a message to subscribers.
+func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
+
+// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
+func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
+
+// OnRetainMessage is called then a published message is retained.
+func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
+
+// OnRetainPublished is called when a retained message is published.
+func (h *HookBase) OnRetainPublished(cl *Client, pk packets.Packet) {}
+
+// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
+func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
+
+// OnQosComplete is called when the Qos flow for a message has been completed.
+func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
+
+// OnQosDropped is called the Qos flow for a message expires.
+func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
+
+// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
+func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
+
+// OnWill is called when a client disconnects and publishes an LWT message.
+func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
+ return will, nil
+}
+
+// OnWillSent is called when an LWT message has been issued from a disconnecting client.
+func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {}
+
+// OnClientExpired is called when a client session has expired.
+func (h *HookBase) OnClientExpired(cl *Client) {}
+
+// OnRetainedExpired is called when a retained message for a topic has expired.
+func (h *HookBase) OnRetainedExpired(topic string) {}
+
+// StoredClients returns all clients from a store.
+func (h *HookBase) StoredClients() (v []storage.Client, err error) {
+ return
+}
+
+// StoredSubscriptions returns all subcriptions from a store.
+func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
+ return
+}
+
+// StoredInflightMessages returns all inflight messages from a store.
+func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) {
+ return
+}
+
+// StoredRetainedMessages returns all retained messages from a store.
+func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
+	return
+}
+
+// StoredSysInfo returns a set of system info values.
+func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
+	return
+}
diff --git a/mqtt/hooks_test.go b/mqtt/hooks_test.go
new file mode 100644
index 0000000..b792727
--- /dev/null
+++ b/mqtt/hooks_test.go
@@ -0,0 +1,667 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package mqtt
+
+import (
+ "errors"
+ "strconv"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ "github.com/stretchr/testify/require"
+)
+
+// modifiedHookBase is a test double embedding HookBase whose failure behavior
+// is driven by flags: fail makes every overridden method error/fail, while
+// failAt selects a single Stored* method to fail (1=clients, 2=subscriptions,
+// 3=retained, 4=inflight, 5=sysinfo). err, when set, overrides the default
+// errTestHook returned by the packet-modifying callbacks.
+type modifiedHookBase struct {
+	HookBase
+	err    error
+	fail   bool
+	failAt int
+}
+
+var errTestHook = errors.New("error")
+
+func (h *modifiedHookBase) ID() string {
+ return "modified"
+}
+
+func (h *modifiedHookBase) Init(config any) error {
+ if config != nil {
+ return errTestHook
+ }
+ return nil
+}
+
+func (h *modifiedHookBase) Provides(b byte) bool {
+ return true
+}
+
+func (h *modifiedHookBase) Stop() error {
+ if h.fail {
+ return errTestHook
+ }
+
+ return nil
+}
+
+func (h *modifiedHookBase) OnConnect(cl *Client, pk packets.Packet) error {
+ if h.fail {
+ return errTestHook
+ }
+
+ return nil
+}
+
+func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
+ return true
+}
+
+func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
+ return true
+}
+
+func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
+ if h.fail {
+ if h.err != nil {
+ return pk, h.err
+ }
+
+ return pk, errTestHook
+ }
+
+ return pk, nil
+}
+
+func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
+ if h.fail {
+ if h.err != nil {
+ return pk, h.err
+ }
+
+ return pk, errTestHook
+ }
+
+ return pk, nil
+}
+
+func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
+ if h.fail {
+ if h.err != nil {
+ return pk, h.err
+ }
+
+ return pk, errTestHook
+ }
+
+ return pk, nil
+}
+
+func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) {
+ if h.fail {
+ return will, errTestHook
+ }
+
+ return will, nil
+}
+
+func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) {
+ if h.fail || h.failAt == 1 {
+ return v, errTestHook
+ }
+
+ return []storage.Client{
+ {ID: "cl1"},
+ {ID: "cl2"},
+ {ID: "cl3"},
+ }, nil
+}
+
+func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
+ if h.fail || h.failAt == 2 {
+ return v, errTestHook
+ }
+
+ return []storage.Subscription{
+ {ID: "sub1"},
+ {ID: "sub2"},
+ {ID: "sub3"},
+ }, nil
+}
+
+func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) {
+ if h.fail || h.failAt == 3 {
+ return v, errTestHook
+ }
+
+ return []storage.Message{
+ {ID: "r1"},
+ {ID: "r2"},
+ {ID: "r3"},
+ }, nil
+}
+
+func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) {
+ if h.fail || h.failAt == 4 {
+ return v, errTestHook
+ }
+
+ return []storage.Message{
+ {ID: "i1"},
+ {ID: "i2"},
+ {ID: "i3"},
+ }, nil
+}
+
+func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
+ if h.fail || h.failAt == 5 {
+ return v, errTestHook
+ }
+
+ return storage.SystemInfo{
+ Info: system.Info{
+ Version: "2.0.0",
+ },
+ }, nil
+}
+
+type providesCheckHook struct {
+ HookBase
+}
+
+func (h *providesCheckHook) Provides(b byte) bool {
+ return b == OnConnect
+}
+
// TestHooksProvides checks that Hooks.Provides reports true when any
// registered hook provides at least one of the queried event bytes.
func TestHooksProvides(t *testing.T) {
	h := new(Hooks)
	err := h.Add(new(providesCheckHook), nil)
	require.NoError(t, err)

	err = h.Add(new(HookBase), nil)
	require.NoError(t, err)

	// OnConnect is provided by providesCheckHook, so a query containing it
	// passes even though OnDisconnect is provided by nobody.
	require.True(t, h.Provides(OnConnect, OnDisconnect))
	require.False(t, h.Provides(OnDisconnect))
}

// TestHooksAddLenGetAll checks that added hooks are counted and that GetAll
// returns them in insertion order.
func TestHooksAddLenGetAll(t *testing.T) {
	h := new(Hooks)
	err := h.Add(new(HookBase), nil)
	require.NoError(t, err)

	err = h.Add(new(modifiedHookBase), nil)
	require.NoError(t, err)

	require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
	require.Equal(t, int64(2), h.Len())

	all := h.GetAll()
	require.Equal(t, "base", all[0].ID())
	require.Equal(t, "modified", all[1].ID())
}

// TestHooksAddInitFailure checks that a hook whose Init returns an error is
// not added and does not increment the hook count. modifiedHookBase.Init
// presumably fails when given a non-nil config — confirm against its
// definition earlier in this file.
func TestHooksAddInitFailure(t *testing.T) {
	h := new(Hooks)
	err := h.Add(new(modifiedHookBase), map[string]any{})
	require.Error(t, err)
	require.Equal(t, int64(0), atomic.LoadInt64(&h.qty))
}

// TestHooksStop checks that stopping a hook collection with one attached
// hook completes without error or panic.
func TestHooksStop(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	err := h.Add(new(HookBase), nil)
	require.NoError(t, err)
	require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
	require.Equal(t, int64(1), h.Len())

	h.Stop()
}

// coverage: also cover some empty functions
// TestHooksNonReturns drives every no-return hook dispatcher twice: the
// first iteration runs with no hooks attached, and a hook is added at the
// end of each iteration so the second iteration exercises the dispatch
// loop with hooks present.
func TestHooksNonReturns(t *testing.T) {
	h := new(Hooks)
	cl := new(Client)

	for i := 0; i < 2; i++ {
		t.Run("step-"+strconv.Itoa(i), func(t *testing.T) {
			// on first iteration, check without hook methods
			h.OnStarted()
			h.OnStopped()
			h.OnSysInfoTick(new(system.Info))
			h.OnSessionEstablish(cl, packets.Packet{})
			h.OnSessionEstablished(cl, packets.Packet{})
			h.OnDisconnect(cl, nil, false)
			h.OnPacketSent(cl, packets.Packet{}, []byte{})
			h.OnPacketProcessed(cl, packets.Packet{}, nil)
			h.OnSubscribed(cl, packets.Packet{}, []byte{1})
			h.OnUnsubscribed(cl, packets.Packet{})
			h.OnPublished(cl, packets.Packet{})
			h.OnPublishDropped(cl, packets.Packet{})
			h.OnRetainMessage(cl, packets.Packet{}, 0)
			h.OnRetainPublished(cl, packets.Packet{})
			h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
			h.OnQosComplete(cl, packets.Packet{})
			h.OnQosDropped(cl, packets.Packet{})
			h.OnPacketIDExhausted(cl, packets.Packet{})
			h.OnWillSent(cl, packets.Packet{})
			h.OnClientExpired(cl)
			h.OnRetainedExpired("a/b/c")

			// on second iteration, check added hook methods
			err := h.Add(new(modifiedHookBase), nil)
			require.NoError(t, err)
		})
	}
}
+
// TestHooksOnConnectAuthenticate checks that authentication defaults to
// false with no hooks, and true once a hook which allows it is added.
func TestHooksOnConnectAuthenticate(t *testing.T) {
	h := new(Hooks)

	ok := h.OnConnectAuthenticate(new(Client), packets.Packet{})
	require.False(t, ok)

	err := h.Add(new(modifiedHookBase), nil)
	require.NoError(t, err)

	ok = h.OnConnectAuthenticate(new(Client), packets.Packet{})
	require.True(t, ok)
}

// TestHooksOnACLCheck checks that ACL checks default to false with no
// hooks, and true once a hook which allows the topic is added.
func TestHooksOnACLCheck(t *testing.T) {
	h := new(Hooks)

	ok := h.OnACLCheck(new(Client), "a/b/c", true)
	require.False(t, ok)

	err := h.Add(new(modifiedHookBase), nil)
	require.NoError(t, err)

	ok = h.OnACLCheck(new(Client), "a/b/c", true)
	require.True(t, ok)
}

// TestHooksOnSubscribe checks that the subscribe hook passes the packet
// through unmodified by the default test hook.
func TestHooksOnSubscribe(t *testing.T) {
	h := new(Hooks)
	err := h.Add(new(modifiedHookBase), nil)
	require.NoError(t, err)

	pki := packets.Packet{
		Filters: packets.Subscriptions{
			{Filter: "a/b/c", Qos: 1},
		},
	}
	pk := h.OnSubscribe(new(Client), pki)
	require.EqualValues(t, pk, pki)
}

// TestHooksOnSelectSubscribers checks that the subscriber-selection hook
// passes the candidate subscribers through unmodified.
func TestHooksOnSelectSubscribers(t *testing.T) {
	h := new(Hooks)
	err := h.Add(new(modifiedHookBase), nil)
	require.NoError(t, err)

	subs := &Subscribers{
		Subscriptions: map[string]packets.Subscription{
			"cl1": {Filter: "a/b/c"},
		},
	}

	subs2 := h.OnSelectSubscribers(subs, packets.Packet{})
	require.EqualValues(t, subs, subs2)
}

// TestHooksOnUnsubscribe checks that the unsubscribe hook passes the packet
// through unmodified.
func TestHooksOnUnsubscribe(t *testing.T) {
	h := new(Hooks)
	err := h.Add(new(modifiedHookBase), nil)
	require.NoError(t, err)

	pki := packets.Packet{
		Filters: packets.Subscriptions{
			{Filter: "a/b/c", Qos: 1},
		},
	}

	pk := h.OnUnsubscribe(new(Client), pki)
	require.EqualValues(t, pk, pki)
}
+
// TestHooksOnPublish checks the publish hook chain: success passes the
// packet through, a generic hook error is surfaced, and ErrRejectPacket
// is surfaced as a wrapped, matchable error.
func TestHooksOnPublish(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	hook := new(modifiedHookBase)
	err := h.Add(hook, nil)
	require.NoError(t, err)

	pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
	require.NoError(t, err)
	require.Equal(t, uint16(10), pk.PacketID)

	// coverage: failure
	hook.fail = true
	pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
	require.Error(t, err)
	require.Equal(t, uint16(10), pk.PacketID)

	// coverage: reject packet
	hook.err = packets.ErrRejectPacket
	pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrRejectPacket)
	require.Equal(t, uint16(10), pk.PacketID)
}

// TestHooksOnPacketRead checks the packet-read hook chain. Note the
// asymmetry with OnPublish: a generic hook failure yields NoError here —
// presumably the dispatcher logs and ignores generic errors on read and
// only propagates ErrRejectPacket; confirm against Hooks.OnPacketRead.
func TestHooksOnPacketRead(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	hook := new(modifiedHookBase)
	err := h.Add(hook, nil)
	require.NoError(t, err)

	pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
	require.NoError(t, err)
	require.Equal(t, uint16(10), pk.PacketID)

	// coverage: failure
	hook.fail = true
	pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
	require.NoError(t, err)
	require.Equal(t, uint16(10), pk.PacketID)

	// coverage: reject packet
	hook.err = packets.ErrRejectPacket
	pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrRejectPacket)
	require.Equal(t, uint16(10), pk.PacketID)
}

// TestHooksOnAuthPacket checks the AUTH packet hook: success passes the
// packet through, and a hook failure is surfaced as an error.
func TestHooksOnAuthPacket(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	hook := new(modifiedHookBase)
	err := h.Add(hook, nil)
	require.NoError(t, err)

	pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
	require.NoError(t, err)
	require.Equal(t, uint16(10), pk.PacketID)

	hook.fail = true
	pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
	require.Error(t, err)
	require.Equal(t, uint16(10), pk.PacketID)
}

// TestHooksOnConnect checks that the connect hook succeeds normally and
// surfaces a hook failure as an error.
func TestHooksOnConnect(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	hook := new(modifiedHookBase)
	err := h.Add(hook, nil)
	require.NoError(t, err)

	err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
	require.NoError(t, err)

	hook.fail = true
	err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
	require.Error(t, err)
}

// TestHooksOnPacketEncode checks that the encode hook passes the packet
// through unmodified.
func TestHooksOnPacketEncode(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	hook := new(modifiedHookBase)
	err := h.Add(hook, nil)
	require.NoError(t, err)

	pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10})
	require.Equal(t, uint16(10), pk.PacketID)
}

// TestHooksOnLWT checks that the will hook passes the LWT through, and
// that a failing hook falls back to the original will rather than erroring.
func TestHooksOnLWT(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	hook := new(modifiedHookBase)
	err := h.Add(hook, nil)
	require.NoError(t, err)

	lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
	require.Equal(t, "a/b/c", lwt.TopicName)

	// coverage: fail lwt
	hook.fail = true
	lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"})
	require.Equal(t, "a/b/c", lwt.TopicName)
}
+
// TestHooksStoredClients checks the stored-clients aggregator: empty with
// no hooks, populated once a storage-providing hook is added, and an error
// with an empty result when the hook fails.
func TestHooksStoredClients(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	v, err := h.StoredClients()
	require.NoError(t, err)
	require.Len(t, v, 0)

	hook := new(modifiedHookBase)
	err = h.Add(hook, nil)
	require.NoError(t, err)

	v, err = h.StoredClients()
	require.NoError(t, err)
	require.Len(t, v, 3)

	hook.fail = true
	v, err = h.StoredClients()
	require.Error(t, err)
	require.Len(t, v, 0)
}

// TestHooksStoredSubscriptions mirrors TestHooksStoredClients for stored
// subscriptions.
func TestHooksStoredSubscriptions(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	v, err := h.StoredSubscriptions()
	require.NoError(t, err)
	require.Len(t, v, 0)

	hook := new(modifiedHookBase)
	err = h.Add(hook, nil)
	require.NoError(t, err)

	v, err = h.StoredSubscriptions()
	require.NoError(t, err)
	require.Len(t, v, 3)

	hook.fail = true
	v, err = h.StoredSubscriptions()
	require.Error(t, err)
	require.Len(t, v, 0)
}

// TestHooksStoredRetainedMessages mirrors TestHooksStoredClients for
// retained messages.
func TestHooksStoredRetainedMessages(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	v, err := h.StoredRetainedMessages()
	require.NoError(t, err)
	require.Len(t, v, 0)

	hook := new(modifiedHookBase)
	err = h.Add(hook, nil)
	require.NoError(t, err)

	v, err = h.StoredRetainedMessages()
	require.NoError(t, err)
	require.Len(t, v, 3)

	hook.fail = true
	v, err = h.StoredRetainedMessages()
	require.Error(t, err)
	require.Len(t, v, 0)
}

// TestHooksStoredInflightMessages mirrors TestHooksStoredClients for
// inflight messages.
func TestHooksStoredInflightMessages(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	v, err := h.StoredInflightMessages()
	require.NoError(t, err)
	require.Len(t, v, 0)

	hook := new(modifiedHookBase)
	err = h.Add(hook, nil)
	require.NoError(t, err)

	v, err = h.StoredInflightMessages()
	require.NoError(t, err)
	require.Len(t, v, 3)

	hook.fail = true
	v, err = h.StoredInflightMessages()
	require.Error(t, err)
	require.Len(t, v, 0)
}

// TestHooksStoredSysInfo checks the stored system info aggregator: zero
// value with no hooks, the hook's "2.0.0" payload once added, and a zero
// value plus error when the hook fails.
func TestHooksStoredSysInfo(t *testing.T) {
	h := new(Hooks)
	h.Log = logger

	v, err := h.StoredSysInfo()
	require.NoError(t, err)
	require.Equal(t, "", v.Info.Version)

	hook := new(modifiedHookBase)
	err = h.Add(hook, nil)
	require.NoError(t, err)

	v, err = h.StoredSysInfo()
	require.NoError(t, err)
	require.Equal(t, "2.0.0", v.Info.Version)

	hook.fail = true
	v, err = h.StoredSysInfo()
	require.Error(t, err)
	require.Equal(t, "", v.Info.Version)
}
+
// TestHookBaseID checks the default hook identifier.
func TestHookBaseID(t *testing.T) {
	h := new(HookBase)
	require.Equal(t, "base", h.ID())
}

// TestHookBaseProvidesNone checks that the base hook provides no events.
func TestHookBaseProvidesNone(t *testing.T) {
	h := new(HookBase)
	require.False(t, h.Provides(OnConnect))
	require.False(t, h.Provides(OnDisconnect))
}

// TestHookBaseInit checks that base Init is a no-op returning nil.
func TestHookBaseInit(t *testing.T) {
	h := new(HookBase)
	require.Nil(t, h.Init(nil))
}

// TestHookBaseSetOpts checks that SetOpts stores the logger and options.
func TestHookBaseSetOpts(t *testing.T) {
	h := new(HookBase)
	h.SetOpts(logger, new(HookOptions))
	require.NotNil(t, h.Log)
	require.NotNil(t, h.Opts)
}

// TestHookBaseClose checks that base Stop is a no-op returning nil.
// NOTE(review): the test name says "Close" but the method under test is
// Stop; consider renaming for discoverability.
func TestHookBaseClose(t *testing.T) {
	h := new(HookBase)
	require.Nil(t, h.Stop())
}

// TestHookBaseOnConnectAuthenticate checks the base hook denies auth.
func TestHookBaseOnConnectAuthenticate(t *testing.T) {
	h := new(HookBase)
	v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
	require.False(t, v)
}

// TestHookBaseOnACLCheck checks the base hook denies ACL access.
func TestHookBaseOnACLCheck(t *testing.T) {
	h := new(HookBase)
	v := h.OnACLCheck(new(Client), "topic", true)
	require.False(t, v)
}

// TestHookBaseOnConnect checks base OnConnect is a no-op returning nil.
func TestHookBaseOnConnect(t *testing.T) {
	h := new(HookBase)
	err := h.OnConnect(new(Client), packets.Packet{})
	require.NoError(t, err)
}

// TestHookBaseOnPublish checks base OnPublish passes the packet through.
func TestHookBaseOnPublish(t *testing.T) {
	h := new(HookBase)
	pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
	require.NoError(t, err)
	require.Equal(t, uint16(10), pk.PacketID)
}

// TestHookBaseOnPacketRead checks base OnPacketRead passes the packet through.
func TestHookBaseOnPacketRead(t *testing.T) {
	h := new(HookBase)
	pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
	require.NoError(t, err)
	require.Equal(t, uint16(10), pk.PacketID)
}

// TestHookBaseOnAuthPacket checks base OnAuthPacket passes the packet through.
func TestHookBaseOnAuthPacket(t *testing.T) {
	h := new(HookBase)
	pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
	require.NoError(t, err)
	require.Equal(t, uint16(10), pk.PacketID)
}

// TestHookBaseOnLWT checks base OnWill passes the will through unchanged.
func TestHookBaseOnLWT(t *testing.T) {
	h := new(HookBase)
	lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
	require.NoError(t, err)
	require.Equal(t, "a/b/c", lwt.TopicName)
}

// TestHookBaseStoredClients checks the base storage getter returns empty.
func TestHookBaseStoredClients(t *testing.T) {
	h := new(HookBase)
	v, err := h.StoredClients()
	require.NoError(t, err)
	require.Empty(t, v)
}

// TestHookBaseStoredSubscriptions checks the base storage getter returns empty.
func TestHookBaseStoredSubscriptions(t *testing.T) {
	h := new(HookBase)
	v, err := h.StoredSubscriptions()
	require.NoError(t, err)
	require.Empty(t, v)
}

// TestHookBaseStoredInflightMessages checks the base storage getter returns empty.
func TestHookBaseStoredInflightMessages(t *testing.T) {
	h := new(HookBase)
	v, err := h.StoredInflightMessages()
	require.NoError(t, err)
	require.Empty(t, v)
}

// TestHookBaseStoredRetainedMessages checks the base storage getter returns empty.
func TestHookBaseStoredRetainedMessages(t *testing.T) {
	h := new(HookBase)
	v, err := h.StoredRetainedMessages()
	require.NoError(t, err)
	require.Empty(t, v)
}

// TestHookBaseStoreSysInfo checks the base sys-info getter returns the zero
// value. NOTE(review): name is likely a typo for "StoredSysInfo".
func TestHookBaseStoreSysInfo(t *testing.T) {
	h := new(HookBase)
	v, err := h.StoredSysInfo()
	require.NoError(t, err)
	require.Equal(t, "", v.Version)
}
diff --git a/mqtt/inflight.go b/mqtt/inflight.go
new file mode 100644
index 0000000..c631c7f
--- /dev/null
+++ b/mqtt/inflight.go
@@ -0,0 +1,156 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package mqtt
+
+import (
+ "sort"
+ "sync"
+ "sync/atomic"
+
+ "testmqtt/packets"
+)
+
// Inflight is a map of InflightMessage keyed on packet id. It also carries
// the MQTT v5 flow-control quotas for the owning session.
type Inflight struct {
	sync.RWMutex
	internal map[uint16]packets.Packet // internal contains the inflight packets
	receiveQuota int32 // remaining inbound qos quota for flow control
	sendQuota int32 // remaining outbound qos quota for flow control
	maximumReceiveQuota int32 // maximum allowed receive quota
	maximumSendQuota int32 // maximum allowed send quota
}

// NewInflights returns a new instance of an Inflight packets map.
// All quotas start at zero until ResetReceiveQuota/ResetSendQuota are called.
func NewInflights() *Inflight {
	return &Inflight{
		internal: map[uint16]packets.Packet{},
	}
}
+
+// Set adds or updates an inflight packet by packet id.
+func (i *Inflight) Set(m packets.Packet) bool {
+ i.Lock()
+ defer i.Unlock()
+
+ _, ok := i.internal[m.PacketID]
+ i.internal[m.PacketID] = m
+ return !ok
+}
+
+// Get returns an inflight packet by packet id.
+func (i *Inflight) Get(id uint16) (packets.Packet, bool) {
+ i.RLock()
+ defer i.RUnlock()
+
+ if m, ok := i.internal[id]; ok {
+ return m, true
+ }
+
+ return packets.Packet{}, false
+}
+
+// Len returns the size of the inflight messages map.
+func (i *Inflight) Len() int {
+ i.RLock()
+ defer i.RUnlock()
+ return len(i.internal)
+}
+
+// Clone returns a new instance of Inflight with the same message data.
+// This is used when transferring inflights from a taken-over session.
+func (i *Inflight) Clone() *Inflight {
+ c := NewInflights()
+ i.RLock()
+ defer i.RUnlock()
+ for k, v := range i.internal {
+ c.internal[k] = v
+ }
+ return c
+}
+
+// GetAll returns all the inflight messages.
+func (i *Inflight) GetAll(immediate bool) []packets.Packet {
+ i.RLock()
+ defer i.RUnlock()
+
+ m := []packets.Packet{}
+ for _, v := range i.internal {
+ if !immediate || (immediate && v.Expiry < 0) {
+ m = append(m, v)
+ }
+ }
+
+ sort.Slice(m, func(i, j int) bool {
+ return uint16(m[i].Created) < uint16(m[j].Created)
+ })
+
+ return m
+}
+
+// NextImmediate returns the next inflight packet which is indicated to be sent immediately.
+// This typically occurs when the quota has been exhausted, and we need to wait until new quota
+// is free to continue sending.
+func (i *Inflight) NextImmediate() (packets.Packet, bool) {
+ i.RLock()
+ defer i.RUnlock()
+
+ m := i.GetAll(true)
+ if len(m) > 0 {
+ return m[0], true
+ }
+
+ return packets.Packet{}, false
+}
+
+// Delete removes an in-flight message from the map. Returns true if the message existed.
+func (i *Inflight) Delete(id uint16) bool {
+ i.Lock()
+ defer i.Unlock()
+
+ _, ok := i.internal[id]
+ delete(i.internal, id)
+
+ return ok
+}
+
// DecreaseReceiveQuota reduces the receive quota by 1, flooring at zero.
// NOTE(review): the load-then-add pair is not atomic as a unit; the bound
// could be crossed under concurrent callers. Presumably calls are
// serialized per client — TODO confirm.
func (i *Inflight) DecreaseReceiveQuota() {
	if atomic.LoadInt32(&i.receiveQuota) > 0 {
		atomic.AddInt32(&i.receiveQuota, -1)
	}
}

// IncreaseReceiveQuota increases the receive quota by 1, capping at the
// configured maximum (same non-atomic check-then-act caveat as above).
func (i *Inflight) IncreaseReceiveQuota() {
	if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
		atomic.AddInt32(&i.receiveQuota, 1)
	}
}

// ResetReceiveQuota resets both the current and maximum receive quota to n.
func (i *Inflight) ResetReceiveQuota(n int32) {
	atomic.StoreInt32(&i.receiveQuota, n)
	atomic.StoreInt32(&i.maximumReceiveQuota, n)
}

// DecreaseSendQuota reduces the send quota by 1, flooring at zero.
func (i *Inflight) DecreaseSendQuota() {
	if atomic.LoadInt32(&i.sendQuota) > 0 {
		atomic.AddInt32(&i.sendQuota, -1)
	}
}

// IncreaseSendQuota increases the send quota by 1, capping at the
// configured maximum.
func (i *Inflight) IncreaseSendQuota() {
	if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
		atomic.AddInt32(&i.sendQuota, 1)
	}
}

// ResetSendQuota resets both the current and maximum send quota to n.
func (i *Inflight) ResetSendQuota(n int32) {
	atomic.StoreInt32(&i.sendQuota, n)
	atomic.StoreInt32(&i.maximumSendQuota, n)
}
diff --git a/mqtt/inflight_test.go b/mqtt/inflight_test.go
new file mode 100644
index 0000000..d755204
--- /dev/null
+++ b/mqtt/inflight_test.go
@@ -0,0 +1,199 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package mqtt
+
+import (
+ "sync/atomic"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "testmqtt/packets"
+)
+
// TestInflightSet checks that Set reports true for a new packet id and
// false when overwriting an existing one.
// NOTE(review): require.NotNil on a non-pointer struct value always passes;
// the internal[1] assertions below are effectively redundant with the
// PacketID check.
func TestInflightSet(t *testing.T) {
	cl, _, _ := newTestClient()

	r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
	require.True(t, r)
	require.NotNil(t, cl.State.Inflight.internal[1])
	require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID)

	r = cl.State.Inflight.Set(packets.Packet{PacketID: 1})
	require.False(t, r)
}

// TestInflightGet checks that a stored packet can be retrieved by id.
func TestInflightGet(t *testing.T) {
	cl, _, _ := newTestClient()
	cl.State.Inflight.Set(packets.Packet{PacketID: 2})

	msg, ok := cl.State.Inflight.Get(2)
	require.True(t, ok)
	require.NotEqual(t, 0, msg.PacketID)
}

// TestInflightGetAllAndImmediate checks GetAll ordering by Created, and
// that immediate mode filters to packets with Expiry < 0.
func TestInflightGetAllAndImmediate(t *testing.T) {
	cl, _, _ := newTestClient()
	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
	cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
	cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})

	require.Equal(t, []packets.Packet{
		{PacketID: 1, Created: 1},
		{PacketID: 2, Created: 2},
		{PacketID: 3, Created: 3, Expiry: -1},
		{PacketID: 4, Created: 4, Expiry: -1},
		{PacketID: 5, Created: 5},
	}, cl.State.Inflight.GetAll(false))

	require.Equal(t, []packets.Packet{
		{PacketID: 3, Created: 3, Expiry: -1},
		{PacketID: 4, Created: 4, Expiry: -1},
	}, cl.State.Inflight.GetAll(true))
}

// TestInflightLen checks that Len reports the number of stored packets.
func TestInflightLen(t *testing.T) {
	cl, _, _ := newTestClient()
	cl.State.Inflight.Set(packets.Packet{PacketID: 2})
	require.Equal(t, 1, cl.State.Inflight.Len())
}

// TestInflightClone checks that Clone returns a distinct Inflight instance.
func TestInflightClone(t *testing.T) {
	cl, _, _ := newTestClient()
	cl.State.Inflight.Set(packets.Packet{PacketID: 2})
	require.Equal(t, 1, cl.State.Inflight.Len())

	cloned := cl.State.Inflight.Clone()
	require.NotNil(t, cloned)
	require.NotSame(t, cloned, cl.State.Inflight)
}

// TestInflightDelete checks that Delete removes an entry and reports
// whether it existed.
func TestInflightDelete(t *testing.T) {
	cl, _, _ := newTestClient()

	cl.State.Inflight.Set(packets.Packet{PacketID: 3})
	require.NotNil(t, cl.State.Inflight.internal[3])

	r := cl.State.Inflight.Delete(3)
	require.True(t, r)
	require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID)

	_, ok := cl.State.Inflight.Get(3)
	require.False(t, ok)

	r = cl.State.Inflight.Delete(3)
	require.False(t, r)
}
+
// TestResetReceiveQuota checks that ResetReceiveQuota sets both the
// current and maximum receive quota.
func TestResetReceiveQuota(t *testing.T) {
	i := NewInflights()
	require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota))
	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
	i.ResetReceiveQuota(6)
	require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota))
	require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota))
}

// TestReceiveQuota exercises the receive quota bounds: increase caps at
// the maximum and decrease floors at zero.
func TestReceiveQuota(t *testing.T) {
	i := NewInflights()
	i.receiveQuota = 4
	i.maximumReceiveQuota = 5
	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
	require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))

	// Return 1
	i.IncreaseReceiveQuota()
	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))

	// Try to go over max limit
	i.IncreaseReceiveQuota()
	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))

	// Reset to max 1
	i.ResetReceiveQuota(1)
	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
	require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))

	// Take 1
	i.DecreaseReceiveQuota()
	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))

	// Try to go below zero
	i.DecreaseReceiveQuota()
	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}

// TestResetSendQuota checks that ResetSendQuota sets both the current and
// maximum send quota.
func TestResetSendQuota(t *testing.T) {
	i := NewInflights()
	require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota))
	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
	i.ResetSendQuota(6)
	require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota))
	require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota))
}

// TestSendQuota exercises the send quota bounds: increase caps at the
// maximum and decrease floors at zero.
func TestSendQuota(t *testing.T) {
	i := NewInflights()
	i.sendQuota = 4
	i.maximumSendQuota = 5
	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
	require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))

	// Return 1
	i.IncreaseSendQuota()
	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))

	// Try to go over max limit
	i.IncreaseSendQuota()
	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))

	// Reset to max 1
	i.ResetSendQuota(1)
	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
	require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))

	// Take 1
	i.DecreaseSendQuota()
	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))

	// Try to go below zero
	i.DecreaseSendQuota()
	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}
+
// TestNextImmediate checks that NextImmediate yields immediate-flagged
// packets (Expiry < 0) in Created order, and reports false once none
// remain.
func TestNextImmediate(t *testing.T) {
	cl, _, _ := newTestClient()
	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
	cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
	cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})

	pk, ok := cl.State.Inflight.NextImmediate()
	require.True(t, ok)
	require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk)

	r := cl.State.Inflight.Delete(3)
	require.True(t, r)

	pk, ok = cl.State.Inflight.NextImmediate()
	require.True(t, ok)
	require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk)

	r = cl.State.Inflight.Delete(4)
	require.True(t, r)

	_, ok = cl.State.Inflight.NextImmediate()
	require.False(t, ok)
}
diff --git a/mqtt/server.go b/mqtt/server.go
new file mode 100644
index 0000000..af7992a
--- /dev/null
+++ b/mqtt/server.go
@@ -0,0 +1,1759 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+// Package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
+package mqtt
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "net"
+ "os"
+ "runtime"
+ "sort"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/listeners"
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ "log/slog"
+)
+
+const (
+	// Version is the current server version.
+	Version = "2.6.5" // the current server version.
+	// defaultSysTopicInterval is the default interval, in seconds, between $SYS topic publishes.
+	defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
+	// LocalListener is the listener id assigned to the inline client's pseudo-connection.
+	LocalListener = "local"
+	// InlineClientId is the client id reserved for the inline client.
+	InlineClientId = "inline"
+)
+
+var (
+	// Deprecated: Use NewDefaultServerCapabilities to avoid data race issue.
+	DefaultServerCapabilities = NewDefaultServerCapabilities()
+
+	ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists
+	ErrConnectionClosed = errors.New("connection not open") // connection is closed
+	ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default
+	ErrOptionsUnreadable = errors.New("unable to read options from bytes") // options bytes could not be decoded
+)
+
+// Capabilities indicates the capabilities and features provided by the server.
+// Byte-typed availability fields (RetainAvailable, SharedSubAvailable,
+// WildcardSubAvailable, SubIDAvailable) act as flags: 1 = available, 0 = not.
+type Capabilities struct {
+	MaximumClients int64 `yaml:"maximum_clients" json:"maximum_clients"` // maximum number of connected clients
+	MaximumMessageExpiryInterval int64 `yaml:"maximum_message_expiry_interval" json:"maximum_message_expiry_interval"` // maximum message expiry if message expiry is 0 or over
+	MaximumClientWritesPending int32 `yaml:"maximum_client_writes_pending" json:"maximum_client_writes_pending"` // maximum number of pending message writes for a client
+	MaximumSessionExpiryInterval uint32 `yaml:"maximum_session_expiry_interval" json:"maximum_session_expiry_interval"` // maximum number of seconds to keep disconnected sessions
+	MaximumPacketSize uint32 `yaml:"maximum_packet_size" json:"maximum_packet_size"` // maximum packet size, no limit if 0
+	maximumPacketID uint32 // unexported, used for testing only
+	ReceiveMaximum uint16 `yaml:"receive_maximum" json:"receive_maximum"` // maximum number of concurrent qos messages per client
+	MaximumInflight uint16 `yaml:"maximum_inflight" json:"maximum_inflight"` // maximum number of qos > 0 messages can be stored, 0(=8192)-65535
+	TopicAliasMaximum uint16 `yaml:"topic_alias_maximum" json:"topic_alias_maximum"` // maximum topic alias value
+	SharedSubAvailable byte `yaml:"shared_sub_available" json:"shared_sub_available"` // support of shared subscriptions
+	MinimumProtocolVersion byte `yaml:"minimum_protocol_version" json:"minimum_protocol_version"` // minimum supported mqtt version
+	Compatibilities Compatibilities `yaml:"compatibilities" json:"compatibilities"` // version compatibilities the server provides
+	MaximumQos byte `yaml:"maximum_qos" json:"maximum_qos"` // maximum qos value available to clients
+	RetainAvailable byte `yaml:"retain_available" json:"retain_available"` // support of retain messages
+	WildcardSubAvailable byte `yaml:"wildcard_sub_available" json:"wildcard_sub_available"` // support of wildcard subscriptions
+	SubIDAvailable byte `yaml:"sub_id_available" json:"sub_id_available"` // support of subscription identifiers
+}
+
+// NewDefaultServerCapabilities defines the default features and capabilities provided by the server.
+// Each call returns a fresh value, so callers may mutate the result freely.
+func NewDefaultServerCapabilities() *Capabilities {
+	c := new(Capabilities)
+	c.MaximumClients = math.MaxInt64          // maximum number of connected clients
+	c.MaximumMessageExpiryInterval = 60 * 60 * 24 // maximum message expiry if message expiry is 0 or over
+	c.MaximumClientWritesPending = 1024 * 8   // maximum number of pending message writes for a client
+	c.MaximumSessionExpiryInterval = math.MaxUint32 // maximum number of seconds to keep disconnected sessions
+	c.MaximumPacketSize = 0                   // no maximum packet size
+	c.maximumPacketID = math.MaxUint16
+	c.ReceiveMaximum = 1024                   // maximum number of concurrent qos messages per client
+	c.MaximumInflight = 1024 * 8              // maximum number of qos > 0 messages can be stored
+	c.TopicAliasMaximum = math.MaxUint16      // maximum topic alias value
+	c.SharedSubAvailable = 1                  // shared subscriptions are available
+	c.MinimumProtocolVersion = 3              // minimum supported mqtt version (3.0.0)
+	c.MaximumQos = 2                          // maximum qos value available to clients
+	c.RetainAvailable = 1                     // retain messages is available
+	c.WildcardSubAvailable = 1                // wildcard subscriptions are available
+	c.SubIDAvailable = 1                      // subscription identifiers are available
+	return c
+}
+
+// Compatibilities provides flags for using compatibility modes.
+// These exist mainly to accommodate client libraries (e.g. paho) whose
+// behavior deviates from the MQTT specification.
+type Compatibilities struct {
+	ObscureNotAuthorized bool `yaml:"obscure_not_authorized" json:"obscure_not_authorized"` // return unspecified errors instead of not authorized
+	PassiveClientDisconnect bool `yaml:"passive_client_disconnect" json:"passive_client_disconnect"` // don't disconnect the client forcefully after sending disconnect packet (paho - spec violation)
+	AlwaysReturnResponseInfo bool `yaml:"always_return_response_info" json:"always_return_response_info"` // always return response info (useful for testing)
+	RestoreSysInfoOnRestart bool `yaml:"restore_sys_info_on_restart" json:"restore_sys_info_on_restart"` // restore system info from store as if server never stopped
+	NoInheritedPropertiesOnAck bool `yaml:"no_inherited_properties_on_ack" json:"no_inherited_properties_on_ack"` // don't allow inherited user properties on ack (paho - spec violation)
+}
+
+// Options contains configurable options for the server.
+type Options struct {
+	// Listeners specifies any listeners which should be dynamically added on serve. Used when setting listeners by config.
+	Listeners []listeners.Config `yaml:"listeners" json:"listeners"`
+
+	// Hooks specifies any hooks which should be dynamically added on serve. Used when setting hooks by config.
+	Hooks []HookLoadConfig `yaml:"hooks" json:"hooks"`
+
+	// Capabilities defines the server features and behaviour. If you only wish to modify
+	// several of these values, set them explicitly - e.g.
+	// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
+	Capabilities *Capabilities `yaml:"capabilities" json:"capabilities"`
+
+	// ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer.
+	ClientNetWriteBufferSize int `yaml:"client_net_write_buffer_size" json:"client_net_write_buffer_size"`
+
+	// ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer.
+	ClientNetReadBufferSize int `yaml:"client_net_read_buffer_size" json:"client_net_read_buffer_size"`
+
+	// Logger specifies a custom configured implementation of log/slog to override
+	// the server's default logger configuration. If you wish to change the log
+	// level of the default logger, supply a configured logger via options, e.g.
+	// level := new(slog.LevelVar)
+	// server := mqtt.New(&mqtt.Options{
+	// 	Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
+	// 		Level: level,
+	// 	})),
+	// })
+	// level.Set(slog.LevelDebug)
+	Logger *slog.Logger `yaml:"-" json:"-"`
+
+	// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
+	SysTopicResendInterval int64 `yaml:"sys_topic_resend_interval" json:"sys_topic_resend_interval"`
+
+	// Enable Inline client to allow direct subscribing and publishing from the parent codebase,
+	// with negligible performance difference (disabled by default to prevent confusion in statistics).
+	InlineClient bool `yaml:"inline_client" json:"inline_client"`
+}
+
+// Server is an MQTT broker server. It should be created with server.New()
+// in order to ensure all the internal fields are correctly populated; the
+// zero value is not usable.
+type Server struct {
+	Options *Options // configurable server options
+	Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections
+	Clients *Clients // clients known to the broker
+	Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
+	Info *system.Info // values about the server commonly known as $SYS topics
+	loop *loop // loop contains tickers for the system event loop
+	done chan bool // indicate that the server is ending
+	Log *slog.Logger // minimal no-alloc logger
+	hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage
+	inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish
+}
+
+// loop contains interval tickers for the system events loop.
+type loop struct {
+	sysTopics *time.Ticker // interval ticker for sending $SYS topic updates
+	clientExpiry *time.Ticker // interval ticker for cleaning expired clients
+	inflightExpiry *time.Ticker // interval ticker for cleaning up expired inflight messages
+	retainedExpiry *time.Ticker // interval ticker for cleaning retained messages
+	willDelaySend *time.Ticker // interval ticker for sending Will Messages with a delay
+	willDelayed *packets.Packets // activate LWT packets which will be sent after a delay
+}
+
+// ops contains server values which can be propagated to other structs
+// (notably clients), avoiding a direct back-reference to the Server.
+type ops struct {
+	options *Options // a pointer to the server options and capabilities, for referencing in clients
+	info *system.Info // pointers to server system info
+	hooks *Hooks // pointer to the server hooks
+	log *slog.Logger // a structured logger for the client
+}
+
+// New returns a new instance of mochi mqtt broker. Optional parameters
+// can be specified to override some default settings (see Options).
+func New(opts *Options) *Server {
+	if opts == nil {
+		opts = new(Options)
+	}
+	opts.ensureDefaults()
+
+	// Tickers driving the periodic housekeeping performed by eventLoop.
+	tickers := &loop{
+		sysTopics:      time.NewTicker(time.Second * time.Duration(opts.SysTopicResendInterval)),
+		clientExpiry:   time.NewTicker(time.Second),
+		inflightExpiry: time.NewTicker(time.Second),
+		retainedExpiry: time.NewTicker(time.Second),
+		willDelaySend:  time.NewTicker(time.Second),
+		willDelayed:    packets.NewPackets(),
+	}
+
+	s := &Server{
+		done:      make(chan bool),
+		Clients:   NewClients(),
+		Topics:    NewTopicsIndex(),
+		Listeners: listeners.New(),
+		loop:      tickers,
+		Options:   opts,
+		Info: &system.Info{
+			Version: Version,
+			Started: time.Now().Unix(),
+		},
+		Log:   opts.Logger,
+		hooks: &Hooks{Log: opts.Logger},
+	}
+
+	// The inline client lets the embedding application publish and subscribe
+	// directly, bypassing ACL and topic validity checks.
+	if s.Options.InlineClient {
+		s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true)
+		s.Clients.Add(s.inlineClient)
+	}
+
+	return s
+}
+
+// ensureDefaults ensures that the server starts with sane default values, if none are provided.
+// Zero-valued fields are replaced; explicitly-set fields are left untouched.
+func (o *Options) ensureDefaults() {
+	if o.Capabilities == nil {
+		o.Capabilities = NewDefaultServerCapabilities()
+	}
+	o.Capabilities.maximumPacketID = math.MaxUint16 // spec maximum is 65535
+
+	if o.Capabilities.MaximumInflight == 0 {
+		o.Capabilities.MaximumInflight = 1024 * 8
+	}
+	if o.SysTopicResendInterval == 0 {
+		o.SysTopicResendInterval = defaultSysTopicInterval
+	}
+	if o.ClientNetWriteBufferSize == 0 {
+		o.ClientNetWriteBufferSize = 1024 * 2
+	}
+	if o.ClientNetReadBufferSize == 0 {
+		o.ClientNetReadBufferSize = 1024 * 2
+	}
+	if o.Logger == nil {
+		// Default to a text logger writing to stdout at the default level.
+		o.Logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
+	}
+}
+
+// NewClient returns a new Client instance, populated with all the required values and
+// references to be used with the server. If you are using this client to directly publish
+// messages from the embedding application, set the inline flag to true to bypass ACL and
+// topic validation checks.
+func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client {
+	client := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit
+		options: s.Options,
+		info:    s.Info,
+		hooks:   s.hooks,
+		log:     s.Log,
+	})
+	client.ID = id
+	client.Net.Listener = listener
+
+	// Inline clients bypass acl and some validity checks. By default we don't
+	// want to restrict developer publishes, so the receive quota is lifted;
+	// reset it after creating the inline client if you do want a limit.
+	if inline {
+		client.Net.Inline = true
+		client.State.Inflight.ResetReceiveQuota(math.MaxInt32)
+	}
+
+	return client
+}
+
+// AddHook attaches a new Hook to the server. Ideally, this should be called
+// before the server is started with s.Serve(). The hook is given a scoped
+// logger and the server capabilities before registration.
+func (s *Server) AddHook(hook Hook, config any) error {
+	nl := s.Log.With("hook", hook.ID())
+	hook.SetOpts(nl, &HookOptions{
+		Capabilities: s.Options.Capabilities,
+	})
+
+	// Register first, log after: previously "added hook" was logged even
+	// when s.hooks.Add returned an error.
+	if err := s.hooks.Add(hook, config); err != nil {
+		return err
+	}
+	s.Log.Info("added hook", "hook", hook.ID())
+	return nil
+}
+
+// AddHooksFromConfig adds hooks to the server which were specified in the hooks config (usually from a config file).
+// New built-in hooks should be added to this list. The first failing hook aborts the process.
+func (s *Server) AddHooksFromConfig(hooks []HookLoadConfig) error {
+	for _, hc := range hooks {
+		err := s.AddHook(hc.Hook, hc.Config)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// AddListener adds a new network listener to the server, for receiving incoming client connections.
+// Listener ids must be unique; the listener is initialised before being attached.
+func (s *Server) AddListener(l listeners.Listener) error {
+	if _, exists := s.Listeners.Get(l.ID()); exists {
+		return ErrListenerIDExists
+	}
+
+	// Initialise the listener with a logger scoped to its id.
+	if err := l.Init(s.Log.With(slog.String("listener", l.ID()))); err != nil {
+		return err
+	}
+	s.Listeners.Add(l)
+
+	s.Log.Info("attached listener", "id", l.ID(), "protocol", l.Protocol(), "address", l.Address())
+	return nil
+}
+
+// AddListenersFromConfig adds listeners to the server which were specified in the listeners config (usually from a config file).
+// New built-in listeners should be added to this list. Unknown listener types
+// are logged and skipped; a failure to attach a known type aborts the process.
+func (s *Server) AddListenersFromConfig(configs []listeners.Config) error {
+	for _, cfg := range configs {
+		var lis listeners.Listener
+		switch strings.ToLower(cfg.Type) {
+		case listeners.TypeTCP:
+			lis = listeners.NewTCP(cfg)
+		case listeners.TypeWS:
+			lis = listeners.NewWebsocket(cfg)
+		case listeners.TypeUnix:
+			lis = listeners.NewUnixSock(cfg)
+		case listeners.TypeHealthCheck:
+			lis = listeners.NewHTTPHealthCheck(cfg)
+		case listeners.TypeSysInfo:
+			lis = listeners.NewHTTPStats(cfg, s.Info)
+		case listeners.TypeMock:
+			lis = listeners.NewMockListener(cfg.ID, cfg.Address)
+		default:
+			s.Log.Error("listener type unavailable by config", "listener", cfg.Type)
+			continue
+		}
+		if err := s.AddListener(lis); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Serve starts the event loops responsible for establishing client connections
+// on all attached listeners, publishing the system topics, and starting all hooks.
+// It returns an error if configured listeners, hooks, or persisted state could
+// not be loaded; "server started" is only logged on success.
+func (s *Server) Serve() error {
+	s.Log.Info("mochi mqtt starting", "version", Version)
+
+	if len(s.Options.Listeners) > 0 {
+		err := s.AddListenersFromConfig(s.Options.Listeners)
+		if err != nil {
+			return err
+		}
+	}
+
+	if len(s.Options.Hooks) > 0 {
+		err := s.AddHooksFromConfig(s.Options.Hooks)
+		if err != nil {
+			return err
+		}
+	}
+
+	// Restore persisted server state if any attached hook provides storage.
+	if s.hooks.Provides(
+		StoredClients,
+		StoredInflightMessages,
+		StoredRetainedMessages,
+		StoredSubscriptions,
+		StoredSysInfo,
+	) {
+		err := s.readStore()
+		if err != nil {
+			return err
+		}
+	}
+
+	go s.eventLoop()                            // spin up event loop for issuing $SYS values and closing server.
+	s.Listeners.ServeAll(s.EstablishConnection) // start listening on all listeners.
+	s.publishSysTopics()                        // begin publishing $SYS system values.
+	s.hooks.OnStarted()
+
+	// Previously deferred, which misleadingly logged "started" even when
+	// Serve returned an error above.
+	s.Log.Info("mochi mqtt server started")
+	return nil
+}
+
+// eventLoop loops forever, running various server housekeeping methods at different intervals.
+// It exits when the done channel receives, stopping all housekeeping tickers on
+// the way out (previously only sysTopics was stopped, leaking the other four).
+func (s *Server) eventLoop() {
+	s.Log.Debug("system event loop started")
+	defer s.Log.Debug("system event loop halted")
+
+	for {
+		select {
+		case <-s.done:
+			// Stop every ticker so none keep firing (or stay alive) after shutdown.
+			s.loop.sysTopics.Stop()
+			s.loop.clientExpiry.Stop()
+			s.loop.inflightExpiry.Stop()
+			s.loop.retainedExpiry.Stop()
+			s.loop.willDelaySend.Stop()
+			return
+		case <-s.loop.sysTopics.C:
+			s.publishSysTopics()
+		case <-s.loop.clientExpiry.C:
+			s.clearExpiredClients(time.Now().Unix())
+		case <-s.loop.retainedExpiry.C:
+			s.clearExpiredRetainedMessages(time.Now().Unix())
+		case <-s.loop.willDelaySend.C:
+			s.sendDelayedLWT(time.Now().Unix())
+		case <-s.loop.inflightExpiry.C:
+			s.clearExpiredInflights(time.Now().Unix())
+		}
+	}
+}
+
+// EstablishConnection establishes a new client when a listener accepts a new connection.
+// It blocks for the lifetime of the connection (see attachClient).
+func (s *Server) EstablishConnection(listener string, c net.Conn) error {
+	client := s.NewClient(c, listener, "", false)
+	return s.attachClient(client, listener)
+}
+
+// attachClient validates an incoming client connection and if viable, attaches the client
+// to the server, performs session housekeeping, and reads incoming packets.
+// It blocks until the client disconnects, returning the error which ended the
+// read loop (nil for a clean disconnect).
+func (s *Server) attachClient(cl *Client, listener string) error {
+	defer s.Listeners.ClientsWg.Done()
+	s.Listeners.ClientsWg.Add(1)
+
+	go cl.WriteLoop()
+	defer cl.Stop(nil)
+
+	pk, err := s.readConnectionPacket(cl)
+	if err != nil {
+		return fmt.Errorf("read connection: %w", err)
+	}
+
+	cl.ParseConnect(listener, pk)
+	// Reject the connection if the broker is at client capacity; v3 clients
+	// have no "server busy" code, so "unavailable" is sent to them instead.
+	if atomic.LoadInt64(&s.Info.ClientsConnected) >= s.Options.Capabilities.MaximumClients {
+		if cl.Properties.ProtocolVersion < 5 {
+			s.SendConnack(cl, packets.ErrServerUnavailable, false, nil)
+		} else {
+			s.SendConnack(cl, packets.ErrServerBusy, false, nil)
+		}
+
+		return packets.ErrServerBusy
+	}
+
+	code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2]
+	if code != packets.CodeSuccess {
+		if err := s.SendConnack(cl, code, false, nil); err != nil {
+			return fmt.Errorf("invalid connection send ack: %w", err)
+		}
+		return code // [MQTT-3.2.2-7] [MQTT-3.1.4-6]
+	}
+
+	err = s.hooks.OnConnect(cl, pk)
+	if err != nil {
+		return err
+	}
+
+	cl.refreshDeadline(cl.State.Keepalive)
+	if !s.hooks.OnConnectAuthenticate(cl, pk) { // [MQTT-3.1.4-2]
+		err := s.SendConnack(cl, packets.ErrBadUsernameOrPassword, false, nil)
+		if err != nil {
+			return fmt.Errorf("invalid connection send ack: %w", err)
+		}
+
+		return packets.ErrBadUsernameOrPassword
+	}
+
+	atomic.AddInt64(&s.Info.ClientsConnected, 1)
+	defer atomic.AddInt64(&s.Info.ClientsConnected, -1)
+
+	s.hooks.OnSessionEstablish(cl, pk)
+
+	// Take over any existing session with the same client id, if present.
+	sessionPresent := s.inheritClientSession(pk, cl)
+	s.Clients.Add(cl) // [MQTT-4.1.0-1]
+
+	err = s.SendConnack(cl, code, sessionPresent, nil) // [MQTT-3.1.4-5] [MQTT-3.2.0-1] [MQTT-3.2.0-2] &[MQTT-3.14.0-1]
+	if err != nil {
+		return fmt.Errorf("ack connection packet: %w", err)
+	}
+
+	s.loop.willDelayed.Delete(cl.ID) // [MQTT-3.1.3-9]
+
+	if sessionPresent {
+		err = cl.ResendInflightMessages(true)
+		if err != nil {
+			return fmt.Errorf("resend inflight: %w", err)
+		}
+	}
+
+	s.hooks.OnSessionEstablished(cl, pk)
+
+	// Block reading incoming packets until the connection ends; on error the
+	// will message is dispatched, otherwise the will is discarded.
+	err = cl.Read(s.receivePacket)
+	if err != nil {
+		s.sendLWT(cl)
+		cl.Stop(err)
+	} else {
+		cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
+	}
+	s.Log.Debug("client disconnected", "error", err, "client", cl.ID, "remote", cl.Net.Remote, "listener", listener)
+
+	// The session expires immediately if the client requested no retention
+	// (v5 session expiry interval 0, or pre-v5 clean session).
+	expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
+	s.hooks.OnDisconnect(cl, err, expire)
+
+	if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
+		cl.ClearInflights()
+		s.UnsubscribeClient(cl)
+		s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
+	}
+
+	return err
+}
+
+// readConnectionPacket reads the first incoming header for a connection, and if
+// acceptable, returns the valid connection packet. Any packet type other than
+// CONNECT as the first packet is a protocol violation.
+func (s *Server) readConnectionPacket(cl *Client) (packets.Packet, error) {
+	fh := new(packets.FixedHeader)
+	if err := cl.ReadFixedHeader(fh); err != nil {
+		return packets.Packet{}, err
+	}
+
+	if fh.Type != packets.Connect {
+		return packets.Packet{}, packets.ErrProtocolViolationRequireFirstConnect // [MQTT-3.1.0-1]
+	}
+
+	return cl.ReadPacket(fh)
+}
+
+// receivePacket processes an incoming packet for a client, and issues a disconnect to the client
+// if an error has occurred (if mqtt v5).
+func (s *Server) receivePacket(cl *Client, pk packets.Packet) error {
+	err := s.processPacket(cl, pk)
+	if err == nil {
+		return nil
+	}
+
+	// v5 clients are sent a DISCONNECT carrying the error reason code.
+	if code, ok := err.(packets.Code); ok && cl.Properties.ProtocolVersion == 5 && code.Code >= packets.ErrUnspecifiedError.Code {
+		_ = s.DisconnectClient(cl, code)
+	}
+
+	s.Log.Warn("error processing packet", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "pk", pk)
+	return err
+}
+
+// validateConnect validates that a connect packet is compliant with the
+// server's minimum protocol version and will-message capabilities.
+func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
+	code := pk.ConnectValidate() // [MQTT-3.1.4-1] [MQTT-3.1.4-2]
+	if code != packets.CodeSuccess {
+		return code
+	}
+
+	// Pre-v5 clients requesting a persistent session must supply a client id.
+	if cl.Properties.ProtocolVersion < 5 && !pk.Connect.Clean && pk.Connect.ClientIdentifier == "" {
+		return packets.ErrUnspecifiedError
+	}
+
+	switch {
+	case cl.Properties.ProtocolVersion < s.Options.Capabilities.MinimumProtocolVersion:
+		return packets.ErrUnsupportedProtocolVersion // [MQTT-3.1.2-2]
+	case cl.Properties.Will.Qos > s.Options.Capabilities.MaximumQos:
+		return packets.ErrQosNotSupported // [MQTT-3.2.2-12]
+	case cl.Properties.Will.Retain && s.Options.Capabilities.RetainAvailable == 0x00:
+		return packets.ErrRetainNotSupported // [MQTT-3.2.2-13]
+	default:
+		return code
+	}
+}
+
+// inheritClientSession inherits the state of an existing client sharing the same
+// connection ID. If clean is true, the state of any previously existing client
+// session is abandoned. Returns the session-present flag for the CONNACK:
+// true only when a previous session existed and was inherited.
+func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
+	if existing, ok := s.Clients.Get(cl.ID); ok {
+		_ = s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
+		if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
+			s.UnsubscribeClient(existing)
+			existing.ClearInflights()
+			atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred
+			return false // [MQTT-3.2.2-3]
+		}
+
+		atomic.StoreUint32(&existing.State.isTakenOver, 1)
+		// Carry over pending inflight messages, re-establishing the send and
+		// receive quotas if the cloned state has none.
+		if existing.State.Inflight.Len() > 0 {
+			cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5]
+			if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 {
+				cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
+				cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
+			}
+		}
+
+		// Re-register the old session's subscriptions under the new client.
+		for _, sub := range existing.State.Subscriptions.GetAll() {
+			existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
+			if !existed {
+				atomic.AddInt64(&s.Info.Subscriptions, 1)
+			}
+			cl.State.Subscriptions.Add(sub.Filter, sub)
+		}
+
+		// Clean the state of the existing client to prevent sequential take-overs
+		// from increasing memory usage by inflights + subs * client-id.
+		s.UnsubscribeClient(existing)
+		existing.ClearInflights()
+
+		s.Log.Debug("session taken over", "client", cl.ID, "old_remote", existing.Net.Remote, "new_remote", cl.Net.Remote)
+
+		return true // [MQTT-3.2.2-3]
+	}
+
+	// No previous session; update the high-water mark of connected clients.
+	if atomic.LoadInt64(&s.Info.ClientsConnected) > atomic.LoadInt64(&s.Info.ClientsMaximum) {
+		atomic.AddInt64(&s.Info.ClientsMaximum, 1)
+	}
+
+	return false // [MQTT-3.2.2-2]
+}
+
+// SendConnack returns a Connack packet to a client. If properties is nil a
+// default set is created; the server's ReceiveMaximum is always applied.
+// Error reason codes produce a rejection CONNACK with session-present false,
+// translated to a v3 return code for pre-v5 clients where one exists.
+func (s *Server) SendConnack(cl *Client, reason packets.Code, present bool, properties *packets.Properties) error {
+	if properties == nil {
+		properties = &packets.Properties{
+			ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum,
+		}
+	}
+
+	properties.ReceiveMaximum = s.Options.Capabilities.ReceiveMaximum // 3.2.2.3.3 Receive Maximum
+	if cl.State.ServerKeepalive { // You can set this dynamically using the OnConnect hook.
+		properties.ServerKeepAlive = cl.State.Keepalive // [MQTT-3.1.2-21]
+		properties.ServerKeepAliveFlag = true
+	}
+
+	// Rejection path: send the error reason code and return early.
+	if reason.Code >= packets.ErrUnspecifiedError.Code {
+		if cl.Properties.ProtocolVersion < 5 {
+			if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes
+				reason = v3reason
+			}
+		}
+
+		properties.ReasonString = reason.Reason
+		ack := packets.Packet{
+			FixedHeader: packets.FixedHeader{
+				Type: packets.Connack,
+			},
+			SessionPresent: false, // [MQTT-3.2.2-6]
+			ReasonCode: reason.Code, // [MQTT-3.2.2-8]
+			Properties: *properties,
+		}
+		return cl.WritePacket(ack)
+	}
+
+	// Success path: advertise server capabilities in the CONNACK properties.
+	if s.Options.Capabilities.MaximumQos < 2 {
+		properties.MaximumQos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
+		properties.MaximumQosFlag = true
+	}
+
+	if cl.Properties.Props.AssignedClientID != "" {
+		properties.AssignedClientID = cl.Properties.Props.AssignedClientID // [MQTT-3.1.3-7] [MQTT-3.2.2-16]
+	}
+
+	// Clamp the requested session expiry to the server maximum, mirroring the
+	// clamped value back onto the client's stored properties.
+	if cl.Properties.Props.SessionExpiryInterval > s.Options.Capabilities.MaximumSessionExpiryInterval {
+		properties.SessionExpiryInterval = s.Options.Capabilities.MaximumSessionExpiryInterval
+		properties.SessionExpiryIntervalFlag = true
+		cl.Properties.Props.SessionExpiryInterval = properties.SessionExpiryInterval
+		cl.Properties.Props.SessionExpiryIntervalFlag = true
+	}
+
+	ack := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Type: packets.Connack,
+		},
+		SessionPresent: present,
+		ReasonCode: reason.Code, // [MQTT-3.2.2-8]
+		Properties: *properties,
+	}
+	return cl.WritePacket(ack)
+}
+
+// processPacket processes an inbound packet for a client. Since the method is
+// typically called as a goroutine, errors are primarily for test checking purposes.
+func (s *Server) processPacket(cl *Client, pk packets.Packet) error {
+	var err error
+
+	// Dispatch on the packet type. Publish, subscribe, unsubscribe and auth
+	// packets are validated before being processed; validation failures are
+	// returned without invoking OnPacketProcessed.
+	switch pk.FixedHeader.Type {
+	case packets.Connect:
+		err = s.processConnect(cl, pk)
+	case packets.Disconnect:
+		err = s.processDisconnect(cl, pk)
+	case packets.Pingreq:
+		err = s.processPingreq(cl, pk)
+	case packets.Publish:
+		code := pk.PublishValidate(s.Options.Capabilities.TopicAliasMaximum)
+		if code != packets.CodeSuccess {
+			return code
+		}
+		err = s.processPublish(cl, pk)
+	case packets.Puback:
+		err = s.processPuback(cl, pk)
+	case packets.Pubrec:
+		err = s.processPubrec(cl, pk)
+	case packets.Pubrel:
+		err = s.processPubrel(cl, pk)
+	case packets.Pubcomp:
+		err = s.processPubcomp(cl, pk)
+	case packets.Subscribe:
+		code := pk.SubscribeValidate()
+		if code != packets.CodeSuccess {
+			return code
+		}
+		err = s.processSubscribe(cl, pk)
+	case packets.Unsubscribe:
+		code := pk.UnsubscribeValidate()
+		if code != packets.CodeSuccess {
+			return code
+		}
+		err = s.processUnsubscribe(cl, pk)
+	case packets.Auth:
+		code := pk.AuthValidate()
+		if code != packets.CodeSuccess {
+			return code
+		}
+		err = s.processAuth(cl, pk)
+	default:
+		return fmt.Errorf("no valid packet available; %v", pk.FixedHeader.Type)
+	}
+
+	// Notify hooks of the processing outcome before propagating any error.
+	s.hooks.OnPacketProcessed(cl, pk, err)
+	if err != nil {
+		return err
+	}
+
+	// If inflight messages are pending and send quota remains, flush the next
+	// message flagged for immediate delivery, consuming one unit of quota.
+	if cl.State.Inflight.Len() > 0 && atomic.LoadInt32(&cl.State.Inflight.sendQuota) > 0 {
+		next, ok := cl.State.Inflight.NextImmediate()
+		if ok {
+			_ = cl.WritePacket(next)
+			if ok := cl.State.Inflight.Delete(next.PacketID); ok {
+				atomic.AddInt64(&s.Info.Inflight, -1)
+			}
+			cl.State.Inflight.DecreaseSendQuota()
+		}
+	}
+
+	return nil
+}
+
+// processConnect processes a Connect packet. The packet cannot be used to establish
+// a new connection on an existing connection. See EstablishConnection instead.
+// A second CONNECT is a protocol violation; the client's will message is
+// dispatched before the violation error is returned.
+func (s *Server) processConnect(cl *Client, _ packets.Packet) error {
+	s.sendLWT(cl)
+	return packets.ErrProtocolViolationSecondConnect // [MQTT-3.1.0-2]
+}
+
+// processPingreq processes a Pingreq packet by replying with a Pingresp.
+func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
+	resp := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Type: packets.Pingresp, // [MQTT-3.12.4-1]
+		},
+	}
+	return cl.WritePacket(resp)
+}
+
+// Publish publishes a publish packet into the broker as if it were sent from the specified client.
+// This is a convenience function which wraps InjectPacket. As such, this method can publish packets
+// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
+// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
+func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
+	if !s.Options.InlineClient {
+		return ErrInlineClientNotEnabled
+	}
+
+	pk := packets.Packet{
+		FixedHeader: packets.FixedHeader{
+			Type:   packets.Publish,
+			Qos:    qos,
+			Retain: retain,
+		},
+		TopicName: topic,
+		Payload:   payload,
+		PacketID:  uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks.
+	}
+	return s.InjectPacket(s.inlineClient, pk)
+}
+
+// Subscribe adds an inline subscription for the specified topic filter and subscription identifier
+// with the provided handler function. Retained messages matching the filter are
+// delivered to the handler immediately.
+func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error {
+	if !s.Options.InlineClient {
+		return ErrInlineClientNotEnabled
+	}
+	if handler == nil {
+		return packets.ErrInlineSubscriptionHandlerInvalid
+	}
+	if !IsValidFilter(filter, false) {
+		return packets.ErrTopicFilterInvalid
+	}
+
+	sub := packets.Subscription{
+		Identifier: subscriptionId,
+		Filter:     filter,
+	}
+
+	// Run the subscribe hooks exactly as a normal client subscription would.
+	pk := s.hooks.OnSubscribe(s.inlineClient, packets.Packet{
+		Origin:      s.inlineClient.ID,
+		FixedHeader: packets.FixedHeader{Type: packets.Subscribe},
+		Filters:     packets.Subscriptions{sub},
+	})
+
+	inlineSub := InlineSubscription{
+		Subscription: sub,
+		Handler:      handler,
+	}
+	s.Topics.InlineSubscribe(inlineSub)
+	s.hooks.OnSubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code})
+
+	// Deliver any retained messages matching the filter. [MQTT-3.8.4-4]
+	for _, retained := range s.Topics.Messages(filter) {
+		handler(s.inlineClient, inlineSub.Subscription, retained)
+	}
+	return nil
+}
+
+// Unsubscribe removes an inline subscription for the specified subscription and topic filter.
+// It allows you to unsubscribe a specific subscription from the internal subscription
+// associated with the given topic filter.
+func (s *Server) Unsubscribe(filter string, subscriptionId int) error {
+	if !s.Options.InlineClient {
+		return ErrInlineClientNotEnabled
+	}
+	if !IsValidFilter(filter, false) {
+		return packets.ErrTopicFilterInvalid
+	}
+
+	// Run the unsubscribe hooks exactly as a normal client unsubscribe would.
+	unsub := packets.Packet{
+		Origin:      s.inlineClient.ID,
+		FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe},
+		Filters: packets.Subscriptions{
+			{Identifier: subscriptionId, Filter: filter},
+		},
+	}
+	pk := s.hooks.OnUnsubscribe(s.inlineClient, unsub)
+
+	s.Topics.InlineUnsubscribe(subscriptionId, filter)
+	s.hooks.OnUnsubscribed(s.inlineClient, pk)
+	return nil
+}
+
+// InjectPacket injects a packet into the broker as if it were sent from the specified client.
+// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
+func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
+	pk.ProtocolVersion = cl.Properties.ProtocolVersion
+
+	if err := s.processPacket(cl, pk); err != nil {
+		return err
+	}
+
+	// Account for the injected packet in the $SYS statistics.
+	atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
+	if pk.FixedHeader.Type == packets.Publish {
+		atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
+	}
+	return nil
+}
+
// processPublish processes a Publish packet: it validates the topic, enforces
// receive quota and ACLs, resolves topic aliases, runs the OnPublish hook,
// retains the message if flagged, acknowledges qos > 0 publishes, and fans the
// message out to matching subscribers.
func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
	// Network clients must publish to a valid topic name; inline clients are
	// trusted and skip validation. Invalid names are dropped silently.
	if !cl.Net.Inline && !IsValidFilter(pk.TopicName, true) {
		return nil
	}

	// The client has exhausted its receive quota; disconnect it.
	if atomic.LoadInt32(&cl.State.Inflight.receiveQuota) == 0 {
		return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8]
	}

	// ACL check applies to network clients only; inline clients bypass ACLs.
	if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) {
		if pk.FixedHeader.Qos == 0 {
			return nil // qos 0 has no ack to carry a failure code; drop silently
		}

		// Pre-v5 protocols have no error reason codes; disconnect instead.
		if cl.Properties.ProtocolVersion != 5 {
			return s.DisconnectClient(cl, packets.ErrNotAuthorized)
		}

		ackType := packets.Puback
		if pk.FixedHeader.Qos == 2 {
			ackType = packets.Pubrec
		}

		ack := s.buildAck(pk.PacketID, ackType, 0, pk.Properties, packets.ErrNotAuthorized)
		return cl.WritePacket(ack)
	}

	pk.Origin = cl.ID
	pk.Created = time.Now().Unix()

	if !cl.Net.Inline {
		// Handle a re-used packet id: a pending Pubrec means the id is still
		// in use; otherwise drop the stale inflight entry.
		if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
			if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
				ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
				return cl.WritePacket(ack)
			}
			if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
				atomic.AddInt64(&s.Info.Inflight, -1)
			}
		}
	}

	// Resolve an inbound topic alias to its full topic name.
	if pk.Properties.TopicAliasFlag && pk.Properties.TopicAlias > 0 { // [MQTT-3.3.2-11]
		pk.TopicName = cl.State.TopicAliases.Inbound.Set(pk.Properties.TopicAlias, pk.TopicName)
	}

	if pk.FixedHeader.Qos > s.Options.Capabilities.MaximumQos {
		pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce qos based on server max qos capability
	}

	// The OnPublish hook may modify, reject, or mark the packet as ignored;
	// a packets.Code error on a v5 qos > 0 publish is relayed in the ack.
	pkx, err := s.hooks.OnPublish(cl, pk)
	if err == nil {
		pk = pkx
	} else if errors.Is(err, packets.ErrRejectPacket) {
		return nil
	} else if errors.Is(err, packets.CodeSuccessIgnore) {
		pk.Ignore = true
	} else if cl.Properties.ProtocolVersion == 5 && pk.FixedHeader.Qos > 0 && errors.As(err, new(packets.Code)) {
		err = cl.WritePacket(s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, err.(packets.Code)))
		if err != nil {
			return err
		}
		return nil
	}

	if pk.FixedHeader.Retain { // [MQTT-3.3.1-5] ![MQTT-3.3.1-8]
		s.retainMessage(cl, pk)
	}

	// If it's inlineClient, it can't handle PUBREC and PUBREL.
	// When it publishes a package with a qos > 0, the server treats
	// the package as qos=0, and the client receives it as qos=1 or 2.
	if pk.FixedHeader.Qos == 0 || cl.Net.Inline {
		s.publishToSubscribers(pk)
		s.hooks.OnPublished(cl, pk)
		return nil
	}

	// qos > 0 from a network client: consume receive quota and ack with
	// Puback (qos 1) or Pubrec (qos 2).
	cl.State.Inflight.DecreaseReceiveQuota()
	ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4]
	if pk.FixedHeader.Qos == 2 {
		ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8]
	}

	if ok := cl.State.Inflight.Set(ack); ok {
		atomic.AddInt64(&s.Info.Inflight, 1)
		s.hooks.OnQosPublish(cl, ack, ack.Created, 0)
	}

	err = cl.WritePacket(ack)
	if err != nil {
		return err
	}

	// qos 1 completes immediately once the Puback is written; restore quota.
	if pk.FixedHeader.Qos == 1 {
		if ok := cl.State.Inflight.Delete(ack.PacketID); ok {
			atomic.AddInt64(&s.Info.Inflight, -1)
		}
		cl.State.Inflight.IncreaseReceiveQuota()
		s.hooks.OnQosComplete(cl, ack)
	}

	s.publishToSubscribers(pk)
	s.hooks.OnPublished(cl, pk)

	return nil
}
+
+// retainMessage adds a message to a topic, and if a persistent store is provided,
+// adds the message to the store to be reloaded if necessary.
+func (s *Server) retainMessage(cl *Client, pk packets.Packet) {
+ if s.Options.Capabilities.RetainAvailable == 0 || pk.Ignore {
+ return
+ }
+
+ out := pk.Copy(false)
+ r := s.Topics.RetainMessage(out)
+ s.hooks.OnRetainMessage(cl, pk, r)
+ atomic.StoreInt64(&s.Info.Retained, int64(s.Topics.Retained.Len()))
+}
+
// publishToSubscribers publishes a publish packet to all subscribers with matching topic filters.
func (s *Server) publishToSubscribers(pk packets.Packet) {
	if pk.Ignore {
		return
	}

	if pk.Created == 0 {
		pk.Created = time.Now().Unix()
	}

	// Default to the server maximum expiry; a per-message expiry interval,
	// when set, takes precedence.
	pk.Expiry = pk.Created + s.Options.Capabilities.MaximumMessageExpiryInterval
	if pk.Properties.MessageExpiryInterval > 0 {
		pk.Expiry = pk.Created + int64(pk.Properties.MessageExpiryInterval)
	}

	subscribers := s.Topics.Subscribers(pk.TopicName)
	if len(subscribers.Shared) > 0 {
		// Let the OnSelectSubscribers hook pick recipients for shared
		// subscriptions; fall back to the default selection when it doesn't.
		subscribers = s.hooks.OnSelectSubscribers(subscribers, pk)
		if len(subscribers.SharedSelected) == 0 {
			subscribers.SelectShared()
		}
		subscribers.MergeSharedSelected()
	}

	// Inline subscriptions are delivered by direct handler callback.
	for _, inlineSubscription := range subscribers.InlineSubscriptions {
		inlineSubscription.Handler(s.inlineClient, inlineSubscription.Subscription, pk)
	}

	// Delivery failures to individual clients are logged but do not stop
	// delivery to the remaining subscribers.
	for id, subs := range subscribers.Subscriptions {
		if cl, ok := s.Clients.Get(id); ok {
			_, err := s.publishToClient(cl, subs, pk)
			if err != nil {
				s.Log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
			}
		}
	}
}
+
// publishToClient prepares a copy of the publish packet according to the
// client's subscription (qos downgrade, retain flag, subscription identifiers,
// outbound topic alias), reserves inflight/quota state for qos > 0, and queues
// the packet on the client's outbound channel. It returns the (possibly
// modified) outbound packet and any delivery error.
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
	if sub.NoLocal && pk.Origin == cl.ID {
		return pk, nil // [MQTT-3.8.3-3]
	}

	out := pk.Copy(false)
	if !s.hooks.OnACLCheck(cl, pk.TopicName, false) {
		return out, packets.ErrNotAuthorized
	}
	if !sub.FwdRetainedFlag && ((cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished) || cl.Properties.ProtocolVersion < 5) { // ![MQTT-3.3.1-13] [v3 MQTT-3.3.1-9]
		out.FixedHeader.Retain = false // [MQTT-3.3.1-12]
	}

	if len(sub.Identifiers) > 0 { // [MQTT-3.3.4-3]
		out.Properties.SubscriptionIdentifier = []int{}
		for _, id := range sub.Identifiers {
			out.Properties.SubscriptionIdentifier = append(out.Properties.SubscriptionIdentifier, id) // [MQTT-3.3.4-4] ![MQTT-3.3.4-5]
		}
		sort.Ints(out.Properties.SubscriptionIdentifier)
	}

	// Downgrade qos to the lesser of the subscription qos and server maximum.
	if out.FixedHeader.Qos > sub.Qos {
		out.FixedHeader.Qos = sub.Qos
	}

	if out.FixedHeader.Qos > s.Options.Capabilities.MaximumQos {
		out.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
	}

	// If the client accepts topic aliases, substitute an alias and elide the
	// topic name once the alias is already known to the client.
	if cl.Properties.Props.TopicAliasMaximum > 0 {
		var aliasExists bool
		out.Properties.TopicAlias, aliasExists = cl.State.TopicAliases.Outbound.Set(pk.TopicName)
		if out.Properties.TopicAlias > 0 {
			out.Properties.TopicAliasFlag = true
			if aliasExists {
				out.TopicName = ""
			}
		}
	}

	if out.FixedHeader.Qos > 0 {
		// Enforce the per-client inflight cap before assigning a packet id.
		if cl.State.Inflight.Len() >= int(s.Options.Capabilities.MaximumInflight) {
			// add hook?
			atomic.AddInt64(&s.Info.InflightDropped, 1)
			s.Log.Warn("client store quota reached", "client", cl.ID, "listener", cl.Net.Listener)
			return out, packets.ErrQuotaExceeded
		}

		i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1]
		if err != nil {
			s.hooks.OnPacketIDExhausted(cl, pk)
			atomic.AddInt64(&s.Info.InflightDropped, 1)
			s.Log.Warn("packet ids exhausted", "error", err, "client", cl.ID, "listener", cl.Net.Listener)
			return out, packets.ErrQuotaExceeded
		}

		out.PacketID = uint16(i) // [MQTT-2.2.1-4]
		sentQuota := atomic.LoadInt32(&cl.State.Inflight.sendQuota)

		if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3]
			atomic.AddInt64(&s.Info.Inflight, 1)
			s.hooks.OnQosPublish(cl, out, out.Created, 0)
			cl.State.Inflight.DecreaseSendQuota()
		}

		// Send quota already exhausted before this packet: keep it inflight
		// (marked with Expiry -1) instead of writing it out now.
		if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 {
			out.Expiry = -1
			cl.State.Inflight.Set(out)
			return out, nil
		}
	}

	if cl.Net.Conn == nil || cl.Closed() {
		return out, packets.CodeDisconnect
	}

	// Non-blocking enqueue; if the outbound buffer is full the packet is
	// dropped and any inflight/quota reservation is rolled back.
	select {
	case cl.State.outbound <- &out:
		atomic.AddInt32(&cl.State.outboundQty, 1)
	default:
		atomic.AddInt64(&s.Info.MessagesDropped, 1)
		cl.ops.hooks.OnPublishDropped(cl, pk)
		if out.FixedHeader.Qos > 0 {
			cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight.
			cl.State.Inflight.IncreaseSendQuota()
		}
		return out, packets.ErrPendingClientWritesExceeded
	}

	return out, nil
}
+
+func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, existed bool) {
+ if IsSharedFilter(sub.Filter) {
+ return // 4.8.2 Non-normative - Shared Subscriptions - No Retained Messages are sent to the Session when it first subscribes.
+ }
+
+ if sub.RetainHandling == 1 && existed || sub.RetainHandling == 2 { // [MQTT-3.3.1-10] [MQTT-3.3.1-11]
+ return
+ }
+
+ sub.FwdRetainedFlag = true
+ for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4]
+ _, err := s.publishToClient(cl, sub, pkv)
+ if err != nil {
+ s.Log.Debug("failed to publish retained message", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "packet", pkv)
+ continue
+ }
+ s.hooks.OnRetainPublished(cl, pkv)
+ }
+}
+
+// buildAck builds a standardised ack message for Puback, Pubrec, Pubrel, Pubcomp packets.
+func (s *Server) buildAck(packetID uint16, pkt, qos byte, properties packets.Properties, reason packets.Code) packets.Packet {
+ if s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck {
+ properties = packets.Properties{}
+ }
+ if reason.Code >= packets.ErrUnspecifiedError.Code {
+ properties.ReasonString = reason.Reason
+ }
+
+ pk := packets.Packet{
+ FixedHeader: packets.FixedHeader{
+ Type: pkt,
+ Qos: qos,
+ },
+ PacketID: packetID, // [MQTT-2.2.1-5]
+ ReasonCode: reason.Code, // [MQTT-3.4.2-1]
+ Properties: properties,
+ Created: time.Now().Unix(),
+ Expiry: time.Now().Unix() + s.Options.Capabilities.MaximumMessageExpiryInterval,
+ }
+
+ return pk
+}
+
+// processPuback processes a Puback packet, denoting completion of a QOS 1 packet sent from the server.
+func (s *Server) processPuback(cl *Client, pk packets.Packet) error {
+ if _, ok := cl.State.Inflight.Get(pk.PacketID); !ok {
+ return nil // omit, but would be packets.ErrPacketIdentifierNotFound
+ }
+
+ if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
+ cl.State.Inflight.IncreaseSendQuota()
+ atomic.AddInt64(&s.Info.Inflight, -1)
+ s.hooks.OnQosComplete(cl, pk)
+ }
+
+ return nil
+}
+
// processPubrec processes a Pubrec packet, denoting receipt of a QOS 2 packet sent from the server.
func (s *Server) processPubrec(cl *Client, pk packets.Packet) error {
	// Unknown packet id: reply with a Pubrel carrying a not-found reason.
	if _, ok := cl.State.Inflight.Get(pk.PacketID); !ok { // [MQTT-4.3.3-7] [MQTT-4.3.3-13]
		return cl.WritePacket(s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.ErrPacketIdentifierNotFound))
	}

	// An error reason code (or an invalid one) ends the qos 2 flow here.
	if pk.ReasonCode >= packets.ErrUnspecifiedError.Code || !pk.ReasonCodeValid() { // [MQTT-4.3.3-4]
		if ok := cl.State.Inflight.Delete(pk.PacketID); ok {
			atomic.AddInt64(&s.Info.Inflight, -1)
		}
		cl.ops.hooks.OnQosDropped(cl, pk)
		return nil // as per MQTT5 Section 4.13.2 paragraph 2
	}

	ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6]
	cl.State.Inflight.DecreaseReceiveQuota()                                              // -1 RECV QUOTA
	cl.State.Inflight.Set(ack)                                                            // [MQTT-4.3.3-5]
	return cl.WritePacket(ack)
}
+
// processPubrel processes a Pubrel packet, denoting completion of a QOS 2 packet sent from the client.
func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
	// Unknown packet id: reply with a Pubcomp carrying a not-found reason.
	if _, ok := cl.State.Inflight.Get(pk.PacketID); !ok { // [MQTT-4.3.3-7] [MQTT-4.3.3-13]
		return cl.WritePacket(s.buildAck(pk.PacketID, packets.Pubcomp, 0, pk.Properties, packets.ErrPacketIdentifierNotFound))
	}

	// An error reason code (or an invalid one) ends the qos 2 flow here.
	if pk.ReasonCode >= packets.ErrUnspecifiedError.Code || !pk.ReasonCodeValid() { // [MQTT-4.3.3-9]
		if ok := cl.State.Inflight.Delete(pk.PacketID); ok {
			atomic.AddInt64(&s.Info.Inflight, -1)
		}
		cl.ops.hooks.OnQosDropped(cl, pk)
		return nil
	}

	ack := s.buildAck(pk.PacketID, packets.Pubcomp, 0, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-11]
	cl.State.Inflight.Set(ack)

	err := cl.WritePacket(ack)
	if err != nil {
		return err
	}

	// The flow is complete: restore both quotas and drop the inflight entry.
	cl.State.Inflight.IncreaseReceiveQuota()             // +1 RECV QUOTA
	cl.State.Inflight.IncreaseSendQuota()                // +1 SENT QUOTA
	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12]
		atomic.AddInt64(&s.Info.Inflight, -1)
		s.hooks.OnQosComplete(cl, pk)
	}

	return nil
}
+
+// processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server.
+func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error {
+ // regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas.
+ cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
+ cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
+ if ok := cl.State.Inflight.Delete(pk.PacketID); ok {
+ atomic.AddInt64(&s.Info.Inflight, -1)
+ s.hooks.OnQosComplete(cl, pk)
+ }
+
+ return nil
+}
+
// processSubscribe processes a Subscribe packet: each filter is validated,
// ACL-checked, and registered in the topic tree; a Suback carrying per-filter
// reason codes is written; and retained messages are then forwarded for every
// successful filter.
func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
	pk = s.hooks.OnSubscribe(cl, pk)
	code := packets.CodeSuccess
	if _, ok := cl.State.Inflight.Get(pk.PacketID); ok {
		code = packets.ErrPacketIdentifierInUse
	}

	filterExisted := make([]bool, len(pk.Filters))
	reasonCodes := make([]byte, len(pk.Filters))
	for i, sub := range pk.Filters {
		// A packet-level error applies to every filter in the subscription.
		if code != packets.CodeSuccess {
			reasonCodes[i] = code.Code // NB 3.9.3 Non-normative 0x91
			continue
		} else if !IsValidFilter(sub.Filter, false) {
			reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
		} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
			reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
		} else if !s.hooks.OnACLCheck(cl, sub.Filter, false) {
			reasonCodes[i] = packets.ErrNotAuthorized.Code
			if s.Options.Capabilities.Compatibilities.ObscureNotAuthorized {
				reasonCodes[i] = packets.ErrUnspecifiedError.Code
			}
		} else {
			// Register the subscription in the topic tree and on the client.
			isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
			if isNew {
				atomic.AddInt64(&s.Info.Subscriptions, 1)
			}
			cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]

			if sub.Qos > s.Options.Capabilities.MaximumQos {
				sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
			}

			filterExisted[i] = !isNew
			reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
		}

		// MQTT v3 has no granular error codes; collapse to the generic failure.
		if reasonCodes[i] > packets.CodeGrantedQos2.Code && cl.Properties.ProtocolVersion < 5 { // MQTT3
			reasonCodes[i] = packets.ErrUnspecifiedError.Code
		}
	}

	ack := packets.Packet{ // [MQTT-3.8.4-1] [MQTT-3.8.4-5]
		FixedHeader: packets.FixedHeader{
			Type: packets.Suback,
		},
		PacketID:    pk.PacketID, // [MQTT-2.2.1-6] [MQTT-3.8.4-2]
		ReasonCodes: reasonCodes, // [MQTT-3.8.4-6]
		Properties: packets.Properties{
			User: pk.Properties.User,
		},
	}

	if code.Code >= packets.ErrUnspecifiedError.Code {
		ack.Properties.ReasonString = code.Reason
	}

	s.hooks.OnSubscribed(cl, pk, reasonCodes)
	err := cl.WritePacket(ack)
	if err != nil {
		return err
	}

	// Forward retained messages for each filter that subscribed successfully.
	for i, sub := range pk.Filters { // [MQTT-3.3.1-9]
		if reasonCodes[i] >= packets.ErrUnspecifiedError.Code {
			continue
		}

		s.publishRetainedToClient(cl, sub, filterExisted[i])
	}

	return nil
}
+
// processUnsubscribe processes an unsubscribe packet: each filter is removed
// from the topic tree and from the client's subscription set, and an Unsuback
// carrying per-filter reason codes is written back.
func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
	code := packets.CodeSuccess
	if _, ok := cl.State.Inflight.Get(pk.PacketID); ok {
		code = packets.ErrPacketIdentifierInUse
	}

	pk = s.hooks.OnUnsubscribe(cl, pk)
	reasonCodes := make([]byte, len(pk.Filters))
	for i, sub := range pk.Filters { // [MQTT-3.10.4-6] [MQTT-3.11.3-1]
		// A packet-level error applies to every filter in the packet.
		if code != packets.CodeSuccess {
			reasonCodes[i] = code.Code // NB 3.11.3 Non-normative 0x91
			continue
		}

		if q := s.Topics.Unsubscribe(sub.Filter, cl.ID); q {
			atomic.AddInt64(&s.Info.Subscriptions, -1)
			reasonCodes[i] = packets.CodeSuccess.Code
		} else {
			reasonCodes[i] = packets.CodeNoSubscriptionExisted.Code
		}

		cl.State.Subscriptions.Delete(sub.Filter) // [MQTT-3.10.4-2] [MQTT-3.10.4-2] ~[MQTT-3.10.4-3]
	}

	ack := packets.Packet{ // [MQTT-3.10.4-4]
		FixedHeader: packets.FixedHeader{
			Type: packets.Unsuback,
		},
		PacketID:    pk.PacketID, // [MQTT-2.2.1-6] [MQTT-3.10.4-5]
		ReasonCodes: reasonCodes, // [MQTT-3.11.3-2]
		Properties: packets.Properties{
			User: pk.Properties.User,
		},
	}

	if code.Code >= packets.ErrUnspecifiedError.Code {
		ack.Properties.ReasonString = code.Reason
	}

	s.hooks.OnUnsubscribed(cl, pk)
	return cl.WritePacket(ack)
}
+
// UnsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) UnsubscribeClient(cl *Client) {
	i := 0
	filterMap := cl.State.Subscriptions.GetAll()
	filters := make([]packets.Subscription, len(filterMap))
	// Always clear the client's local record of its subscriptions.
	for k := range filterMap {
		cl.State.Subscriptions.Delete(k)
	}

	// If the session was taken over, topic-tree subscriptions are left in
	// place and no hook fires — presumably they now belong to the superseding
	// session (NOTE(review): confirm against the takeover path).
	if atomic.LoadUint32(&cl.State.isTakenOver) == 1 {
		return
	}

	for k, v := range filterMap {
		if s.Topics.Unsubscribe(k, cl.ID) {
			atomic.AddInt64(&s.Info.Subscriptions, -1)
		}
		filters[i] = v
		i++
	}
	s.hooks.OnUnsubscribed(cl, packets.Packet{FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe}, Filters: filters})
}
+
+// processAuth processes an Auth packet.
+func (s *Server) processAuth(cl *Client, pk packets.Packet) error {
+ _, err := s.hooks.OnAuthPacket(cl, pk)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
// processDisconnect processes a Disconnect packet.
func (s *Server) processDisconnect(cl *Client, pk packets.Packet) error {
	if pk.Properties.SessionExpiryIntervalFlag {
		// Raising a zero session expiry to a non-zero value at disconnect
		// time is a protocol violation.
		if pk.Properties.SessionExpiryInterval > 0 && cl.Properties.Props.SessionExpiryInterval == 0 {
			return packets.ErrProtocolViolationZeroNonZeroExpiry
		}

		cl.Properties.Props.SessionExpiryInterval = pk.Properties.SessionExpiryInterval
		cl.Properties.Props.SessionExpiryIntervalFlag = true
	}

	// A disconnect-with-will reason is propagated so the will message fires.
	if pk.ReasonCode == packets.CodeDisconnectWillMessage.Code { // [MQTT-3.1.2.5] Non-normative comment
		return packets.CodeDisconnectWillMessage
	}

	s.loop.willDelayed.Delete(cl.ID) // [MQTT-3.1.3-9] [MQTT-3.1.2-8]
	cl.Stop(packets.CodeDisconnect)  // [MQTT-3.14.4-2]

	return nil
}
+
+// DisconnectClient sends a Disconnect packet to a client and then closes the client connection.
+func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {
+ out := packets.Packet{
+ FixedHeader: packets.FixedHeader{
+ Type: packets.Disconnect,
+ },
+ ReasonCode: code.Code,
+ Properties: packets.Properties{},
+ }
+
+ if code.Code >= packets.ErrUnspecifiedError.Code {
+ out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1]
+ }
+
+ // We already have a code we are using to disconnect the client, so we are not
+ // interested if the write packet fails due to a closed connection (as we are closing it).
+ err := cl.WritePacket(out)
+ if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
+ cl.Stop(code)
+ if code.Code >= packets.ErrUnspecifiedError.Code {
+ return code
+ }
+ }
+
+ return err
+}
+
// publishSysTopics publishes the current values to the server $SYS topics.
// Due to the int to string conversions this method is not as cheap as
// some of the others so the publishing interval should be set appropriately.
func (s *Server) publishSysTopics() {
	pk := packets.Packet{
		FixedHeader: packets.FixedHeader{
			Type:   packets.Publish,
			Retain: true,
		},
		Created: time.Now().Unix(),
	}

	// Refresh the runtime-derived counters before snapshotting.
	var m runtime.MemStats
	runtime.ReadMemStats(&m)
	atomic.StoreInt64(&s.Info.MemoryAlloc, int64(m.HeapInuse))
	atomic.StoreInt64(&s.Info.Threads, int64(runtime.NumGoroutine()))
	atomic.StoreInt64(&s.Info.Time, time.Now().Unix())
	atomic.StoreInt64(&s.Info.Uptime, time.Now().Unix()-atomic.LoadInt64(&s.Info.Started))
	atomic.StoreInt64(&s.Info.ClientsTotal, int64(s.Clients.Len()))
	atomic.StoreInt64(&s.Info.ClientsDisconnected, atomic.LoadInt64(&s.Info.ClientsTotal)-atomic.LoadInt64(&s.Info.ClientsConnected))

	// Publish from a consistent snapshot of the counters.
	info := s.Info.Clone()
	topics := map[string]string{
		SysPrefix + "/broker/version":              s.Info.Version,
		SysPrefix + "/broker/time":                 Int64toa(info.Time),
		SysPrefix + "/broker/uptime":               Int64toa(info.Uptime),
		SysPrefix + "/broker/started":              Int64toa(info.Started),
		SysPrefix + "/broker/load/bytes/received":  Int64toa(info.BytesReceived),
		SysPrefix + "/broker/load/bytes/sent":      Int64toa(info.BytesSent),
		SysPrefix + "/broker/clients/connected":    Int64toa(info.ClientsConnected),
		SysPrefix + "/broker/clients/disconnected": Int64toa(info.ClientsDisconnected),
		SysPrefix + "/broker/clients/maximum":      Int64toa(info.ClientsMaximum),
		SysPrefix + "/broker/clients/total":        Int64toa(info.ClientsTotal),
		SysPrefix + "/broker/packets/received":     Int64toa(info.PacketsReceived),
		SysPrefix + "/broker/packets/sent":         Int64toa(info.PacketsSent),
		SysPrefix + "/broker/messages/received":    Int64toa(info.MessagesReceived),
		SysPrefix + "/broker/messages/sent":        Int64toa(info.MessagesSent),
		SysPrefix + "/broker/messages/dropped":     Int64toa(info.MessagesDropped),
		SysPrefix + "/broker/messages/inflight":    Int64toa(info.Inflight),
		SysPrefix + "/broker/retained":             Int64toa(info.Retained),
		SysPrefix + "/broker/subscriptions":        Int64toa(info.Subscriptions),
		SysPrefix + "/broker/system/memory":        Int64toa(info.MemoryAlloc),
		SysPrefix + "/broker/system/threads":       Int64toa(info.Threads),
	}

	// Each value is both retained and delivered to current subscribers.
	for topic, payload := range topics {
		pk.TopicName = topic
		pk.Payload = []byte(payload)
		s.Topics.RetainMessage(pk.Copy(false))
		s.publishToSubscribers(pk)
	}

	s.hooks.OnSysInfoTick(info)
}
+
// Close attempts to gracefully shut down the server, all listeners, clients, and stores.
func (s *Server) Close() error {
	close(s.done) // signal background loops to stop
	s.Log.Info("gracefully stopping server")
	s.Listeners.CloseAll(s.closeListenerClients) // close listeners, disconnecting their clients
	s.hooks.OnStopped()
	s.hooks.Stop()

	s.Log.Info("mochi mqtt server stopped")
	return nil
}
+
+// closeListenerClients closes all clients on the specified listener.
+func (s *Server) closeListenerClients(listener string) {
+ clients := s.Clients.GetByListener(listener)
+ for _, cl := range clients {
+ _ = s.DisconnectClient(cl, packets.ErrServerShuttingDown)
+ }
+}
+
// sendLWT issues an LWT message to a topic when a client disconnects.
func (s *Server) sendLWT(cl *Client) {
	// No will registered for this client.
	if atomic.LoadUint32(&cl.Properties.Will.Flag) == 0 {
		return
	}

	// The OnWill hook may modify (or veto aspects of) the will message.
	modifiedLWT := s.hooks.OnWill(cl, cl.Properties.Will)

	pk := packets.Packet{
		FixedHeader: packets.FixedHeader{
			Type:   packets.Publish,
			Retain: modifiedLWT.Retain, // [MQTT-3.1.2-14] [MQTT-3.1.2-15]
			Qos:    modifiedLWT.Qos,
		},
		TopicName: modifiedLWT.TopicName,
		Payload:   modifiedLWT.Payload,
		Properties: packets.Properties{
			User: modifiedLWT.User,
		},
		Origin:  cl.ID,
		Created: time.Now().Unix(),
	}

	// A will delay interval defers publication to the delayed-will loop
	// instead of publishing immediately.
	if cl.Properties.Will.WillDelayInterval > 0 {
		pk.Connect.WillProperties.WillDelayInterval = cl.Properties.Will.WillDelayInterval
		pk.Expiry = time.Now().Unix() + int64(pk.Connect.WillProperties.WillDelayInterval)
		s.loop.willDelayed.Add(cl.ID, pk)
		return
	}

	if pk.FixedHeader.Retain {
		s.retainMessage(cl, pk)
	}

	s.publishToSubscribers(pk)                      // [MQTT-3.1.2-8]
	atomic.StoreUint32(&cl.Properties.Will.Flag, 0) // [MQTT-3.1.2-10]
	s.hooks.OnWillSent(cl, pk)
}
+
+// readStore reads in any data from the persistent datastore (if applicable).
+func (s *Server) readStore() error {
+ if s.hooks.Provides(StoredClients) {
+ clients, err := s.hooks.StoredClients()
+ if err != nil {
+ return fmt.Errorf("failed to load clients; %w", err)
+ }
+ s.loadClients(clients)
+ s.Log.Debug("loaded clients from store", "len", len(clients))
+ }
+
+ if s.hooks.Provides(StoredSubscriptions) {
+ subs, err := s.hooks.StoredSubscriptions()
+ if err != nil {
+ return fmt.Errorf("load subscriptions; %w", err)
+ }
+ s.loadSubscriptions(subs)
+ s.Log.Debug("loaded subscriptions from store", "len", len(subs))
+ }
+
+ if s.hooks.Provides(StoredInflightMessages) {
+ inflight, err := s.hooks.StoredInflightMessages()
+ if err != nil {
+ return fmt.Errorf("load inflight; %w", err)
+ }
+ s.loadInflight(inflight)
+ s.Log.Debug("loaded inflights from store", "len", len(inflight))
+ }
+
+ if s.hooks.Provides(StoredRetainedMessages) {
+ retained, err := s.hooks.StoredRetainedMessages()
+ if err != nil {
+ return fmt.Errorf("load retained; %w", err)
+ }
+ s.loadRetained(retained)
+ s.Log.Debug("loaded retained messages from store", "len", len(retained))
+ }
+
+ if s.hooks.Provides(StoredSysInfo) {
+ sysInfo, err := s.hooks.StoredSysInfo()
+ if err != nil {
+ return fmt.Errorf("load server info; %w", err)
+ }
+ s.loadServerInfo(sysInfo.Info)
+ s.Log.Debug("loaded $SYS info from store")
+ }
+
+ return nil
+}
+
// loadServerInfo restores server info from the datastore. Cumulative traffic
// counters are restored only when the RestoreSysInfoOnRestart compatibility
// option is enabled; the retained/inflight/subscription gauges are always
// restored.
func (s *Server) loadServerInfo(v system.Info) {
	if s.Options.Capabilities.Compatibilities.RestoreSysInfoOnRestart {
		atomic.StoreInt64(&s.Info.BytesReceived, v.BytesReceived)
		atomic.StoreInt64(&s.Info.BytesSent, v.BytesSent)
		atomic.StoreInt64(&s.Info.ClientsMaximum, v.ClientsMaximum)
		atomic.StoreInt64(&s.Info.ClientsTotal, v.ClientsTotal)
		atomic.StoreInt64(&s.Info.ClientsDisconnected, v.ClientsDisconnected)
		atomic.StoreInt64(&s.Info.MessagesReceived, v.MessagesReceived)
		atomic.StoreInt64(&s.Info.MessagesSent, v.MessagesSent)
		atomic.StoreInt64(&s.Info.MessagesDropped, v.MessagesDropped)
		atomic.StoreInt64(&s.Info.PacketsReceived, v.PacketsReceived)
		atomic.StoreInt64(&s.Info.PacketsSent, v.PacketsSent)
		atomic.StoreInt64(&s.Info.InflightDropped, v.InflightDropped)
	}
	atomic.StoreInt64(&s.Info.Retained, v.Retained)
	atomic.StoreInt64(&s.Info.Inflight, v.Inflight)
	atomic.StoreInt64(&s.Info.Subscriptions, v.Subscriptions)
}
+
+// loadSubscriptions restores subscriptions from the datastore.
+func (s *Server) loadSubscriptions(v []storage.Subscription) {
+ for _, sub := range v {
+ sb := packets.Subscription{
+ Filter: sub.Filter,
+ RetainHandling: sub.RetainHandling,
+ Qos: sub.Qos,
+ RetainAsPublished: sub.RetainAsPublished,
+ NoLocal: sub.NoLocal,
+ Identifier: sub.Identifier,
+ }
+ if s.Topics.Subscribe(sub.Client, sb) {
+ if cl, ok := s.Clients.Get(sub.Client); ok {
+ cl.State.Subscriptions.Add(sub.Filter, sb)
+ }
+ }
+ }
+}
+
// loadClients restores clients from the datastore. Each restored client is
// created in a stopped state; sessions that qualify for expiry are discarded,
// the rest are re-added to the client registry.
func (s *Server) loadClients(v []storage.Client) {
	for _, c := range v {
		cl := s.NewClient(nil, c.Listener, c.ID, false)
		cl.Properties.Username = c.Username
		cl.Properties.Clean = c.Clean
		cl.Properties.ProtocolVersion = c.ProtocolVersion
		cl.Properties.Props = packets.Properties{
			SessionExpiryInterval:     c.Properties.SessionExpiryInterval,
			SessionExpiryIntervalFlag: c.Properties.SessionExpiryIntervalFlag,
			AuthenticationMethod:      c.Properties.AuthenticationMethod,
			AuthenticationData:        c.Properties.AuthenticationData,
			RequestProblemInfoFlag:    c.Properties.RequestProblemInfoFlag,
			RequestProblemInfo:        c.Properties.RequestProblemInfo,
			RequestResponseInfo:       c.Properties.RequestResponseInfo,
			ReceiveMaximum:            c.Properties.ReceiveMaximum,
			TopicAliasMaximum:         c.Properties.TopicAliasMaximum,
			User:                      c.Properties.User,
			MaximumPacketSize:         c.Properties.MaximumPacketSize,
		}
		cl.Properties.Will = Will(c.Will)

		// cancel the context, update cl.State such as disconnected time and stopCause.
		cl.Stop(packets.ErrServerShuttingDown)

		// The session expires on restore when v5 with a zero expiry interval,
		// or pre-v5 with the clean flag set.
		expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
		s.hooks.OnDisconnect(cl, packets.ErrServerShuttingDown, expire)
		if expire {
			cl.ClearInflights()
			s.UnsubscribeClient(cl)
		} else {
			s.Clients.Add(cl)
		}
	}
}
+
+// loadInflight restores inflight messages from the datastore.
+func (s *Server) loadInflight(v []storage.Message) {
+ for _, msg := range v {
+ if client, ok := s.Clients.Get(msg.Origin); ok {
+ client.State.Inflight.Set(msg.ToPacket())
+ }
+ }
+}
+
+// loadRetained restores retained messages from the datastore.
+func (s *Server) loadRetained(v []storage.Message) {
+ for _, msg := range v {
+ s.Topics.RetainMessage(msg.ToPacket())
+ }
+}
+
+// clearExpiredClients deletes all clients which have been disconnected for longer
+// than their given expiry intervals.
+func (s *Server) clearExpiredClients(dt int64) {
+ for id, client := range s.Clients.GetAll() {
+ disconnected := client.StopTime()
+ if disconnected == 0 {
+ continue
+ }
+
+ expire := s.Options.Capabilities.MaximumSessionExpiryInterval
+ if client.Properties.ProtocolVersion == 5 && client.Properties.Props.SessionExpiryIntervalFlag {
+ expire = client.Properties.Props.SessionExpiryInterval
+ }
+
+ if disconnected+int64(expire) < dt {
+ s.hooks.OnClientExpired(client)
+ s.Clients.Delete(id) // [MQTT-4.1.0-2]
+ }
+ }
+}
+
+// clearExpiredRetainedMessage deletes retained messages from topics if they have expired.
+func (s *Server) clearExpiredRetainedMessages(now int64) {
+ for filter, pk := range s.Topics.Retained.GetAll() {
+ expired := pk.ProtocolVersion == 5 && pk.Expiry > 0 && pk.Expiry < now // [MQTT-3.3.2-5]
+
+ // If the maximum message expiry interval is set (greater than 0), and the message
+ // retention period exceeds the maximum expiry, the message will be forcibly removed.
+ enforced := s.Options.Capabilities.MaximumMessageExpiryInterval > 0 &&
+ now-pk.Created > s.Options.Capabilities.MaximumMessageExpiryInterval
+
+ if expired || enforced {
+ s.Topics.Retained.Delete(filter)
+ s.hooks.OnRetainedExpired(filter)
+ }
+ }
+}
+
+// clearExpiredInflights deletes any inflight messages which have expired.
+func (s *Server) clearExpiredInflights(now int64) {
+ for _, client := range s.Clients.GetAll() {
+ if deleted := client.ClearExpiredInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
+ for _, id := range deleted {
+ s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
+ }
+ }
+ }
+}
+
+// sendDelayedLWT sends any LWT messages which have reached their issue time.
+func (s *Server) sendDelayedLWT(dt int64) {
+ for id, pk := range s.loop.willDelayed.GetAll() {
+ if dt > pk.Expiry {
+ s.publishToSubscribers(pk) // [MQTT-3.1.2-8]
+ if cl, ok := s.Clients.Get(id); ok {
+ if pk.FixedHeader.Retain {
+ s.retainMessage(cl, pk)
+ }
+ cl.Properties.Will = Will{} // [MQTT-3.1.2-10]
+ s.hooks.OnWillSent(cl, pk)
+ }
+ s.loop.willDelayed.Delete(id)
+ }
+ }
+}
+
// Int64toa converts an int64 to its base-10 string representation.
func Int64toa(v int64) string {
	const base = 10
	return strconv.FormatInt(v, base)
}
diff --git a/mqtt/server_test.go b/mqtt/server_test.go
new file mode 100644
index 0000000..d9b8705
--- /dev/null
+++ b/mqtt/server_test.go
@@ -0,0 +1,3915 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package mqtt
+
+import (
+ "bytes"
+ "encoding/binary"
+ "io"
+ "log/slog"
+ "net"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "testmqtt/hooks/storage"
+ "testmqtt/listeners"
+ "testmqtt/packets"
+ "testmqtt/system"
+
+ "github.com/stretchr/testify/require"
+)
+
+var logger = slog.New(slog.NewTextHandler(io.Discard, nil))
+
+type ProtocolTest []struct {
+ protocolVersion byte
+ in packets.TPacketCase
+ out packets.TPacketCase
+ data map[string]any
+}
+
+type AllowHook struct {
+ HookBase
+}
+
+func (h *AllowHook) SetOpts(l *slog.Logger, opts *HookOptions) {
+ h.Log = l
+ h.Opts = opts
+}
+
+func (h *AllowHook) ID() string {
+ return "allow-all-auth"
+}
+
+func (h *AllowHook) Provides(b byte) bool {
+ return bytes.Contains([]byte{OnConnectAuthenticate, OnACLCheck}, []byte{b})
+}
+
+func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return true }
+func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool { return true }
+
+type DenyHook struct {
+ HookBase
+}
+
+func (h *DenyHook) SetOpts(l *slog.Logger, opts *HookOptions) {
+ h.Log = l
+ h.Opts = opts
+}
+
+func (h *DenyHook) ID() string {
+ return "deny-all-auth"
+}
+
+func (h *DenyHook) Provides(b byte) bool {
+ return bytes.Contains([]byte{OnConnectAuthenticate, OnACLCheck}, []byte{b})
+}
+
+func (h *DenyHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return false }
+func (h *DenyHook) OnACLCheck(cl *Client, topic string, write bool) bool { return false }
+
+type DelayHook struct {
+ HookBase
+ DisconnectDelay time.Duration
+}
+
+func (h *DelayHook) SetOpts(l *slog.Logger, opts *HookOptions) {
+ h.Log = l
+ h.Opts = opts
+}
+
+func (h *DelayHook) ID() string {
+ return "delay-hook"
+}
+
+func (h *DelayHook) Provides(b byte) bool {
+ return bytes.Contains([]byte{OnDisconnect}, []byte{b})
+}
+
+func (h *DelayHook) OnDisconnect(cl *Client, err error, expire bool) {
+ time.Sleep(h.DisconnectDelay)
+}
+
+func newServer() *Server {
+ cc := NewDefaultServerCapabilities()
+ cc.MaximumMessageExpiryInterval = 0
+ cc.ReceiveMaximum = 0
+ s := New(&Options{
+ Logger: logger,
+ Capabilities: cc,
+ })
+ _ = s.AddHook(new(AllowHook), nil)
+ return s
+}
+
+func newServerWithInlineClient() *Server {
+ cc := NewDefaultServerCapabilities()
+ cc.MaximumMessageExpiryInterval = 0
+ cc.ReceiveMaximum = 0
+ s := New(&Options{
+ Logger: logger,
+ Capabilities: cc,
+ InlineClient: true,
+ })
+ _ = s.AddHook(new(AllowHook), nil)
+ return s
+}
+
+func TestOptionsSetDefaults(t *testing.T) {
+ opts := &Options{}
+ opts.ensureDefaults()
+
+ require.Equal(t, defaultSysTopicInterval, opts.SysTopicResendInterval)
+ require.Equal(t, NewDefaultServerCapabilities(), opts.Capabilities)
+
+ opts = new(Options)
+ opts.ensureDefaults()
+ require.Equal(t, defaultSysTopicInterval, opts.SysTopicResendInterval)
+}
+
+func TestNew(t *testing.T) {
+ s := New(nil)
+ require.NotNil(t, s)
+ require.NotNil(t, s.Clients)
+ require.NotNil(t, s.Listeners)
+ require.NotNil(t, s.Topics)
+ require.NotNil(t, s.Info)
+ require.NotNil(t, s.Log)
+ require.NotNil(t, s.Options)
+ require.NotNil(t, s.loop)
+ require.NotNil(t, s.loop.sysTopics)
+ require.NotNil(t, s.loop.inflightExpiry)
+ require.NotNil(t, s.loop.clientExpiry)
+ require.NotNil(t, s.hooks)
+ require.NotNil(t, s.hooks.Log)
+ require.NotNil(t, s.done)
+ require.Nil(t, s.inlineClient)
+ require.Equal(t, 0, s.Clients.Len())
+}
+
+func TestNewWithInlineClient(t *testing.T) {
+ s := New(&Options{
+ InlineClient: true,
+ })
+ require.NotNil(t, s.inlineClient)
+ require.Equal(t, 1, s.Clients.Len())
+}
+
+func TestNewNilOpts(t *testing.T) {
+ s := New(nil)
+ require.NotNil(t, s)
+ require.NotNil(t, s.Options)
+}
+
+func TestServerNewClient(t *testing.T) {
+ s := New(nil)
+ s.Log = logger
+ r, _ := net.Pipe()
+
+ cl := s.NewClient(r, "testing", "test", false)
+ require.NotNil(t, cl)
+ require.Equal(t, "test", cl.ID)
+ require.Equal(t, "testing", cl.Net.Listener)
+ require.False(t, cl.Net.Inline)
+ require.NotNil(t, cl.State.Inflight.internal)
+ require.NotNil(t, cl.State.Subscriptions)
+ require.NotNil(t, cl.State.TopicAliases)
+ require.Equal(t, defaultKeepalive, cl.State.Keepalive)
+ require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
+ require.NotNil(t, cl.Net.Conn)
+ require.NotNil(t, cl.Net.bconn)
+ require.NotNil(t, cl.ops)
+ require.Equal(t, s.Log, cl.ops.log)
+}
+
+func TestServerNewClientInline(t *testing.T) {
+ s := New(nil)
+ cl := s.NewClient(nil, "testing", "test", true)
+ require.True(t, cl.Net.Inline)
+}
+
+func TestServerAddHook(t *testing.T) {
+ s := New(nil)
+
+ s.Log = logger
+ require.NotNil(t, s)
+
+ require.Equal(t, int64(0), s.hooks.Len())
+ err := s.AddHook(new(HookBase), nil)
+ require.NoError(t, err)
+ require.Equal(t, int64(1), s.hooks.Len())
+}
+
+func TestServerAddListener(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ require.NotNil(t, s)
+
+ err := s.AddListener(listeners.NewMockListener("t1", ":1882"))
+ require.NoError(t, err)
+
+ // add existing listener
+ err = s.AddListener(listeners.NewMockListener("t1", ":1882"))
+ require.Error(t, err)
+ require.Equal(t, ErrListenerIDExists, err)
+}
+
+func TestServerAddHooksFromConfig(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ require.NotNil(t, s)
+ s.Log = logger
+
+ hooks := []HookLoadConfig{
+ {Hook: new(modifiedHookBase)},
+ }
+
+ err := s.AddHooksFromConfig(hooks)
+ require.NoError(t, err)
+}
+
+func TestServerAddHooksFromConfigError(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ require.NotNil(t, s)
+ s.Log = logger
+
+ hooks := []HookLoadConfig{
+ {Hook: new(modifiedHookBase), Config: map[string]interface{}{}},
+ }
+
+ err := s.AddHooksFromConfig(hooks)
+ require.Error(t, err)
+}
+
+func TestServerAddListenerInitFailure(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ require.NotNil(t, s)
+
+ m := listeners.NewMockListener("t1", ":1882")
+ m.ErrListen = true
+ err := s.AddListener(m)
+ require.Error(t, err)
+}
+
+func TestServerAddListenersFromConfig(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ require.NotNil(t, s)
+ s.Log = logger
+
+ lc := []listeners.Config{
+ {Type: listeners.TypeTCP, ID: "tcp", Address: ":1883"},
+ {Type: listeners.TypeWS, ID: "ws", Address: ":1882"},
+ {Type: listeners.TypeHealthCheck, ID: "health", Address: ":1881"},
+ {Type: listeners.TypeSysInfo, ID: "info", Address: ":1880"},
+ {Type: listeners.TypeUnix, ID: "unix", Address: "mochi.sock"},
+ {Type: listeners.TypeMock, ID: "mock", Address: "0"},
+ {Type: "unknown", ID: "unknown"},
+ }
+
+ err := s.AddListenersFromConfig(lc)
+ require.NoError(t, err)
+ require.Equal(t, 6, s.Listeners.Len())
+
+ tcp, _ := s.Listeners.Get("tcp")
+ require.Equal(t, "[::]:1883", tcp.Address())
+
+ ws, _ := s.Listeners.Get("ws")
+ require.Equal(t, ":1882", ws.Address())
+
+ health, _ := s.Listeners.Get("health")
+ require.Equal(t, ":1881", health.Address())
+
+ info, _ := s.Listeners.Get("info")
+ require.Equal(t, ":1880", info.Address())
+
+ unix, _ := s.Listeners.Get("unix")
+ require.Equal(t, "mochi.sock", unix.Address())
+
+ mock, _ := s.Listeners.Get("mock")
+ require.Equal(t, "0", mock.Address())
+}
+
+func TestServerAddListenersFromConfigError(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ require.NotNil(t, s)
+ s.Log = logger
+
+ lc := []listeners.Config{
+ {Type: listeners.TypeTCP, ID: "tcp", Address: "x"},
+ }
+
+ err := s.AddListenersFromConfig(lc)
+ require.Error(t, err)
+ require.Equal(t, 0, s.Listeners.Len())
+}
+
+func TestServerServe(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ require.NotNil(t, s)
+
+ err := s.AddListener(listeners.NewMockListener("t1", ":1882"))
+ require.NoError(t, err)
+
+ err = s.Serve()
+ require.NoError(t, err)
+
+ time.Sleep(time.Millisecond)
+
+ require.Equal(t, 1, s.Listeners.Len())
+ listener, ok := s.Listeners.Get("t1")
+
+ require.Equal(t, true, ok)
+ require.Equal(t, true, listener.(*listeners.MockListener).IsServing())
+}
+
+func TestServerServeFromConfig(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ require.NotNil(t, s)
+
+ s.Options.Listeners = []listeners.Config{
+ {Type: listeners.TypeMock, ID: "mock", Address: "0"},
+ }
+
+ s.Options.Hooks = []HookLoadConfig{
+ {Hook: new(modifiedHookBase)},
+ }
+
+ err := s.Serve()
+ require.NoError(t, err)
+
+ time.Sleep(time.Millisecond)
+
+ require.Equal(t, 1, s.Listeners.Len())
+ listener, ok := s.Listeners.Get("mock")
+
+ require.Equal(t, true, ok)
+ require.Equal(t, true, listener.(*listeners.MockListener).IsServing())
+}
+
+func TestServerServeFromConfigListenerError(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ require.NotNil(t, s)
+
+ s.Options.Listeners = []listeners.Config{
+ {Type: listeners.TypeTCP, ID: "tcp", Address: "x"},
+ }
+
+ err := s.Serve()
+ require.Error(t, err)
+}
+
+func TestServerServeFromConfigHookError(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ require.NotNil(t, s)
+
+ s.Options.Hooks = []HookLoadConfig{
+ {Hook: new(modifiedHookBase), Config: map[string]interface{}{}},
+ }
+
+ err := s.Serve()
+ require.Error(t, err)
+}
+
+func TestServerServeReadStoreFailure(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ require.NotNil(t, s)
+
+ err := s.AddListener(listeners.NewMockListener("t1", ":1882"))
+ require.NoError(t, err)
+
+ hook := new(modifiedHookBase)
+ hook.failAt = 1
+ err = s.AddHook(hook, nil)
+ require.NoError(t, err)
+
+ err = s.Serve()
+ require.Error(t, err)
+}
+
+func TestServerEventLoop(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ s.loop.sysTopics = time.NewTicker(time.Millisecond)
+ s.loop.inflightExpiry = time.NewTicker(time.Millisecond)
+ s.loop.clientExpiry = time.NewTicker(time.Millisecond)
+ s.loop.retainedExpiry = time.NewTicker(time.Millisecond)
+ s.loop.willDelaySend = time.NewTicker(time.Millisecond)
+ go s.eventLoop()
+
+ time.Sleep(time.Millisecond * 3)
+}
+
+func TestServerReadConnectionPacket(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ cl, r, _ := newTestClient()
+ s.Clients.Add(cl)
+
+ o := make(chan packets.Packet)
+ go func() {
+ pk, err := s.readConnectionPacket(cl)
+ require.NoError(t, err)
+ o <- pk
+ }()
+
+ go func() {
+ _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
+ _ = r.Close()
+ }()
+
+ require.Equal(t, *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet, <-o)
+}
+
+func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ cl, r, _ := newTestClient()
+ s.Clients.Add(cl)
+
+ o := make(chan error)
+ go func() {
+ _, err := s.readConnectionPacket(cl)
+ o <- err
+ }()
+
+ go func() {
+ _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes)
+ _ = r.Close()
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.Equal(t, packets.ErrMalformedVariableByteInteger, err)
+}
+
+func TestServerReadConnectionPacketBadPacketType(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ cl, r, _ := newTestClient()
+ s.Clients.Add(cl)
+
+ go func() {
+ _, _ = r.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes)
+ _ = r.Close()
+ }()
+
+ _, err := s.readConnectionPacket(cl)
+ require.Error(t, err)
+ require.Equal(t, packets.ErrProtocolViolationRequireFirstConnect, err)
+}
+
+func TestServerReadConnectionPacketBadPacket(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ cl, r, _ := newTestClient()
+ s.Clients.Add(cl)
+
+ go func() {
+ _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes)
+ _ = r.Close()
+ }()
+
+ _, err := s.readConnectionPacket(cl)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrMalformedProtocolName)
+}
+
+func TestEstablishConnection(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
+ _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
+ }()
+
+ // receive the connack
+ recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w)
+ require.NoError(t, err)
+ recv <- buf
+ }()
+
+ err := <-o
+ require.NoError(t, err)
+
+ // Todo:
+ // s.Clients is already empty here. Is it necessary to check v.StopCause()?
+
+ // for _, v := range s.Clients.GetAll() {
+ // require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
+ // }
+
+ require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
+
+ _ = w.Close()
+ _ = r.Close()
+
+ // client must be deleted on session close if Clean = true
+ _, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet.Connect.ClientIdentifier)
+ require.False(t, ok)
+}
+
+func TestEstablishConnectionAckFailure(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
+ _ = w.Close()
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, io.ErrClosedPipe)
+
+ _ = r.Close()
+}
+
+func TestEstablishConnectionReadError(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes)
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error
+ }()
+
+ // receive the connack
+ recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w)
+ require.NoError(t, err)
+ recv <- buf
+ }()
+
+ err := <-o
+ require.Error(t, err)
+
+ // Retrieve the client corresponding to the Client Identifier.
+ retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet.Connect.ClientIdentifier)
+ require.True(t, ok)
+ require.ErrorIs(t, retrievedCl.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect
+
+ ret := <-recv
+ require.Equal(t, append(
+ packets.TPacketData[packets.Connack].Get(packets.TConnackMinCleanMqtt5).RawBytes,
+ packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectSecondConnect).RawBytes...),
+ ret,
+ )
+
+ _ = w.Close()
+ _ = r.Close()
+}
+
+func TestEstablishConnectionInheritExisting(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ cl, r0, _ := newTestClient()
+ cl.Properties.ProtocolVersion = 5
+ cl.Properties.Username = []byte("mochi")
+ cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
+ cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
+ cl.State.Inflight.Set(*packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
+ s.Clients.Add(cl)
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ err := s.EstablishConnection("tcp", r)
+ o <- err
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
+ time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect.
+ _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
+ }()
+
+ // receive the disconnect session takeover
+ takeover := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(r0)
+ require.NoError(t, err)
+ takeover <- buf
+ }()
+
+ // receive the connack
+ recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w)
+ require.NoError(t, err)
+ recv <- buf
+ }()
+
+ err := <-o
+ require.NoError(t, err)
+
+ // Retrieve the client corresponding to the Client Identifier.
+ retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
+ require.True(t, ok)
+ require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect
+
+ connackPlusPacket := append(
+ packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes,
+ packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes...,
+ )
+ require.Equal(t, connackPlusPacket, <-recv)
+ require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover)
+
+ time.Sleep(time.Microsecond * 100)
+ _ = w.Close()
+ _ = r.Close()
+
+ clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
+ require.True(t, ok)
+ require.NotEmpty(t, clw.State.Subscriptions)
+
+ // Prevent sequential takeover memory-bloom.
+ require.Empty(t, cl.State.Subscriptions.GetAll())
+}
+
+// See https://github.com/mochi-mqtt/server/issues/173
+func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) {
+ s := newServer()
+ d := new(DelayHook)
+ d.DisconnectDelay = time.Millisecond * 200
+ _ = s.AddHook(d, nil)
+ defer s.Close()
+
+ // Clean session, 0 session expiry interval
+ cl1RawBytes := []byte{
+ packets.Connect << 4, 21, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 5, // Protocol Version
+ 1 << 1, // Packet Flags
+ 0, 30, // Keepalive
+ 5, // Properties length
+ 17, 0, 0, 0, 0, // Session Expiry Interval (17)
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ }
+
+ // Make first connection
+ r1, w1 := net.Pipe()
+ o1 := make(chan error)
+ go func() {
+ err := s.EstablishConnection("tcp", r1)
+ o1 <- err
+ }()
+ go func() {
+ _, _ = w1.Write(cl1RawBytes)
+ }()
+
+ // receive the first connack
+ recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w1)
+ require.NoError(t, err)
+ recv <- buf
+ }()
+
+ // Get the first client pointer
+ time.Sleep(time.Millisecond * 50)
+ cl1, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).Packet.Connect.ClientIdentifier)
+ require.True(t, ok)
+ cl1.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
+ cl1.State.Subscriptions.Add("d/e/f", packets.Subscription{Filter: "d/e/f", Qos: 0})
+ time.Sleep(time.Millisecond * 50)
+
+ // Make the second connection
+ r2, w2 := net.Pipe()
+ o2 := make(chan error)
+ go func() {
+ err := s.EstablishConnection("tcp", r2)
+ o2 <- err
+ }()
+ go func() {
+ x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:]
+ x[19] = '.' // differentiate username bytes in debugging
+ _, _ = w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes)
+ }()
+
+ // receive the second connack
+ recv2 := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w2)
+ require.NoError(t, err)
+ recv2 <- buf
+ }()
+
+ // Capture first Client pointer
+ clp1, ok := s.Clients.Get("zen")
+ require.True(t, ok)
+ require.Empty(t, clp1.Properties.Username)
+ require.NotEmpty(t, clp1.State.Subscriptions.GetAll())
+
+ err1 := <-o1
+ require.Error(t, err1)
+ require.ErrorIs(t, err1, io.ErrClosedPipe)
+
+ // Capture second Client pointer
+ clp2, ok := s.Clients.Get("zen")
+ require.True(t, ok)
+ require.Equal(t, []byte(".ochi"), clp2.Properties.Username)
+ require.NotEmpty(t, clp2.State.Subscriptions.GetAll())
+ require.Empty(t, clp1.State.Subscriptions.GetAll())
+
+ _, _ = w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
+ require.NoError(t, <-o2)
+}
+
+func TestEstablishConnectionResentPendingInflightsError(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ n := time.Now().Unix()
+ cl, r0, _ := newTestClient()
+ cl.Properties.ProtocolVersion = 5
+ cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
+ cl.State.Inflight = NewInflights()
+ cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: n - 2}) // no packet type
+ s.Clients.Add(cl)
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
+ }()
+
+ go func() {
+ _, err := io.ReadAll(r0)
+ require.NoError(t, err)
+ }()
+
+ go func() {
+ _, err := io.ReadAll(w)
+ require.NoError(t, err)
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrNoValidPacketAvailable)
+}
+
+func TestEstablishConnectionInheritExistingClean(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+
+ cl, r0, _ := newTestClient()
+ cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
+ cl.Properties.Clean = true
+ cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
+ s.Clients.Add(cl)
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
+ _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
+ }()
+
+ // receive the disconnect
+ takeover := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(r0)
+ require.NoError(t, err)
+ takeover <- buf
+ }()
+
+ // receive the connack
+ recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w)
+ require.NoError(t, err)
+ recv <- buf
+ }()
+
+ err := <-o
+ require.NoError(t, err)
+
+ // Retrieve the client corresponding to the Client Identifier.
+ retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
+ require.True(t, ok)
+ require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect
+
+ require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
+ require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover)
+
+ _ = w.Close()
+ _ = r.Close()
+
+ clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
+ require.True(t, ok)
+ require.Equal(t, 0, clw.State.Subscriptions.Len())
+}
+
+func TestEstablishConnectionBadAuthentication(t *testing.T) {
+ s := New(&Options{
+ Logger: logger,
+ })
+ defer s.Close()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
+ _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
+ }()
+
+ // receive the connack
+ recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w)
+ require.NoError(t, err)
+ recv <- buf
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrBadUsernameOrPassword)
+ require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackBadUsernamePasswordNoSession).RawBytes, <-recv)
+
+ _ = w.Close()
+ _ = r.Close()
+}
+
+func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) {
+ s := New(&Options{
+ Logger: logger,
+ })
+ defer s.Close()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
+ _ = w.Close()
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, io.ErrClosedPipe)
+
+ _ = r.Close()
+}
+
+func TestServerEstablishConnectionInvalidConnect(t *testing.T) {
+ s := newServer()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes)
+ _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
+ }()
+
+ // receive the connack
+ recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w)
+ require.NoError(t, err)
+ recv <- buf
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, packets.ErrProtocolViolationReservedBit, err)
+ require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackProtocolViolationNoSession).RawBytes, <-recv)
+
+ _ = r.Close()
+}
+
+func TestEstablishConnectionMaximumClientsReached(t *testing.T) {
+ cc := NewDefaultServerCapabilities()
+ cc.MaximumClients = 0
+ s := New(&Options{
+ Logger: logger,
+ Capabilities: cc,
+ })
+ _ = s.AddHook(new(AllowHook), nil)
+ defer s.Close()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
+ }()
+
+ // receive the connack
+ recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(w)
+ require.NoError(t, err)
+ recv <- buf
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrServerBusy)
+
+ _ = r.Close()
+}
+
+// See https://github.com/mochi-mqtt/server/issues/178
+func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) {
+ s := newServer()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes)
+ _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
+ }()
+
+ // receive the connack error
+ go func() {
+ _, err := io.ReadAll(w)
+ require.NoError(t, err)
+ }()
+
+ err := <-o
+ require.NoError(t, err)
+
+ _ = r.Close()
+}
+
+func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) {
+ s := newServer()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes)
+ _ = w.Close()
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, io.ErrClosedPipe)
+
+ _ = r.Close()
+}
+
+func TestServerEstablishConnectionBadPacket(t *testing.T) {
+ s := newServer()
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnackBadProtocolVersion).RawBytes)
+ _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
+ }()
+
+ err := <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrProtocolViolationRequireFirstConnect)
+
+ _ = r.Close()
+}
+
+func TestServerEstablishConnectionOnConnectError(t *testing.T) {
+ s := newServer()
+ hook := new(modifiedHookBase)
+ hook.fail = true
+ err := s.AddHook(hook, nil)
+ require.NoError(t, err)
+
+ r, w := net.Pipe()
+ o := make(chan error)
+ go func() {
+ o <- s.EstablishConnection("tcp", r)
+ }()
+
+ go func() {
+ _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
+ }()
+
+ err = <-o
+ require.Error(t, err)
+ require.ErrorIs(t, err, errTestHook)
+
+ _ = r.Close()
+}
+
+func TestServerSendConnack(t *testing.T) {
+ s := newServer()
+ cl, r, w := newTestClient()
+ cl.Properties.ProtocolVersion = 5
+ s.Options.Capabilities.MaximumQos = 1
+ cl.Properties.Props = packets.Properties{
+ AssignedClientID: "mochi",
+ }
+ go func() {
+ err := s.SendConnack(cl, packets.CodeSuccess, true, nil)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackMinMqtt5).RawBytes, buf)
+}
+
+func TestServerSendConnackFailureReason(t *testing.T) {
+ s := newServer()
+ cl, r, w := newTestClient()
+ cl.Properties.ProtocolVersion = 5
+ go func() {
+ err := s.SendConnack(cl, packets.ErrUnspecifiedError, true, nil)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackInvalidMinMqtt5).RawBytes, buf)
+}
+
+func TestServerSendConnackWithServerKeepalive(t *testing.T) {
+ s := newServer()
+ cl, r, w := newTestClient()
+ cl.Properties.ProtocolVersion = 5
+ cl.State.Keepalive = 10
+ cl.State.ServerKeepalive = true
+ go func() {
+ err := s.SendConnack(cl, packets.CodeSuccess, true, nil)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackServerKeepalive).RawBytes, buf)
+}
+
+func TestServerValidateConnect(t *testing.T) {
+ packet := *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet
+ invalidBitPacket := packet
+ invalidBitPacket.ReservedBit = 1
+ packetCleanIdPacket := packet
+ packetCleanIdPacket.Connect.Clean = false
+ packetCleanIdPacket.Connect.ClientIdentifier = ""
+ tt := []struct {
+ desc string
+ client *Client
+ capabilities Capabilities
+ packet packets.Packet
+ expect packets.Code
+ }{
+ {
+ desc: "unsupported protocol version",
+ client: &Client{Properties: ClientProperties{ProtocolVersion: 3}},
+ capabilities: Capabilities{MinimumProtocolVersion: 4},
+ packet: packet,
+ expect: packets.ErrUnsupportedProtocolVersion,
+ },
+ {
+ desc: "will qos not supported",
+ client: &Client{Properties: ClientProperties{Will: Will{Qos: 2}}},
+ capabilities: Capabilities{MaximumQos: 1},
+ packet: packet,
+ expect: packets.ErrQosNotSupported,
+ },
+ {
+ desc: "retain not supported",
+ client: &Client{Properties: ClientProperties{Will: Will{Retain: true}}},
+ capabilities: Capabilities{RetainAvailable: 0},
+ packet: packet,
+ expect: packets.ErrRetainNotSupported,
+ },
+ {
+ desc: "invalid packet validate",
+ client: &Client{Properties: ClientProperties{Will: Will{Retain: true}}},
+ capabilities: Capabilities{RetainAvailable: 0},
+ packet: invalidBitPacket,
+ expect: packets.ErrProtocolViolationReservedBit,
+ },
+ {
+ desc: "mqtt3 clean no client id ",
+ client: &Client{Properties: ClientProperties{ProtocolVersion: 3}},
+ capabilities: Capabilities{},
+ packet: packetCleanIdPacket,
+ expect: packets.ErrUnspecifiedError,
+ },
+ }
+
+ s := newServer()
+ for _, tx := range tt {
+ t.Run(tx.desc, func(t *testing.T) {
+ s.Options.Capabilities = &tx.capabilities
+ err := s.validateConnect(tx.client, tx.packet)
+ require.Error(t, err)
+ require.ErrorIs(t, err, tx.expect)
+ })
+ }
+}
+
+func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) {
+ s := newServer()
+ cl, r, w := newTestClient()
+ cl.Properties.ProtocolVersion = 5
+ cl.Properties.Props.SessionExpiryInterval = uint32(300)
+ s.Options.Capabilities.MaximumSessionExpiryInterval = 120
+ go func() {
+ err := s.SendConnack(cl, packets.CodeSuccess, false, nil)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedAdjustedExpiryInterval).RawBytes, buf)
+}
+
// TestInheritClientSession verifies that a reconnecting client using an
// existing client id inherits the stored session's subscriptions and
// inflight messages, and that a clean connect discards them instead.
func TestInheritClientSession(t *testing.T) {
	s := newServer()

	n := time.Now().Unix()

	existing, _, _ := newTestClient()
	existing.Net.Conn = nil // the stored session has no live connection
	existing.ID = "mochi"
	existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
	existing.State.Inflight = NewInflights()
	existing.State.Inflight.Set(packets.Packet{PacketID: 1, Created: n - 1})
	existing.State.Inflight.Set(packets.Packet{PacketID: 2, Created: n - 2})

	s.Clients.Add(existing)

	cl, _, _ := newTestClient()
	cl.Properties.ProtocolVersion = 5

	// New client starts with an empty session.
	require.Equal(t, 0, cl.State.Inflight.Len())
	require.Equal(t, 0, cl.State.Subscriptions.Len())

	// Inherit existing client properties
	b := s.inheritClientSession(packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi"}}, cl)
	require.True(t, b)
	require.Equal(t, 2, cl.State.Inflight.Len())
	require.Equal(t, 1, cl.State.Subscriptions.Len())

	// On clean, clear existing properties
	cl, _, _ = newTestClient()
	cl.Properties.ProtocolVersion = 5
	b = s.inheritClientSession(packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi", Clean: true}}, cl)
	require.False(t, b)
	require.Equal(t, 0, cl.State.Inflight.Len())
	require.Equal(t, 0, cl.State.Subscriptions.Len())
}
+
+func TestServerUnsubscribeClient(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+ pk := packets.Subscription{Filter: "a/b/c", Qos: 1}
+ cl.State.Subscriptions.Add("a/b/c", pk)
+ s.Topics.Subscribe(cl.ID, pk)
+ subs := s.Topics.Subscribers("a/b/c")
+ require.Equal(t, 1, len(subs.Subscriptions))
+ s.UnsubscribeClient(cl)
+ subs = s.Topics.Subscribers("a/b/c")
+ require.Equal(t, 0, len(subs.Subscriptions))
+}
+
+func TestServerProcessPacketFailure(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+ err := s.processPacket(cl, packets.Packet{})
+ require.Error(t, err)
+}
+
+func TestServerProcessPacketConnect(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+
+ err := s.processPacket(cl, *packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet)
+ require.Error(t, err)
+}
+
+func TestServerProcessPacketPingreq(t *testing.T) {
+ s := newServer()
+ cl, r, w := newTestClient()
+
+ go func() {
+ err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, packets.TPacketData[packets.Pingresp].Get(packets.TPingresp).RawBytes, buf)
+}
+
+func TestServerProcessPacketPingreqError(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+ cl.Stop(packets.CodeDisconnect)
+
+ err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet)
+ require.Error(t, err)
+ require.ErrorIs(t, cl.StopCause(), packets.CodeDisconnect)
+}
+
+func TestServerProcessPacketPublishInvalid(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+
+ err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishInvalidQosMustPacketID).Packet)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
+}
+
// TestInjectPacketPublishAndReceive publishes via InjectPacket from an
// inline client and verifies a subscribed client receives the raw bytes.
func TestInjectPacketPublishAndReceive(t *testing.T) {
	s := newServer()
	_ = s.Serve()
	defer s.Close()

	sender, _, w1 := newTestClient()
	sender.Net.Inline = true // inline clients may inject packets directly
	sender.ID = "sender"
	s.Clients.Add(sender)

	receiver, r2, w2 := newTestClient()
	receiver.ID = "receiver"
	s.Clients.Add(receiver)
	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})

	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	go func() {
		err := s.InjectPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
		require.NoError(t, err)
		_ = w1.Close()
		time.Sleep(time.Millisecond * 10) // allow delivery before closing the receiver pipe
		_ = w2.Close()
	}()

	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
}
+
// TestServerPublishAndReceive publishes through the public Publish API
// (backed by the built-in inline client) and verifies that a subscribed
// client receives the raw publish bytes.
func TestServerPublishAndReceive(t *testing.T) {
	s := newServerWithInlineClient()

	_ = s.Serve()
	defer s.Close()

	sender, _, w1 := newTestClient()
	sender.Net.Inline = true
	sender.ID = "sender"
	s.Clients.Add(sender)

	receiver, r2, w2 := newTestClient()
	receiver.ID = "receiver"
	s.Clients.Add(receiver)
	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})

	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos)
		require.NoError(t, err)
		_ = w1.Close()
		time.Sleep(time.Millisecond * 10) // allow delivery before closing the receiver pipe
		_ = w2.Close()
	}()

	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
}
+
+func TestServerPublishNoInlineClient(t *testing.T) {
+ s := newServer()
+ pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
+ err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos)
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrInlineClientNotEnabled)
+}
+
+func TestInjectPacketError(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ cl, _, _ := newTestClient()
+ cl.Net.Inline = true
+ pkx := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet
+ pkx.Filters = packets.Subscriptions{}
+ err := s.InjectPacket(cl, pkx)
+ require.Error(t, err)
+}
+
+func TestInjectPacketPublishInvalidTopic(t *testing.T) {
+ s := newServer()
+ defer s.Close()
+ cl, _, _ := newTestClient()
+ cl.Net.Inline = true
+ pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
+ pkx.TopicName = "$SYS/test"
+ err := s.InjectPacket(cl, pkx)
+ require.NoError(t, err) // bypass topic validity and acl checks
+}
+
// TestServerProcessPacketPublishAndReceive processes a retained publish from
// one client and verifies both delivery to a live subscriber and that the
// message is stored as the topic's retained message.
func TestServerProcessPacketPublishAndReceive(t *testing.T) {
	s := newServer()
	_ = s.Serve()
	defer s.Close()

	sender, _, w1 := newTestClient()
	sender.ID = "sender"
	s.Clients.Add(sender)

	receiver, r2, w2 := newTestClient()
	receiver.ID = "receiver"
	s.Clients.Add(receiver)
	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})

	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
	require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) // no retained message yet

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	go func() {
		err := s.processPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
		require.NoError(t, err)
		time.Sleep(time.Millisecond * 10) // allow delivery before closing the pipes
		_ = w1.Close()
		_ = w2.Close()
	}()

	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
	require.Equal(t, 1, len(s.Topics.Messages("a/b/c"))) // message was retained
}
+
+func TestServerBuildAck(t *testing.T) {
+ s := newServer()
+ properties := packets.Properties{
+ User: []packets.UserProperty{
+ {Key: "hello", Val: "世界"},
+ },
+ }
+ ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1)
+ require.Equal(t, packets.Puback, ack.FixedHeader.Type)
+ require.Equal(t, uint8(1), ack.FixedHeader.Qos)
+ require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode)
+ require.Equal(t, properties, ack.Properties)
+}
+
+func TestServerBuildAckError(t *testing.T) {
+ s := newServer()
+ properties := packets.Properties{
+ User: []packets.UserProperty{
+ {Key: "hello", Val: "世界"},
+ },
+ }
+ ack := s.buildAck(7, packets.Puback, 1, properties, packets.ErrMalformedPacket)
+ require.Equal(t, packets.Puback, ack.FixedHeader.Type)
+ require.Equal(t, uint8(1), ack.FixedHeader.Qos)
+ require.Equal(t, packets.ErrMalformedPacket.Code, ack.ReasonCode)
+ properties.ReasonString = packets.ErrMalformedPacket.Reason
+ require.Equal(t, properties, ack.Properties)
+}
+
+func TestServerBuildAckPahoCompatibility(t *testing.T) {
+ s := newServer()
+ s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
+ properties := packets.Properties{
+ User: []packets.UserProperty{
+ {Key: "hello", Val: "世界"},
+ },
+ }
+ ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1)
+ require.Equal(t, packets.Puback, ack.FixedHeader.Type)
+ require.Equal(t, uint8(1), ack.FixedHeader.Qos)
+ require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode)
+ require.Equal(t, packets.Properties{}, ack.Properties)
+}
+
// TestServerProcessPacketAndNextImmediate checks that processing a packet
// also flushes a pending inflight message, decrementing both the server
// inflight count and the client's send quota.
func TestServerProcessPacketAndNextImmediate(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()

	next := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	next.Expiry = -1 // NOTE(review): presumably flags the message for immediate (re)send — confirm against inflight handling
	cl.State.Inflight.Set(next)
	atomic.StoreInt64(&s.Info.Inflight, 1)
	require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Inflight))
	require.Equal(t, int32(5), cl.State.Inflight.sendQuota)

	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
		require.NoError(t, err)
		_ = w.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, buf)
	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight))  // pending message cleared
	require.Equal(t, int32(4), cl.State.Inflight.sendQuota) // one quota consumed by the resend
}
+
+func TestServerProcessPublishAckFailure(t *testing.T) {
+ s := newServer()
+ _ = s.Serve()
+ defer s.Close()
+
+ cl, _, w := newTestClient()
+ s.Clients.Add(cl)
+
+ _ = w.Close()
+ err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet)
+ require.Error(t, err)
+ require.ErrorIs(t, err, io.ErrClosedPipe)
+}
+
+func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) {
+ s := newServer()
+ hook := new(modifiedHookBase)
+ hook.fail = true
+ hook.err = packets.ErrUnspecifiedError
+ err := s.AddHook(hook, nil)
+ require.NoError(t, err)
+
+ cl, _, w := newTestClient()
+ cl.Properties.ProtocolVersion = 5
+ s.Clients.Add(cl)
+ _ = w.Close()
+
+ err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
+ require.Error(t, err)
+ require.ErrorIs(t, err, io.ErrClosedPipe)
+}
+
// TestServerProcessPublishOnPublishAckErrorContinue checks that when the
// OnPublish hook rejects with ErrPayloadFormatInvalid, processing does not
// abort: the client receives a PUBACK carrying the error code.
func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) {
	s := newServer()
	hook := new(modifiedHookBase)
	hook.fail = true
	hook.err = packets.ErrPayloadFormatInvalid // non-fatal: ack carries the code
	err := s.AddHook(hook, nil)
	require.NoError(t, err)
	_ = s.Serve()
	defer s.Close()

	cl, r, w := newTestClient()
	cl.Properties.ProtocolVersion = 5
	s.Clients.Add(cl)

	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
		require.NoError(t, err)
		_ = w.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPubackUnexpectedError).RawBytes, buf)
}
+
// TestServerProcessPublishOnPublishPkIgnore checks that when the OnPublish
// hook returns CodeSuccessIgnore, the publisher is still acked but the
// message is neither delivered to subscribers nor retained.
func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) {
	s := newServer()
	hook := new(modifiedHookBase)
	hook.fail = true
	hook.err = packets.CodeSuccessIgnore // ack the sender, drop the message
	err := s.AddHook(hook, nil)
	require.NoError(t, err)
	_ = s.Serve()
	defer s.Close()

	cl, r, w := newTestClient()
	s.Clients.Add(cl)

	receiver, r2, w2 := newTestClient()
	receiver.ID = "receiver"
	s.Clients.Add(receiver)
	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})

	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
	require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
		require.NoError(t, err)
		_ = w.Close()
		_ = w2.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
	require.Equal(t, []byte{}, <-receiverBuf)            // subscriber received nothing
	require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) // nothing was retained
}
+
// TestServerProcessPacketPublishMaximumReceive checks that a qos 1 publish
// from a client with no remaining receive quota fails with
// ErrReceiveMaximum and a disconnect packet is written to the client.
func TestServerProcessPacketPublishMaximumReceive(t *testing.T) {
	s := newServer()
	_ = s.Serve()
	defer s.Close()

	cl, r, w := newTestClient()
	cl.Properties.ProtocolVersion = 5
	cl.State.Inflight.ResetReceiveQuota(0) // exhaust the receive quota
	s.Clients.Add(cl)

	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
		require.Error(t, err)
		require.ErrorIs(t, err, packets.ErrReceiveMaximum)
		_ = w.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectReceiveMaximum).RawBytes, buf)
}
+
+func TestServerProcessPublishInvalidTopic(t *testing.T) {
+ s := newServer()
+ _ = s.Serve()
+ defer s.Close()
+ cl, _, _ := newTestClient()
+ err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet)
+ require.NoError(t, err) // $SYS Topics should be ignored?
+}
+
+func TestServerProcessPublishACLCheckDeny(t *testing.T) {
+ tt := []struct {
+ name string
+ protocolVersion byte
+ pk packets.Packet
+ expectErr error
+ expectReponse []byte
+ expectDisconnect bool
+ }{
+ {
+ name: "v4_QOS0",
+ protocolVersion: 4,
+ pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet,
+ expectErr: nil,
+ expectReponse: nil,
+ expectDisconnect: false,
+ },
+ {
+ name: "v4_QOS1",
+ protocolVersion: 4,
+ pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet,
+ expectErr: packets.ErrNotAuthorized,
+ expectReponse: nil,
+ expectDisconnect: true,
+ },
+ {
+ name: "v4_QOS2",
+ protocolVersion: 4,
+ pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet,
+ expectErr: packets.ErrNotAuthorized,
+ expectReponse: nil,
+ expectDisconnect: true,
+ },
+ {
+ name: "v5_QOS0",
+ protocolVersion: 5,
+ pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet,
+ expectErr: nil,
+ expectReponse: nil,
+ expectDisconnect: false,
+ },
+ {
+ name: "v5_QOS1",
+ protocolVersion: 5,
+ pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Mqtt5).Packet,
+ expectErr: nil,
+ expectReponse: packets.TPacketData[packets.Puback].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes,
+ expectDisconnect: false,
+ },
+ {
+ name: "v5_QOS2",
+ protocolVersion: 5,
+ pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet,
+ expectErr: nil,
+ expectReponse: packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes,
+ expectDisconnect: false,
+ },
+ }
+
+ for _, tx := range tt {
+ t.Run(tx.name, func(t *testing.T) {
+ cc := NewDefaultServerCapabilities()
+ s := New(&Options{
+ Logger: logger,
+ Capabilities: cc,
+ })
+ _ = s.AddHook(new(DenyHook), nil)
+ _ = s.Serve()
+ defer s.Close()
+
+ cl, r, w := newTestClient()
+ cl.Properties.ProtocolVersion = tx.protocolVersion
+ s.Clients.Add(cl)
+
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ err := s.processPublish(cl, tx.pk)
+ require.ErrorIs(t, err, tx.expectErr)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+
+ if tx.expectReponse != nil {
+ require.Equal(t, tx.expectReponse, buf)
+ }
+
+ require.Equal(t, tx.expectDisconnect, cl.Closed())
+ wg.Wait()
+ })
+ }
+}
+
+func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) {
+ s := newServer()
+ require.NotNil(t, s)
+ hook := new(modifiedHookBase)
+ hook.fail = true
+ hook.err = packets.ErrRejectPacket
+
+ err := s.AddHook(hook, nil)
+ require.NoError(t, err)
+
+ _ = s.Serve()
+ defer s.Close()
+ cl, _, _ := newTestClient()
+ err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
+ require.NoError(t, err) // packets rejected silently
+}
+
+func TestServerProcessPacketPublishQos0(t *testing.T) {
+ s := newServer()
+ cl, r, w := newTestClient()
+
+ go func() {
+ err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, []byte{}, buf)
+}
+
// TestServerProcessPacketPublishQos1PacketIDInUse processes a qos 1 publish
// whose packet id (7) already has an inflight Publish entry; the client is
// still acked and the server inflight count drops back to zero.
func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) // same id as the fixture publish
	atomic.StoreInt64(&s.Info.Inflight, 1)

	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
		require.NoError(t, err)
		_ = w.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight)) // the stale inflight entry was cleared
}
+
// TestServerProcessPacketPublishQos2PacketIDInUse processes a v5 qos 2
// publish whose packet id (7) is already inflight as a Pubrec; the client
// receives a PUBREC carrying the id-in-use code and the inflight count is
// left unchanged.
func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	cl.Properties.ProtocolVersion = 5
	cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Pubrec}}) // same id as the fixture publish
	atomic.StoreInt64(&s.Info.Inflight, 1)

	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet)
		require.NoError(t, err)
		_ = w.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5IDInUse).RawBytes, buf)
	require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Inflight)) // the existing entry remains
}
+
+func TestServerProcessPacketPublishQos1(t *testing.T) {
+ s := newServer()
+ cl, r, w := newTestClient()
+
+ go func() {
+ err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
+}
+
+func TestServerProcessPacketPublishQos2(t *testing.T) {
+ s := newServer()
+ cl, r, w := newTestClient()
+
+ go func() {
+ err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).RawBytes, buf)
+}
+
+func TestServerProcessPacketPublishDowngradeQos(t *testing.T) {
+ s := newServer()
+ s.Options.Capabilities.MaximumQos = 1
+ cl, r, w := newTestClient()
+
+ go func() {
+ err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet)
+ require.NoError(t, err)
+ _ = w.Close()
+ }()
+
+ buf, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
+}
+
// TestPublishToSubscribersSelfNoLocal verifies that a NoLocal subscription
// suppresses delivery of a client's own messages back to itself (the
// receiver pipe yields no bytes).
func TestPublishToSubscribersSelfNoLocal(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	s.Clients.Add(cl)
	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", NoLocal: true})
	require.True(t, subbed)

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		pkx.Origin = cl.ID // message originates from the subscriber itself
		s.publishToSubscribers(pkx)
		time.Sleep(time.Millisecond)
		_ = w.Close()
	}()

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	require.Equal(t, []byte{}, <-receiverBuf) // nothing was echoed back
}
+
// TestPublishToSubscribers delivers one message to a direct subscriber and
// to a shared-subscription group; the direct subscriber always receives it,
// and exactly one of the two group members (cl2 or cl3) must receive it.
func TestPublishToSubscribers(t *testing.T) {
	s := newServer()
	cl, r1, w1 := newTestClient()
	cl.ID = "cl1"
	cl2, r2, w2 := newTestClient()
	cl2.ID = "cl2"
	cl3, r3, w3 := newTestClient()
	cl3.ID = "cl3"
	s.Clients.Add(cl)
	s.Clients.Add(cl2)
	s.Clients.Add(cl3)
	require.True(t, s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"}))
	require.True(t, s.Topics.Subscribe(cl2.ID, packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c"}))
	require.True(t, s.Topics.Subscribe(cl3.ID, packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c"}))

	cl1Recv := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r1)
		require.NoError(t, err)
		cl1Recv <- buf
	}()

	cl2Recv := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		cl2Recv <- buf
	}()

	cl3Recv := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r3)
		require.NoError(t, err)
		cl3Recv <- buf
	}()

	go func() {
		s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
		time.Sleep(time.Millisecond)
		_ = w1.Close()
		_ = w2.Close()
		_ = w3.Close()
	}()

	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-cl1Recv)
	rcv2 := <-cl2Recv
	rcv3 := <-cl3Recv

	// The shared-group recipient is not deterministic; accept either, but
	// require that exactly one of the two received the message.
	ok := false
	if len(rcv2) > 0 {
		require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, rcv2)
		require.Equal(t, []byte{}, rcv3)
		ok = true
	} else if len(rcv3) > 0 {
		require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, rcv3)
		require.Equal(t, []byte{}, rcv2)
		ok = true
	}
	require.True(t, ok)
}
+
+func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) {
+ s := newServer()
+ s.Options.Capabilities.MaximumMessageExpiryInterval = 86400
+ cl, r1, w1 := newTestClient()
+ cl.ID = "cl1"
+ cl.Properties.ProtocolVersion = 5
+ s.Clients.Add(cl)
+ require.True(t, s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"}))
+
+ cl1Recv := make(chan []byte)
+ go func() {
+ buf, err := io.ReadAll(r1)
+ require.NoError(t, err)
+ cl1Recv <- buf
+ }()
+
+ go func() {
+ pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
+ pkx.Created = time.Now().Unix() - 30
+ s.publishToSubscribers(pkx)
+ time.Sleep(time.Millisecond)
+ _ = w1.Close()
+ }()
+
+ b := <-cl1Recv
+ pk := new(packets.Packet)
+ pk.ProtocolVersion = 5
+ require.Equal(t, uint32(s.Options.Capabilities.MaximumMessageExpiryInterval-30), binary.BigEndian.Uint32(b[11:15]))
+}
+
// TestPublishToSubscribersIdentifiers verifies that the subscription
// identifiers of all filters matching the published topic (a/b/+ and a/#,
// but not d/e/f) are included in the publish delivered to a v5 subscriber.
func TestPublishToSubscribersIdentifiers(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	cl.Properties.ProtocolVersion = 5
	s.Clients.Add(cl)
	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/+", Identifier: 2})
	require.True(t, subbed)
	subbed = s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/#", Identifier: 3})
	require.True(t, subbed)
	subbed = s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "d/e/f", Identifier: 4}) // non-matching filter
	require.True(t, subbed)

	go func() {
		s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
		time.Sleep(time.Millisecond)
		_ = w.Close()
	}()

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishSubscriberIdentifier).RawBytes, <-receiverBuf)
}
+
// TestPublishToSubscribersPkIgnore verifies that a packet flagged Ignore is
// not delivered to any subscriber (the receiver pipe yields no bytes).
func TestPublishToSubscribersPkIgnore(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	s.Clients.Add(cl)
	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "#", Identifier: 1})
	require.True(t, subbed)

	go func() {
		pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		pk.Ignore = true // flagged packets must be skipped during delivery
		s.publishToSubscribers(pk)
		time.Sleep(time.Millisecond)
		_ = w.Close()
	}()

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	require.Equal(t, []byte{}, <-receiverBuf)
}
+
// TestPublishToClientServerDowngradeQos checks that a qos 2 publish is
// downgraded to the server's MaximumQos (1) before being delivered.
func TestPublishToClientServerDowngradeQos(t *testing.T) {
	s := newServer()
	s.Options.Capabilities.MaximumQos = 1

	cl, r, w := newTestClient()
	s.Clients.Add(cl)

	_, ok := cl.State.Inflight.Get(1)
	require.False(t, ok)
	cl.State.packetID = 6 // just to match the same packet id (7) in the fixtures

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
		pkx.FixedHeader.Qos = 2 // above the server maximum
		_, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx)
		time.Sleep(time.Microsecond * 100)
		_ = w.Close()
	}()

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf)
}
+
// TestPublishToClientSubscriptionDowngradeQos checks that a qos 2 publish
// is downgraded to the subscription's granted qos (1) before delivery,
// even though the server permits qos 2.
func TestPublishToClientSubscriptionDowngradeQos(t *testing.T) {
	s := newServer()
	s.Options.Capabilities.MaximumQos = 2

	cl, r, w := newTestClient()
	s.Clients.Add(cl)

	_, ok := cl.State.Inflight.Get(1)
	require.False(t, ok)
	cl.State.packetID = 6 // just to match the same packet id (7) in the fixtures

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
		pkx.FixedHeader.Qos = 2 // above the subscription's granted qos
		_, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx)
		time.Sleep(time.Microsecond * 100)
		_ = w.Close()
	}()

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf)
}
+
+func TestPublishToClientExceedClientWritesPending(t *testing.T) {
+ var sendQuota uint16 = 5
+ s := newServer()
+
+ _, w := net.Pipe()
+ cl := newClient(w, &ops{
+ info: new(system.Info),
+ hooks: new(Hooks),
+ log: logger,
+ options: &Options{
+ Capabilities: &Capabilities{
+ MaximumClientWritesPending: 3,
+ maximumPacketID: 10,
+ },
+ },
+ })
+ cl.Properties.Props.ReceiveMaximum = sendQuota
+ cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))
+
+ s.Clients.Add(cl)
+
+ for i := int32(0); i < cl.ops.options.Capabilities.MaximumClientWritesPending; i++ {
+ cl.State.outbound <- new(packets.Packet)
+ atomic.AddInt32(&cl.State.outboundQty, 1)
+ }
+
+ id, _ := cl.NextPacketID()
+ cl.State.Inflight.Set(packets.Packet{PacketID: uint16(id)})
+ cl.State.Inflight.DecreaseSendQuota()
+ sendQuota--
+
+ _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{})
+ require.Error(t, err)
+ require.ErrorIs(t, packets.ErrPendingClientWritesExceeded, err)
+ require.Equal(t, int32(sendQuota), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+
+ _, err = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{FixedHeader: packets.FixedHeader{Qos: 1}})
+ require.Error(t, err)
+ require.ErrorIs(t, packets.ErrPendingClientWritesExceeded, err)
+ require.Equal(t, int32(sendQuota), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+}
+
// TestPublishToClientServerTopicAlias publishes the same topic twice to a
// v5 client with TopicAliasMaximum set; the second publish is expected to
// reuse the alias and be 5 bytes shorter (the omitted topic name).
func TestPublishToClientServerTopicAlias(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	cl.Properties.ProtocolVersion = 5
	cl.Properties.Props.TopicAliasMaximum = 5
	s.Clients.Add(cl)

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet
		_, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) // first publish establishes the alias
		_, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) // second publish may use it
		time.Sleep(time.Millisecond)
		_ = w.Close()
	}()

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	// NOTE(review): this split-and-reassemble check effectively only pins the
	// combined stream length (full packet + packet minus 5 bytes) — consider
	// decoding both packets to assert the alias bytes directly.
	ret := <-receiverBuf
	pk1 := make([]byte, len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes))
	pk2 := make([]byte, len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes)-5)
	copy(pk1, ret[:len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes)])
	copy(pk2, ret[len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes):])
	require.Equal(t, append(pk1, pk2...), ret)
}
+
+func TestPublishToClientMqtt3RetainFalseLeverageNoConn(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+ cl.Net.Conn = nil
+
+ out, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", RetainAsPublished: true}, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
+ require.False(t, out.FixedHeader.Retain)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.CodeDisconnect)
+}
+
+func TestPublishToClientMqtt5RetainAsPublishedTrueLeverageNoConn(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+ cl.Properties.ProtocolVersion = 5
+ cl.Net.Conn = nil
+
+ out, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", RetainAsPublished: true}, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
+ require.True(t, out.FixedHeader.Retain)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.CodeDisconnect)
+}
+
+func TestPublishToClientExceedMaximumInflight(t *testing.T) {
+ const MaxInflight uint16 = 5
+ s := newServer()
+ cl, _, _ := newTestClient()
+ s.Options.Capabilities.MaximumInflight = MaxInflight
+ cl.ops.options.Capabilities.MaximumInflight = MaxInflight
+ for i := uint16(0); i < MaxInflight; i++ {
+ cl.State.Inflight.Set(packets.Packet{PacketID: i})
+ }
+
+ _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrQuotaExceeded)
+ require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.InflightDropped))
+}
+
+func TestPublishToClientExhaustedPacketID(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+ for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
+ cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
+ }
+
+ _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrQuotaExceeded)
+ require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.InflightDropped))
+}
+
+func TestPublishToClientACLNotAuthorized(t *testing.T) {
+ s := New(&Options{
+ Logger: logger,
+ })
+ err := s.AddHook(new(DenyHook), nil)
+ require.NoError(t, err)
+ cl, _, _ := newTestClient()
+
+ _, err = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.ErrNotAuthorized)
+}
+
+func TestPublishToClientNoConn(t *testing.T) {
+ s := newServer()
+ cl, _, _ := newTestClient()
+ cl.Net.Conn = nil
+
+ _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
+ require.Error(t, err)
+ require.ErrorIs(t, err, packets.CodeDisconnect)
+}
+
// TestProcessPublishWithTopicAlias sends a v5 publish with an empty topic
// name and inbound topic alias 1 (previously mapped to "a/b/c"); the
// subscriber must receive a normal publish with the topic name restored.
func TestProcessPublishWithTopicAlias(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	s.Clients.Add(cl)
	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})
	require.True(t, subbed)

	cl2, _, w2 := newTestClient()
	cl2.Properties.ProtocolVersion = 5
	cl2.State.TopicAliases.Inbound.Set(1, "a/b/c") // pre-established alias mapping

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5).Packet
		pkx.Properties.SubscriptionIdentifier = []int{} // must not contain from client to server
		pkx.TopicName = ""                              // alias stands in for the topic name
		pkx.Properties.TopicAlias = 1
		_ = s.processPacket(cl2, pkx)
		time.Sleep(time.Millisecond)
		_ = w2.Close()
		_ = w.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, buf)
}
+
// TestPublishToSubscribersExhaustedSendQuota drives publishToSubscribers at
// a client whose send quota is zero; delivery errors are swallowed by the
// server, so this test exists only to cover that code path.
func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	s.Clients.Add(cl)
	cl.State.Inflight.sendQuota = 0 // no quota left for delivery

	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
	require.True(t, subbed)

	// coverage: subscriber publish errors are non-returnable
	// can we hook into log/slog ?
	_ = r.Close()
	pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	pkx.PacketID = 0
	s.publishToSubscribers(pkx)
	time.Sleep(time.Millisecond)
	_ = w.Close()
}
+
+// TestPublishToSubscribersExhaustedPacketIDs exercises the delivery path when
+// every inflight packet ID is occupied so no new QoS>0 packet ID can be
+// issued; delivery errors here are logged, not returned.
+func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	s.Clients.Add(cl)
+	for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
+		// occupy every packet ID; the original stored PacketID 1 on each
+		// iteration, leaving all other IDs free and never exhausting them
+		cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
+	}
+
+	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
+	require.True(t, subbed)
+
+	// coverage: subscriber publish errors are non-returnable
+	// can we hook into log/slog ?
+	_ = r.Close()
+	pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
+	pkx.PacketID = 0
+	s.publishToSubscribers(pkx)
+	time.Sleep(time.Millisecond)
+	_ = w.Close()
+}
+
+// TestPublishToSubscribersNoConnection exercises delivery to a subscriber
+// whose connection reader has been closed; the resulting write error is
+// swallowed by the delivery loop, so only absence of panics is verified.
+func TestPublishToSubscribersNoConnection(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	s.Clients.Add(cl)
+	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
+	require.True(t, subbed)
+
+	// coverage: subscriber publish errors are non-returnable
+	// can we hook into log/slog ?
+	_ = r.Close()
+	s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
+	time.Sleep(time.Millisecond)
+	_ = w.Close()
+}
+
+// TestPublishRetainedToClient verifies that a stored retained message is
+// delivered to a client when its matching subscription is (re)established.
+func TestPublishRetainedToClient(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	s.Clients.Add(cl)
+
+	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
+	require.True(t, subbed)
+
+	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetainMqtt5).Packet)
+	require.Equal(t, int64(1), retained)
+
+	go func() {
+		s.publishRetainedToClient(cl, packets.Subscription{Filter: "a/b/c"}, false)
+		time.Sleep(time.Millisecond)
+		_ = w.Close() // unblock io.ReadAll below
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, buf)
+}
+
+// TestPublishRetainedToClientIsShared verifies that retained messages are NOT
+// sent for shared subscriptions ($share prefix): the client receives nothing.
+func TestPublishRetainedToClientIsShared(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	s.Clients.Add(cl)
+
+	sub := packets.Subscription{Filter: SharePrefix + "/test/a/b/c"}
+	subbed := s.Topics.Subscribe(cl.ID, sub)
+	require.True(t, subbed)
+
+	go func() {
+		s.publishRetainedToClient(cl, sub, false)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, []byte{}, buf) // nothing was written to the connection
+}
+
+// TestPublishRetainedToClientError verifies that delivering a retained
+// message over an already-closed connection does not panic; the write error
+// is swallowed internally.
+func TestPublishRetainedToClientError(t *testing.T) {
+	srv := newServer()
+	client, _, w := newTestClient()
+	srv.Clients.Add(client)
+
+	sub := packets.Subscription{Filter: "a/b/c"}
+	require.True(t, srv.Topics.Subscribe(client.ID, sub))
+
+	count := srv.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
+	require.Equal(t, int64(1), count)
+
+	_ = w.Close()
+	srv.publishRetainedToClient(client, sub, false)
+}
+
+// TestNoRetainMessageIfUnavailable verifies that no message is retained when
+// the server capabilities disable retained messages entirely.
+func TestNoRetainMessageIfUnavailable(t *testing.T) {
+	srv := newServer()
+	srv.Options.Capabilities.RetainAvailable = 0
+	client, _, _ := newTestClient()
+	srv.Clients.Add(client)
+
+	pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
+	srv.retainMessage(new(Client), pk)
+	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Retained))
+}
+
+// TestNoRetainMessageIfPkIgnore verifies that a packet flagged Ignore is
+// never stored as a retained message.
+func TestNoRetainMessageIfPkIgnore(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	srv.Clients.Add(client)
+
+	pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
+	pk.Ignore = true
+	srv.retainMessage(new(Client), pk)
+	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Retained))
+}
+
+// TestNoRetainMessage verifies that a retain-flagged publish increments the
+// server's retained-message counter.
+func TestNoRetainMessage(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	srv.Clients.Add(client)
+
+	pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
+	srv.retainMessage(new(Client), pk)
+	require.Equal(t, int64(1), atomic.LoadInt64(&srv.Info.Retained))
+}
+
+// TestServerProcessPacketPuback verifies (for v4 and v5) that a PUBACK for a
+// known inflight packet restores one unit of send quota, removes the inflight
+// entry, and decrements the server's inflight counter.
+func TestServerProcessPacketPuback(t *testing.T) {
+	tt := ProtocolTest{
+		{
+			protocolVersion: 4,
+			in:              packets.TPacketData[packets.Puback].Get(packets.TPuback),
+		},
+		{
+			protocolVersion: 5,
+			in:              packets.TPacketData[packets.Puback].Get(packets.TPubackMqtt5),
+		},
+	}
+
+	for _, tx := range tt {
+		t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) {
+			pID := uint16(7)
+			s := newServer()
+			cl, _, _ := newTestClient()
+			cl.State.Inflight.sendQuota = 3
+			cl.State.Inflight.receiveQuota = 3
+
+			cl.State.Inflight.Set(packets.Packet{PacketID: pID})
+			atomic.AddInt64(&s.Info.Inflight, 1)
+
+			err := s.processPacket(cl, *tx.in.Packet)
+			require.NoError(t, err)
+
+			// send quota is returned by the ack; receive quota is untouched
+			require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+			require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
+
+			require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight))
+			_, ok := cl.State.Inflight.Get(pID)
+			require.False(t, ok)
+		})
+	}
+}
+
+// TestServerProcessPacketPubackNoPacketID ensures a PUBACK for an unknown
+// packet ID is ignored and leaves both quotas untouched.
+func TestServerProcessPacketPubackNoPacketID(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	client.State.Inflight.sendQuota = 3
+	client.State.Inflight.receiveQuota = 3
+
+	ack := *packets.TPacketData[packets.Puback].Get(packets.TPuback).Packet
+	require.NoError(t, srv.processPacket(client, ack))
+
+	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.sendQuota))
+	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.receiveQuota))
+}
+
+// TestServerProcessPacketPubrec verifies the middle of the outbound QoS 2
+// flow: a PUBREC for a known packet ID is answered with PUBREL, consumes one
+// unit of receive quota, and keeps the inflight entry alive.
+func TestServerProcessPacketPubrec(t *testing.T) {
+	pID := uint16(7)
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.State.Inflight.sendQuota = 3
+	cl.State.Inflight.receiveQuota = 3
+
+	cl.State.Inflight.Set(packets.Packet{PacketID: pID})
+	atomic.AddInt64(&s.Info.Inflight, 1)
+
+	recv := make(chan []byte)
+	go func() { // receive the ack
+		buf, err := io.ReadAll(r)
+		require.NoError(t, err)
+		recv <- buf
+	}()
+
+	err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet)
+	require.NoError(t, err)
+	_ = w.Close()
+
+	require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).RawBytes, <-recv)
+
+	require.Equal(t, int32(2), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
+	require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+	require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Inflight)) // flow not finished yet
+	_, ok := cl.State.Inflight.Get(pID)
+	require.True(t, ok)
+}
+
+// TestServerProcessPacketPubrecNoPacketID verifies that a PUBREC for an
+// unknown packet ID is answered with a PUBREL error ack (v5) and that the
+// quotas are left untouched.
+func TestServerProcessPacketPubrecNoPacketID(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+	cl.State.Inflight.sendQuota = 3
+	cl.State.Inflight.receiveQuota = 3
+
+	recv := make(chan []byte)
+	go func() { // receive the ack
+		buf, err := io.ReadAll(r)
+		require.NoError(t, err)
+		recv <- buf
+	}()
+
+	pk := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet // not sending properties
+	err := s.processPacket(cl, pk)
+	require.NoError(t, err)
+	_ = w.Close()
+
+	require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrelMqtt5AckNoPacket).RawBytes, <-recv)
+
+	require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+	require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
+}
+
+// TestServerProcessPacketPubrecInvalidReason verifies that a PUBREC carrying
+// a failure reason code aborts the QoS 2 flow for that packet ID.
+func TestServerProcessPacketPubrecInvalidReason(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	id := uint16(7)
+	client.State.Inflight.Set(packets.Packet{PacketID: id})
+
+	pk := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrecInvalidReason).Packet
+	require.NoError(t, srv.processPacket(client, pk))
+	require.Equal(t, int64(-1), atomic.LoadInt64(&srv.Info.Inflight))
+	_, found := client.State.Inflight.Get(id)
+	require.False(t, found)
+}
+
+// TestServerProcessPacketPubrecFailure verifies that answering a PUBREC on a
+// stopped client surfaces an error and preserves the original stop cause.
+func TestServerProcessPacketPubrecFailure(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	client.State.Inflight.Set(packets.Packet{PacketID: uint16(7)})
+	client.Stop(packets.CodeDisconnect)
+
+	err := srv.processPacket(client, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet)
+	require.Error(t, err)
+	require.ErrorIs(t, client.StopCause(), packets.CodeDisconnect)
+}
+
+// TestServerProcessPacketPubrel verifies the end of the inbound QoS 2 flow: a
+// PUBREL is answered with PUBCOMP, both quotas are restored, and the inflight
+// entry is removed.
+func TestServerProcessPacketPubrel(t *testing.T) {
+	pID := uint16(7)
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.State.Inflight.sendQuota = 3
+	cl.State.Inflight.receiveQuota = 3
+
+	cl.State.Inflight.Set(packets.Packet{PacketID: pID})
+	atomic.AddInt64(&s.Info.Inflight, 1)
+
+	recv := make(chan []byte)
+	go func() { // receive the ack
+		buf, err := io.ReadAll(r)
+		require.NoError(t, err)
+		recv <- buf
+	}()
+
+	err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet)
+	require.NoError(t, err)
+	_ = w.Close()
+
+	require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
+	require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+
+	require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp).RawBytes, <-recv)
+
+	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight))
+	_, ok := cl.State.Inflight.Get(pID)
+	require.False(t, ok)
+}
+
+// TestServerProcessPacketPubrelNoPacketID verifies that a PUBREL for an
+// unknown packet ID is answered with a PUBCOMP error ack (v5) and leaves the
+// quotas untouched.
+func TestServerProcessPacketPubrelNoPacketID(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+	cl.State.Inflight.sendQuota = 3
+	cl.State.Inflight.receiveQuota = 3
+
+	recv := make(chan []byte)
+	go func() { // receive the ack
+		buf, err := io.ReadAll(r)
+		require.NoError(t, err)
+		recv <- buf
+	}()
+
+	pk := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet // not sending properties
+	err := s.processPacket(cl, pk)
+	require.NoError(t, err)
+	_ = w.Close()
+
+	require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5AckNoPacket).RawBytes, <-recv)
+
+	require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+	require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
+}
+
+// TestServerProcessPacketPubrelFailure verifies that answering a PUBREL on a
+// stopped client surfaces an error and preserves the original stop cause.
+func TestServerProcessPacketPubrelFailure(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	client.State.Inflight.Set(packets.Packet{PacketID: uint16(7)})
+	client.Stop(packets.CodeDisconnect)
+
+	err := srv.processPacket(client, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet)
+	require.Error(t, err)
+	require.ErrorIs(t, client.StopCause(), packets.CodeDisconnect)
+}
+
+// TestServerProcessPacketPubrelBadReason verifies that a PUBREL carrying a
+// failure reason code aborts the QoS 2 flow for that packet ID.
+func TestServerProcessPacketPubrelBadReason(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	id := uint16(7)
+	client.State.Inflight.Set(packets.Packet{PacketID: id})
+
+	pk := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrelInvalidReason).Packet
+	require.NoError(t, srv.processPacket(client, pk))
+	require.Equal(t, int64(-1), atomic.LoadInt64(&srv.Info.Inflight))
+	_, found := client.State.Inflight.Get(id)
+	require.False(t, found)
+}
+
+// TestServerProcessPacketPubcomp verifies (for v4 and v5) that a PUBCOMP
+// completes the outbound QoS 2 flow: both quotas are restored and the
+// inflight entry and counter are cleared.
+func TestServerProcessPacketPubcomp(t *testing.T) {
+	tt := ProtocolTest{
+		{
+			protocolVersion: 4,
+			in:              packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
+		},
+		{
+			protocolVersion: 5,
+			in:              packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5),
+		},
+	}
+
+	for _, tx := range tt {
+		t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) {
+			pID := uint16(7)
+			s := newServer()
+			cl, _, _ := newTestClient()
+			cl.Properties.ProtocolVersion = tx.protocolVersion
+			cl.State.Inflight.sendQuota = 3
+			cl.State.Inflight.receiveQuota = 3
+
+			cl.State.Inflight.Set(packets.Packet{PacketID: pID})
+			atomic.AddInt64(&s.Info.Inflight, 1)
+
+			err := s.processPacket(cl, *tx.in.Packet)
+			require.NoError(t, err)
+			require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight))
+
+			require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
+			require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+
+			_, ok := cl.State.Inflight.Get(pID)
+			require.False(t, ok)
+		})
+	}
+}
+
+// TestServerProcessInboundQos2Flow drives the complete inbound QoS 2 exchange
+// (PUBLISH -> PUBREC, PUBREL -> PUBCOMP) against one client and checks the
+// quota and inflight bookkeeping after each step.
+func TestServerProcessInboundQos2Flow(t *testing.T) {
+	tt := ProtocolTest{
+		{
+			protocolVersion: 5,
+			in:              packets.TPacketData[packets.Publish].Get(packets.TPublishQos2),
+			out:             packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
+			data: map[string]any{
+				"sendquota": int32(3),
+				"recvquota": int32(2),
+				"inflight":  int64(1),
+			},
+		},
+		{
+			protocolVersion: 5,
+			in:              packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
+			out:             packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
+			data: map[string]any{
+				"sendquota": int32(4),
+				"recvquota": int32(3),
+				"inflight":  int64(0),
+			},
+		},
+	}
+
+	pID := uint16(7)
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.State.Inflight.sendQuota = 3
+	cl.State.Inflight.receiveQuota = 3
+
+	for i, tx := range tt {
+		t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
+			// fresh pipe per step; the previous one was closed below
+			r, w = net.Pipe()
+			cl.Net.Conn = w
+
+			recv := make(chan []byte)
+			go func() { // receive the ack
+				buf, err := io.ReadAll(r)
+				require.NoError(t, err)
+				recv <- buf
+			}()
+
+			err := s.processPacket(cl, *tx.in.Packet)
+			require.NoError(t, err)
+			_ = w.Close()
+
+			require.Equal(t, tx.out.RawBytes, <-recv)
+			if i == 0 {
+				// after PUBLISH the packet must be held inflight until PUBREL
+				_, ok := cl.State.Inflight.Get(pID)
+				require.True(t, ok)
+			}
+
+			require.Equal(t, tx.data["inflight"].(int64), atomic.LoadInt64(&s.Info.Inflight))
+			require.Equal(t, tx.data["recvquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
+			require.Equal(t, tx.data["sendquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+		})
+	}
+
+	_, ok := cl.State.Inflight.Get(pID)
+	require.False(t, ok)
+}
+
+// TestServerProcessOutboundQos2Flow drives the complete outbound QoS 2
+// exchange (PUBLISH out, PUBREC -> PUBREL, PUBCOMP) for a subscribed client
+// and checks the quota and inflight bookkeeping after each step.
+func TestServerProcessOutboundQos2Flow(t *testing.T) {
+	tt := ProtocolTest{
+		{
+			protocolVersion: 5,
+			in:              packets.TPacketData[packets.Publish].Get(packets.TPublishQos2),
+			out:             packets.TPacketData[packets.Publish].Get(packets.TPublishQos2),
+			data: map[string]any{
+				"sendquota": int32(2),
+				"recvquota": int32(3),
+				"inflight":  int64(1),
+			},
+		},
+		{
+			protocolVersion: 5,
+			in:              packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
+			out:             packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
+			data: map[string]any{
+				"sendquota": int32(2),
+				"recvquota": int32(2),
+				"inflight":  int64(1),
+			},
+		},
+		{
+			protocolVersion: 5,
+			in:              packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
+			data: map[string]any{
+				"sendquota": int32(3),
+				"recvquota": int32(3),
+				"inflight":  int64(0),
+			},
+		},
+	}
+
+	pID := uint16(6)
+	s := newServer()
+	cl, _, _ := newTestClient()
+	cl.State.packetID = uint32(6) // so the outgoing publish is assigned the expected packet id
+	cl.State.Inflight.sendQuota = 3
+	cl.State.Inflight.receiveQuota = 3
+	s.Clients.Add(cl)
+	s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
+
+	for i, tx := range tt {
+		t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
+			r, w := net.Pipe()
+			time.Sleep(time.Millisecond)
+			cl.Net.Conn = w
+
+			recv := make(chan []byte)
+			go func() { // receive the ack
+				buf, err := io.ReadAll(r)
+				require.NoError(t, err)
+				recv <- buf
+			}()
+
+			if i == 0 {
+				// step 0 originates on the server side: deliver to subscribers
+				s.publishToSubscribers(*tx.in.Packet)
+			} else {
+				err := s.processPacket(cl, *tx.in.Packet)
+				require.NoError(t, err)
+			}
+
+			time.Sleep(time.Millisecond)
+			_ = w.Close()
+
+			if i != 2 { // PUBCOMP produces no response packet
+				require.Equal(t, tx.out.RawBytes, <-recv)
+			}
+
+			require.Equal(t, tx.data["inflight"].(int64), atomic.LoadInt64(&s.Info.Inflight))
+			require.Equal(t, tx.data["recvquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
+			require.Equal(t, tx.data["sendquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
+		})
+	}
+
+	_, ok := cl.State.Inflight.Get(pID)
+	require.False(t, ok)
+}
+
+// TestServerProcessPacketSubscribe verifies that a valid v5 SUBSCRIBE is
+// answered with the expected SUBACK.
+func TestServerProcessPacketSubscribe(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5).RawBytes, buf)
+}
+
+// TestServerProcessPacketSubscribePacketIDInUse verifies that a SUBSCRIBE
+// reusing a packet ID currently held by an inflight PUBLISH is refused with a
+// packet-identifier-in-use SUBACK.
+func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+	cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
+
+	pkx := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet
+	pkx.PacketID = 15 // collides with the inflight publish above
+	go func() {
+		err := s.processPacket(cl, pkx)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackPacketIDInUse).RawBytes, buf)
+}
+
+// TestServerProcessPacketSubscribeInvalid ensures a SUBSCRIBE without the
+// required packet ID is rejected as a protocol violation.
+func TestServerProcessPacketSubscribeInvalid(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	client.Properties.ProtocolVersion = 5
+
+	pk := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeSpecQosMustPacketID).Packet
+	err := srv.processPacket(client, pk)
+	require.Error(t, err)
+	require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
+}
+
+// TestServerProcessPacketSubscribeInvalidFilter verifies that a SUBSCRIBE
+// with an invalid topic filter is answered with the matching error SUBACK
+// rather than a returned error.
+func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidFilter).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackInvalidFilter).RawBytes, buf)
+}
+
+// TestServerProcessPacketSubscribeInvalidSharedNoLocal verifies that a shared
+// subscription with the NoLocal option set (forbidden by the v5 spec) is
+// refused via the matching error SUBACK.
+func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackInvalidSharedNoLocal).RawBytes, buf)
+}
+
+// TestServerProcessSubscribeWithRetain verifies that subscribing to a topic
+// holding a retained message produces the SUBACK followed by the retained
+// publish on the same connection.
+func TestServerProcessSubscribeWithRetain(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+
+	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
+	require.Equal(t, int64(1), retained)
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet)
+		require.NoError(t, err)
+
+		time.Sleep(time.Millisecond)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	// SUBACK first, then the retained message
+	require.Equal(t, append(
+		packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes,
+		packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes...,
+	), buf)
+}
+
+// TestServerProcessSubscribeDowngradeQos verifies that requested QoS levels
+// above the server's MaximumQos are downgraded in the SUBACK reason codes.
+func TestServerProcessSubscribeDowngradeQos(t *testing.T) {
+	s := newServer()
+	s.Options.Capabilities.MaximumQos = 1
+	cl, r, w := newTestClient()
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMany).Packet)
+		require.NoError(t, err)
+
+		time.Sleep(time.Millisecond)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	// only the granted-QoS bytes of the SUBACK payload are inspected
+	require.Equal(t, []byte{0, 1, 1}, buf[4:])
+}
+
+// TestServerProcessSubscribeWithRetainHandling1 verifies RetainHandling=1:
+// retained messages are not re-sent when the subscription already exists, so
+// only the SUBACK is received.
+func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"}) // pre-existing subscription
+	s.Clients.Add(cl)
+
+	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
+	require.Equal(t, int64(1), retained)
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainHandling1).Packet)
+		require.NoError(t, err)
+
+		time.Sleep(time.Millisecond)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, buf)
+}
+
+// TestServerProcessSubscribeWithRetainHandling2 verifies RetainHandling=2:
+// retained messages are never sent at subscribe time, so only the SUBACK is
+// received.
+func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	s.Clients.Add(cl)
+
+	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
+	require.Equal(t, int64(1), retained)
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainHandling2).Packet)
+		require.NoError(t, err)
+
+		time.Sleep(time.Millisecond)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, buf)
+}
+
+// TestServerProcessSubscribeWithNotRetainAsPublished verifies the
+// RetainAsPublished subscription option path: the SUBACK is followed by the
+// retained message.
+func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	s.Clients.Add(cl)
+
+	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
+	require.Equal(t, int64(1), retained)
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainAsPublished).Packet)
+		require.NoError(t, err)
+
+		time.Sleep(time.Millisecond)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, append(
+		packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes,
+		packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes...,
+	), buf)
+}
+
+// TestServerProcessSubscribeNoConnection ensures subscribing fails with
+// io.ErrClosedPipe once the client's connection has been closed.
+func TestServerProcessSubscribeNoConnection(t *testing.T) {
+	srv := newServer()
+	client, r, _ := newTestClient()
+	_ = r.Close()
+
+	err := srv.processSubscribe(client, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet)
+	require.Error(t, err)
+	require.ErrorIs(t, err, io.ErrClosedPipe)
+}
+
+// TestServerProcessSubscribeACLCheckDeny verifies that with no auth hooks
+// granting access, a SUBSCRIBE is answered with a not-authorized SUBACK.
+func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
+	s := New(&Options{ // bare server: no allow-all hook attached
+		Logger: logger,
+	})
+	_ = s.Serve()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+
+	go func() {
+		err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackDeny).RawBytes, buf)
+}
+
+// TestServerProcessSubscribeACLCheckDenyObscure verifies that with
+// ObscureNotAuthorized enabled, an ACL denial is reported as an unspecified
+// error instead of not-authorized.
+func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
+	s := New(&Options{ // bare server: no allow-all hook attached
+		Logger: logger,
+	})
+	_ = s.Serve()
+	s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+
+	go func() {
+		err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackUnspecifiedErrorMqtt5).RawBytes, buf)
+}
+
+// TestServerProcessSubscribeErrorDowngrade verifies that v5-specific SUBACK
+// reason codes are downgraded to a generic unspecified error for pre-v5
+// clients.
+func TestServerProcessSubscribeErrorDowngrade(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 3
+	cl.State.packetID = 1 // just to match the same packet id (7) in the fixtures
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackUnspecifiedError).RawBytes, buf)
+}
+
+// TestServerProcessPacketUnsubscribe verifies that an UNSUBSCRIBE is answered
+// with the expected UNSUBACK and decrements the subscription counter.
+// NOTE(review): the counter ends at -1 because the direct Topics.Subscribe
+// call above bypasses the counter increment — presumably intentional here.
+func TestServerProcessPacketUnsubscribe(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+	s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b", Qos: 0})
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5).RawBytes, buf)
+	require.Equal(t, int64(-1), atomic.LoadInt64(&s.Info.Subscriptions))
+}
+
+// TestServerProcessPacketUnsubscribePackedIDInUse verifies that an
+// UNSUBSCRIBE reusing a packet ID held by an inflight PUBLISH is refused with
+// a packet-identifier-in-use UNSUBACK and no subscription change.
+func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.ProtocolVersion = 5
+	cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackPacketIDInUse).RawBytes, buf)
+	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Subscriptions))
+}
+
+// TestServerProcessPacketUnsubscribeInvalid ensures an UNSUBSCRIBE without
+// the required packet ID is rejected as a protocol violation.
+func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+
+	pk := *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet
+	err := srv.processPacket(client, pk)
+	require.Error(t, err)
+	require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
+}
+
+// TestServerReceivePacketError ensures receivePacket propagates processing
+// errors such as a missing packet ID.
+func TestServerReceivePacketError(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+
+	pk := *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet
+	err := srv.receivePacket(client, pk)
+	require.Error(t, err)
+	require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
+}
+
+// TestServerRecievePacketDisconnectClientZeroNonZero verifies that a v5
+// client which connected with a zero session expiry is disconnected with a
+// protocol-violation DISCONNECT when it tries to raise the expiry.
+// NOTE(review): "Recieve" is a typo for "Receive" in the test name.
+func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+	cl.Properties.Props.SessionExpiryInterval = 0
+	cl.Properties.ProtocolVersion = 5
+	cl.Properties.Props.RequestProblemInfo = 0
+	cl.Properties.Props.RequestProblemInfoFlag = true
+	go func() {
+		err := s.receivePacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet)
+		require.Error(t, err)
+		require.ErrorIs(t, err, packets.ErrProtocolViolationZeroNonZeroExpiry)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf)
+}
+
+// TestServerRecievePacketDisconnectClient verifies DisconnectClient writes a
+// DISCONNECT packet to the client connection.
+// NOTE(review): "Recieve" is a typo for "Receive" in the test name.
+func TestServerRecievePacketDisconnectClient(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+
+	go func() {
+		err := s.DisconnectClient(cl, packets.CodeDisconnect)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, buf)
+}
+
+// TestServerProcessPacketDisconnect verifies a clean v5 DISCONNECT: any
+// pending delayed will is discarded, the client is closed, and the
+// disconnection timestamp is recorded.
+func TestServerProcessPacketDisconnect(t *testing.T) {
+	s := newServer()
+	cl, _, _ := newTestClient()
+	cl.Properties.Props.SessionExpiryInterval = 30
+	cl.Properties.ProtocolVersion = 5
+
+	s.loop.willDelayed.Add(cl.ID, packets.Packet{TopicName: "a/b/c", Payload: []byte("hello")})
+	require.Equal(t, 1, s.loop.willDelayed.Len())
+
+	err := s.processPacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet)
+	require.NoError(t, err)
+
+	require.Equal(t, 0, s.loop.willDelayed.Len())
+	require.True(t, cl.Closed())
+	// Allow one second of slack instead of exact equality: the previous
+	// require.Equal against time.Now().Unix() was flaky whenever the wall
+	// clock ticked over between the disconnect and the assertion.
+	require.InDelta(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected), 1)
+}
+
+// TestServerProcessPacketDisconnectNonZeroExpiryViolation ensures a v5 client
+// that connected with a zero session expiry interval cannot raise it at
+// disconnect time.
+func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	client.Properties.ProtocolVersion = 5
+	client.Properties.Props.SessionExpiryInterval = 0
+	client.Properties.Props.RequestProblemInfo = 0
+	client.Properties.Props.RequestProblemInfoFlag = true
+
+	err := srv.processPacket(client, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet)
+	require.Error(t, err)
+	require.ErrorIs(t, err, packets.ErrProtocolViolationZeroNonZeroExpiry)
+}
+
+// TestServerProcessPacketDisconnectDisconnectWithWillMessage verifies that a
+// v5 Disconnect-with-Will leaves the delayed will intact and the client open.
+func TestServerProcessPacketDisconnectDisconnectWithWillMessage(t *testing.T) {
+	srv := newServer()
+	client, _, _ := newTestClient()
+	client.Properties.ProtocolVersion = 5
+	client.Properties.Props.SessionExpiryInterval = 30
+
+	srv.loop.willDelayed.Add(client.ID, packets.Packet{TopicName: "a/b/c", Payload: []byte("hello")})
+	require.Equal(t, 1, srv.loop.willDelayed.Len())
+
+	err := srv.processPacket(client, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5DisconnectWithWillMessage).Packet)
+	require.Error(t, err)
+
+	require.Equal(t, 1, srv.loop.willDelayed.Len())
+	require.False(t, client.Closed())
+}
+
+// TestServerProcessPacketAuth verifies a valid AUTH packet is processed
+// without error and without writing anything back to the client.
+func TestServerProcessPacketAuth(t *testing.T) {
+	s := newServer()
+	cl, r, w := newTestClient()
+
+	go func() {
+		err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet)
+		require.NoError(t, err)
+		_ = w.Close()
+	}()
+
+	buf, err := io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, []byte{}, buf) // no response is expected
+}
+
+// TestServerProcessPacketAuthInvalidReason ensures an AUTH packet bearing an
+// unsupported reason code is rejected as a protocol violation.
+func TestServerProcessPacketAuthInvalidReason(t *testing.T) {
+	s := newServer()
+	cl, _, _ := newTestClient()
+	pkx := *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet
+	pkx.ReasonCode = 99
+	err := s.processPacket(cl, pkx)
+	require.Error(t, err)
+	// require.ErrorIs takes (t, err, target); the arguments were previously
+	// reversed, which would fail if the sentinel were ever wrapped.
+	require.ErrorIs(t, err, packets.ErrProtocolViolationInvalidReason)
+}
+
+// TestServerProcessPacketAuthFailure ensures that a hook failing the auth
+// exchange causes processAuth to return the hook's error.
+func TestServerProcessPacketAuthFailure(t *testing.T) {
+	s := newServer()
+	cl, _, _ := newTestClient()
+
+	hook := new(modifiedHookBase)
+	hook.fail = true
+	err := s.AddHook(hook, nil)
+	require.NoError(t, err)
+
+	err = s.processAuth(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet)
+	require.Error(t, err)
+	// require.ErrorIs takes (t, err, target); the arguments were previously
+	// reversed, which would fail if the sentinel were ever wrapped.
+	require.ErrorIs(t, err, errTestHook)
+}
+
+// TestServerSendLWT verifies that a client's last-will message is published
+// to matching subscribers when sendLWT is invoked.
+func TestServerSendLWT(t *testing.T) {
+	s := newServer()
+	_ = s.Serve()
+	defer s.Close()
+
+	sender, _, w1 := newTestClient()
+	sender.ID = "sender"
+	sender.Properties.Will = Will{
+		Flag:      1,
+		TopicName: "a/b/c",
+		Payload:   []byte("hello mochi"),
+	}
+	s.Clients.Add(sender)
+
+	receiver, r2, w2 := newTestClient()
+	receiver.ID = "receiver"
+	s.Clients.Add(receiver)
+	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})
+
+	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
+	require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) // nothing retained yet
+
+	receiverBuf := make(chan []byte)
+	go func() {
+		buf, err := io.ReadAll(r2)
+		require.NoError(t, err)
+		receiverBuf <- buf
+	}()
+
+	go func() {
+		s.sendLWT(sender)
+		time.Sleep(time.Millisecond * 10)
+		_ = w1.Close()
+		_ = w2.Close() // unblock the reader goroutine
+	}()
+
+	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
+}
+
+// TestServerSendLWTRetain verifies that a retain-flagged last-will message is
+// still delivered to live subscribers when sendLWT is invoked.
+func TestServerSendLWTRetain(t *testing.T) {
+	s := newServer()
+	_ = s.Serve()
+	defer s.Close()
+
+	sender, _, w1 := newTestClient()
+	sender.ID = "sender"
+	sender.Properties.Will = Will{
+		Flag:      1,
+		TopicName: "a/b/c",
+		Payload:   []byte("hello mochi"),
+		Retain:    true,
+	}
+	s.Clients.Add(sender)
+
+	receiver, r2, w2 := newTestClient()
+	receiver.ID = "receiver"
+	s.Clients.Add(receiver)
+	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})
+
+	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
+	require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) // nothing retained yet
+
+	receiverBuf := make(chan []byte)
+	go func() {
+		buf, err := io.ReadAll(r2)
+		require.NoError(t, err)
+		receiverBuf <- buf
+	}()
+
+	go func() {
+		s.sendLWT(sender)
+		time.Sleep(time.Millisecond * 10)
+		_ = w1.Close()
+		_ = w2.Close() // unblock the reader goroutine
+	}()
+
+	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
+}
+
+// TestServerSendLWTDelayed verifies the delayed-will path: sendLWT queues the
+// will when WillDelayInterval is set, and sendDelayedLWT publishes it once
+// the expiry time has passed.
+func TestServerSendLWTDelayed(t *testing.T) {
+	s := newServer()
+	cl1, _, _ := newTestClient()
+	cl1.ID = "cl1"
+	cl1.Properties.Will = Will{
+		Flag:              1,
+		TopicName:         "a/b/c",
+		Payload:           []byte("hello mochi"),
+		Retain:            true,
+		WillDelayInterval: 2,
+	}
+	s.Clients.Add(cl1)
+
+	cl2, r, w := newTestClient()
+	cl2.ID = "cl2"
+	s.Clients.Add(cl2)
+	require.True(t, s.Topics.Subscribe(cl2.ID, packets.Subscription{Filter: "a/b/c"}))
+
+	go func() {
+		s.sendLWT(cl1) // should queue, not publish, due to the delay interval
+		pk, ok := s.loop.willDelayed.Get(cl1.ID)
+		require.True(t, ok)
+		pk.Expiry = time.Now().Unix() - 1 // set back expiry time
+		s.loop.willDelayed.Add(cl1.ID, pk)
+		require.Equal(t, 1, s.loop.willDelayed.Len())
+		s.sendDelayedLWT(time.Now().Unix())
+		require.Equal(t, 0, s.loop.willDelayed.Len())
+		time.Sleep(time.Millisecond)
+		_ = w.Close()
+	}()
+
+	recv := make(chan []byte)
+	go func() {
+		buf, err := io.ReadAll(r)
+		require.NoError(t, err)
+		recv <- buf
+	}()
+
+	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-recv)
+}
+
+// TestServerReadStore exercises every failure point of readStore by making
+// the attached hook fail at each successive storage read stage.
+func TestServerReadStore(t *testing.T) {
+	srv := newServer()
+	hook := new(modifiedHookBase)
+	_ = srv.AddHook(hook, nil)
+
+	stages := []int{
+		1, // clients
+		2, // subscriptions
+		3, // inflight
+		4, // retained
+		5, // sys info
+	}
+	for _, stage := range stages {
+		hook.failAt = stage
+		require.Error(t, srv.readStore())
+	}
+}
+
// TestServerLoadClients checks which stored sessions loadClients restores:
// v3 clean-session clients and v5 clients with a zero session expiry interval
// are dropped, while a v5 clean client with a non-zero expiry interval is
// still restored.
func TestServerLoadClients(t *testing.T) {
	v := []storage.Client{
		{ID: "mochi"},
		{ID: "zen"},
		{ID: "mochi-co"},
		{ID: "v3-clean", ProtocolVersion: 4, Clean: true},
		{ID: "v3-not-clean", ProtocolVersion: 4, Clean: false},
		{
			ID:              "v5-clean",
			ProtocolVersion: 5,
			Clean:           true,
			Properties: storage.ClientProperties{
				SessionExpiryInterval: 10,
			},
		},
		{
			ID:              "v5-expire-interval-0",
			ProtocolVersion: 5,
			Properties: storage.ClientProperties{
				SessionExpiryInterval: 0,
			},
		},
		{
			ID:              "v5-expire-interval-not-0",
			ProtocolVersion: 5,
			Properties: storage.ClientProperties{
				SessionExpiryInterval: 10,
			},
		},
	}

	s := newServer()
	require.Equal(t, 0, s.Clients.Len())
	s.loadClients(v)
	// Two of the eight stored clients are dropped (see asserts below).
	require.Equal(t, 6, s.Clients.Len())
	cl, ok := s.Clients.Get("mochi")
	require.True(t, ok)
	require.Equal(t, "mochi", cl.ID)

	_, ok = s.Clients.Get("v3-clean")
	require.False(t, ok)
	_, ok = s.Clients.Get("v3-not-clean")
	require.True(t, ok)
	_, ok = s.Clients.Get("v5-clean")
	require.True(t, ok)
	_, ok = s.Clients.Get("v5-expire-interval-0")
	require.False(t, ok)
	_, ok = s.Clients.Get("v5-expire-interval-not-0")
	require.True(t, ok)
}
+
+func TestServerLoadSubscriptions(t *testing.T) {
+ v := []storage.Subscription{
+ {ID: "sub1", Client: "mochi", Filter: "a/b/c"},
+ {ID: "sub2", Client: "mochi", Filter: "d/e/f", Qos: 1},
+ {ID: "sub3", Client: "mochi", Filter: "h/i/j", Qos: 2},
+ }
+
+ s := newServer()
+ cl, _, _ := newTestClient()
+ s.Clients.Add(cl)
+ require.Equal(t, 0, cl.State.Subscriptions.Len())
+ s.loadSubscriptions(v)
+ require.Equal(t, 3, cl.State.Subscriptions.Len())
+}
+
+func TestServerLoadInflightMessages(t *testing.T) {
+ s := newServer()
+ s.loadClients([]storage.Client{
+ {ID: "mochi"},
+ {ID: "zen"},
+ {ID: "mochi-co"},
+ })
+
+ require.Equal(t, 3, s.Clients.Len())
+
+ v := []storage.Message{
+ {Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"},
+ {Origin: "mochi", PacketID: 2, Payload: []byte("yes"), TopicName: "a/b/c"},
+ {Origin: "zen", PacketID: 3, Payload: []byte("hello world"), TopicName: "a/b/c"},
+ {Origin: "mochi-co", PacketID: 4, Payload: []byte("hello world"), TopicName: "a/b/c"},
+ }
+ s.loadInflight(v)
+
+ cl, ok := s.Clients.Get("mochi")
+ require.True(t, ok)
+ require.Equal(t, "mochi", cl.ID)
+
+ msg, ok := cl.State.Inflight.Get(2)
+ require.True(t, ok)
+ require.Equal(t, []byte{'y', 'e', 's'}, msg.Payload)
+ require.Equal(t, "a/b/c", msg.TopicName)
+
+ cl, ok = s.Clients.Get("mochi-co")
+ require.True(t, ok)
+ msg, ok = cl.State.Inflight.Get(4)
+ require.True(t, ok)
+}
+
+func TestServerLoadRetainedMessages(t *testing.T) {
+ s := newServer()
+
+ v := []storage.Message{
+ {Origin: "mochi", FixedHeader: packets.FixedHeader{Retain: true}, Payload: []byte("hello world"), TopicName: "a/b/c"},
+ {Origin: "mochi-co", FixedHeader: packets.FixedHeader{Retain: true}, Payload: []byte("yes"), TopicName: "d/e/f"},
+ {Origin: "zen", FixedHeader: packets.FixedHeader{Retain: true}, Payload: []byte("hello world"), TopicName: "h/i/j"},
+ }
+ s.loadRetained(v)
+ require.Equal(t, 1, len(s.Topics.Messages("a/b/c")))
+ require.Equal(t, 1, len(s.Topics.Messages("d/e/f")))
+ require.Equal(t, 1, len(s.Topics.Messages("h/i/j")))
+ require.Equal(t, 0, len(s.Topics.Messages("w/x/y")))
+}
+
// TestServerClose checks that Close stops all running listeners and sends a
// v5 DISCONNECT (shutting down) packet to connected clients.
func TestServerClose(t *testing.T) {
	s := newServer()

	hook := new(modifiedHookBase)
	_ = s.AddHook(hook, nil)

	cl, r, _ := newTestClient()
	cl.Net.Listener = "t1"
	cl.Properties.ProtocolVersion = 5
	s.Clients.Add(cl)

	err := s.AddListener(listeners.NewMockListener("t1", ":1882"))
	require.NoError(t, err)
	_ = s.Serve() // error deliberately ignored; serving state is asserted below

	// receive the disconnect
	recv := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		recv <- buf
	}()

	time.Sleep(time.Millisecond) // give Serve a moment to start the listener
	require.Equal(t, 1, s.Listeners.Len())

	listener, ok := s.Listeners.Get("t1")
	require.Equal(t, true, ok)
	require.Equal(t, true, listener.(*listeners.MockListener).IsServing())

	_ = s.Close()
	time.Sleep(time.Millisecond)
	require.Equal(t, false, listener.(*listeners.MockListener).IsServing())
	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv)
}
+
// TestServerClearExpiredInflights checks that clearExpiredInflights removes
// inflight messages whose Expiry has passed or whose Created time falls
// outside MaximumMessageExpiryInterval, decrementing s.Info.Inflight for
// each, and that a zero maximum interval leaves messages untouched.
func TestServerClearExpiredInflights(t *testing.T) {
	s := New(nil)
	require.NotNil(t, s)
	s.Options.Capabilities.MaximumMessageExpiryInterval = 4

	n := time.Now().Unix()
	cl, _, _ := newTestClient()
	cl.ops.info = s.Info

	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2})
	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds
	cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit
	cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})

	s.Clients.Add(cl)

	require.Len(t, cl.State.Inflight.GetAll(false), 5)
	s.clearExpiredInflights(n)
	require.Len(t, cl.State.Inflight.GetAll(false), 2) // ids 3 and 7 survive
	require.Equal(t, int64(-3), s.Info.Inflight)       // three expired messages deducted

	// With the maximum interval at 0 the expired packet is NOT cleared —
	// presumably 0 disables expiry entirely; confirm against
	// clearExpiredInflights.
	s.Options.Capabilities.MaximumMessageExpiryInterval = 0
	cl.State.Inflight.Set(packets.Packet{PacketID: 8, Expiry: n - 8})
	s.clearExpiredInflights(n)
	require.Len(t, cl.State.Inflight.GetAll(false), 3)
}
+
// TestServerClearExpiredRetained checks clearExpiredRetainedMessages: v5
// messages expire via Expiry or via Created exceeding
// MaximumMessageExpiryInterval; for v3 (no ProtocolVersion) only the
// Created-based server limit applies; a zero maximum clears nothing.
func TestServerClearExpiredRetained(t *testing.T) {
	s := New(nil)
	require.NotNil(t, s)
	s.Options.Capabilities.MaximumMessageExpiryInterval = 4

	n := time.Now().Unix()
	s.Topics.Retained.Add("a/b/c", packets.Packet{ProtocolVersion: 5, Created: n, Expiry: n - 1})
	s.Topics.Retained.Add("d/e/f", packets.Packet{ProtocolVersion: 5, Created: n, Expiry: n - 2})
	s.Topics.Retained.Add("g/h/i", packets.Packet{ProtocolVersion: 5, Created: n - 3}) // within bounds
	s.Topics.Retained.Add("j/k/l", packets.Packet{ProtocolVersion: 5, Created: n - 5}) // over max server expiry limit
	s.Topics.Retained.Add("m/n/o", packets.Packet{ProtocolVersion: 5, Created: n})

	require.Len(t, s.Topics.Retained.GetAll(), 5)
	s.clearExpiredRetainedMessages(n)
	require.Len(t, s.Topics.Retained.GetAll(), 2) // g/h/i and m/n/o survive

	s.Topics.Retained.Add("p/q/r", packets.Packet{Created: n, Expiry: n - 1})
	s.Topics.Retained.Add("s/t/u", packets.Packet{Created: n, Expiry: n - 2}) // expiry is ineffective for v3.
	s.Topics.Retained.Add("v/w/x", packets.Packet{Created: n - 3})            // within bounds for v3
	s.Topics.Retained.Add("y/z/1", packets.Packet{Created: n - 5})            // over max server expiry limit
	require.Len(t, s.Topics.Retained.GetAll(), 6)
	s.clearExpiredRetainedMessages(n)
	require.Len(t, s.Topics.Retained.GetAll(), 5) // only y/z/1 is removed

	// A zero maximum interval disables retained-message expiry.
	s.Options.Capabilities.MaximumMessageExpiryInterval = 0
	s.Topics.Retained.Add("2/3/4", packets.Packet{Created: n - 8})
	s.clearExpiredRetainedMessages(n)
	require.Len(t, s.Topics.Retained.GetAll(), 6)
}
+
+func TestServerClearExpiredClients(t *testing.T) {
+ s := New(nil)
+ require.NotNil(t, s)
+
+ n := time.Now().Unix()
+
+ cl, _, _ := newTestClient()
+ cl.ID = "cl"
+ s.Clients.Add(cl)
+
+ // No Expiry
+ cl0, _, _ := newTestClient()
+ cl0.ID = "c0"
+ cl0.State.disconnected = n - 10
+ cl0.State.cancelOpen()
+ cl0.Properties.ProtocolVersion = 5
+ cl0.Properties.Props.SessionExpiryInterval = 12
+ cl0.Properties.Props.SessionExpiryIntervalFlag = true
+ s.Clients.Add(cl0)
+
+ // Normal Expiry
+ cl1, _, _ := newTestClient()
+ cl1.ID = "c1"
+ cl1.State.disconnected = n - 10
+ cl1.State.cancelOpen()
+ cl1.Properties.ProtocolVersion = 5
+ cl1.Properties.Props.SessionExpiryInterval = 8
+ cl1.Properties.Props.SessionExpiryIntervalFlag = true
+ s.Clients.Add(cl1)
+
+ // No Expiry, indefinite session
+ cl2, _, _ := newTestClient()
+ cl2.ID = "c2"
+ cl2.State.disconnected = n - 10
+ cl2.State.cancelOpen()
+ cl2.Properties.ProtocolVersion = 5
+ cl2.Properties.Props.SessionExpiryInterval = 0
+ cl2.Properties.Props.SessionExpiryIntervalFlag = true
+ s.Clients.Add(cl2)
+
+ require.Equal(t, 4, s.Clients.Len())
+
+ s.clearExpiredClients(n)
+ require.Equal(t, 2, s.Clients.Len())
+}
+
+func TestLoadServerInfoRestoreOnRestart(t *testing.T) {
+ s := New(nil)
+ s.Options.Capabilities.Compatibilities.RestoreSysInfoOnRestart = true
+ info := system.Info{
+ BytesReceived: 60,
+ }
+
+ s.loadServerInfo(info)
+ require.Equal(t, int64(60), s.Info.BytesReceived)
+}
+
+func TestItoa(t *testing.T) {
+ i := int64(22)
+ require.Equal(t, "22", Int64toa(i))
+}
+
// TestServerSubscribe exercises inline Subscribe over a table of cases: new
// and repeated subscriptions, the same filter under different identifiers
// and handlers, a $SYS filter, an invalid filter, and a nil handler.
func TestServerSubscribe(t *testing.T) {
	handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {}

	s := newServerWithInlineClient()
	require.NotNil(t, s)

	tt := []struct {
		desc       string
		filter     string
		identifier int
		handler    InlineSubFn
		expect     error
	}{
		{
			desc:       "subscribe",
			filter:     "a/b/c",
			identifier: 1,
			handler:    handler,
			expect:     nil,
		},
		{
			desc:       "re-subscribe",
			filter:     "a/b/c",
			identifier: 1,
			handler:    handler,
			expect:     nil,
		},
		{
			desc:       "subscribe d/e/f",
			filter:     "d/e/f",
			identifier: 1,
			handler:    handler,
			expect:     nil,
		},
		{
			desc:       "re-subscribe d/e/f by different identifier",
			filter:     "d/e/f",
			identifier: 2,
			handler:    handler,
			expect:     nil,
		},
		{
			desc:       "subscribe different handler",
			filter:     "a/b/c",
			identifier: 1,
			handler:    func(cl *Client, sub packets.Subscription, pk packets.Packet) {},
			expect:     nil,
		},
		{
			desc:       "subscribe $SYS/info",
			filter:     "$SYS/info",
			identifier: 1,
			handler:    handler,
			expect:     nil,
		},
		{
			desc:       "subscribe invalid ###",
			filter:     "###",
			identifier: 1,
			handler:    handler,
			expect:     packets.ErrTopicFilterInvalid,
		},
		{
			desc:       "subscribe invalid handler",
			filter:     "a/b/c",
			identifier: 1,
			handler:    nil,
			expect:     packets.ErrInlineSubscriptionHandlerInvalid,
		},
	}

	for _, tx := range tt {
		t.Run(tx.desc, func(t *testing.T) {
			require.Equal(t, tx.expect, s.Subscribe(tx.filter, tx.identifier, tx.handler))
		})
	}
}
+
+func TestServerSubscribeNoInlineClient(t *testing.T) {
+ s := newServer()
+ err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {})
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrInlineClientNotEnabled)
+}
+
+func TestServerUnsubscribe(t *testing.T) {
+ handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
+ // handler logic
+ }
+
+ s := newServerWithInlineClient()
+ err := s.Subscribe("a/b/c", 1, handler)
+ require.Nil(t, err)
+
+ err = s.Subscribe("d/e/f", 1, handler)
+ require.Nil(t, err)
+
+ err = s.Subscribe("d/e/f", 2, handler)
+ require.Nil(t, err)
+
+ err = s.Unsubscribe("a/b/c", 1)
+ require.Nil(t, err)
+
+ err = s.Unsubscribe("d/e/f", 1)
+ require.Nil(t, err)
+
+ err = s.Unsubscribe("d/e/f", 2)
+ require.Nil(t, err)
+
+ err = s.Unsubscribe("not/exist", 1)
+ require.Nil(t, err)
+
+ err = s.Unsubscribe("#/#/invalid", 1)
+ require.Equal(t, packets.ErrTopicFilterInvalid, err)
+}
+
+func TestServerUnsubscribeNoInlineClient(t *testing.T) {
+ s := newServer()
+ err := s.Unsubscribe("a/b/c", 1)
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrInlineClientNotEnabled)
+}
+
// TestPublishToInlineSubscriber checks that an inline subscription handler
// receives a published message, with the inline client identity and the
// matched subscription attached.
func TestPublishToInlineSubscriber(t *testing.T) {
	s := newServerWithInlineClient()
	finishCh := make(chan bool)
	err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		s.publishToSubscribers(pkx)
	}()

	// Block until the handler has run (and asserted) once.
	require.Equal(t, true, <-finishCh)
}
+
// TestPublishToInlineSubscribersDifferentFilter checks that two inline
// subscriptions on different filters each receive only their own message.
func TestPublishToInlineSubscribersDifferentFilter(t *testing.T) {
	s := newServerWithInlineClient()
	subNumber := 2
	finishCh := make(chan bool, subNumber)

	err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("mochi mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "z/e/n", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		s.publishToSubscribers(pkx)

		pkx = *packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet
		s.publishToSubscribers(pkx)
	}()

	// Wait for both handlers to fire, in either order.
	for i := 0; i < subNumber; i++ {
		require.Equal(t, true, <-finishCh)
	}
}
+
// TestPublishToInlineSubscribersDifferentIdentifier checks that two inline
// subscriptions on the same filter with different identifiers each receive
// a single published message.
func TestPublishToInlineSubscribersDifferentIdentifier(t *testing.T) {
	s := newServerWithInlineClient()
	subNumber := 2
	finishCh := make(chan bool, subNumber)

	err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 2, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		s.publishToSubscribers(pkx)
	}()

	// Wait for both handlers to fire.
	for i := 0; i < subNumber; i++ {
		require.Equal(t, true, <-finishCh)
	}
}
+
// TestServerSubscribeWithRetain checks that an inline subscription
// immediately receives an existing retained message matching its filter.
func TestServerSubscribeWithRetain(t *testing.T) {
	s := newServerWithInlineClient()
	subNumber := 1
	finishCh := make(chan bool, subNumber)

	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), retained)

	err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)
	require.Equal(t, true, <-finishCh)
}
+
// TestServerSubscribeWithRetainDifferentFilter checks that inline
// subscriptions on two different filters each receive the retained message
// matching their own filter.
func TestServerSubscribeWithRetainDifferentFilter(t *testing.T) {
	s := newServerWithInlineClient()
	subNumber := 2
	finishCh := make(chan bool, subNumber)

	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), retained)
	retained = s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet)
	require.Equal(t, int64(1), retained)

	err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("mochi mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "z/e/n", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	// Wait for both handlers to fire.
	for i := 0; i < subNumber; i++ {
		require.Equal(t, true, <-finishCh)
	}
}
+
// TestServerSubscribeWithRetainDifferentIdentifier checks that two inline
// subscriptions on the same filter with different identifiers each receive
// the retained message.
func TestServerSubscribeWithRetainDifferentIdentifier(t *testing.T) {
	s := newServerWithInlineClient()
	subNumber := 2
	finishCh := make(chan bool, subNumber)

	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), retained)

	err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 2, sub.Identifier)
		finishCh <- true
	})
	require.Nil(t, err)

	// Wait for both handlers to fire.
	for i := 0; i < subNumber; i++ {
		require.Equal(t, true, <-finishCh)
	}
}
diff --git a/mqtt/topics.go b/mqtt/topics.go
new file mode 100644
index 0000000..f42e621
--- /dev/null
+++ b/mqtt/topics.go
@@ -0,0 +1,824 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package mqtt
+
+import (
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "testmqtt/packets"
+)
+
var (
	// SharePrefix is the leading topic segment marking a shared subscription.
	SharePrefix = "$SHARE" // the prefix indicating a share topic
	// SysPrefix is the leading topic segment marking a system info topic.
	SysPrefix = "$SYS" // the prefix indicating a system info topic
)
+
+// TopicAliases contains inbound and outbound topic alias registrations.
+type TopicAliases struct {
+ Inbound *InboundTopicAliases
+ Outbound *OutboundTopicAliases
+}
+
+// NewTopicAliases returns an instance of TopicAliases.
+func NewTopicAliases(topicAliasMaximum uint16) TopicAliases {
+ return TopicAliases{
+ Inbound: NewInboundTopicAliases(topicAliasMaximum),
+ Outbound: NewOutboundTopicAliases(topicAliasMaximum),
+ }
+}
+
// NewInboundTopicAliases returns a pointer to InboundTopicAliases.
func NewInboundTopicAliases(topicAliasMaximum uint16) *InboundTopicAliases {
	return &InboundTopicAliases{
		maximum:  topicAliasMaximum,
		internal: make(map[uint16]string),
	}
}

// InboundTopicAliases contains a map of topic aliases received from the client.
type InboundTopicAliases struct {
	internal map[uint16]string
	sync.RWMutex
	maximum uint16
}

// Set records topic as the value of alias id and returns the topic to use.
// An empty topic with a known alias id recalls the previously registered
// topic instead.
func (a *InboundTopicAliases) Set(id uint16, topic string) string {
	a.Lock()
	defer a.Unlock()

	if a.maximum == 0 {
		// Aliasing is disabled; pass the topic through untouched.
		return topic
	}

	existing, ok := a.internal[id]
	if ok && topic == "" {
		return existing
	}

	a.internal[id] = topic
	return topic
}
+
// OutboundTopicAliases contains a map of topic aliases sent from the broker to the client.
type OutboundTopicAliases struct {
	internal map[string]uint16
	sync.RWMutex
	cursor  uint32 // the most recently issued alias value
	maximum uint16
}

// NewOutboundTopicAliases returns a pointer to OutboundTopicAliases.
func NewOutboundTopicAliases(topicAliasMaximum uint16) *OutboundTopicAliases {
	return &OutboundTopicAliases{
		maximum:  topicAliasMaximum,
		internal: make(map[string]uint16),
	}
}

// Set assigns an alias to a topic and returns the alias value plus a boolean
// indicating whether the alias already existed. (0, false) is returned when
// aliasing is disabled or all alias values are in use.
func (a *OutboundTopicAliases) Set(topic string) (uint16, bool) {
	a.Lock()
	defer a.Unlock()

	if a.maximum == 0 {
		return 0, false
	}

	if existing, ok := a.internal[topic]; ok {
		return existing, true
	}

	// NOTE(review): the atomic accesses are redundant because cursor is only
	// touched under the mutex held above; kept for now to avoid orphaning
	// the file's sync/atomic import.
	next := atomic.LoadUint32(&a.cursor) + 1
	if next > uint32(a.maximum) {
		return 0, false
	}

	a.internal[topic] = uint16(next)
	atomic.StoreUint32(&a.cursor, next)
	return uint16(next), false
}
+
+// SharedSubscriptions contains a map of subscriptions to a shared filter,
+// keyed on share group then client id.
+type SharedSubscriptions struct {
+ internal map[string]map[string]packets.Subscription
+ sync.RWMutex
+}
+
+// NewSharedSubscriptions returns a new instance of Subscriptions.
+func NewSharedSubscriptions() *SharedSubscriptions {
+ return &SharedSubscriptions{
+ internal: map[string]map[string]packets.Subscription{},
+ }
+}
+
+// Add creates a new shared subscription for a group and client id pair.
+func (s *SharedSubscriptions) Add(group, id string, val packets.Subscription) {
+ s.Lock()
+ defer s.Unlock()
+ if _, ok := s.internal[group]; !ok {
+ s.internal[group] = map[string]packets.Subscription{}
+ }
+ s.internal[group][id] = val
+}
+
+// Delete deletes a client id from a shared subscription group.
+func (s *SharedSubscriptions) Delete(group, id string) {
+ s.Lock()
+ defer s.Unlock()
+ delete(s.internal[group], id)
+ if len(s.internal[group]) == 0 {
+ delete(s.internal, group)
+ }
+}
+
+// Get returns the subscription properties for a client id in a share group, if one exists.
+func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, ok bool) {
+ s.RLock()
+ defer s.RUnlock()
+ if _, ok := s.internal[group]; !ok {
+ return val, ok
+ }
+
+ val, ok = s.internal[group][id]
+ return val, ok
+}
+
+// GroupLen returns the number of groups subscribed to the filter.
+func (s *SharedSubscriptions) GroupLen() int {
+ s.RLock()
+ defer s.RUnlock()
+ val := len(s.internal)
+ return val
+}
+
+// Len returns the total number of shared subscriptions to a filter across all groups.
+func (s *SharedSubscriptions) Len() int {
+ s.RLock()
+ defer s.RUnlock()
+ n := 0
+ for _, group := range s.internal {
+ n += len(group)
+ }
+ return n
+}
+
+// GetAll returns all shared subscription groups and their subscriptions.
+func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscription {
+ s.RLock()
+ defer s.RUnlock()
+ m := map[string]map[string]packets.Subscription{}
+ for group, subs := range s.internal {
+ if _, ok := m[group]; !ok {
+ m[group] = map[string]packets.Subscription{}
+ }
+
+ for id, sub := range subs {
+ m[group][id] = sub
+ }
+ }
+ return m
+}
+
+// InlineSubFn is the signature for a callback function which will be called
+// when an inline client receives a message on a topic it is subscribed to.
+// The sub argument contains information about the subscription that was matched for any filters.
+type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet)
+
+// InlineSubscriptions represents a map of internal subscriptions keyed on client.
+type InlineSubscriptions struct {
+ internal map[int]InlineSubscription
+ sync.RWMutex
+}
+
+// NewInlineSubscriptions returns a new instance of InlineSubscriptions.
+func NewInlineSubscriptions() *InlineSubscriptions {
+ return &InlineSubscriptions{
+ internal: map[int]InlineSubscription{},
+ }
+}
+
+// Add adds a new internal subscription for a client id.
+func (s *InlineSubscriptions) Add(val InlineSubscription) {
+ s.Lock()
+ defer s.Unlock()
+ s.internal[val.Identifier] = val
+}
+
+// GetAll returns all internal subscriptions.
+func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription {
+ s.RLock()
+ defer s.RUnlock()
+ m := map[int]InlineSubscription{}
+ for k, v := range s.internal {
+ m[k] = v
+ }
+ return m
+}
+
+// Get returns an internal subscription for a client id.
+func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) {
+ s.RLock()
+ defer s.RUnlock()
+ val, ok = s.internal[id]
+ return val, ok
+}
+
+// Len returns the number of internal subscriptions.
+func (s *InlineSubscriptions) Len() int {
+ s.RLock()
+ defer s.RUnlock()
+ val := len(s.internal)
+ return val
+}
+
+// Delete removes an internal subscription by the client id.
+func (s *InlineSubscriptions) Delete(id int) {
+ s.Lock()
+ defer s.Unlock()
+ delete(s.internal, id)
+}
+
+// Subscriptions is a map of subscriptions keyed on client.
+type Subscriptions struct {
+ internal map[string]packets.Subscription
+ sync.RWMutex
+}
+
+// NewSubscriptions returns a new instance of Subscriptions.
+func NewSubscriptions() *Subscriptions {
+ return &Subscriptions{
+ internal: map[string]packets.Subscription{},
+ }
+}
+
+// Add adds a new subscription for a client. ID can be a filter in the
+// case this map is client state, or a client id if particle state.
+func (s *Subscriptions) Add(id string, val packets.Subscription) {
+ s.Lock()
+ defer s.Unlock()
+ s.internal[id] = val
+}
+
+// GetAll returns all subscriptions.
+func (s *Subscriptions) GetAll() map[string]packets.Subscription {
+ s.RLock()
+ defer s.RUnlock()
+ m := map[string]packets.Subscription{}
+ for k, v := range s.internal {
+ m[k] = v
+ }
+ return m
+}
+
+// Get returns a subscriptions for a specific client or filter id.
+func (s *Subscriptions) Get(id string) (val packets.Subscription, ok bool) {
+ s.RLock()
+ defer s.RUnlock()
+ val, ok = s.internal[id]
+ return val, ok
+}
+
+// Len returns the number of subscriptions.
+func (s *Subscriptions) Len() int {
+ s.RLock()
+ defer s.RUnlock()
+ val := len(s.internal)
+ return val
+}
+
+// Delete removes a subscription by client or filter id.
+func (s *Subscriptions) Delete(id string) {
+ s.Lock()
+ defer s.Unlock()
+ delete(s.internal, id)
+}
+
// ClientSubscriptions is a map of aggregated subscriptions for a client.
type ClientSubscriptions map[string]packets.Subscription

// InlineSubscription pairs a subscription with the handler to invoke when a
// matching message is received by the inline client.
type InlineSubscription struct {
	packets.Subscription
	Handler InlineSubFn
}
+
// Subscribers contains the shared and non-shared subscribers matching a topic.
type Subscribers struct {
	Shared              map[string]map[string]packets.Subscription
	SharedSelected      map[string]packets.Subscription
	Subscriptions       map[string]packets.Subscription
	InlineSubscriptions map[int]InlineSubscription
}

// SelectShared returns one subscriber for each shared subscription group.
// The break after the first map entry means one client per group is chosen,
// and since Go map iteration order is random, that choice is effectively
// arbitrary.
func (s *Subscribers) SelectShared() {
	s.SharedSelected = map[string]packets.Subscription{}
	for _, subs := range s.Shared {
		for client, sub := range subs {
			// cls is any subscription already selected for this client via
			// another group; a new sub merges with itself, which preserves it.
			cls, ok := s.SharedSelected[client]
			if !ok {
				cls = sub
			}

			s.SharedSelected[client] = cls.Merge(sub)
			break
		}
	}
}

// MergeSharedSelected merges the selected subscribers for a shared subscription group
// and the non-shared subscribers, to ensure that no subscriber gets multiple messages
// due to have both types of subscription matching the same filter.
func (s *Subscribers) MergeSharedSelected() {
	for client, sub := range s.SharedSelected {
		cls, ok := s.Subscriptions[client]
		if !ok {
			cls = sub
		}

		s.Subscriptions[client] = cls.Merge(sub)
	}
}
+
+// TopicsIndex is a prefix/trie tree containing topic subscribers and retained messages.
+type TopicsIndex struct {
+ Retained *packets.Packets
+ root *particle // a leaf containing a message and more leaves.
+}
+
+// NewTopicsIndex returns a pointer to a new instance of Index.
+func NewTopicsIndex() *TopicsIndex {
+ return &TopicsIndex{
+ Retained: packets.NewPackets(),
+ root: &particle{
+ particles: newParticles(),
+ subscriptions: NewSubscriptions(),
+ },
+ }
+}
+
+// InlineSubscribe adds a new internal subscription for a topic filter, returning
+// true if the subscription was new.
+func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool {
+ x.root.Lock()
+ defer x.root.Unlock()
+
+ var existed bool
+ n := x.set(subscription.Filter, 0)
+ _, existed = n.inlineSubscriptions.Get(subscription.Identifier)
+ n.inlineSubscriptions.Add(subscription)
+
+ return !existed
+}
+
// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client,
// returning true if the subscription existed.
//
// NOTE(review): true is returned whenever the filter path exists in the
// index, even when no inline subscription with this id was registered on it
// — confirm whether callers depend on per-id accuracy here.
func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool {
	x.root.Lock()
	defer x.root.Unlock()

	particle := x.seek(filter, 0)
	if particle == nil {
		return false
	}

	particle.inlineSubscriptions.Delete(id)

	// Prune the branch once its last inline subscription is gone.
	if particle.inlineSubscriptions.Len() == 0 {
		x.trim(particle)
	}
	return true
}
+
+// Subscribe adds a new subscription for a client to a topic filter, returning
+// true if the subscription was new.
+func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
+ x.root.Lock()
+ defer x.root.Unlock()
+
+ var existed bool
+ prefix, _ := isolateParticle(subscription.Filter, 0)
+ if strings.EqualFold(prefix, SharePrefix) {
+ group, _ := isolateParticle(subscription.Filter, 1)
+ n := x.set(subscription.Filter, 2)
+ _, existed = n.shared.Get(group, client)
+ n.shared.Add(group, client, subscription)
+ } else {
+ n := x.set(subscription.Filter, 0)
+ _, existed = n.subscriptions.Get(client)
+ n.subscriptions.Add(client, subscription)
+ }
+
+ return !existed
+}
+
// Unsubscribe removes a subscription filter for a client, returning true if the
// subscription existed.
//
// NOTE(review): as with InlineUnsubscribe, true is returned whenever the
// filter path exists in the index, regardless of whether this client held a
// subscription on it — verify callers tolerate this.
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
	x.root.Lock()
	defer x.root.Unlock()

	// Shared subscriptions ($SHARE/<group>/<filter>) are stored from the
	// third segment onward, so seeking starts at depth 2 for them.
	var d int
	prefix, _ := isolateParticle(filter, 0)
	shareSub := strings.EqualFold(prefix, SharePrefix)
	if shareSub {
		d = 2
	}

	particle := x.seek(filter, d)
	if particle == nil {
		return false
	}

	if shareSub {
		group, _ := isolateParticle(filter, 1)
		particle.shared.Delete(group, client)
	} else {
		particle.subscriptions.Delete(client)
	}

	// Prune any now-empty branch.
	x.trim(particle)
	return true
}
+
// RetainMessage saves a message payload to the end of a topic address. Returns
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
	x.root.Lock()
	defer x.root.Unlock()

	// The topic's leaf particle is locked as well, so the retained store and
	// the particle's retainPath marker are updated together.
	n := x.set(pk.TopicName, 0)
	n.Lock()
	defer n.Unlock()
	if len(pk.Payload) > 0 {
		n.retainPath = pk.TopicName
		x.Retained.Add(pk.TopicName, pk)
		return 1
	}

	// An empty payload clears any retained message for the topic.
	var out int64
	if pke, ok := x.Retained.Get(pk.TopicName); ok && len(pke.Payload) > 0 && pke.FixedHeader.Retain {
		out = -1 // if a retained packet existed, return -1
	}

	n.retainPath = ""
	x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
	x.trim(n)

	return out
}
+
+// set creates a topic address in the index and returns the final particle.
+func (x *TopicsIndex) set(topic string, d int) *particle {
+ var key string
+ var hasNext = true
+ n := x.root
+ for hasNext {
+ key, hasNext = isolateParticle(topic, d)
+ d++
+
+ p := n.particles.get(key)
+ if p == nil {
+ p = newParticle(key, n)
+ n.particles.add(p)
+ }
+ n = p
+ }
+
+ return n
+}
+
+// seek finds the particle at a specific index in a topic filter.
+func (x *TopicsIndex) seek(filter string, d int) *particle {
+ var key string
+ var hasNext = true
+ n := x.root
+ for hasNext {
+ key, hasNext = isolateParticle(filter, d)
+ n = n.particles.get(key)
+ d++
+ if n == nil {
+ return nil
+ }
+ }
+
+ return n
+}
+
// trim removes empty filter particles from the index, walking from n towards
// the root and deleting each particle which holds no retained message, no
// child particles, and no subscriptions of any kind. It stops at the first
// non-empty ancestor; the root (parent == nil) is never removed.
func (x *TopicsIndex) trim(n *particle) {
	for n.parent != nil && n.retainPath == "" && n.particles.len()+n.subscriptions.Len()+n.shared.Len()+n.inlineSubscriptions.Len() == 0 {
		key := n.key
		n = n.parent
		n.particles.delete(key)
	}
}
+
+// Messages returns a slice of any retained messages which match a filter.
+func (x *TopicsIndex) Messages(filter string) []packets.Packet {
+ return x.scanMessages(filter, 0, nil, []packets.Packet{})
+}
+
// scanMessages returns all retained messages on topics matching a given filter.
// It recurses one filter segment (depth d) at a time; n is the particle
// reached so far (nil means start at the root).
func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []packets.Packet) []packets.Packet {
	if n == nil {
		n = x.root
	}

	// Nothing can match an empty filter or an empty retained store.
	if len(filter) == 0 || x.Retained.Len() == 0 {
		return pks
	}

	// A filter without wildcards addresses exactly one topic, so look it up
	// directly rather than walking the trie.
	if !strings.ContainsRune(filter, '#') && !strings.ContainsRune(filter, '+') {
		if pk, ok := x.Retained.Get(filter); ok {
			pks = append(pks, pk)
		}
		return pks
	}

	key, hasNext := isolateParticle(filter, d)
	if key == "+" || key == "#" || d == -1 {
		// Wildcard segment: consider every child of the current particle.
		// (d == -1 appears to be a match-all sentinel; Messages always starts
		// at d == 0 — confirm which callers pass -1.)
		for _, adjacent := range n.particles.getAll() {
			if d == 0 && adjacent.key == SysPrefix {
				// Top-level wildcards must not match $SYS topics.
				continue
			}

			if !hasNext {
				if adjacent.retainPath != "" {
					if pk, ok := x.Retained.Get(adjacent.retainPath); ok {
						pks = append(pks, pk)
					}
				}
			}

			// '#' keeps descending even after the filter is exhausted.
			if hasNext || (d >= 0 && key == "#") {
				pks = x.scanMessages(filter, d+1, adjacent, pks)
			}
		}
		return pks
	}

	// Literal segment: descend only into the matching child.
	if particle := n.particles.get(key); particle != nil {
		if hasNext {
			return x.scanMessages(filter, d+1, particle, pks)
		}

		if pk, ok := x.Retained.Get(particle.retainPath); ok {
			pks = append(pks, pk)
		}
	}

	return pks
}
+
+// Subscribers returns a map of clients who are subscribed to matching filters,
+// their subscription ids and highest qos.
+func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
+ return x.scanSubscribers(topic, 0, nil, &Subscribers{
+ Shared: map[string]map[string]packets.Subscription{},
+ SharedSelected: map[string]packets.Subscription{},
+ Subscriptions: map[string]packets.Subscription{},
+ InlineSubscriptions: map[int]InlineSubscription{},
+ })
+}
+
+// scanSubscribers returns a list of client subscriptions matching an indexed topic address.
+func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Subscribers) *Subscribers {
+ if n == nil {
+ n = x.root
+ }
+
+ if len(topic) == 0 {
+ return subs
+ }
+
+ key, hasNext := isolateParticle(topic, d)
+ for _, partKey := range []string{key, "+"} {
+ if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
+ if hasNext {
+ x.scanSubscribers(topic, d+1, particle, subs)
+ } else {
+ x.gatherSubscriptions(topic, particle, subs)
+ x.gatherSharedSubscriptions(particle, subs)
+ x.gatherInlineSubscriptions(particle, subs)
+
+ if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
+ x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
+ x.gatherSharedSubscriptions(wild, subs)
+ x.gatherInlineSubscriptions(particle, subs)
+ }
+ }
+ }
+ }
+
+ if particle := n.particles.get("#"); particle != nil {
+ x.gatherSubscriptions(topic, particle, subs)
+ x.gatherSharedSubscriptions(particle, subs)
+ x.gatherInlineSubscriptions(particle, subs)
+ }
+
+ return subs
+}
+
+// gatherSubscriptions collects any matching subscriptions, and gathers any identifiers or highest qos values.
+func (x *TopicsIndex) gatherSubscriptions(topic string, particle *particle, subs *Subscribers) {
+ if subs.Subscriptions == nil {
+ subs.Subscriptions = map[string]packets.Subscription{}
+ }
+
+ for client, sub := range particle.subscriptions.GetAll() {
+ if len(sub.Filter) > 0 && topic[0] == '$' && (sub.Filter[0] == '+' || sub.Filter[0] == '#') { // don't match $ topics with top level wildcards [MQTT-4.7.1-1] [MQTT-4.7.1-2]
+ continue
+ }
+
+ cls, ok := subs.Subscriptions[client]
+ if !ok {
+ cls = sub
+ }
+
+ subs.Subscriptions[client] = cls.Merge(sub)
+ }
+}
+
+// gatherSharedSubscriptions gathers all shared subscriptions for a particle.
+func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscribers) {
+ if subs.Shared == nil {
+ subs.Shared = map[string]map[string]packets.Subscription{}
+ }
+
+ for _, shares := range particle.shared.GetAll() {
+ for client, sub := range shares {
+ if _, ok := subs.Shared[sub.Filter]; !ok {
+ subs.Shared[sub.Filter] = map[string]packets.Subscription{}
+ }
+
+ subs.Shared[sub.Filter][client] = sub
+ }
+ }
+}
+
+// gatherSharedSubscriptions gathers all inline subscriptions for a particle.
+func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) {
+ if subs.InlineSubscriptions == nil {
+ subs.InlineSubscriptions = map[int]InlineSubscription{}
+ }
+
+ for id, inline := range particle.inlineSubscriptions.GetAll() {
+ subs.InlineSubscriptions[id] = inline
+ }
+}
+
// isolateParticle extracts the topic segment between the d-th and (d+1)-th
// '/' separator without allocating. hasNext reports whether further segments
// follow. If d exceeds the number of segments, the last segment is returned;
// a negative d yields an empty particle.
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
	rem := filter
	for i := 0; i <= d; i++ {
		slash := strings.IndexByte(rem, '/')
		if slash < 0 {
			return rem, false // final segment reached before (or at) depth d
		}
		if i == d {
			return rem[:slash], true
		}
		rem = rem[slash+1:] // discard the segment at depth i and continue
	}
	return "", false // d < 0: there is no particle at a negative depth
}
+
+// IsSharedFilter returns true if the filter uses the share prefix.
+func IsSharedFilter(filter string) bool {
+ prefix, _ := isolateParticle(filter, 0)
+ return strings.EqualFold(prefix, SharePrefix)
+}
+
+// IsValidFilter returns true if the filter is valid.
+func IsValidFilter(filter string, forPublish bool) bool {
+ if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish.
+ return false // [MQTT-4.7.3-1]
+ }
+
+ if forPublish {
+ if len(filter) >= len(SysPrefix) && strings.EqualFold(filter[0:len(SysPrefix)], SysPrefix) {
+ // 4.7.2 Non-normative - The Server SHOULD prevent Clients from using such Topic Names [$SYS] to exchange messages with other Clients.
+ return false
+ }
+
+ if strings.ContainsRune(filter, '+') || strings.ContainsRune(filter, '#') {
+ return false //[MQTT-3.3.2-2]
+ }
+ }
+
+ wildhash := strings.IndexRune(filter, '#')
+ if wildhash >= 0 && wildhash != len(filter)-1 { // [MQTT-4.7.1-2]
+ return false
+ }
+
+ prefix, hasNext := isolateParticle(filter, 0)
+ if !hasNext && strings.EqualFold(prefix, SharePrefix) {
+ return false // [MQTT-4.8.2-1]
+ }
+
+ if hasNext && strings.EqualFold(prefix, SharePrefix) {
+ group, hasNext := isolateParticle(filter, 1)
+ if !hasNext {
+ return false // [MQTT-4.8.2-1]
+ }
+
+ if strings.ContainsRune(group, '+') || strings.ContainsRune(group, '#') {
+ return false // [MQTT-4.8.2-2]
+ }
+ }
+
+ return true
+}
+
// particle is a child node on the topics index tree. Each particle represents
// one '/'-delimited level of a topic address.
type particle struct {
	key                 string               // the key of the particle (one topic level)
	parent              *particle            // a pointer to the parent of the particle (nil for root)
	particles           particles            // a map of child particles, keyed on their next topic level
	subscriptions       *Subscriptions       // a map of subscriptions made by clients to this ending address
	shared              *SharedSubscriptions // a map of shared subscriptions keyed on group name
	inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle
	retainPath          string               // topic name of a retained message ending at this particle ("" if none)
	sync.Mutex                               // mutex for when making changes to the particle
}
+
+// newParticle returns a pointer to a new instance of particle.
+func newParticle(key string, parent *particle) *particle {
+ return &particle{
+ key: key,
+ parent: parent,
+ particles: newParticles(),
+ subscriptions: NewSubscriptions(),
+ shared: NewSharedSubscriptions(),
+ inlineSubscriptions: NewInlineSubscriptions(),
+ }
+}
+
// particles is a concurrency safe map of particles, keyed on particle key.
type particles struct {
	internal map[string]*particle // the underlying map; guarded by the embedded RWMutex
	sync.RWMutex
}
+
+// newParticles returns a map of particles.
+func newParticles() particles {
+ return particles{
+ internal: map[string]*particle{},
+ }
+}
+
+// add adds a new particle.
+func (p *particles) add(val *particle) {
+ p.Lock()
+ p.internal[val.key] = val
+ p.Unlock()
+}
+
+// getAll returns all particles.
+func (p *particles) getAll() map[string]*particle {
+ p.RLock()
+ defer p.RUnlock()
+ m := map[string]*particle{}
+ for k, v := range p.internal {
+ m[k] = v
+ }
+ return m
+}
+
+// get returns a particle by id (key).
+func (p *particles) get(id string) *particle {
+ p.RLock()
+ defer p.RUnlock()
+ return p.internal[id]
+}
+
+// len returns the number of particles.
+func (p *particles) len() int {
+ p.RLock()
+ defer p.RUnlock()
+ val := len(p.internal)
+ return val
+}
+
+// delete removes a particle.
+func (p *particles) delete(id string) {
+ p.Lock()
+ defer p.Unlock()
+ delete(p.internal, id)
+}
diff --git a/mqtt/topics_test.go b/mqtt/topics_test.go
new file mode 100644
index 0000000..ee63b71
--- /dev/null
+++ b/mqtt/topics_test.go
@@ -0,0 +1,1068 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package mqtt
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "testmqtt/packets"
+)
+
// Shared-subscription group names used throughout the tests below.
const (
	testGroup  = "testgroup"
	otherGroup = "other"
)
+
// TestNewSharedSubscriptions verifies the constructor initialises the internal map.
func TestNewSharedSubscriptions(t *testing.T) {
	s := NewSharedSubscriptions()
	require.NotNil(t, s.internal)
}

// TestSharedSubscriptionsAdd verifies a subscription is stored under its group and client id.
func TestSharedSubscriptionsAdd(t *testing.T) {
	s := NewSharedSubscriptions()
	s.Add(testGroup, "cl1", packets.Subscription{Filter: "a/b/c"})
	require.Contains(t, s.internal, testGroup)
	require.Contains(t, s.internal[testGroup], "cl1")
}
+
+func TestSharedSubscriptionsGet(t *testing.T) {
+ s := NewSharedSubscriptions()
+ s.Add(testGroup, "cl1", packets.Subscription{Qos: 2})
+ s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
+ require.Contains(t, s.internal, testGroup)
+ require.Contains(t, s.internal[testGroup], "cl1")
+ require.Contains(t, s.internal, testGroup)
+ require.Contains(t, s.internal[testGroup], "cl2")
+
+ sub, ok := s.Get(testGroup, "cl2")
+ require.Equal(t, true, ok)
+ require.Equal(t, byte(2), sub.Qos)
+}
+
+func TestSharedSubscriptionsGetAll(t *testing.T) {
+ s := NewSharedSubscriptions()
+ s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
+ s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
+ s.Add(otherGroup, "cl3", packets.Subscription{Qos: 2})
+ require.Contains(t, s.internal, testGroup)
+ require.Contains(t, s.internal[testGroup], "cl1")
+ require.Contains(t, s.internal, testGroup)
+ require.Contains(t, s.internal[testGroup], "cl2")
+ require.Contains(t, s.internal, otherGroup)
+ require.Contains(t, s.internal[otherGroup], "cl3")
+
+ subs := s.GetAll()
+ require.Len(t, subs, 2)
+ require.Len(t, subs[testGroup], 2)
+ require.Len(t, subs[otherGroup], 1)
+}
+
+func TestSharedSubscriptionsLen(t *testing.T) {
+ s := NewSharedSubscriptions()
+ s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
+ s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
+ s.Add(otherGroup, "cl2", packets.Subscription{Qos: 1})
+ require.Contains(t, s.internal, testGroup)
+ require.Contains(t, s.internal[testGroup], "cl1")
+ require.Contains(t, s.internal, testGroup)
+ require.Contains(t, s.internal[testGroup], "cl2")
+ require.Contains(t, s.internal, otherGroup)
+ require.Contains(t, s.internal[otherGroup], "cl2")
+ require.Equal(t, 3, s.Len())
+ require.Equal(t, 2, s.GroupLen())
+}
+
+func TestSharedSubscriptionsDelete(t *testing.T) {
+ s := NewSharedSubscriptions()
+ s.Add(testGroup, "cl1", packets.Subscription{Qos: 1})
+ s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
+ require.Contains(t, s.internal, testGroup)
+ require.Contains(t, s.internal[testGroup], "cl1")
+ require.Contains(t, s.internal, testGroup)
+ require.Contains(t, s.internal[testGroup], "cl2")
+
+ require.Equal(t, 2, s.Len())
+
+ s.Delete(testGroup, "cl1")
+ _, ok := s.Get(testGroup, "cl1")
+ require.False(t, ok)
+ require.Equal(t, 1, s.GroupLen())
+ require.Equal(t, 1, s.Len())
+
+ s.Delete(testGroup, "cl2")
+ _, ok = s.Get(testGroup, "cl2")
+ require.False(t, ok)
+ require.Equal(t, 0, s.GroupLen())
+ require.Equal(t, 0, s.Len())
+}
+
// TestNewSubscriptions verifies the constructor initialises the internal map.
func TestNewSubscriptions(t *testing.T) {
	s := NewSubscriptions()
	require.NotNil(t, s.internal)
}

// TestSubscriptionsAdd verifies a subscription is stored under its client id.
func TestSubscriptionsAdd(t *testing.T) {
	s := NewSubscriptions()
	s.Add("cl1", packets.Subscription{})
	require.Contains(t, s.internal, "cl1")
}

// TestSubscriptionsGet verifies a stored subscription can be retrieved by client id.
func TestSubscriptionsGet(t *testing.T) {
	s := NewSubscriptions()
	s.Add("cl1", packets.Subscription{Qos: 2})
	s.Add("cl2", packets.Subscription{Qos: 2})
	require.Contains(t, s.internal, "cl1")
	require.Contains(t, s.internal, "cl2")

	sub, ok := s.Get("cl1")
	require.True(t, ok)
	require.Equal(t, byte(2), sub.Qos)
}

// TestSubscriptionsGetAll verifies GetAll returns every stored subscription.
func TestSubscriptionsGetAll(t *testing.T) {
	s := NewSubscriptions()
	s.Add("cl1", packets.Subscription{Qos: 0})
	s.Add("cl2", packets.Subscription{Qos: 1})
	s.Add("cl3", packets.Subscription{Qos: 2})
	require.Contains(t, s.internal, "cl1")
	require.Contains(t, s.internal, "cl2")
	require.Contains(t, s.internal, "cl3")

	subs := s.GetAll()
	require.Len(t, subs, 3)
}

// TestSubscriptionsLen verifies Len reports the number of stored subscriptions.
func TestSubscriptionsLen(t *testing.T) {
	s := NewSubscriptions()
	s.Add("cl1", packets.Subscription{Qos: 0})
	s.Add("cl2", packets.Subscription{Qos: 1})
	require.Contains(t, s.internal, "cl1")
	require.Contains(t, s.internal, "cl2")
	require.Equal(t, 2, s.Len())
}

// TestSubscriptionsDelete verifies a deleted subscription can no longer be retrieved.
func TestSubscriptionsDelete(t *testing.T) {
	s := NewSubscriptions()
	s.Add("cl1", packets.Subscription{Qos: 1})
	require.Contains(t, s.internal, "cl1")

	s.Delete("cl1")
	_, ok := s.Get("cl1")
	require.False(t, ok)
}
+
// TestNewTopicsIndex verifies the index constructor creates a root particle.
func TestNewTopicsIndex(t *testing.T) {
	index := NewTopicsIndex()
	require.NotNil(t, index)
	require.NotNil(t, index.root)
}

// BenchmarkNewTopicsIndex measures the cost of constructing an empty index.
func BenchmarkNewTopicsIndex(b *testing.B) {
	for n := 0; n < b.N; n++ {
		NewTopicsIndex()
	}
}
+
// TestSubscribe verifies Subscribe reports whether a subscription is new,
// that filters are case sensitive, and that a re-subscription replaces the
// previous subscription (final qos check below).
func TestSubscribe(t *testing.T) {
	tt := []struct {
		desc         string
		client       string
		filter       string
		subscription packets.Subscription
		wasNew       bool
	}{
		{
			desc:   "subscribe",
			client: "cl1",

			subscription: packets.Subscription{Filter: "a/b/c", Qos: 2},
			wasNew:       true,
		},
		{
			desc:   "subscribe existed",
			client: "cl1",

			subscription: packets.Subscription{Filter: "a/b/c", Qos: 1},
			wasNew:       false,
		},
		{
			desc:   "subscribe case sensitive didnt exist",
			client: "cl1",

			subscription: packets.Subscription{Filter: "A/B/c", Qos: 1},
			wasNew:       true,
		},
		{
			desc:   "wildcard+ sub",
			client: "cl1",

			subscription: packets.Subscription{Filter: "d/+"},
			wasNew:       true,
		},
		{
			desc:         "wildcard# sub",
			client:       "cl1",
			subscription: packets.Subscription{Filter: "d/e/#"},
			wasNew:       true,
		},
	}

	index := NewTopicsIndex()
	for _, tx := range tt {
		t.Run(tx.desc, func(t *testing.T) {
			require.Equal(t, tx.wasNew, index.Subscribe(tx.client, tx.subscription))
		})
	}

	// The second subscribe (qos 1) should have overwritten the first (qos 2).
	final := index.root.particles.get("a").particles.get("b").particles.get("c")
	require.NotNil(t, final)
	client, exists := final.subscriptions.Get("cl1")
	require.True(t, exists)
	require.Equal(t, byte(1), client.Qos)
}
+
// TestSubscribeShared verifies a $SHARE-prefixed filter is indexed under the
// shared map of the final particle, not the ordinary subscriptions map.
func TestSubscribeShared(t *testing.T) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c", Qos: 2})
	final := index.root.particles.get("a").particles.get("b").particles.get("c")
	require.NotNil(t, final)
	client, exists := final.shared.Get("tmp", "cl1")
	require.True(t, exists)
	require.Equal(t, byte(2), client.Qos)
	require.Equal(t, 0, final.subscriptions.Len())
	require.Equal(t, 1, final.shared.Len())
}

// BenchmarkSubscribe measures repeated subscription to the same filter.
func BenchmarkSubscribe(b *testing.B) {
	index := NewTopicsIndex()
	for n := 0; n < b.N; n++ {
		index.Subscribe("client-1", packets.Subscription{Filter: "a/b/c"})
	}
}

// BenchmarkSubscribeShared measures repeated shared subscription to the same filter.
func BenchmarkSubscribeShared(b *testing.B) {
	index := NewTopicsIndex()
	for n := 0; n < b.N; n++ {
		index.Subscribe("client-1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c"})
	}
}
+
// TestUnsubscribe verifies unsubscribing removes only the targeted branch
// (pruning emptied particles), leaves sibling and other-client subscriptions
// intact, and returns false for unknown filters/clients.
func TestUnsubscribe(t *testing.T) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/d", Qos: 1})
	client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").subscriptions.Get("cl1")
	require.NotNil(t, client)
	require.True(t, exists)

	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/+/d", Qos: 1})
	client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
	require.NotNil(t, client)
	require.True(t, exists)

	index.Subscribe("cl1", packets.Subscription{Filter: "d/e/f", Qos: 1})
	client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl1")
	require.NotNil(t, client)
	require.True(t, exists)

	index.Subscribe("cl2", packets.Subscription{Filter: "d/e/f", Qos: 1})
	client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
	require.NotNil(t, client)
	require.True(t, exists)

	index.Subscribe("cl3", packets.Subscription{Filter: "#", Qos: 2})
	client, exists = index.root.particles.get("#").subscriptions.Get("cl3")
	require.NotNil(t, client)
	require.True(t, exists)

	// The now-empty a/b/c branch is pruned, but a/b/+/d must survive.
	ok := index.Unsubscribe("a/b/c/d", "cl1")
	require.True(t, ok)
	require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
	client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
	require.NotNil(t, client)
	require.True(t, exists)

	ok = index.Unsubscribe("d/e/f", "cl1")
	require.True(t, ok)

	// cl2's subscription on the same filter must remain.
	require.Equal(t, 1, index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Len())
	client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
	require.NotNil(t, client)
	require.True(t, exists)

	ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "nobody")
	require.False(t, ok)
}
+
// TestUnsubscribeNoCascade verifies pruning stops at a particle that still
// has a subscription (a/b/c), rather than cascading to the root.
func TestUnsubscribeNoCascade(t *testing.T) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/e/e"})

	ok := index.Unsubscribe("a/b/c/e/e", "cl1")
	require.True(t, ok)
	require.Equal(t, 1, index.root.particles.len())

	client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").subscriptions.Get("cl1")
	require.NotNil(t, client)
	require.True(t, exists)
}

// TestUnsubscribeShared verifies a shared subscription can be removed, and that
// the $share prefix match is case-insensitive.
func TestUnsubscribeShared(t *testing.T) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c", Qos: 2})
	final := index.root.particles.get("a").particles.get("b").particles.get("c")
	require.NotNil(t, final)
	client, exists := final.shared.Get("tmp", "cl1")
	require.True(t, exists)
	require.Equal(t, byte(2), client.Qos)

	require.True(t, index.Unsubscribe("$share/tmp/a/b/c", "cl1"))
	_, exists = final.shared.Get("tmp", "cl1")
	require.False(t, exists)
}

// BenchmarkUnsubscribe measures unsubscription cost, excluding the setup subscribe.
func BenchmarkUnsubscribe(b *testing.B) {
	index := NewTopicsIndex()

	for n := 0; n < b.N; n++ {
		b.StopTimer()
		index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
		b.StartTimer()
		index.Unsubscribe("a/b/c", "cl1")
	}
}
+
// TestIndexSeek verifies seek finds a previously-set particle and returns nil
// for an address not present in the index.
func TestIndexSeek(t *testing.T) {
	filter := "a/b/c/d/e/f"
	index := NewTopicsIndex()
	k1 := index.set(filter, 0)
	require.Equal(t, "f", k1.key)
	k1.subscriptions.Add("cl1", packets.Subscription{})

	require.Equal(t, k1, index.seek(filter, 0))
	require.Nil(t, index.seek("d/e/f", 0))
}

// TestIndexTrim verifies trim only prunes particles with no subscriptions,
// children, or retained messages, and prunes fully once all are removed.
func TestIndexTrim(t *testing.T) {
	index := NewTopicsIndex()
	k1 := index.set("a/b/c", 0)
	require.Equal(t, "c", k1.key)
	k1.subscriptions.Add("cl1", packets.Subscription{})

	k2 := index.set("a/b/c/d/e/f", 0)
	require.Equal(t, "f", k2.key)
	k2.subscriptions.Add("cl1", packets.Subscription{})

	k3 := index.set("a/b", 0)
	require.Equal(t, "b", k3.key)
	k3.subscriptions.Add("cl1", packets.Subscription{})

	// All particles are in use; trim must remove nothing.
	index.trim(k2)
	require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
	require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").particles.get("e").particles.get("f"))
	require.NotNil(t, index.root.particles.get("a").particles.get("b"))

	// Once f's subscription is gone, d/e/f should be pruned up to c.
	k2.subscriptions.Delete("cl1")
	index.trim(k2)

	require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d"))
	require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))

	// With every subscription gone, the whole branch should disappear.
	k1.subscriptions.Delete("cl1")
	k3.subscriptions.Delete("cl1")
	index.trim(k2)
	require.Nil(t, index.root.particles.get("a"))
}
+
// TestIndexSet verifies set creates nested particles and returns the final one.
func TestIndexSet(t *testing.T) {
	index := NewTopicsIndex()
	child := index.set("a/b/c", 0)
	require.Equal(t, "c", child.key)
	require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))

	child = index.set("a/b/c/d/e", 0)
	require.Equal(t, "e", child.key)

	child = index.set("a/b/c/c/a", 0)
	require.Equal(t, "a", child.key)
}

// TestIndexSetPrefixed verifies a leading '/' creates an empty-key root particle.
func TestIndexSetPrefixed(t *testing.T) {
	index := NewTopicsIndex()
	child := index.set("/c", 0)
	require.Equal(t, "c", child.key)
	require.NotNil(t, index.root.particles.get("").particles.get("c"))
}

// BenchmarkIndexSet measures repeated set calls on an existing address.
func BenchmarkIndexSet(b *testing.B) {
	index := NewTopicsIndex()
	for n := 0; n < b.N; n++ {
		index.set("a/b/c", 0)
	}
}
+
// TestRetainMessage verifies RetainMessage returns 1 when storing a retained
// message, -1 when a zero-byte payload clears an existing one, and 0 when
// there was nothing to clear.
func TestRetainMessage(t *testing.T) {
	pk := packets.Packet{
		FixedHeader: packets.FixedHeader{Retain: true},
		TopicName:   "a/b/c",
		Payload:     []byte("hello"),
	}

	index := NewTopicsIndex()
	r := index.RetainMessage(pk)
	require.Equal(t, int64(1), r)
	pke, ok := index.Retained.Get(pk.TopicName)
	require.True(t, ok)
	require.Equal(t, pk, pke)

	pk2 := packets.Packet{
		FixedHeader: packets.FixedHeader{Retain: true},
		TopicName:   "a/b/d/f",
		Payload:     []byte("hello"),
	}
	r = index.RetainMessage(pk2)
	require.Equal(t, int64(1), r)
	// The same message already exists, but we're not doing a deep-copy check, so it's considered to be a new message.
	r = index.RetainMessage(pk2)
	require.Equal(t, int64(1), r)

	// Clear existing retained
	pk3 := packets.Packet{TopicName: "a/b/c", Payload: []byte{}}
	r = index.RetainMessage(pk3)
	require.Equal(t, int64(-1), r)
	_, ok = index.Retained.Get(pk.TopicName)
	require.False(t, ok)

	// Clear no retained
	r = index.RetainMessage(pk3)
	require.Equal(t, int64(0), r)
}

// BenchmarkRetainMessage measures repeated retention on the same topic.
func BenchmarkRetainMessage(b *testing.B) {
	index := NewTopicsIndex()
	for n := 0; n < b.N; n++ {
		index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
	}
}
+
// TestIsolateParticle verifies segment extraction at each depth, including
// empty segments produced by leading/trailing slashes and wildcard segments.
func TestIsolateParticle(t *testing.T) {
	particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
	require.Equal(t, "path", particle)
	require.Equal(t, true, hasNext)
	particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
	require.Equal(t, "to", particle)
	require.Equal(t, true, hasNext)
	particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
	require.Equal(t, "my", particle)
	require.Equal(t, true, hasNext)
	particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
	require.Equal(t, "mqtt", particle)
	require.Equal(t, false, hasNext)

	particle, hasNext = isolateParticle("/path/", 0)
	require.Equal(t, "", particle)
	require.Equal(t, true, hasNext)
	particle, hasNext = isolateParticle("/path/", 1)
	require.Equal(t, "path", particle)
	require.Equal(t, true, hasNext)
	particle, hasNext = isolateParticle("/path/", 2)
	require.Equal(t, "", particle)
	require.Equal(t, false, hasNext)

	particle, hasNext = isolateParticle("a/b/c/+/+", 3)
	require.Equal(t, "+", particle)
	require.Equal(t, true, hasNext)
	particle, hasNext = isolateParticle("a/b/c/+/+", 4)
	require.Equal(t, "+", particle)
	require.Equal(t, false, hasNext)
}

// BenchmarkIsolateParticle measures extraction of the last segment of a topic.
func BenchmarkIsolateParticle(b *testing.B) {
	for n := 0; n < b.N; n++ {
		isolateParticle("path/to/my/mqtt", 3)
	}
}
+
// TestScanSubscribers verifies wildcard/literal matching, identifier
// collection, highest-qos merging, and that $SYS filters do not leak into
// normal topic matches.
func TestScanSubscribers(t *testing.T) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c", Identifier: 22})
	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c/d/e/f"})
	index.Subscribe("cl1", packets.Subscription{Qos: 2, Filter: "a/b/c/d/+/f"})
	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/#"})
	index.Subscribe("cl2", packets.Subscription{Qos: 1, Filter: "a/b/c"})
	index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "a/b/+", Identifier: 77})
	index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "d/e/f", Identifier: 7237})
	index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "$SYS/uptime", Identifier: 3})
	index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: "+/b", Identifier: 234})
	index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: "#", Identifier: 5})
	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})

	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
	require.Equal(t, 3, len(subs.Subscriptions))
	require.Contains(t, subs.Subscriptions, "cl1")
	require.Contains(t, subs.Subscriptions, "cl2")
	require.Contains(t, subs.Subscriptions, "cl4")

	require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
	require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
	require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)

	require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
	require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
	require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
	require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
	require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])

	subs = index.scanSubscribers("d/e/f/g", 0, nil, new(Subscribers))
	require.Equal(t, 1, len(subs.Subscriptions))
	require.Contains(t, subs.Subscriptions, "cl4")
	require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
	require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])

	subs = index.scanSubscribers("", 0, nil, new(Subscribers))
	require.Equal(t, 0, len(subs.Subscriptions))
}

// TestScanSubscribersTopicInheritanceBug verifies a subscription to a parent
// filter (a/b) does not match a deeper topic (a/b/c).
func TestScanSubscribersTopicInheritanceBug(t *testing.T) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Qos: 0, Filter: "a/b/c"})
	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/b"})

	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
	require.Equal(t, 1, len(subs.Subscriptions))
}

// TestScanSubscribersShared verifies shared subscriptions are grouped per
// distinct full shared filter.
func TestScanSubscribersShared(t *testing.T) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
	index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
	index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
	index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
	index.Subscribe("cl5", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c/#"})
	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
	require.Equal(t, 4, len(subs.Shared))
}
+
// TestSelectSharedSubscriber verifies SelectShared picks exactly one client
// per shared group.
func TestSelectSharedSubscriber(t *testing.T) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110})
	index.Subscribe("cl1b", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
	index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
	require.Equal(t, 2, len(subs.Shared))
	require.Contains(t, subs.Shared, SharePrefix+"/tmp/a/b/c")
	require.Contains(t, subs.Shared, SharePrefix+"/tmp2/a/b/c")
	require.Len(t, subs.Shared[SharePrefix+"/tmp/a/b/c"], 3)
	require.Len(t, subs.Shared[SharePrefix+"/tmp2/a/b/c"], 1)
	subs.SelectShared()
	require.Len(t, subs.SharedSelected, 2)
}

// TestMergeSharedSelected verifies selected shared subscriptions are merged
// into the ordinary subscriptions map, combining identifiers per client.
func TestMergeSharedSelected(t *testing.T) {
	s := &Subscribers{
		SharedSelected: map[string]packets.Subscription{
			"cl1": {Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110},
			"cl2": {Qos: 1, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 111},
		},
		Subscriptions: map[string]packets.Subscription{
			"cl2": {Qos: 1, Filter: "a/b/c", Identifier: 112},
		},
	}

	s.MergeSharedSelected()

	require.Equal(t, 2, len(s.Subscriptions))
	require.Contains(t, s.Subscriptions, "cl1")
	require.Contains(t, s.Subscriptions, "cl2")
	require.EqualValues(t, map[string]int{
		SharePrefix + "/tmp2/a/b/c": 111,
		"a/b/c":                     112,
	}, s.Subscriptions["cl2"].Identifiers)
}
+
// TestSubscribersFind is a table-driven check of filter/topic matching
// semantics, including wildcard rules and $SYS exclusions per spec 4.7.
func TestSubscribersFind(t *testing.T) {
	tt := []struct {
		filter  string
		topic   string
		matched bool
	}{
		{filter: "a", topic: "a", matched: true},
		{filter: "a/", topic: "a", matched: false},
		{filter: "a/", topic: "a/", matched: true},
		{filter: "/a", topic: "/a", matched: true},
		{filter: "path/to/my/mqtt", topic: "path/to/my/mqtt", matched: true},
		{filter: "path/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
		{filter: "+/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
		{filter: "#", topic: "path/to/my/mqtt", matched: true},
		{filter: "+/+/+/+", topic: "path/to/my/mqtt", matched: true},
		{filter: "+/+/+/#", topic: "path/to/my/mqtt", matched: true},
		{filter: "zen/#", topic: "zen", matched: true}, // as per 4.7.1.2
		{filter: "trailing-end/#", topic: "trailing-end/", matched: true},
		{filter: "+/prefixed", topic: "/prefixed", matched: true},
		{filter: "+/+/#", topic: "path/to/my/mqtt", matched: true},
		{filter: "path/to/", topic: "path/to/my/mqtt", matched: false},
		{filter: "#/stuff", topic: "path/to/my/mqtt", matched: false},
		{filter: "#", topic: "$SYS/info", matched: false},
		{filter: "$SYS/#", topic: "$SYS/info", matched: true},
		{filter: "+/info", topic: "$SYS/info", matched: false},
	}

	for _, tx := range tt {
		t.Run("filter:'"+tx.filter+"' vs topic:'"+tx.topic+"'", func(t *testing.T) {
			index := NewTopicsIndex()
			index.Subscribe("cl1", packets.Subscription{Filter: tx.filter})
			subs := index.Subscribers(tx.topic)
			require.Equal(t, tx.matched, len(subs.Subscriptions) == 1)
		})
	}
}

// BenchmarkSubscribers measures subscriber matching against a populated index.
func BenchmarkSubscribers(b *testing.B) {
	index := NewTopicsIndex()
	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
	index.Subscribe("cl1", packets.Subscription{Filter: "a/+/c"})
	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/+"})
	index.Subscribe("cl2", packets.Subscription{Filter: "a/b/c/d"})
	index.Subscribe("cl3", packets.Subscription{Filter: "#"})

	for n := 0; n < b.N; n++ {
		index.Subscribers("a/b/c")
	}
}
+
+func TestMessagesPattern(t *testing.T) {
+ payload := []byte("hello")
+ fh := packets.FixedHeader{Type: packets.Publish, Retain: true}
+
+ pks := []packets.Packet{
+ {TopicName: "$SYS/uptime", Payload: payload, FixedHeader: fh},
+ {TopicName: "$SYS/info", Payload: payload, FixedHeader: fh},
+ {TopicName: "a/b/c/d", Payload: payload, FixedHeader: fh},
+ {TopicName: "a/b/c/e", Payload: payload, FixedHeader: fh},
+ {TopicName: "a/b/d/f", Payload: payload, FixedHeader: fh},
+ {TopicName: "q/w/e/r/t/y", Payload: payload, FixedHeader: fh},
+ {TopicName: "q/x/e/r/t/o", Payload: payload, FixedHeader: fh},
+ {TopicName: "asdf", Payload: payload, FixedHeader: fh},
+ }
+
+ tt := []struct {
+ filter string
+ len int
+ }{
+ {"a/b/c/d", 1},
+ {"$SYS/+", 2},
+ {"$SYS/#", 2},
+ {"#", len(pks) - 2},
+ {"a/b/c/+", 2},
+ {"a/+/c/+", 2},
+ {"+/+/+/d", 1},
+ {"q/w/e/#", 1},
+ {"+/+/+/+", 3},
+ {"q/#", 2},
+ {"asdf", 1},
+ {"", 0},
+ {"#", 6},
+ }
+
+ index := NewTopicsIndex()
+ for _, pk := range pks {
+ index.RetainMessage(pk)
+ }
+
+ for _, tx := range tt {
+ t.Run("filter:'"+tx.filter, func(t *testing.T) {
+ messages := index.Messages(tx.filter)
+ require.Equal(t, tx.len, len(messages))
+ })
+ }
+}
+
// BenchmarkMessages measures wildcard retained-message lookup against a populated index.
func BenchmarkMessages(b *testing.B) {
	index := NewTopicsIndex()
	index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
	index.RetainMessage(packets.Packet{TopicName: "a/b/d/e/f"})
	index.RetainMessage(packets.Packet{TopicName: "d/e/f/g"})
	index.RetainMessage(packets.Packet{TopicName: "$SYS/info"})
	index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})

	for n := 0; n < b.N; n++ {
		index.Messages("+/b/c/+")
	}
}
+
// TestNewParticles verifies the constructor initialises the internal map.
func TestNewParticles(t *testing.T) {
	cl := newParticles()
	require.NotNil(t, cl.internal)
}

// TestParticlesAdd verifies a particle is stored under its key.
func TestParticlesAdd(t *testing.T) {
	p := newParticles()
	p.add(&particle{key: "a"})
	require.Contains(t, p.internal, "a")
}

// TestParticlesGet verifies a stored particle can be retrieved by key.
func TestParticlesGet(t *testing.T) {
	p := newParticles()
	p.add(&particle{key: "a"})
	p.add(&particle{key: "b"})
	require.Contains(t, p.internal, "a")
	require.Contains(t, p.internal, "b")

	particle := p.get("a")
	require.NotNil(t, particle)
	require.Equal(t, "a", particle.key)
}

// TestParticlesGetAll verifies getAll returns every stored particle.
func TestParticlesGetAll(t *testing.T) {
	p := newParticles()
	p.add(&particle{key: "a"})
	p.add(&particle{key: "b"})
	p.add(&particle{key: "c"})
	require.Contains(t, p.internal, "a")
	require.Contains(t, p.internal, "b")
	require.Contains(t, p.internal, "c")

	particles := p.getAll()
	require.Len(t, particles, 3)
}

// TestParticlesLen verifies len reports the number of stored particles.
func TestParticlesLen(t *testing.T) {
	p := newParticles()
	p.add(&particle{key: "a"})
	p.add(&particle{key: "b"})
	require.Contains(t, p.internal, "a")
	require.Contains(t, p.internal, "b")
	require.Equal(t, 2, p.len())
}

// TestParticlesDelete verifies a deleted particle can no longer be retrieved.
func TestParticlesDelete(t *testing.T) {
	p := newParticles()
	p.add(&particle{key: "a"})
	require.Contains(t, p.internal, "a")

	p.delete("a")
	particle := p.get("a")
	require.Nil(t, particle)
}
+
+func TestIsValid(t *testing.T) {
+ require.True(t, IsValidFilter("a/b/c", false))
+ require.True(t, IsValidFilter("a/b//c", false))
+ require.True(t, IsValidFilter("$SYS", false))
+ require.True(t, IsValidFilter("$SYS/info", false))
+ require.True(t, IsValidFilter("$sys/info", false))
+ require.True(t, IsValidFilter("abc/#", false))
+ require.False(t, IsValidFilter("", false))
+ require.False(t, IsValidFilter(SharePrefix, false))
+ require.False(t, IsValidFilter(SharePrefix+"/", false))
+ require.False(t, IsValidFilter(SharePrefix+"/b+/", false))
+ require.False(t, IsValidFilter(SharePrefix+"/+", false))
+ require.False(t, IsValidFilter(SharePrefix+"/#", false))
+ require.False(t, IsValidFilter(SharePrefix+"/#/", false))
+ require.False(t, IsValidFilter("a/#/c", false))
+}
+
+func TestIsValidForPublish(t *testing.T) {
+ require.True(t, IsValidFilter("", true))
+ require.True(t, IsValidFilter("a/b/c", true))
+ require.False(t, IsValidFilter("a/b/+/d", true))
+ require.False(t, IsValidFilter("a/b/#", true))
+ require.False(t, IsValidFilter("$SYS/info", true))
+}
+
+func TestIsSharedFilter(t *testing.T) {
+ require.True(t, IsSharedFilter(SharePrefix+"/tmp/a/b/c"))
+ require.False(t, IsSharedFilter("a/b/c"))
+}
+
+func TestNewInboundAliases(t *testing.T) {
+ a := NewInboundTopicAliases(5)
+ require.NotNil(t, a)
+ require.NotNil(t, a.internal)
+ require.Equal(t, uint16(5), a.maximum)
+}
+
+func TestInboundAliasesSet(t *testing.T) {
+ topic := "test"
+ id := uint16(1)
+ a := NewInboundTopicAliases(5)
+ require.Equal(t, topic, a.Set(id, topic))
+ require.Contains(t, a.internal, id)
+ require.Equal(t, a.internal[id], topic)
+
+ require.Equal(t, topic, a.Set(id, ""))
+}
+
+func TestInboundAliasesSetMaxZero(t *testing.T) {
+ topic := "test"
+ id := uint16(1)
+ a := NewInboundTopicAliases(0)
+ require.Equal(t, topic, a.Set(id, topic))
+ require.NotContains(t, a.internal, id)
+}
+
+func TestNewOutboundAliases(t *testing.T) {
+ a := NewOutboundTopicAliases(5)
+ require.NotNil(t, a)
+ require.NotNil(t, a.internal)
+ require.Equal(t, uint16(5), a.maximum)
+ require.Equal(t, uint32(0), a.cursor)
+}
+
+func TestOutboundAliasesSet(t *testing.T) {
+ a := NewOutboundTopicAliases(3)
+ n, ok := a.Set("t1")
+ require.False(t, ok)
+ require.Equal(t, uint16(1), n)
+
+ n, ok = a.Set("t2")
+ require.False(t, ok)
+ require.Equal(t, uint16(2), n)
+
+ n, ok = a.Set("t3")
+ require.False(t, ok)
+ require.Equal(t, uint16(3), n)
+
+ n, ok = a.Set("t4")
+ require.False(t, ok)
+ require.Equal(t, uint16(0), n)
+
+ n, ok = a.Set("t2")
+ require.True(t, ok)
+ require.Equal(t, uint16(2), n)
+}
+
+func TestOutboundAliasesSetMaxZero(t *testing.T) {
+ topic := "test"
+ a := NewOutboundTopicAliases(0)
+ n, ok := a.Set(topic)
+ require.False(t, ok)
+ require.Equal(t, uint16(0), n)
+}
+
+func TestNewTopicAliases(t *testing.T) {
+ a := NewTopicAliases(5)
+ require.NotNil(t, a.Inbound)
+ require.Equal(t, uint16(5), a.Inbound.maximum)
+ require.NotNil(t, a.Outbound)
+ require.Equal(t, uint16(5), a.Outbound.maximum)
+}
+
+func TestNewInlineSubscriptions(t *testing.T) {
+ subscriptions := NewInlineSubscriptions()
+ require.NotNil(t, subscriptions)
+ require.NotNil(t, subscriptions.internal)
+ require.Equal(t, 0, subscriptions.Len())
+}
+
+func TestInlineSubscriptionAdd(t *testing.T) {
+ subscriptions := NewInlineSubscriptions()
+ handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
+ // handler logic
+ }
+
+ subscription := InlineSubscription{
+ Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
+ Handler: handler,
+ }
+ subscriptions.Add(subscription)
+
+ sub, ok := subscriptions.Get(1)
+ require.True(t, ok)
+ require.Equal(t, "a/b/c", sub.Filter)
+ require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
+}
+
+func TestInlineSubscriptionGet(t *testing.T) {
+ subscriptions := NewInlineSubscriptions()
+ handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
+ // handler logic
+ }
+
+ subscription := InlineSubscription{
+ Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
+ Handler: handler,
+ }
+ subscriptions.Add(subscription)
+
+ sub, ok := subscriptions.Get(1)
+ require.True(t, ok)
+ require.Equal(t, "a/b/c", sub.Filter)
+ require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
+
+ _, ok = subscriptions.Get(999)
+ require.False(t, ok)
+}
+
+func TestInlineSubscriptionsGetAll(t *testing.T) {
+ subscriptions := NewInlineSubscriptions()
+
+ subscriptions.Add(InlineSubscription{
+ Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
+ })
+ subscriptions.Add(InlineSubscription{
+ Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
+ })
+ subscriptions.Add(InlineSubscription{
+ Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2},
+ })
+ subscriptions.Add(InlineSubscription{
+ Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3},
+ })
+
+ allSubs := subscriptions.GetAll()
+ require.Len(t, allSubs, 3)
+ require.Contains(t, allSubs, 1)
+ require.Contains(t, allSubs, 2)
+ require.Contains(t, allSubs, 3)
+}
+
+func TestInlineSubscriptionDelete(t *testing.T) {
+ subscriptions := NewInlineSubscriptions()
+ handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
+ // handler logic
+ }
+
+ subscription := InlineSubscription{
+ Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
+ Handler: handler,
+ }
+ subscriptions.Add(subscription)
+
+ subscriptions.Delete(1)
+ _, ok := subscriptions.Get(1)
+ require.False(t, ok)
+ require.Empty(t, subscriptions.GetAll())
+ require.Zero(t, subscriptions.Len())
+}
+
+func TestInlineSubscribe(t *testing.T) {
+
+ handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
+ // handler logic
+ }
+
+ tt := []struct {
+ desc string
+ filter string
+ subscription InlineSubscription
+ wasNew bool
+ }{
+ {
+ desc: "subscribe",
+ filter: "a/b/c",
+ subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
+ wasNew: true,
+ },
+ {
+ desc: "subscribe existed",
+ filter: "a/b/c",
+ subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
+ wasNew: false,
+ },
+ {
+ desc: "subscribe different identifier",
+ filter: "a/b/c",
+ subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}},
+ wasNew: true,
+ },
+ {
+ desc: "subscribe case sensitive didnt exist",
+ filter: "A/B/c",
+ subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}},
+ wasNew: true,
+ },
+ {
+ desc: "wildcard+ sub",
+ filter: "d/+",
+ subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}},
+ wasNew: true,
+ },
+ {
+ desc: "wildcard# sub",
+ filter: "d/e/#",
+ subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}},
+ wasNew: true,
+ },
+ }
+
+ index := NewTopicsIndex()
+ for _, tx := range tt {
+ t.Run(tx.desc, func(t *testing.T) {
+ require.Equal(t, tx.wasNew, index.InlineSubscribe(tx.subscription))
+ })
+ }
+
+ final := index.root.particles.get("a").particles.get("b").particles.get("c")
+ require.NotNil(t, final)
+}
+
+func TestInlineUnsubscribe(t *testing.T) {
+ handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
+ // handler logic
+ }
+
+ index := NewTopicsIndex()
+ index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
+ sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ index = NewTopicsIndex()
+ index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
+ sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
+ sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}})
+ sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}})
+ sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
+ sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
+ sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}})
+ sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ ok := index.InlineUnsubscribe(1, "a/b/c/d")
+ require.True(t, ok)
+ require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
+
+ sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
+ require.NotNil(t, sub)
+ require.True(t, exists)
+
+ ok = index.InlineUnsubscribe(1, "d/e/f")
+ require.True(t, ok)
+ require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f"))
+
+ ok = index.InlineUnsubscribe(1, "not/exist")
+ require.False(t, ok)
+}
diff --git a/packets/codec.go b/packets/codec.go
new file mode 100644
index 0000000..152d777
--- /dev/null
+++ b/packets/codec.go
@@ -0,0 +1,172 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "bytes"
+ "encoding/binary"
+ "io"
+ "unicode/utf8"
+ "unsafe"
+)
+
// bytesToString converts a byte slice to a string without copying the
// underlying data. The caller must not mutate b afterwards.
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
func bytesToString(b []byte) string {
	return *(*string)(unsafe.Pointer(&b))
}
+
+// decodeUint16 extracts the value of two bytes from a byte array.
+func decodeUint16(buf []byte, offset int) (uint16, int, error) {
+ if len(buf) < offset+2 {
+ return 0, 0, ErrMalformedOffsetUintOutOfRange
+ }
+
+ return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
+}
+
+// decodeUint32 extracts the value of four bytes from a byte array.
+func decodeUint32(buf []byte, offset int) (uint32, int, error) {
+ if len(buf) < offset+4 {
+ return 0, 0, ErrMalformedOffsetUintOutOfRange
+ }
+
+ return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil
+}
+
+// decodeString extracts a string from a byte array, beginning at an offset.
+func decodeString(buf []byte, offset int) (string, int, error) {
+ b, n, err := decodeBytes(buf, offset)
+ if err != nil {
+ return "", 0, err
+ }
+
+ if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5]
+ return "", 0, ErrMalformedInvalidUTF8
+ }
+
+ return bytesToString(b), n, nil
+}
+
+// validUTF8 checks if the byte array contains valid UTF-8 characters.
+func validUTF8(b []byte) bool {
+ return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2]
+}
+
+// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
+func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
+ length, next, err := decodeUint16(buf, offset)
+ if err != nil {
+ return make([]byte, 0), 0, err
+ }
+
+ if next+int(length) > len(buf) {
+ return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange
+ }
+
+ return buf[next : next+int(length)], next + int(length), nil
+}
+
+// decodeByte extracts the value of a byte from a byte array.
+func decodeByte(buf []byte, offset int) (byte, int, error) {
+ if len(buf) <= offset {
+ return 0, 0, ErrMalformedOffsetByteOutOfRange
+ }
+ return buf[offset], offset + 1, nil
+}
+
+// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
+func decodeByteBool(buf []byte, offset int) (bool, int, error) {
+ if len(buf) <= offset {
+ return false, 0, ErrMalformedOffsetBoolOutOfRange
+ }
+ return 1&buf[offset] > 0, offset + 1, nil
+}
+
// encodeBool converts a boolean to its single-byte wire form (1 or 0).
func encodeBool(b bool) byte {
	var v byte
	if b {
		v = 1
	}
	return v
}
+
+// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
+func encodeBytes(val []byte) []byte {
+ // In most circumstances the number of bytes being encoded is small.
+ // Setting the cap to a low amount allows us to account for those without
+ // triggering allocation growth on append unless we need to.
+ buf := make([]byte, 2, 32)
+ binary.BigEndian.PutUint16(buf, uint16(len(val)))
+ return append(buf, val...)
+}
+
// encodeUint16 returns val as two big-endian bytes.
func encodeUint16(val uint16) []byte {
	return []byte{byte(val >> 8), byte(val)}
}
+
// encodeUint32 encodes a uint32 value to a 4-byte big-endian byte array.
// (The previous comment incorrectly said uint16.)
func encodeUint32(val uint32) []byte {
	buf := make([]byte, 4)
	binary.BigEndian.PutUint32(buf, val)
	return buf
}
+
+// encodeString encodes a string to a byte array.
+func encodeString(val string) []byte {
+ // Like encodeBytes, we set the cap to a small number to avoid
+ // triggering allocation growth on append unless we absolutely need to.
+ buf := make([]byte, 2, 32)
+ binary.BigEndian.PutUint16(buf, uint16(len(val)))
+ return append(buf, []byte(val)...)
+}
+
+// encodeLength writes length bits for the header.
+func encodeLength(b *bytes.Buffer, length int64) {
+ // 1.5.5 Variable Byte Integer encode non-normative
+ // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
+ for {
+ eb := byte(length % 128)
+ length /= 128
+ if length > 0 {
+ eb |= 0x80
+ }
+ b.WriteByte(eb)
+ if length == 0 {
+ break // [MQTT-1.5.5-1]
+ }
+ }
+}
+
+func DecodeLength(b io.ByteReader) (n, bu int, err error) {
+ // see 1.5.5 Variable Byte Integer decode non-normative
+ // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
+ var multiplier uint32
+ var value uint32
+ bu = 1
+ for {
+ eb, err := b.ReadByte()
+ if err != nil {
+ return 0, bu, err
+ }
+
+ value |= uint32(eb&127) << multiplier
+ if value > 268435455 {
+ return 0, bu, ErrMalformedVariableByteInteger
+ }
+
+ if (eb & 128) == 0 {
+ break
+ }
+
+ multiplier += 7
+ bu++
+ }
+
+ return int(value), bu, nil
+}
diff --git a/packets/codec_test.go b/packets/codec_test.go
new file mode 100644
index 0000000..9129721
--- /dev/null
+++ b/packets/codec_test.go
@@ -0,0 +1,422 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "math"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestBytesToString(t *testing.T) {
+ b := []byte{'a', 'b', 'c'}
+ require.Equal(t, "abc", bytesToString(b))
+}
+
+func TestDecodeString(t *testing.T) {
+ expect := []struct {
+ name string
+ rawBytes []byte
+ result string
+ offset int
+ shouldFail error
+ }{
+ {
+ offset: 0,
+ rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
+ result: "a/b/c/d",
+ },
+ {
+ offset: 14,
+ rawBytes: []byte{
+ Connect << 4, 17, // Fixed header
+ 0, 6, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
+ 3, // Protocol Version
+ 0, // Packet Flags
+ 0, 30, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'h', 'e', 'y', // Client ID "zen"},
+ },
+ result: "hey",
+ },
+ {
+ offset: 2,
+ rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
+ result: "1/2/3/4/a/b/c/d/e/^/@/!",
+ },
+ {
+ offset: 0,
+ rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
+ result: "x/y/z",
+ },
+ {
+ offset: 0,
+ rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
+ shouldFail: ErrMalformedOffsetBytesOutOfRange,
+ },
+ {
+ offset: 5,
+ rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
+ shouldFail: ErrMalformedOffsetBytesOutOfRange,
+ },
+ {
+ offset: 9,
+ rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
+ shouldFail: ErrMalformedOffsetUintOutOfRange,
+ },
+ {
+ offset: 17,
+ rawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0, // Flags
+ 0, 20, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 0, 6, // Will Topic - MSB+LSB
+ 'l',
+ },
+ shouldFail: ErrMalformedOffsetBytesOutOfRange,
+ },
+ {
+ offset: 0,
+ rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
+ shouldFail: ErrMalformedInvalidUTF8,
+ },
+ }
+
+ for i, wanted := range expect {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ result, _, err := decodeString(wanted.rawBytes, wanted.offset)
+ if wanted.shouldFail != nil {
+ require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, wanted.result, result)
+ })
+ }
+}
+
+func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3]
+ result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0)
+ require.NoError(t, err)
+ require.Equal(t, "\ufeff", result)
+}
+
+func TestDecodeBytes(t *testing.T) {
+ expect := []struct {
+ rawBytes []byte
+ result []uint8
+ next int
+ offset int
+ shouldFail error
+ }{
+ {
+ rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session)
+ result: []byte{0x4d, 0x51, 0x54, 0x54},
+ next: 6,
+ offset: 0,
+ },
+ {
+ rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start
+ result: []byte{0x4d, 0x51, 0x54, 0x54},
+ next: 6,
+ offset: 0,
+ },
+ {
+ rawBytes: []byte{0, 4, 77, 81},
+ offset: 0,
+ shouldFail: ErrMalformedOffsetBytesOutOfRange,
+ },
+ {
+ rawBytes: []byte{0, 4, 77, 81},
+ offset: 8,
+ shouldFail: ErrMalformedOffsetUintOutOfRange,
+ },
+ }
+
+ for i, wanted := range expect {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
+ if wanted.shouldFail != nil {
+ require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, wanted.result, result)
+ })
+ }
+}
+
+func TestDecodeByte(t *testing.T) {
+ expect := []struct {
+ rawBytes []byte
+ result uint8
+ offset int
+ shouldFail error
+ }{
+ {
+ rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
+ result: uint8(0x00),
+ offset: 0,
+ },
+ {
+ rawBytes: []byte{0, 4, 77, 81, 84, 84},
+ result: uint8(0x04),
+ offset: 1,
+ },
+ {
+ rawBytes: []byte{0, 4, 77, 81, 84, 84},
+ result: uint8(0x4d),
+ offset: 2,
+ },
+ {
+ rawBytes: []byte{0, 4, 77, 81, 84, 84},
+ result: uint8(0x51),
+ offset: 3,
+ },
+ {
+ rawBytes: []byte{0, 4, 77, 80, 82, 84},
+ offset: 8,
+ shouldFail: ErrMalformedOffsetByteOutOfRange,
+ },
+ }
+
+ for i, wanted := range expect {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
+ if wanted.shouldFail != nil {
+ require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, wanted.result, result)
+ require.Equal(t, i+1, offset)
+ })
+ }
+}
+
+func TestDecodeUint16(t *testing.T) {
+ expect := []struct {
+ rawBytes []byte
+ result uint16
+ offset int
+ shouldFail error
+ }{
+ {
+ rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
+ result: uint16(0x07),
+ offset: 0,
+ },
+ {
+ rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
+ result: uint16(0x761),
+ offset: 1,
+ },
+ {
+ rawBytes: []byte{0, 7, 255, 47},
+ offset: 8,
+ shouldFail: ErrMalformedOffsetUintOutOfRange,
+ },
+ }
+
+ for i, wanted := range expect {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
+ if wanted.shouldFail != nil {
+ require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, wanted.result, result)
+ require.Equal(t, i+2, offset)
+ })
+ }
+}
+
+func TestDecodeUint32(t *testing.T) {
+ expect := []struct {
+ rawBytes []byte
+ result uint32
+ offset int
+ shouldFail error
+ }{
+ {
+ rawBytes: []byte{0, 0, 0, 7, 8},
+ result: uint32(7),
+ offset: 0,
+ },
+ {
+ rawBytes: []byte{0, 0, 1, 226, 64, 8},
+ result: uint32(123456),
+ offset: 1,
+ },
+ {
+ rawBytes: []byte{0, 7, 255, 47},
+ offset: 8,
+ shouldFail: ErrMalformedOffsetUintOutOfRange,
+ },
+ }
+
+ for i, wanted := range expect {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset)
+ if wanted.shouldFail != nil {
+ require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, wanted.result, result)
+ require.Equal(t, i+4, offset)
+ })
+ }
+}
+
+func TestDecodeByteBool(t *testing.T) {
+ expect := []struct {
+ rawBytes []byte
+ result bool
+ offset int
+ shouldFail error
+ }{
+ {
+ rawBytes: []byte{0x00, 0x00},
+ result: false,
+ },
+ {
+ rawBytes: []byte{0x01, 0x00},
+ result: true,
+ },
+ {
+ rawBytes: []byte{0x01, 0x00},
+ offset: 5,
+ shouldFail: ErrMalformedOffsetBoolOutOfRange,
+ },
+ }
+
+ for i, wanted := range expect {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
+ if wanted.shouldFail != nil {
+ require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, wanted.result, result)
+ require.Equal(t, 1, offset)
+ })
+ }
+}
+
+func TestDecodeLength(t *testing.T) {
+ b := bytes.NewBuffer([]byte{0x78})
+ n, bu, err := DecodeLength(b)
+ require.NoError(t, err)
+ require.Equal(t, 120, n)
+ require.Equal(t, 1, bu)
+
+ b = bytes.NewBuffer([]byte{255, 255, 255, 127})
+ n, bu, err = DecodeLength(b)
+ require.NoError(t, err)
+ require.Equal(t, 268435455, n)
+ require.Equal(t, 4, bu)
+}
+
+func TestDecodeLengthErrors(t *testing.T) {
+ b := bytes.NewBuffer([]byte{})
+ _, _, err := DecodeLength(b)
+ require.Error(t, err)
+
+ b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f})
+ _, _, err = DecodeLength(b)
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
+}
+
+func TestEncodeBool(t *testing.T) {
+ result := encodeBool(true)
+ require.Equal(t, byte(1), result)
+
+ result = encodeBool(false)
+ require.Equal(t, byte(0), result)
+
+ // Check failure.
+ result = encodeBool(false)
+ require.NotEqual(t, byte(1), result)
+}
+
+func TestEncodeBytes(t *testing.T) {
+ result := encodeBytes([]byte("testing"))
+ require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result)
+
+ result = encodeBytes([]byte("testing"))
+ require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result)
+}
+
+func TestEncodeUint16(t *testing.T) {
+ result := encodeUint16(0)
+ require.Equal(t, []byte{0x00, 0x00}, result)
+
+ result = encodeUint16(32767)
+ require.Equal(t, []byte{0x7f, 0xff}, result)
+
+ result = encodeUint16(math.MaxUint16)
+ require.Equal(t, []byte{0xff, 0xff}, result)
+}
+
+func TestEncodeUint32(t *testing.T) {
+ result := encodeUint32(7)
+ require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result)
+
+ result = encodeUint32(32767)
+ require.Equal(t, []byte{0, 0, 127, 255}, result)
+
+ result = encodeUint32(math.MaxUint32)
+ require.Equal(t, []byte{255, 255, 255, 255}, result)
+}
+
+func TestEncodeString(t *testing.T) {
+ result := encodeString("testing")
+ require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result)
+
+ result = encodeString("")
+ require.Equal(t, []uint8{0x00, 0x00}, result)
+
+ result = encodeString("a")
+ require.Equal(t, []uint8{0x00, 0x01, 0x61}, result)
+
+ result = encodeString("b")
+ require.NotEqual(t, []uint8{0x00, 0x00}, result)
+}
+
+func TestEncodeLength(t *testing.T) {
+ b := new(bytes.Buffer)
+ encodeLength(b, 120)
+ require.Equal(t, []byte{0x78}, b.Bytes())
+
+ b = new(bytes.Buffer)
+ encodeLength(b, math.MaxInt64)
+ require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes())
+}
+
+func TestValidUTF8(t *testing.T) {
+ require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}))
+ require.False(t, validUTF8([]byte{0xff, 0xff}))
+ require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74}))
+}
diff --git a/packets/codes.go b/packets/codes.go
new file mode 100644
index 0000000..5af1b74
--- /dev/null
+++ b/packets/codes.go
@@ -0,0 +1,149 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
// Code pairs a wire-level reason code byte with its human-readable reason string.
type Code struct {
	Reason string
	Code   byte
}

// String returns the readable reason for a code.
func (c Code) String() string {
	return c.Reason
}

// Error returns the readable reason for a code, allowing a Code to be used as
// an error value.
func (c Code) Error() string {
	return c.Reason
}
+
+var (
+ // QosCodes indicates the reason codes for each Qos byte.
+ QosCodes = map[byte]Code{
+ 0: CodeGrantedQos0,
+ 1: CodeGrantedQos1,
+ 2: CodeGrantedQos2,
+ }
+
+ CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"}
+ CodeSuccess = Code{Code: 0x00, Reason: "success"}
+ CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
+ CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
+ CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"}
+ CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"}
+ CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"}
+ CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"}
+ CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"}
+ CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"}
+ CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"}
+ ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"}
+ ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"}
+ ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"}
+ ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"}
+ ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"}
+ ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"}
+ ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"}
+ ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"}
+ ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"}
+ ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"}
+ ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"}
+ ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"}
+ ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"}
+ ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"}
+ ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"}
+ ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"}
+ ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"}
+ ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"}
+ ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"}
+ ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"}
+ ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"}
+ ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"}
+ ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"}
+ ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"}
+ ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"}
+ ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"}
+ ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"}
+ ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"}
+ ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"}
+ ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"}
+ ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
+ ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
+ ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"}
+ ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"}
+ ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"}
+ ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"}
+ ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"}
+ ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"}
+ ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"}
+ ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"}
+ ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"}
+ ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"}
+ ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"}
+ ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"}
+ ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"}
+ ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"}
+ ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"}
+ ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"}
+ ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"}
+ ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"}
+ ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"}
+ ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"}
+ ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"}
+ ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"}
+ ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"}
+ ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"}
+ ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"}
+ ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"}
+ ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"}
+ ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"}
+ ErrServerBusy = Code{Code: 0x89, Reason: "server busy"}
+ ErrBanned = Code{Code: 0x8A, Reason: "banned"}
+ ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"}
+ ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"}
+ ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"}
+ ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"}
+ ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"}
+ ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"}
+ ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"}
+ ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"}
+ ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"}
+ ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"}
+ ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
+ ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
+ ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
+ ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
+ ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
+ ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
+ ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
+ ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
+ ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
+ ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
+ ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
+ ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
+ ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
+ ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
+ ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
+ ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."}
+
+ // MQTTv3 specific bytes.
+ Err3UnsupportedProtocolVersion = Code{Code: 0x01}
+ Err3ClientIdentifierNotValid = Code{Code: 0x02}
+ Err3ServerUnavailable = Code{Code: 0x03}
+ ErrMalformedUsernameOrPassword = Code{Code: 0x04}
+ Err3NotAuthorized = Code{Code: 0x05}
+
+ // V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
+ // This is required because MQTTv3 has different return byte specification.
+ // See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
+ V5CodesToV3 = map[Code]Code{
+ ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
+ ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
+ ErrServerUnavailable: Err3ServerUnavailable,
+ ErrMalformedUsername: ErrMalformedUsernameOrPassword,
+ ErrMalformedPassword: ErrMalformedUsernameOrPassword,
+ ErrBadUsernameOrPassword: Err3NotAuthorized,
+ }
+)
diff --git a/packets/codes_test.go b/packets/codes_test.go
new file mode 100644
index 0000000..aed8e57
--- /dev/null
+++ b/packets/codes_test.go
@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestCodesString(t *testing.T) {
+ c := Code{
+ Reason: "test",
+ Code: 0x1,
+ }
+
+ require.Equal(t, "test", c.String())
+}
+
+func TestCodesError(t *testing.T) {
+ c := Code{
+ Reason: "error",
+ Code: 0x1,
+ }
+
+ require.Equal(t, "error", error(c).Error())
+}
diff --git a/packets/fixedheader.go b/packets/fixedheader.go
new file mode 100644
index 0000000..eb20451
--- /dev/null
+++ b/packets/fixedheader.go
@@ -0,0 +1,63 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "bytes"
+)
+
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
	Remaining int  `json:"remaining"` // the number of remaining bytes in the payload.
	Type      byte `json:"type"`      // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
	Qos       byte `json:"qos"`       // indicates the quality of service expected (0-2, from bits 2-1).
	Dup       bool `json:"dup"`       // indicates if the packet was already sent at an earlier time (bit 3).
	Retain    bool `json:"retain"`    // whether the message should be retained (bit 0).
}
+
+// Encode encodes the FixedHeader and returns a bytes buffer.
+func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
+ buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
+ encodeLength(buf, int64(fh.Remaining))
+}
+
+// Decode extracts the specification bits from the header byte.
+func (fh *FixedHeader) Decode(hb byte) error {
+ fh.Type = hb >> 4 // Get the message type from the first 4 bytes.
+
+ switch fh.Type {
+ case Publish:
+ if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 {
+ return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4]
+ }
+
+ fh.Dup = (hb>>3)&0x01 > 0 // is duplicate
+ fh.Qos = (hb >> 1) & 0x03 // qos flag
+ fh.Retain = hb&0x01 > 0 // is retain flag
+ case Pubrel:
+ fallthrough
+ case Subscribe:
+ fallthrough
+ case Unsubscribe:
+ if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1]
+ return ErrMalformedFlags
+ }
+
+ fh.Qos = (hb >> 1) & 0x03
+ default:
+ if (hb>>0)&0x01 != 0 ||
+ (hb>>1)&0x01 != 0 ||
+ (hb>>2)&0x01 != 0 ||
+ (hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1]
+ return ErrMalformedFlags
+ }
+ }
+
+ if fh.Qos == 0 && fh.Dup {
+ return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2]
+ }
+
+ return nil
+}
diff --git a/packets/fixedheader_test.go b/packets/fixedheader_test.go
new file mode 100644
index 0000000..fe8c497
--- /dev/null
+++ b/packets/fixedheader_test.go
@@ -0,0 +1,237 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
// fixedHeaderTable describes a single fixed header test case, pairing the raw
// wire bytes with the decoded header values they represent.
type fixedHeaderTable struct {
	desc        string      // human-readable name of the test case
	rawBytes    []byte      // the encoded wire bytes of the fixed header
	header      FixedHeader // the header values expected to correspond to rawBytes
	packetError bool        // set for entries whose remaining-length bytes are invalid (see "remaining length oversize")
	expect      error       // the error expected from Decode, or nil for valid headers
}
+
+var fixedHeaderExpected = []fixedHeaderTable{
+ {
+ desc: "connect",
+ rawBytes: []byte{Connect << 4, 0x00},
+ header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "connack",
+ rawBytes: []byte{Connack << 4, 0x00},
+ header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "publish",
+ rawBytes: []byte{Publish << 4, 0x00},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "publish qos 1",
+ rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "publish qos 1 retain",
+ rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
+ },
+ {
+ desc: "publish qos 2",
+ rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "publish qos 2 retain",
+ rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
+ },
+ {
+ desc: "publish dup qos 0",
+ rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
+ header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
+ expect: ErrProtocolViolationDupNoQos,
+ },
+ {
+ desc: "publish dup qos 0 retain",
+ rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
+ header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
+ expect: ErrProtocolViolationDupNoQos,
+ },
+ {
+ desc: "publish dup qos 1 retain",
+ rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
+ header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
+ },
+ {
+ desc: "publish dup qos 2 retain",
+ rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
+ header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
+ },
+ {
+ desc: "puback",
+ rawBytes: []byte{Puback << 4, 0x00},
+ header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "pubrec",
+ rawBytes: []byte{Pubrec << 4, 0x00},
+ header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "pubrel",
+ rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
+ header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "pubcomp",
+ rawBytes: []byte{Pubcomp << 4, 0x00},
+ header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "subscribe",
+ rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
+ header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "suback",
+ rawBytes: []byte{Suback << 4, 0x00},
+ header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "unsubscribe",
+ rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
+ header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "unsuback",
+ rawBytes: []byte{Unsuback << 4, 0x00},
+ header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "pingreq",
+ rawBytes: []byte{Pingreq << 4, 0x00},
+ header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "pingresp",
+ rawBytes: []byte{Pingresp << 4, 0x00},
+ header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "disconnect",
+ rawBytes: []byte{Disconnect << 4, 0x00},
+ header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+ {
+ desc: "auth",
+ rawBytes: []byte{Auth << 4, 0x00},
+ header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0},
+ },
+
+ // remaining length
+ {
+ desc: "remaining length 10",
+ rawBytes: []byte{Publish << 4, 0x0a},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
+ },
+ {
+ desc: "remaining length 512",
+ rawBytes: []byte{Publish << 4, 0x80, 0x04},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
+ },
+ {
+ desc: "remaining length 978",
+ rawBytes: []byte{Publish << 4, 0xd2, 0x07},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
+ },
+ {
+ desc: "remaining length 20202",
+ rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
+ },
+ {
+ desc: "remaining length oversize",
+ rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
+ header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
+ packetError: true,
+ },
+
+ // Invalid flags for packet
+ {
+ desc: "invalid type dup is true",
+ rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
+ header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
+ expect: ErrMalformedFlags,
+ },
+ {
+ desc: "invalid type qos is 1",
+ rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
+ header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
+ expect: ErrMalformedFlags,
+ },
+ {
+ desc: "invalid type retain is true",
+ rawBytes: []byte{Connect<<4 | 1, 0x00},
+ header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
+ expect: ErrMalformedFlags,
+ },
+ {
+ desc: "invalid publish qos bits 1 + 2 set",
+ rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00},
+ header: FixedHeader{Type: Publish},
+ expect: ErrProtocolViolationQosOutOfRange,
+ },
+ {
+ desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0",
+ rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00},
+ header: FixedHeader{Type: Pubrel, Qos: 1},
+ expect: ErrMalformedFlags,
+ },
+ {
+ desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0",
+ rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00},
+ header: FixedHeader{Type: Subscribe, Qos: 1},
+ expect: ErrMalformedFlags,
+ },
+}
+
+func TestFixedHeaderEncode(t *testing.T) {
+ for _, wanted := range fixedHeaderExpected {
+ t.Run(wanted.desc, func(t *testing.T) {
+ buf := new(bytes.Buffer)
+ wanted.header.Encode(buf)
+ if wanted.expect == nil {
+ require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()))
+ require.EqualValues(t, wanted.rawBytes, buf.Bytes())
+ }
+ })
+ }
+}
+
+func TestFixedHeaderDecode(t *testing.T) {
+ for _, wanted := range fixedHeaderExpected {
+ t.Run(wanted.desc, func(t *testing.T) {
+ fh := new(FixedHeader)
+ err := fh.Decode(wanted.rawBytes[0])
+ if wanted.expect != nil {
+ require.Equal(t, wanted.expect, err)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, wanted.header.Type, fh.Type)
+ require.Equal(t, wanted.header.Dup, fh.Dup)
+ require.Equal(t, wanted.header.Qos, fh.Qos)
+ require.Equal(t, wanted.header.Retain, fh.Retain)
+ }
+ })
+ }
+}
diff --git a/packets/packets.go b/packets/packets.go
new file mode 100644
index 0000000..d52d5af
--- /dev/null
+++ b/packets/packets.go
@@ -0,0 +1,1173 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+ "sync"
+
+ "testmqtt/mempool"
+)
+
// All valid packet types and their packet identifiers.
const (
	Reserved       byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
	Connect                    // 1
	Connack                    // 2
	Publish                    // 3
	Puback                     // 4
	Pubrec                     // 5
	Pubrel                     // 6
	Pubcomp                    // 7
	Subscribe                  // 8
	Suback                     // 9
	Unsubscribe                // 10
	Unsuback                   // 11
	Pingreq                    // 12
	Pingresp                   // 13
	Disconnect                 // 14
	Auth                       // 15
	WillProperties byte = 99   // Special byte for validating Will Properties.
)
+
var (
	// ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification.
	ErrNoValidPacketAvailable = errors.New("no valid packet available")

	// PacketNames is a map of packet type bytes to human-readable names, for easier debugging.
	PacketNames = map[byte]string{
		0:  "Reserved",
		1:  "Connect",
		2:  "Connack",
		3:  "Publish",
		4:  "Puback",
		5:  "Pubrec",
		6:  "Pubrel",
		7:  "Pubcomp",
		8:  "Subscribe",
		9:  "Suback",
		10: "Unsubscribe",
		11: "Unsuback",
		12: "Pingreq",
		13: "Pingresp",
		14: "Disconnect",
		15: "Auth",
	}
)
+
// Packets is a concurrency safe map of packets, keyed by packet id string.
type Packets struct {
	internal map[string]Packet // the underlying store; guarded by the embedded RWMutex.
	sync.RWMutex
}
+
+// NewPackets returns a new instance of Packets.
+func NewPackets() *Packets {
+ return &Packets{
+ internal: map[string]Packet{},
+ }
+}
+
+// Add adds a new packet to the map.
+func (p *Packets) Add(id string, val Packet) {
+ p.Lock()
+ defer p.Unlock()
+ p.internal[id] = val
+}
+
+// GetAll returns all packets in the map.
+func (p *Packets) GetAll() map[string]Packet {
+ p.RLock()
+ defer p.RUnlock()
+ m := map[string]Packet{}
+ for k, v := range p.internal {
+ m[k] = v
+ }
+ return m
+}
+
+// Get returns a specific packet in the map by packet id.
+func (p *Packets) Get(id string) (val Packet, ok bool) {
+ p.RLock()
+ defer p.RUnlock()
+ val, ok = p.internal[id]
+ return val, ok
+}
+
+// Len returns the number of packets in the map.
+func (p *Packets) Len() int {
+ p.RLock()
+ defer p.RUnlock()
+ val := len(p.internal)
+ return val
+}
+
+// Delete removes a packet from the map by packet id.
+func (p *Packets) Delete(id string) {
+ p.Lock()
+ defer p.Unlock()
+ delete(p.internal, id)
+}
+
// Packet represents an MQTT packet. Instead of providing a packet interface
// variant packet structs, this is a single concrete packet type to cover all packet
// types, which allows us to take advantage of various compiler optimizations. It
// contains a combination of mqtt spec values and internal broker control codes.
type Packet struct {
	Connect         ConnectParams // parameters for connect packets (just for organisation)
	Properties      Properties    // all mqtt v5 packet properties
	Payload         []byte        // a message/payload for publish packets
	ReasonCodes     []byte        // one or more reason codes for multi-reason responses (suback, etc)
	Filters         Subscriptions // a list of subscription filters and their properties (subscribe, unsubscribe)
	TopicName       string        // the topic a payload is being published to
	Origin          string        // client id of the client who is issuing the packet (mostly internal use)
	FixedHeader     FixedHeader   // the fixed header values of the packet
	Created         int64         // unix timestamp indicating time packet was created/received on the server
	Expiry          int64         // unix timestamp indicating when the packet will expire and should be deleted
	Mods            Mods          // internal broker control values for controlling certain mqtt v5 compliance
	PacketID        uint16        // packet id for the packet (publish, qos, etc)
	ProtocolVersion byte          // protocol version of the client the packet belongs to
	SessionPresent  bool          // session existed for connack
	ReasonCode      byte          // reason code for a packet response (acks, etc)
	ReservedBit     byte          // reserved, do not use (except in testing)
	Ignore          bool          // if true, do not perform any message forwarding operations
}
+
// Mods specifies certain values required for certain mqtt v5 compliance within packet encoding/decoding.
type Mods struct {
	MaxSize             uint32 // the maximum packet size specified by the client / server
	DisallowProblemInfo bool   // if problem info is disallowed
	AllowResponseInfo   bool   // if response info is allowed (NOTE(review): field name says allow, prior comment said "disallowed" — confirm intended polarity)
}
+
// ConnectParams contains packet values which are specifically related to connect packets.
type ConnectParams struct {
	WillProperties   Properties `json:"willProperties"` // v5 properties attached to the will message
	Password         []byte     `json:"password"`       // password bytes from the connect payload
	Username         []byte     `json:"username"`       // username bytes from the connect payload
	ProtocolName     []byte     `json:"protocolName"`   // protocol name bytes, "MQTT" (v4/v5) or "MQIsdp" (v3)
	WillPayload      []byte     `json:"willPayload"`    // the will message payload
	ClientIdentifier string     `json:"clientId"`       // the client id presented by the connecting client
	WillTopic        string     `json:"willTopic"`      // the topic the will message is published to
	Keepalive        uint16     `json:"keepalive"`      // the keepalive interval requested by the client
	PasswordFlag     bool       `json:"passwordFlag"`   // true if the password flag bit (6) was set
	UsernameFlag     bool       `json:"usernameFlag"`   // true if the username flag bit (7) was set
	WillQos          byte       `json:"willQos"`        // qos for the will message (0-2)
	WillFlag         bool       `json:"willFlag"`       // true if the packet carries a will message
	WillRetain       bool       `json:"willRetain"`     // true if the will message should be retained
	Clean            bool       `json:"clean"`          // CleanSession in v3.1.1, CleanStart in v5
}
+
// Subscriptions is a slice of Subscription.
type Subscriptions []Subscription // must be a slice to retain the order in which filters were received.
+
// Subscription contains details about a client subscription to a topic filter.
type Subscription struct {
	ShareName         []string       // share names for a shared subscription — assumed parsed from the filter elsewhere; TODO confirm
	Filter            string         // the topic filter subscribed to
	Identifier        int            // the v5 subscription identifier for this filter
	Identifiers       map[string]int // filter -> subscription identifier, built by Merge when subscriptions are combined
	RetainHandling    byte           // v5 retain handling option (bits 4-5 of the options byte)
	Qos               byte           // the quality of service of the subscription (bits 0-1)
	RetainAsPublished bool           // v5 retain-as-published option (bit 3)
	NoLocal           bool           // v5 no-local option (bit 2)
	FwdRetainedFlag   bool           // true if the subscription forms part of a publish response to a client subscription and packet is retained.
}
+
// Copy creates a new instance of a packet, but with an empty header for
// inheriting new QoS flags, etc. When allowTransfer is true the packet id
// is carried over to the copy; otherwise it is reset to zero.
func (pk *Packet) Copy(allowTransfer bool) Packet {
	p := Packet{
		FixedHeader: FixedHeader{
			Remaining: pk.FixedHeader.Remaining,
			Type:      pk.FixedHeader.Type,
			Retain:    pk.FixedHeader.Retain,
			Dup:       false, // [MQTT-4.3.1-1] [MQTT-4.3.2-2]
			Qos:       pk.FixedHeader.Qos,
		},
		Mods: Mods{
			MaxSize: pk.Mods.MaxSize,
		},
		ReservedBit:     pk.ReservedBit,
		ProtocolVersion: pk.ProtocolVersion,
		Connect: ConnectParams{
			ClientIdentifier: pk.Connect.ClientIdentifier,
			Keepalive:        pk.Connect.Keepalive,
			WillQos:          pk.Connect.WillQos,
			WillTopic:        pk.Connect.WillTopic,
			WillFlag:         pk.Connect.WillFlag,
			WillRetain:       pk.Connect.WillRetain,
			WillProperties:   pk.Connect.WillProperties.Copy(allowTransfer),
			Clean:            pk.Connect.Clean,
		},
		TopicName:      pk.TopicName,
		Properties:     pk.Properties.Copy(allowTransfer),
		SessionPresent: pk.SessionPresent,
		ReasonCode:     pk.ReasonCode,
		Filters:        pk.Filters, // NOTE(review): shallow copy — shares the underlying slice with the original; confirm callers do not mutate it.
		Created:        pk.Created,
		Expiry:         pk.Expiry,
		Origin:         pk.Origin,
	}

	if allowTransfer {
		p.PacketID = pk.PacketID
	}

	// Byte-slice fields are deep-copied below so the copy does not alias the
	// original packet's backing arrays. The username/password flags are also
	// re-derived from the presence of values.
	if len(pk.Connect.ProtocolName) > 0 {
		p.Connect.ProtocolName = append([]byte{}, pk.Connect.ProtocolName...)
	}

	if len(pk.Connect.Password) > 0 {
		p.Connect.PasswordFlag = true
		p.Connect.Password = append([]byte{}, pk.Connect.Password...)
	}

	if len(pk.Connect.Username) > 0 {
		p.Connect.UsernameFlag = true
		p.Connect.Username = append([]byte{}, pk.Connect.Username...)
	}

	if len(pk.Connect.WillPayload) > 0 {
		p.Connect.WillPayload = append([]byte{}, pk.Connect.WillPayload...)
	}

	if len(pk.Payload) > 0 {
		p.Payload = append([]byte{}, pk.Payload...)
	}

	if len(pk.ReasonCodes) > 0 {
		p.ReasonCodes = append([]byte{}, pk.ReasonCodes...)
	}

	return p
}
+
+// Merge merges a new subscription with a base subscription, preserving the highest
+// qos value, matched identifiers and any special properties.
+func (s Subscription) Merge(n Subscription) Subscription {
+ if s.Identifiers == nil {
+ s.Identifiers = map[string]int{
+ s.Filter: s.Identifier,
+ }
+ }
+
+ if n.Identifier > 0 {
+ s.Identifiers[n.Filter] = n.Identifier
+ }
+
+ if n.Qos > s.Qos {
+ s.Qos = n.Qos // [MQTT-3.3.4-2]
+ }
+
+ if n.NoLocal {
+ s.NoLocal = true // [MQTT-3.8.3-3]
+ }
+
+ return s
+}
+
+// encode encodes a subscription and properties into bytes.
+func (s Subscription) encode() byte {
+ var flag byte
+ flag |= s.Qos
+
+ if s.NoLocal {
+ flag |= 1 << 2
+ }
+
+ if s.RetainAsPublished {
+ flag |= 1 << 3
+ }
+
+ flag |= s.RetainHandling << 4
+ return flag
+}
+
+// decode decodes subscription bytes into a subscription struct.
+func (s *Subscription) decode(b byte) {
+ s.Qos = b & 3 // byte
+ s.NoLocal = 1&(b>>2) > 0 // bool
+ s.RetainAsPublished = 1&(b>>3) > 0 // bool
+ s.RetainHandling = 3 & (b >> 4) // byte
+}
+
// ConnectEncode encodes a connect packet into buf: the variable header
// (protocol name/version, connect flags, keepalive, v5 properties) followed by
// the payload (client id, will properties/topic/payload, username, password).
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)
	nb.Write(encodeBytes(pk.Connect.ProtocolName))
	nb.WriteByte(pk.ProtocolVersion)

	// Connect flags byte; bit 0 is reserved and must be 0. [MQTT-2.1.3-1]
	nb.WriteByte(
		encodeBool(pk.Connect.Clean)<<1 |
			encodeBool(pk.Connect.WillFlag)<<2 |
			pk.Connect.WillQos<<3 |
			encodeBool(pk.Connect.WillRetain)<<5 |
			encodeBool(pk.Connect.PasswordFlag)<<6 |
			encodeBool(pk.Connect.UsernameFlag)<<7 |
			0, // [MQTT-2.1.3-1]
	)

	nb.Write(encodeUint16(pk.Connect.Keepalive))

	if pk.ProtocolVersion == 5 {
		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
		nb.Write(pb.Bytes())
	}

	nb.Write(encodeString(pk.Connect.ClientIdentifier))

	if pk.Connect.WillFlag {
		if pk.ProtocolVersion == 5 {
			pb := mempool.GetBuffer()
			defer mempool.PutBuffer(pb)
			(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
			nb.Write(pb.Bytes())
		}

		nb.Write(encodeString(pk.Connect.WillTopic))
		nb.Write(encodeBytes(pk.Connect.WillPayload))
	}

	if pk.Connect.UsernameFlag {
		nb.Write(encodeBytes(pk.Connect.Username))
	}

	if pk.Connect.PasswordFlag {
		nb.Write(encodeBytes(pk.Connect.Password))
	}

	// The remaining length is only known once the variable header and payload
	// are assembled, so the fixed header is written last.
	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())

	return nil
}
+
// ConnectDecode decodes a connect packet from buf, populating pk.Connect and
// the protocol version. It returns a malformed-packet code on decode failure;
// structural validation beyond decoding is performed by ConnectValidate.
func (pk *Packet) ConnectDecode(buf []byte) error {
	var offset int
	var err error

	pk.Connect.ProtocolName, offset, err = decodeBytes(buf, 0)
	if err != nil {
		return ErrMalformedProtocolName
	}

	pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
	if err != nil {
		return ErrMalformedProtocolVersion
	}

	flags, offset, err := decodeByte(buf, offset)
	if err != nil {
		return ErrMalformedFlags
	}

	// Unpack the connect flags byte. The reserved bit (0) is retained so that
	// ConnectValidate can reject a non-zero value. [MQTT-3.1.2-3]
	pk.ReservedBit = 1 & flags
	pk.Connect.Clean = 1&(flags>>1) > 0
	pk.Connect.WillFlag = 1&(flags>>2) > 0
	pk.Connect.WillQos = 3 & (flags >> 3) // this one is not a bool
	pk.Connect.WillRetain = 1&(flags>>5) > 0
	pk.Connect.PasswordFlag = 1&(flags>>6) > 0
	pk.Connect.UsernameFlag = 1&(flags>>7) > 0

	pk.Connect.Keepalive, offset, err = decodeUint16(buf, offset)
	if err != nil {
		return ErrMalformedKeepalive
	}

	if pk.ProtocolVersion == 5 {
		n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
		if err != nil {
			return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
		}
		offset += n
	}

	pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) // [MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
	if err != nil {
		return ErrClientIdentifierNotValid // [MQTT-3.1.3-8]
	}

	if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
		if pk.ProtocolVersion == 5 {
			n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
			if err != nil {
				return ErrMalformedWillProperties
			}
			offset += n
		}

		pk.Connect.WillTopic, offset, err = decodeString(buf, offset)
		if err != nil {
			return ErrMalformedWillTopic
		}

		pk.Connect.WillPayload, offset, err = decodeBytes(buf, offset)
		if err != nil {
			return ErrMalformedWillPayload
		}
	}

	if pk.Connect.UsernameFlag { // [MQTT-3.1.3-12]
		if offset >= len(buf) { // we are at the end of the packet
			return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
		}

		pk.Connect.Username, offset, err = decodeBytes(buf, offset)
		if err != nil {
			return ErrMalformedUsername
		}
	}

	if pk.Connect.PasswordFlag {
		// Password is the final field of the payload, so the returned offset is discarded.
		pk.Connect.Password, _, err = decodeBytes(buf, offset)
		if err != nil {
			return ErrMalformedPassword
		}
	}

	return nil
}
+
+// ConnectValidate ensures the connect packet is compliant.
+func (pk *Packet) ConnectValidate() Code {
+ if !bytes.Equal(pk.Connect.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) && !bytes.Equal(pk.Connect.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) {
+ return ErrProtocolViolationProtocolName // [MQTT-3.1.2-1]
+ }
+
+ if (bytes.Equal(pk.Connect.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) && pk.ProtocolVersion != 3) ||
+ (bytes.Equal(pk.Connect.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) && pk.ProtocolVersion != 4 && pk.ProtocolVersion != 5) {
+ return ErrProtocolViolationProtocolVersion // [MQTT-3.1.2-2]
+ }
+
+ if pk.ReservedBit != 0 {
+ return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3]
+ }
+
+ if len(pk.Connect.Password) > math.MaxUint16 {
+ return ErrProtocolViolationPasswordTooLong
+ }
+
+ if len(pk.Connect.Username) > math.MaxUint16 {
+ return ErrProtocolViolationUsernameTooLong
+ }
+
+ if !pk.Connect.UsernameFlag && len(pk.Connect.Username) > 0 {
+ return ErrProtocolViolationUsernameNoFlag // [MQTT-3.1.2-16]
+ }
+
+ if pk.Connect.PasswordFlag && len(pk.Connect.Password) == 0 {
+ return ErrProtocolViolationFlagNoPassword // [MQTT-3.1.2-19]
+ }
+
+ if !pk.Connect.PasswordFlag && len(pk.Connect.Password) > 0 {
+ return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18]
+ }
+
+ if len(pk.Connect.ClientIdentifier) > math.MaxUint16 {
+ return ErrClientIdentifierNotValid
+ }
+
+ if pk.Connect.WillFlag {
+ if len(pk.Connect.WillPayload) == 0 || pk.Connect.WillTopic == "" {
+ return ErrProtocolViolationWillFlagNoPayload // [MQTT-3.1.2-9]
+ }
+
+ if pk.Connect.WillQos > 2 {
+ return ErrProtocolViolationQosOutOfRange // [MQTT-3.1.2-12]
+ }
+ }
+
+ if !pk.Connect.WillFlag && pk.Connect.WillRetain {
+ return ErrProtocolViolationWillFlagSurplusRetain // [MQTT-3.1.2-13]
+ }
+
+ return CodeSuccess
+}
+
// ConnackEncode encodes a Connack packet into buf: the session-present flag,
// the reason code, and (for v5) the connack properties.
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)
	nb.WriteByte(encodeBool(pk.SessionPresent))
	nb.WriteByte(pk.ReasonCode)

	if pk.ProtocolVersion == 5 {
		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		// NOTE(review): nb already contains the session-present and reason-code
		// bytes, so nb.Len()+2 passes 4 here — confirm the size basis expected
		// by Properties.Encode.
		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
		nb.Write(pb.Bytes())
	}

	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())

	return nil
}
+
+// ConnackDecode decodes a Connack packet.
+func (pk *Packet) ConnackDecode(buf []byte) error {
+ var offset int
+ var err error
+
+ pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
+ if err != nil {
+ return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
+ }
+
+ pk.ReasonCode, offset, err = decodeByte(buf, offset)
+ if err != nil {
+ return fmt.Errorf("%s: %w", err, ErrMalformedReasonCode)
+ }
+
+ if pk.ProtocolVersion == 5 {
+ _, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
+ if err != nil {
+ return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
+ }
+ }
+
+ return nil
+}
+
// DisconnectEncode encodes a Disconnect packet into buf. For v5 clients the
// reason code and properties are included; v3 disconnects carry no payload.
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)

	if pk.ProtocolVersion == 5 {
		nb.WriteByte(pk.ReasonCode)

		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
		nb.Write(pb.Bytes())
	}

	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())

	return nil
}
+
+// DisconnectDecode decodes a Disconnect packet.
+func (pk *Packet) DisconnectDecode(buf []byte) error {
+ if pk.ProtocolVersion == 5 && pk.FixedHeader.Remaining > 1 {
+ var err error
+ var offset int
+ pk.ReasonCode, offset, err = decodeByte(buf, offset)
+ if err != nil {
+ return fmt.Errorf("%s: %w", err, ErrMalformedReasonCode)
+ }
+
+ if pk.FixedHeader.Remaining > 2 {
+ _, err = pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
+ if err != nil {
+ return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
+ }
+ }
+ }
+
+ return nil
+}
+
// PingreqEncode encodes a Pingreq packet; it consists of the fixed header only.
func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error {
	pk.FixedHeader.Encode(buf)
	return nil
}
+
// PingreqDecode decodes a Pingreq packet; it has no variable header or payload,
// so there is nothing to do.
func (pk *Packet) PingreqDecode(buf []byte) error {
	return nil
}
+
+// PingrespEncode encodes a Pingresp packet.
+func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error {
+ pk.FixedHeader.Encode(buf)
+ return nil
+}
+
+// PingrespDecode decodes a Pingres packet.
+func (pk *Packet) PingrespDecode(buf []byte) error {
+ return nil
+}
+
// PublishEncode encodes a Publish packet: topic name, packet id (qos > 0
// only), v5 properties, then the application payload.
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)

	nb.Write(encodeString(pk.TopicName)) // [MQTT-3.3.2-1]

	// A packet id is present only for qos 1 and 2, and must be non-zero.
	if pk.FixedHeader.Qos > 0 {
		if pk.PacketID == 0 {
			return ErrProtocolViolationNoPacketID // [MQTT-2.2.1-2]
		}
		nb.Write(encodeUint16(pk.PacketID))
	}

	if pk.ProtocolVersion == 5 {
		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		// The property encoder is given the size of everything written so
		// far plus the payload — presumably so it can respect packet size
		// limits; confirm against Properties.Encode.
		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
		nb.Write(pb.Bytes())
	}

	pk.FixedHeader.Remaining = nb.Len() + len(pk.Payload)
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())
	buf.Write(pk.Payload)

	return nil
}
+
// PublishDecode extracts the data values from the packet.
// Note that pk.Payload aliases the tail of buf rather than copying it;
// the caller must not reuse buf while the packet is live.
func (pk *Packet) PublishDecode(buf []byte) error {
	var offset int
	var err error

	pk.TopicName, offset, err = decodeString(buf, 0) // [MQTT-3.3.2-1]
	if err != nil {
		return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
	}

	// A packet id is only present for qos 1 and 2 publishes.
	if pk.FixedHeader.Qos > 0 {
		pk.PacketID, offset, err = decodeUint16(buf, offset)
		if err != nil {
			return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
		}
	}

	// v5 publishes carry a properties block between the variable header
	// and the payload.
	if pk.ProtocolVersion == 5 {
		n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
		if err != nil {
			return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
		}

		offset += n
	}

	// Everything after the variable header is the application payload.
	pk.Payload = buf[offset:]

	return nil
}
+
+// PublishValidate validates a publish packet.
+func (pk *Packet) PublishValidate(topicAliasMaximum uint16) Code {
+ if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
+ return ErrProtocolViolationNoPacketID // [MQTT-2.2.1-3] [MQTT-2.2.1-4]
+ }
+
+ if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 {
+ return ErrProtocolViolationSurplusPacketID // [MQTT-2.2.1-2]
+ }
+
+ if strings.ContainsAny(pk.TopicName, "+#") {
+ return ErrProtocolViolationSurplusWildcard // [MQTT-3.3.2-2]
+ }
+
+ if pk.Properties.TopicAlias > topicAliasMaximum {
+ return ErrTopicAliasInvalid // [MQTT-3.2.2-17] [MQTT-3.3.2-9] ~[MQTT-3.3.2-10] [MQTT-3.3.2-12]
+ }
+
+ if pk.TopicName == "" && pk.Properties.TopicAlias == 0 {
+ return ErrProtocolViolationNoTopic // ~[MQTT-3.3.2-8]
+ }
+
+ if pk.Properties.TopicAliasFlag && pk.Properties.TopicAlias == 0 {
+ return ErrTopicAliasInvalid // [MQTT-3.3.2-8]
+ }
+
+ if len(pk.Properties.SubscriptionIdentifier) > 0 {
+ return ErrProtocolViolationSurplusSubID // [MQTT-3.3.4-6]
+ }
+
+ return CodeSuccess
+}
+
// encodePubAckRelRecComp encodes a Puback, Pubrel, Pubrec, or Pubcomp packet.
// These four packet types share an identical wire format: packet id, then
// (v5 only) an optional reason code and optional properties block.
func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)
	nb.Write(encodeUint16(pk.PacketID))

	if pk.ProtocolVersion == 5 {
		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		// Properties are encoded first so their size is known: the reason
		// code byte may be elided only when the code is a success code AND
		// there are no properties to send.
		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
		if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
			nb.WriteByte(pk.ReasonCode)
		}

		// pb.Len() > 1 means a non-empty properties block was produced.
		if pb.Len() > 1 {
			nb.Write(pb.Bytes())
		}
	}

	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())
	return nil
}
+
// decodePubAckRelRecComp extracts the data values from a Puback, Pubrel,
// Pubrec, or Pubcomp packet. The reason code and properties are optional
// on the wire, so their presence is inferred from the remaining length.
func (pk *Packet) decodePubAckRelRecComp(buf []byte) error {
	var offset int
	var err error
	pk.PacketID, offset, err = decodeUint16(buf, offset)
	if err != nil {
		return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
	}

	// Remaining > 2 means there are bytes beyond the packet id: a reason
	// code, and (if Remaining > 3) a properties block after it.
	if pk.ProtocolVersion == 5 && pk.FixedHeader.Remaining > 2 {
		pk.ReasonCode, offset, err = decodeByte(buf, offset)
		if err != nil {
			return fmt.Errorf("%s: %w", err, ErrMalformedReasonCode)
		}

		if pk.FixedHeader.Remaining > 3 {
			_, err = pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
			if err != nil {
				return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
			}
		}
	}

	return nil
}
+
// Puback, Pubrec, Pubrel, and Pubcomp share an identical wire format, so
// the following methods all delegate to the shared helpers above.

// PubackEncode encodes a Puback packet.
func (pk *Packet) PubackEncode(buf *bytes.Buffer) error {
	return pk.encodePubAckRelRecComp(buf)
}

// PubackDecode decodes a Puback packet.
func (pk *Packet) PubackDecode(buf []byte) error {
	return pk.decodePubAckRelRecComp(buf)
}

// PubcompEncode encodes a Pubcomp packet.
func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error {
	return pk.encodePubAckRelRecComp(buf)
}

// PubcompDecode decodes a Pubcomp packet.
func (pk *Packet) PubcompDecode(buf []byte) error {
	return pk.decodePubAckRelRecComp(buf)
}

// PubrecEncode encodes a Pubrec packet.
func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error {
	return pk.encodePubAckRelRecComp(buf)
}

// PubrecDecode decodes a Pubrec packet.
func (pk *Packet) PubrecDecode(buf []byte) error {
	return pk.decodePubAckRelRecComp(buf)
}

// PubrelEncode encodes a Pubrel packet.
func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error {
	return pk.encodePubAckRelRecComp(buf)
}

// PubrelDecode decodes a Pubrel packet.
func (pk *Packet) PubrelDecode(buf []byte) error {
	return pk.decodePubAckRelRecComp(buf)
}
+
+// ReasonCodeValid returns true if the provided reason code is valid for the packet type.
+func (pk *Packet) ReasonCodeValid() bool {
+ switch pk.FixedHeader.Type {
+ case Pubrec:
+ return bytes.Contains([]byte{
+ CodeSuccess.Code,
+ CodeNoMatchingSubscribers.Code,
+ ErrUnspecifiedError.Code,
+ ErrImplementationSpecificError.Code,
+ ErrNotAuthorized.Code,
+ ErrTopicNameInvalid.Code,
+ ErrPacketIdentifierInUse.Code,
+ ErrQuotaExceeded.Code,
+ ErrPayloadFormatInvalid.Code,
+ }, []byte{pk.ReasonCode})
+ case Pubrel:
+ fallthrough
+ case Pubcomp:
+ return bytes.Contains([]byte{
+ CodeSuccess.Code,
+ ErrPacketIdentifierNotFound.Code,
+ }, []byte{pk.ReasonCode})
+ case Suback:
+ return bytes.Contains([]byte{
+ CodeGrantedQos0.Code,
+ CodeGrantedQos1.Code,
+ CodeGrantedQos2.Code,
+ ErrUnspecifiedError.Code,
+ ErrImplementationSpecificError.Code,
+ ErrNotAuthorized.Code,
+ ErrTopicFilterInvalid.Code,
+ ErrPacketIdentifierInUse.Code,
+ ErrQuotaExceeded.Code,
+ ErrSharedSubscriptionsNotSupported.Code,
+ ErrSubscriptionIdentifiersNotSupported.Code,
+ ErrWildcardSubscriptionsNotSupported.Code,
+ }, []byte{pk.ReasonCode})
+ case Unsuback:
+ return bytes.Contains([]byte{
+ CodeSuccess.Code,
+ CodeNoSubscriptionExisted.Code,
+ ErrUnspecifiedError.Code,
+ ErrImplementationSpecificError.Code,
+ ErrNotAuthorized.Code,
+ ErrTopicFilterInvalid.Code,
+ ErrPacketIdentifierInUse.Code,
+ }, []byte{pk.ReasonCode})
+ }
+
+ return true
+}
+
// SubackEncode encodes a Suback packet: packet id, v5 properties, then one
// reason (return) code per subscription in the corresponding subscribe.
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)
	nb.Write(encodeUint16(pk.PacketID))

	if pk.ProtocolVersion == 5 {
		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
		nb.Write(pb.Bytes())
	}

	// Reason codes form the payload for both v3 and v5 subacks.
	nb.Write(pk.ReasonCodes)

	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())

	return nil
}
+
// SubackDecode decodes a Suback packet.
// Note that pk.ReasonCodes aliases the tail of buf rather than copying it.
func (pk *Packet) SubackDecode(buf []byte) error {
	var offset int
	var err error

	pk.PacketID, offset, err = decodeUint16(buf, offset)
	if err != nil {
		return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
	}

	// v5 subacks carry a properties block before the reason codes.
	if pk.ProtocolVersion == 5 {
		n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
		if err != nil {
			return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
		}
		offset += n
	}

	// All remaining bytes are per-subscription reason codes.
	pk.ReasonCodes = buf[offset:]

	return nil
}
+
// SubscribeEncode encodes a Subscribe packet: packet id, v5 properties,
// then one (filter, options/qos) pair per subscription.
func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
	// A subscribe must contain a non-zero packet id.
	if pk.PacketID == 0 {
		return ErrProtocolViolationNoPacketID
	}

	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)
	nb.Write(encodeUint16(pk.PacketID))

	xb := mempool.GetBuffer() // capture and write filters after length checks
	defer mempool.PutBuffer(xb)
	for _, opts := range pk.Filters {
		xb.Write(encodeString(opts.Filter)) // [MQTT-3.8.3-1]
		// v5 packs qos plus subscription options into one byte; v3 uses
		// the raw qos byte only.
		if pk.ProtocolVersion == 5 {
			xb.WriteByte(opts.encode())
		} else {
			xb.WriteByte(opts.Qos)
		}
	}

	if pk.ProtocolVersion == 5 {
		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
		nb.Write(pb.Bytes())
	}

	nb.Write(xb.Bytes())

	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())

	return nil
}
+
+// SubscribeDecode decodes a Subscribe packet.
+func (pk *Packet) SubscribeDecode(buf []byte) error {
+ var offset int
+ var err error
+
+ pk.PacketID, offset, err = decodeUint16(buf, offset)
+ if err != nil {
+ return ErrMalformedPacketID
+ }
+
+ if pk.ProtocolVersion == 5 {
+ n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
+ if err != nil {
+ return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
+ }
+ offset += n
+ }
+
+ var filter string
+ pk.Filters = Subscriptions{}
+ for offset < len(buf) {
+ filter, offset, err = decodeString(buf, offset) // [MQTT-3.8.3-1]
+ if err != nil {
+ return ErrMalformedTopic
+ }
+
+ var option byte
+ sub := &Subscription{
+ Filter: filter,
+ }
+
+ if pk.ProtocolVersion == 5 {
+ sub.decode(buf[offset])
+ offset += 1
+ } else {
+ option, offset, err = decodeByte(buf, offset)
+ if err != nil {
+ return ErrMalformedQos
+ }
+ sub.Qos = option
+ }
+
+ if len(pk.Properties.SubscriptionIdentifier) > 0 {
+ sub.Identifier = pk.Properties.SubscriptionIdentifier[0]
+ }
+
+ if sub.Qos > 2 {
+ return ErrProtocolViolationQosOutOfRange
+ }
+
+ pk.Filters = append(pk.Filters, *sub)
+ }
+
+ return nil
+}
+
+// SubscribeValidate ensures the packet is compliant.
+func (pk *Packet) SubscribeValidate() Code {
+ if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
+ return ErrProtocolViolationNoPacketID // [MQTT-2.2.1-3] [MQTT-2.2.1-4]
+ }
+
+ if len(pk.Filters) == 0 {
+ return ErrProtocolViolationNoFilters // [MQTT-3.10.3-2]
+ }
+
+ for _, v := range pk.Filters {
+ if v.Identifier > 268435455 { // 3.3.2.3.8 The Subscription Identifier can have the value of 1 to 268,435,455.
+ return ErrProtocolViolationOversizeSubID //
+ }
+ }
+
+ return CodeSuccess
+}
+
// UnsubackEncode encodes an Unsuback packet.
// Only v5 unsubacks carry properties and reason codes; a v3 unsuback
// consists of the packet id alone.
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)
	nb.Write(encodeUint16(pk.PacketID))

	if pk.ProtocolVersion == 5 {
		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
		nb.Write(pb.Bytes())
		nb.Write(pk.ReasonCodes)
	}

	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())

	return nil
}
+
// UnsubackDecode decodes an Unsuback packet.
// Only v5 unsubacks carry properties and reason codes; note that
// pk.ReasonCodes aliases the tail of buf rather than copying it.
func (pk *Packet) UnsubackDecode(buf []byte) error {
	var offset int
	var err error

	pk.PacketID, offset, err = decodeUint16(buf, offset)
	if err != nil {
		return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
	}

	if pk.ProtocolVersion == 5 {
		n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
		if err != nil {
			return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
		}

		offset += n

		pk.ReasonCodes = buf[offset:]
	}

	return nil
}
+
// UnsubscribeEncode encodes an Unsubscribe packet: packet id, v5
// properties, then one topic filter per subscription to remove.
func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
	// An unsubscribe must contain a non-zero packet id.
	if pk.PacketID == 0 {
		return ErrProtocolViolationNoPacketID
	}

	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)
	nb.Write(encodeUint16(pk.PacketID))

	xb := mempool.GetBuffer() // capture filters and write after length checks
	defer mempool.PutBuffer(xb)
	for _, sub := range pk.Filters {
		xb.Write(encodeString(sub.Filter)) // [MQTT-3.10.3-1]
	}

	if pk.ProtocolVersion == 5 {
		pb := mempool.GetBuffer()
		defer mempool.PutBuffer(pb)
		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
		nb.Write(pb.Bytes())
	}

	nb.Write(xb.Bytes())

	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())

	return nil
}
+
// UnsubscribeDecode decodes an Unsubscribe packet: packet id, v5
// properties, then topic filters until the buffer is exhausted.
func (pk *Packet) UnsubscribeDecode(buf []byte) error {
	var offset int
	var err error

	pk.PacketID, offset, err = decodeUint16(buf, offset)
	if err != nil {
		return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
	}

	if pk.ProtocolVersion == 5 {
		n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
		if err != nil {
			return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
		}
		offset += n
	}

	var filter string
	pk.Filters = Subscriptions{}
	for offset < len(buf) {
		filter, offset, err = decodeString(buf, offset) // [MQTT-3.10.3-1]
		if err != nil {
			return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
		}
		pk.Filters = append(pk.Filters, Subscription{Filter: filter})
	}

	return nil
}
+
+// UnsubscribeValidate validates an Unsubscribe packet.
+func (pk *Packet) UnsubscribeValidate() Code {
+ if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
+ return ErrProtocolViolationNoPacketID // [MQTT-2.2.1-3] [MQTT-2.2.1-4]
+ }
+
+ if len(pk.Filters) == 0 {
+ return ErrProtocolViolationNoFilters // [MQTT-3.10.3-2]
+ }
+
+ return CodeSuccess
+}
+
// AuthEncode encodes an Auth packet: a reason code followed by a
// properties block (Auth packets exist only in MQTT v5).
func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb)
	nb.WriteByte(pk.ReasonCode)

	pb := mempool.GetBuffer()
	defer mempool.PutBuffer(pb)
	pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
	nb.Write(pb.Bytes())

	pk.FixedHeader.Remaining = nb.Len()
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes())
	return nil
}
+
// AuthDecode decodes an Auth packet: a reason code followed by a
// properties block.
func (pk *Packet) AuthDecode(buf []byte) error {
	var offset int
	var err error

	pk.ReasonCode, offset, err = decodeByte(buf, offset)
	if err != nil {
		return fmt.Errorf("%s: %w", err, ErrMalformedReasonCode)
	}

	_, err = pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
	if err != nil {
		return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
	}

	return nil
}
+
+// AuthValidate returns success if the auth packet is valid.
+func (pk *Packet) AuthValidate() Code {
+ if pk.ReasonCode != CodeSuccess.Code &&
+ pk.ReasonCode != CodeContinueAuthentication.Code &&
+ pk.ReasonCode != CodeReAuthenticate.Code {
+ return ErrProtocolViolationInvalidReason // [MQTT-3.15.2-1]
+ }
+
+ return CodeSuccess
+}
+
+// FormatID returns the PacketID field as a decimal integer.
+func (pk *Packet) FormatID() string {
+ return strconv.FormatUint(uint64(pk.PacketID), 10)
+}
diff --git a/packets/packets_test.go b/packets/packets_test.go
new file mode 100644
index 0000000..1e18f1f
--- /dev/null
+++ b/packets/packets_test.go
@@ -0,0 +1,505 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/jinzhu/copier"
+ "github.com/stretchr/testify/require"
+)
+
// pkInfo is the format string used to annotate test failure messages with
// the packet type and the test case description.
const pkInfo = "packet type %v, %s"

// packetList enumerates every MQTT control packet type exercised by the
// table-driven tests below.
var packetList = []byte{
	Connect,
	Connack,
	Publish,
	Puback,
	Pubrec,
	Pubrel,
	Pubcomp,
	Subscribe,
	Suback,
	Unsubscribe,
	Unsuback,
	Pingreq,
	Pingresp,
	Disconnect,
	Auth,
}
+
// pkTable selects representative fixture cases (v3 and v5 variants) from
// TPacketData for tests that iterate over complete packets, e.g. TestCopy.
var pkTable = []TPacketCase{
	TPacketData[Connect].Get(TConnectMqtt311),
	TPacketData[Connect].Get(TConnectMqtt5),
	TPacketData[Connect].Get(TConnectUserPassLWT),
	TPacketData[Connack].Get(TConnackAcceptedMqtt5),
	TPacketData[Connack].Get(TConnackAcceptedNoSession),
	TPacketData[Publish].Get(TPublishBasic),
	TPacketData[Publish].Get(TPublishMqtt5),
	TPacketData[Puback].Get(TPuback),
	TPacketData[Pubrec].Get(TPubrec),
	TPacketData[Pubrel].Get(TPubrel),
	TPacketData[Pubcomp].Get(TPubcomp),
	TPacketData[Subscribe].Get(TSubscribe),
	TPacketData[Subscribe].Get(TSubscribeMqtt5),
	TPacketData[Suback].Get(TSuback),
	TPacketData[Unsubscribe].Get(TUnsubscribe),
	TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
	TPacketData[Pingreq].Get(TPingreq),
	TPacketData[Pingresp].Get(TPingresp),
	TPacketData[Disconnect].Get(TDisconnect),
	TPacketData[Disconnect].Get(TDisconnectMqtt5),
}
+
+func TestNewPackets(t *testing.T) {
+ s := NewPackets()
+ require.NotNil(t, s.internal)
+}
+
+func TestPacketsAdd(t *testing.T) {
+ s := NewPackets()
+ s.Add("cl1", Packet{})
+ require.Contains(t, s.internal, "cl1")
+}
+
+func TestPacketsGet(t *testing.T) {
+ s := NewPackets()
+ s.Add("cl1", Packet{TopicName: "a1"})
+ s.Add("cl2", Packet{TopicName: "a2"})
+ require.Contains(t, s.internal, "cl1")
+ require.Contains(t, s.internal, "cl2")
+
+ pk, ok := s.Get("cl1")
+ require.True(t, ok)
+ require.Equal(t, "a1", pk.TopicName)
+}
+
+func TestPacketsGetAll(t *testing.T) {
+ s := NewPackets()
+ s.Add("cl1", Packet{TopicName: "a1"})
+ s.Add("cl2", Packet{TopicName: "a2"})
+ s.Add("cl3", Packet{TopicName: "a3"})
+ require.Contains(t, s.internal, "cl1")
+ require.Contains(t, s.internal, "cl2")
+ require.Contains(t, s.internal, "cl3")
+
+ subs := s.GetAll()
+ require.Len(t, subs, 3)
+}
+
+func TestPacketsLen(t *testing.T) {
+ s := NewPackets()
+ s.Add("cl1", Packet{TopicName: "a1"})
+ s.Add("cl2", Packet{TopicName: "a2"})
+ require.Contains(t, s.internal, "cl1")
+ require.Contains(t, s.internal, "cl2")
+ require.Equal(t, 2, s.Len())
+}
+
+func TestSPacketsDelete(t *testing.T) {
+ s := NewPackets()
+ s.Add("cl1", Packet{TopicName: "a1"})
+ require.Contains(t, s.internal, "cl1")
+
+ s.Delete("cl1")
+ _, ok := s.Get("cl1")
+ require.False(t, ok)
+}
+
+func TestFormatPacketID(t *testing.T) {
+ for _, id := range []uint16{0, 7, 0x100, 0xffff} {
+ packet := &Packet{PacketID: id}
+ require.Equal(t, fmt.Sprint(id), packet.FormatID())
+ }
+}
+
+func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
+ p := &Subscription{
+ Qos: 2,
+ NoLocal: true,
+ RetainAsPublished: true,
+ RetainHandling: 2,
+ }
+ x := new(Subscription)
+ x.decode(p.encode())
+ require.Equal(t, *p, *x)
+
+ p = &Subscription{
+ Qos: 1,
+ NoLocal: false,
+ RetainAsPublished: false,
+ RetainHandling: 1,
+ }
+ x = new(Subscription)
+ x.decode(p.encode())
+ require.Equal(t, *p, *x)
+}
+
// TestPacketEncode runs every encodable fixture case through the
// corresponding XxxEncode method and compares the produced bytes against
// the fixture's expected byte string.
func TestPacketEncode(t *testing.T) {
	for _, pkt := range packetList {
		require.Contains(t, TPacketData, pkt)
		for _, wanted := range TPacketData[pkt] {
			t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
				// Skip cases not intended for encode testing.
				if !encodeTestOK(wanted) {
					return
				}

				// Deep-copy the fixture packet so encoding side effects
				// (e.g. FixedHeader.Remaining) don't leak between cases.
				pk := new(Packet)
				_ = copier.Copy(pk, wanted.Packet)
				require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)

				pk.Mods.AllowResponseInfo = true

				buf := new(bytes.Buffer)
				var err error
				switch pkt {
				case Connect:
					err = pk.ConnectEncode(buf)
				case Connack:
					err = pk.ConnackEncode(buf)
				case Publish:
					err = pk.PublishEncode(buf)
				case Puback:
					err = pk.PubackEncode(buf)
				case Pubrec:
					err = pk.PubrecEncode(buf)
				case Pubrel:
					err = pk.PubrelEncode(buf)
				case Pubcomp:
					err = pk.PubcompEncode(buf)
				case Subscribe:
					err = pk.SubscribeEncode(buf)
				case Suback:
					err = pk.SubackEncode(buf)
				case Unsubscribe:
					err = pk.UnsubscribeEncode(buf)
				case Unsuback:
					err = pk.UnsubackEncode(buf)
				case Pingreq:
					err = pk.PingreqEncode(buf)
				case Pingresp:
					err = pk.PingrespEncode(buf)
				case Disconnect:
					err = pk.DisconnectEncode(buf)
				case Auth:
					err = pk.AuthEncode(buf)
				}
				// Cases with a non-nil Expect are expected to fail to encode.
				if wanted.Expect != nil {
					require.Error(t, err, pkInfo, pkt, wanted.Desc)
					return
				}

				require.NoError(t, err, pkInfo, pkt, wanted.Desc)
				encoded := buf.Bytes()

				// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
				if len(wanted.ActualBytes) > 0 {
					wanted.RawBytes = wanted.ActualBytes
				}
				require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
			})
		}
	}
}
+
+func TestPacketDecode(t *testing.T) {
+ for _, pkt := range packetList {
+ require.Contains(t, TPacketData, pkt)
+ for _, wanted := range TPacketData[pkt] {
+ t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
+ if !decodeTestOK(wanted) {
+ return
+ }
+
+ pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
+ pk.Mods.AllowResponseInfo = true
+ _ = pk.FixedHeader.Decode(wanted.RawBytes[0])
+ if len(wanted.RawBytes) > 0 {
+ pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
+ }
+
+ if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
+ pk.ProtocolVersion = wanted.Packet.ProtocolVersion
+ }
+
+ buf := wanted.RawBytes[2:]
+ var err error
+ switch pkt {
+ case Connect:
+ err = pk.ConnectDecode(buf)
+ case Connack:
+ err = pk.ConnackDecode(buf)
+ case Publish:
+ err = pk.PublishDecode(buf)
+ case Puback:
+ err = pk.PubackDecode(buf)
+ case Pubrec:
+ err = pk.PubrecDecode(buf)
+ case Pubrel:
+ err = pk.PubrelDecode(buf)
+ case Pubcomp:
+ err = pk.PubcompDecode(buf)
+ case Subscribe:
+ err = pk.SubscribeDecode(buf)
+ case Suback:
+ err = pk.SubackDecode(buf)
+ case Unsubscribe:
+ err = pk.UnsubscribeDecode(buf)
+ case Unsuback:
+ err = pk.UnsubackDecode(buf)
+ case Pingreq:
+ err = pk.PingreqDecode(buf)
+ case Pingresp:
+ err = pk.PingrespDecode(buf)
+ case Disconnect:
+ err = pk.DisconnectDecode(buf)
+ case Auth:
+ err = pk.AuthDecode(buf)
+ }
+
+ if wanted.FailFirst != nil {
+ require.Error(t, err, pkInfo, pkt, wanted.Desc)
+ require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
+ return
+ }
+
+ require.NoError(t, err, pkInfo, pkt, wanted.Desc)
+
+ require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
+
+ require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
+
+ if pkt == Connect {
+ // we use ProtocolVersion for controlling packet encoding, but we don't need to test
+ // against it unless it's a connect packet.
+ require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
+ }
+ require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
+
+ require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
+
+ require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
+
+ require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
+ require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
+
+ require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
+ require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
+ })
+ }
+ }
+}
+
+func TestValidate(t *testing.T) {
+ for _, pkt := range packetList {
+ require.Contains(t, TPacketData, pkt)
+ for _, wanted := range TPacketData[pkt] {
+ t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
+ if wanted.Group == "validate" || wanted.Primary {
+ pk := wanted.Packet
+ var err error
+ switch pkt {
+ case Connect:
+ err = pk.ConnectValidate()
+ case Publish:
+ err = pk.PublishValidate(1024)
+ case Subscribe:
+ err = pk.SubscribeValidate()
+ case Unsubscribe:
+ err = pk.UnsubscribeValidate()
+ case Auth:
+ err = pk.AuthValidate()
+ }
+
+ if wanted.Expect != nil {
+ require.Error(t, err, pkInfo, pkt, wanted.Desc)
+ require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
+ }
+ }
+ })
+ }
+ }
+}
+
// TestAckValidatePubrec checks that every reason code permitted for a
// Pubrec validates, and one outside the set is rejected.
func TestAckValidatePubrec(t *testing.T) {
	for _, b := range []byte{
		CodeSuccess.Code,
		CodeNoMatchingSubscribers.Code,
		ErrUnspecifiedError.Code,
		ErrImplementationSpecificError.Code,
		ErrNotAuthorized.Code,
		ErrTopicNameInvalid.Code,
		ErrPacketIdentifierInUse.Code,
		ErrQuotaExceeded.Code,
		ErrPayloadFormatInvalid.Code,
	} {
		pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
		require.True(t, pk.ReasonCodeValid())
	}
	pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
	require.False(t, pk.ReasonCodeValid())
}
+
// TestAckValidatePubrel checks that every reason code permitted for a
// Pubrel validates, and one outside the set is rejected.
func TestAckValidatePubrel(t *testing.T) {
	for _, b := range []byte{
		CodeSuccess.Code,
		ErrPacketIdentifierNotFound.Code,
	} {
		pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
		require.True(t, pk.ReasonCodeValid())
	}
	pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
	require.False(t, pk.ReasonCodeValid())
}
+
+func TestAckValidatePubcomp(t *testing.T) {
+ for _, b := range []byte{
+ CodeSuccess.Code,
+ ErrPacketIdentifierNotFound.Code,
+ } {
+ pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
+ require.True(t, pk.ReasonCodeValid())
+ }
+ pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
+ require.False(t, pk.ReasonCodeValid())
+}
+
// TestAckValidateSuback checks that every reason code permitted for a
// Suback validates, and one outside the set is rejected.
func TestAckValidateSuback(t *testing.T) {
	for _, b := range []byte{
		CodeGrantedQos0.Code,
		CodeGrantedQos1.Code,
		CodeGrantedQos2.Code,
		ErrUnspecifiedError.Code,
		ErrImplementationSpecificError.Code,
		ErrNotAuthorized.Code,
		ErrTopicFilterInvalid.Code,
		ErrPacketIdentifierInUse.Code,
		ErrQuotaExceeded.Code,
		ErrSharedSubscriptionsNotSupported.Code,
		ErrSubscriptionIdentifiersNotSupported.Code,
		ErrWildcardSubscriptionsNotSupported.Code,
	} {
		pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
		require.True(t, pk.ReasonCodeValid())
	}

	pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
	require.False(t, pk.ReasonCodeValid())
}
+
// TestAckValidateUnsuback checks that every reason code permitted for an
// Unsuback validates, and one outside the set is rejected.
func TestAckValidateUnsuback(t *testing.T) {
	for _, b := range []byte{
		CodeSuccess.Code,
		CodeNoSubscriptionExisted.Code,
		ErrUnspecifiedError.Code,
		ErrImplementationSpecificError.Code,
		ErrNotAuthorized.Code,
		ErrTopicFilterInvalid.Code,
		ErrPacketIdentifierInUse.Code,
	} {
		pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
		require.True(t, pk.ReasonCodeValid())
	}

	pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
	require.False(t, pk.ReasonCodeValid())
}
+
// TestReasonCodeValidMisc checks that packet types without a constrained
// reason-code set (e.g. Connack) always validate.
func TestReasonCodeValidMisc(t *testing.T) {
	pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
	require.True(t, pk.ReasonCodeValid())
}
+
// TestCopy checks that Packet.Copy reproduces every field of the fixture
// packets, that Dup and Retain are always reset to false, and that
// Copy(false) clears the packet id while Copy(true) preserves it.
func TestCopy(t *testing.T) {
	for _, tt := range pkTable {
		pkc := tt.Packet.Copy(true)

		require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)

		require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
		require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
		require.EqualValues(t, pkc.Properties, tt.Packet.Properties)

		// Copy(false) must not carry the packet id over.
		pkcc := tt.Packet.Copy(false)
		require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
	}
}
+
// TestMergeSubscription checks that merging two subscriptions keeps the
// receiver's filter and identifier, takes the maximum qos and NoLocal of
// the pair, and records both identifiers in the Identifiers map.
func TestMergeSubscription(t *testing.T) {
	sub := Subscription{
		Filter:            "a/b/c",
		RetainHandling:    0,
		Qos:               0,
		RetainAsPublished: false,
		NoLocal:           false,
		Identifier:        1,
	}

	sub2 := Subscription{
		Filter:            "a/b/d",
		RetainHandling:    0,
		Qos:               2,
		RetainAsPublished: false,
		NoLocal:           true,
		Identifier:        2,
	}

	expect := Subscription{
		Filter:            "a/b/c",
		RetainHandling:    0,
		Qos:               2,
		RetainAsPublished: false,
		NoLocal:           true,
		Identifier:        1,
		Identifiers: map[string]int{
			"a/b/c": 1,
			"a/b/d": 2,
		},
	}
	require.Equal(t, expect, sub.Merge(sub2))
}
diff --git a/packets/properties.go b/packets/properties.go
new file mode 100644
index 0000000..1ad6294
--- /dev/null
+++ b/packets/properties.go
@@ -0,0 +1,481 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+
+ "testmqtt/mempool"
+)
+
+// MQTT v5 property identifier bytes, as defined in the mqtt v5 spec
+// (section 2.2.2.2). Gaps in the numbering are identifiers reserved or
+// unused by the specification.
+const (
+	PropPayloadFormat          byte = 1
+	PropMessageExpiryInterval  byte = 2
+	PropContentType            byte = 3
+	PropResponseTopic          byte = 8
+	PropCorrelationData        byte = 9
+	PropSubscriptionIdentifier byte = 11
+	PropSessionExpiryInterval  byte = 17
+	PropAssignedClientID       byte = 18
+	PropServerKeepAlive        byte = 19
+	PropAuthenticationMethod   byte = 21
+	PropAuthenticationData     byte = 22
+	PropRequestProblemInfo     byte = 23
+	PropWillDelayInterval      byte = 24
+	PropRequestResponseInfo    byte = 25
+	PropResponseInfo           byte = 26
+	PropServerReference        byte = 28
+	PropReasonString           byte = 31
+	PropReceiveMaximum         byte = 33
+	PropTopicAliasMaximum      byte = 34
+	PropTopicAlias             byte = 35
+	PropMaximumQos             byte = 36
+	PropRetainAvailable        byte = 37
+	PropUser                   byte = 38
+	PropMaximumPacketSize      byte = 39
+	PropWildcardSubAvailable   byte = 40
+	PropSubIDAvailable         byte = 41
+	PropSharedSubAvailable     byte = 42
+)
+
+// validPacketProperties indicates which properties are valid for which packet types.
+// Outer key: property identifier; inner key: packet type; a value of 1 marks
+// the property as permitted on that packet type. Lookups on absent keys
+// yield 0, which canEncode and Decode treat as "not valid".
+var validPacketProperties = map[byte]map[byte]byte{
+	PropPayloadFormat:          {Publish: 1, WillProperties: 1},
+	PropMessageExpiryInterval:  {Publish: 1, WillProperties: 1},
+	PropContentType:            {Publish: 1, WillProperties: 1},
+	PropResponseTopic:          {Publish: 1, WillProperties: 1},
+	PropCorrelationData:        {Publish: 1, WillProperties: 1},
+	PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
+	PropSessionExpiryInterval:  {Connect: 1, Connack: 1, Disconnect: 1},
+	PropAssignedClientID:       {Connack: 1},
+	PropServerKeepAlive:        {Connack: 1},
+	PropAuthenticationMethod:   {Connect: 1, Connack: 1, Auth: 1},
+	PropAuthenticationData:     {Connect: 1, Connack: 1, Auth: 1},
+	PropRequestProblemInfo:     {Connect: 1},
+	PropWillDelayInterval:      {WillProperties: 1},
+	PropRequestResponseInfo:    {Connect: 1},
+	PropResponseInfo:           {Connack: 1},
+	PropServerReference:        {Connack: 1, Disconnect: 1},
+	PropReasonString:           {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
+	PropReceiveMaximum:         {Connect: 1, Connack: 1},
+	PropTopicAliasMaximum:      {Connect: 1, Connack: 1},
+	PropTopicAlias:             {Publish: 1},
+	PropMaximumQos:             {Connack: 1},
+	PropRetainAvailable:        {Connack: 1},
+	PropUser:                   {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
+	PropMaximumPacketSize:      {Connect: 1, Connack: 1},
+	PropWildcardSubAvailable:   {Connack: 1},
+	PropSubIDAvailable:         {Connack: 1},
+	PropSharedSubAvailable:     {Connack: 1},
+}
+
+// UserProperty is an arbitrary key-value pair for a packet user properties array.
+// Pairs are kept in the order they were decoded or added. [MQTT-1.5.7-1]
+type UserProperty struct { // [MQTT-1.5.7-1]
+	Key string `json:"k"`
+	Val string `json:"v"`
+}
+
+// Properties contains all mqtt v5 properties available for a packet.
+// Some properties have valid values of 0 or not-present. In this case, we opt for
+// property flags to indicate the usage of property: each *Flag boolean records
+// that the matching property was explicitly set, distinguishing a deliberate
+// zero value from an absent property.
+// Refer to mqtt v5 2.2.2.2 Property spec for more information.
+type Properties struct {
+	CorrelationData           []byte         `json:"cd"`
+	SubscriptionIdentifier    []int          `json:"si"`
+	AuthenticationData        []byte         `json:"ad"`
+	User                      []UserProperty `json:"user"`
+	ContentType               string         `json:"ct"`
+	ResponseTopic             string         `json:"rt"`
+	AssignedClientID          string         `json:"aci"`
+	AuthenticationMethod      string         `json:"am"`
+	ResponseInfo              string         `json:"ri"`
+	ServerReference           string         `json:"sr"`
+	ReasonString              string         `json:"rs"`
+	MessageExpiryInterval     uint32         `json:"me"`
+	SessionExpiryInterval     uint32         `json:"sei"`
+	WillDelayInterval         uint32         `json:"wdi"`
+	MaximumPacketSize         uint32         `json:"mps"`
+	ServerKeepAlive           uint16         `json:"ska"`
+	ReceiveMaximum            uint16         `json:"rm"`
+	TopicAliasMaximum         uint16         `json:"tam"`
+	TopicAlias                uint16         `json:"ta"`
+	PayloadFormat             byte           `json:"pf"`
+	PayloadFormatFlag         bool           `json:"fpf"`
+	SessionExpiryIntervalFlag bool           `json:"fsei"`
+	ServerKeepAliveFlag       bool           `json:"fska"`
+	RequestProblemInfo        byte           `json:"rpi"`
+	RequestProblemInfoFlag    bool           `json:"frpi"`
+	RequestResponseInfo       byte           `json:"rri"`
+	TopicAliasFlag            bool           `json:"fta"`
+	MaximumQos                byte           `json:"mqos"`
+	MaximumQosFlag            bool           `json:"fmqos"`
+	RetainAvailable           byte           `json:"ra"`
+	RetainAvailableFlag       bool           `json:"fra"`
+	WildcardSubAvailable      byte           `json:"wsa"`
+	WildcardSubAvailableFlag  bool           `json:"fwsa"`
+	SubIDAvailable            byte           `json:"sida"`
+	SubIDAvailableFlag        bool           `json:"fsida"`
+	SharedSubAvailable        byte           `json:"ssa"`
+	SharedSubAvailableFlag    bool           `json:"fssa"`
+}
+
+// Copy creates a new Properties struct with copies of the values.
+// Slice-backed fields (correlation data, subscription identifiers,
+// authentication data, user properties) are deep-copied so the result shares
+// no mutable state with the receiver; copies are pre-sized to avoid append
+// regrowth. TopicAlias is only carried over when allowTransfer is true,
+// since aliases are connection-specific [MQTT-3.3.2-7].
+func (p *Properties) Copy(allowTransfer bool) Properties {
+	pr := Properties{
+		PayloadFormat:             p.PayloadFormat, // [MQTT-3.3.2-4]
+		PayloadFormatFlag:         p.PayloadFormatFlag,
+		MessageExpiryInterval:     p.MessageExpiryInterval,
+		ContentType:               p.ContentType,   // [MQTT-3.3.2-20]
+		ResponseTopic:             p.ResponseTopic, // [MQTT-3.3.2-15]
+		SessionExpiryInterval:     p.SessionExpiryInterval,
+		SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag,
+		AssignedClientID:          p.AssignedClientID,
+		ServerKeepAlive:           p.ServerKeepAlive,
+		ServerKeepAliveFlag:       p.ServerKeepAliveFlag,
+		AuthenticationMethod:      p.AuthenticationMethod,
+		RequestProblemInfo:        p.RequestProblemInfo,
+		RequestProblemInfoFlag:    p.RequestProblemInfoFlag,
+		WillDelayInterval:         p.WillDelayInterval,
+		RequestResponseInfo:       p.RequestResponseInfo,
+		ResponseInfo:              p.ResponseInfo,
+		ServerReference:           p.ServerReference,
+		ReasonString:              p.ReasonString,
+		ReceiveMaximum:            p.ReceiveMaximum,
+		TopicAliasMaximum:         p.TopicAliasMaximum,
+		TopicAlias:                0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27]
+		MaximumQos:                p.MaximumQos,
+		MaximumQosFlag:            p.MaximumQosFlag,
+		RetainAvailable:           p.RetainAvailable,
+		RetainAvailableFlag:       p.RetainAvailableFlag,
+		MaximumPacketSize:         p.MaximumPacketSize,
+		WildcardSubAvailable:      p.WildcardSubAvailable,
+		WildcardSubAvailableFlag:  p.WildcardSubAvailableFlag,
+		SubIDAvailable:            p.SubIDAvailable,
+		SubIDAvailableFlag:        p.SubIDAvailableFlag,
+		SharedSubAvailable:        p.SharedSubAvailable,
+		SharedSubAvailableFlag:    p.SharedSubAvailableFlag,
+	}
+
+	if allowTransfer {
+		pr.TopicAlias = p.TopicAlias
+		pr.TopicAliasFlag = p.TopicAliasFlag
+	}
+
+	// Deep-copy slice fields; each copy is allocated at exactly the source
+	// length so no further growth occurs. Empty sources stay nil, matching
+	// the previous behavior of only allocating when len > 0.
+	if len(p.CorrelationData) > 0 {
+		pr.CorrelationData = make([]byte, len(p.CorrelationData)) // [MQTT-3.3.2-16]
+		copy(pr.CorrelationData, p.CorrelationData)
+	}
+
+	if len(p.SubscriptionIdentifier) > 0 {
+		pr.SubscriptionIdentifier = make([]int, len(p.SubscriptionIdentifier))
+		copy(pr.SubscriptionIdentifier, p.SubscriptionIdentifier)
+	}
+
+	if len(p.AuthenticationData) > 0 {
+		pr.AuthenticationData = make([]byte, len(p.AuthenticationData))
+		copy(pr.AuthenticationData, p.AuthenticationData)
+	}
+
+	if len(p.User) > 0 {
+		pr.User = make([]UserProperty, 0, len(p.User))
+		for _, v := range p.User {
+			pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17]
+				Key: v.Key,
+				Val: v.Val,
+			})
+		}
+	}
+
+	return pr
+}
+
+// canEncode reports whether property identifier k may be encoded on a
+// packet of type pkt, per the validPacketProperties table.
+func (p *Properties) canEncode(pkt byte, k byte) bool {
+	valid, ok := validPacketProperties[k]
+	return ok && valid[pkt] == 1
+}
+
+// Encode encodes properties into a bytes buffer.
+// pkt is the packet type the properties belong to (used with canEncode to
+// filter which properties may be written); mods carries per-connection
+// encoding restrictions (MaxSize, DisallowProblemInfo, AllowResponseInfo);
+// n is the number of bytes already committed to the outgoing packet, used
+// when checking mods.MaxSize. Properties are staged in a pooled scratch
+// buffer so the variable-byte-integer length prefix can be written to b
+// before the property bytes themselves.
+func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
+	if p == nil {
+		return
+	}
+
+	// Scratch buffer from the pool; returned on exit.
+	buf := mempool.GetBuffer()
+	defer mempool.PutBuffer(buf)
+	if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
+		buf.WriteByte(PropPayloadFormat)
+		buf.WriteByte(p.PayloadFormat)
+	}
+
+	if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 {
+		buf.WriteByte(PropMessageExpiryInterval)
+		buf.Write(encodeUint32(p.MessageExpiryInterval))
+	}
+
+	if p.canEncode(pkt, PropContentType) && p.ContentType != "" {
+		buf.WriteByte(PropContentType)
+		buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
+	}
+
+	// Response topics containing wildcard characters are never encoded.
+	if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
+		p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
+		buf.WriteByte(PropResponseTopic)
+		buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
+	}
+
+	if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
+		buf.WriteByte(PropCorrelationData)
+		buf.Write(encodeBytes(p.CorrelationData))
+	}
+
+	// Each subscription identifier is emitted as its own property; zero or
+	// negative values are skipped.
+	if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 {
+		for _, v := range p.SubscriptionIdentifier {
+			if v > 0 {
+				buf.WriteByte(PropSubscriptionIdentifier)
+				encodeLength(buf, int64(v))
+			}
+		}
+	}
+
+	if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2]
+		buf.WriteByte(PropSessionExpiryInterval)
+		buf.Write(encodeUint32(p.SessionExpiryInterval))
+	}
+
+	if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" {
+		buf.WriteByte(PropAssignedClientID)
+		buf.Write(encodeString(p.AssignedClientID))
+	}
+
+	if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag {
+		buf.WriteByte(PropServerKeepAlive)
+		buf.Write(encodeUint16(p.ServerKeepAlive))
+	}
+
+	if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" {
+		buf.WriteByte(PropAuthenticationMethod)
+		buf.Write(encodeString(p.AuthenticationMethod))
+	}
+
+	if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 {
+		buf.WriteByte(PropAuthenticationData)
+		buf.Write(encodeBytes(p.AuthenticationData))
+	}
+
+	if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag {
+		buf.WriteByte(PropRequestProblemInfo)
+		buf.WriteByte(p.RequestProblemInfo)
+	}
+
+	if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 {
+		buf.WriteByte(PropWillDelayInterval)
+		buf.Write(encodeUint32(p.WillDelayInterval))
+	}
+
+	if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 {
+		buf.WriteByte(PropRequestResponseInfo)
+		buf.WriteByte(p.RequestResponseInfo)
+	}
+
+	if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
+		buf.WriteByte(PropResponseInfo)
+		buf.Write(encodeString(p.ResponseInfo))
+	}
+
+	if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 {
+		buf.WriteByte(PropServerReference)
+		buf.Write(encodeString(p.ServerReference))
+	}
+
+	// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
+	// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
+	// The reason string is dropped entirely if adding it would exceed the
+	// negotiated maximum packet size.
+	if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
+		b := encodeString(p.ReasonString) // NB: shadows the output buffer parameter b within this block.
+		if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
+			buf.WriteByte(PropReasonString)
+			buf.Write(b)
+		}
+	}
+
+	if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 {
+		buf.WriteByte(PropReceiveMaximum)
+		buf.Write(encodeUint16(p.ReceiveMaximum))
+	}
+
+	if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 {
+		buf.WriteByte(PropTopicAliasMaximum)
+		buf.Write(encodeUint16(p.TopicAliasMaximum))
+	}
+
+	if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8]
+		buf.WriteByte(PropTopicAlias)
+		buf.Write(encodeUint16(p.TopicAlias))
+	}
+
+	if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 {
+		buf.WriteByte(PropMaximumQos)
+		buf.WriteByte(p.MaximumQos)
+	}
+
+	if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag {
+		buf.WriteByte(PropRetainAvailable)
+		buf.WriteByte(p.RetainAvailable)
+	}
+
+	// User properties are staged into a second pooled buffer so they can be
+	// dropped as a unit if they would exceed the maximum packet size.
+	if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
+		pb := mempool.GetBuffer()
+		defer mempool.PutBuffer(pb)
+		for _, v := range p.User {
+			pb.WriteByte(PropUser)
+			pb.Write(encodeString(v.Key))
+			pb.Write(encodeString(v.Val))
+		}
+		// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
+		// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
+		if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
+			buf.Write(pb.Bytes())
+		}
+	}
+
+	if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 {
+		buf.WriteByte(PropMaximumPacketSize)
+		buf.Write(encodeUint32(p.MaximumPacketSize))
+	}
+
+	if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag {
+		buf.WriteByte(PropWildcardSubAvailable)
+		buf.WriteByte(p.WildcardSubAvailable)
+	}
+
+	if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag {
+		buf.WriteByte(PropSubIDAvailable)
+		buf.WriteByte(p.SubIDAvailable)
+	}
+
+	if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag {
+		buf.WriteByte(PropSharedSubAvailable)
+		buf.WriteByte(p.SharedSubAvailable)
+	}
+
+	// Write the property-length VBI followed by the staged property bytes.
+	encodeLength(b, int64(buf.Len()))
+	b.Write(buf.Bytes()) // [MQTT-3.1.3-10]
+}
+
+// Decode decodes property bytes into a properties struct.
+// pkt is the packet type being decoded; properties not valid for that packet
+// type cause an ErrProtocolViolationUnsupportedProperty error. The returned
+// n is the total bytes consumed: the declared property length plus the size
+// of its variable-byte-integer prefix. A nil receiver is a no-op.
+func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
+	if p == nil {
+		return 0, nil
+	}
+
+	// n = declared property-section length; bu = bytes used by the VBI itself.
+	var bu int
+	n, bu, err = DecodeLength(b)
+	if err != nil {
+		return n + bu, err
+	}
+
+	if n == 0 {
+		return n + bu, nil
+	}
+
+	// Walk the property bytes; each decode helper advances offset, and its
+	// error is checked once at the bottom of the loop.
+	bt := b.Bytes()
+	var k byte
+	for offset := 0; offset < n; {
+		k, offset, err = decodeByte(bt, offset)
+		if err != nil {
+			return n + bu, err
+		}
+
+		if _, ok := validPacketProperties[k][pkt]; !ok {
+			return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
+		}
+
+		switch k {
+		case PropPayloadFormat:
+			p.PayloadFormat, offset, err = decodeByte(bt, offset)
+			p.PayloadFormatFlag = true
+		case PropMessageExpiryInterval:
+			p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset)
+		case PropContentType:
+			p.ContentType, offset, err = decodeString(bt, offset)
+		case PropResponseTopic:
+			p.ResponseTopic, offset, err = decodeString(bt, offset)
+		case PropCorrelationData:
+			p.CorrelationData, offset, err = decodeBytes(bt, offset)
+		case PropSubscriptionIdentifier:
+			if p.SubscriptionIdentifier == nil {
+				p.SubscriptionIdentifier = []int{}
+			}
+
+			// NOTE(review): n and bu here shadow the outer counters, so an
+			// error return in this branch reports the inner VBI's counts
+			// rather than the property section's — confirm intended.
+			n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
+			if err != nil {
+				return n + bu, err
+			}
+			p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
+			offset += bu
+		case PropSessionExpiryInterval:
+			p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset)
+			p.SessionExpiryIntervalFlag = true
+		case PropAssignedClientID:
+			p.AssignedClientID, offset, err = decodeString(bt, offset)
+		case PropServerKeepAlive:
+			p.ServerKeepAlive, offset, err = decodeUint16(bt, offset)
+			p.ServerKeepAliveFlag = true
+		case PropAuthenticationMethod:
+			p.AuthenticationMethod, offset, err = decodeString(bt, offset)
+		case PropAuthenticationData:
+			p.AuthenticationData, offset, err = decodeBytes(bt, offset)
+		case PropRequestProblemInfo:
+			p.RequestProblemInfo, offset, err = decodeByte(bt, offset)
+			p.RequestProblemInfoFlag = true
+		case PropWillDelayInterval:
+			p.WillDelayInterval, offset, err = decodeUint32(bt, offset)
+		case PropRequestResponseInfo:
+			p.RequestResponseInfo, offset, err = decodeByte(bt, offset)
+		case PropResponseInfo:
+			p.ResponseInfo, offset, err = decodeString(bt, offset)
+		case PropServerReference:
+			p.ServerReference, offset, err = decodeString(bt, offset)
+		case PropReasonString:
+			p.ReasonString, offset, err = decodeString(bt, offset)
+		case PropReceiveMaximum:
+			p.ReceiveMaximum, offset, err = decodeUint16(bt, offset)
+		case PropTopicAliasMaximum:
+			p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset)
+		case PropTopicAlias:
+			p.TopicAlias, offset, err = decodeUint16(bt, offset)
+			p.TopicAliasFlag = true
+		case PropMaximumQos:
+			p.MaximumQos, offset, err = decodeByte(bt, offset)
+			p.MaximumQosFlag = true
+		case PropRetainAvailable:
+			p.RetainAvailable, offset, err = decodeByte(bt, offset)
+			p.RetainAvailableFlag = true
+		case PropUser:
+			// Key is checked immediately; the value error falls through to
+			// the shared check below.
+			var k, v string
+			k, offset, err = decodeString(bt, offset)
+			if err != nil {
+				return n + bu, err
+			}
+			v, offset, err = decodeString(bt, offset)
+			p.User = append(p.User, UserProperty{Key: k, Val: v})
+		case PropMaximumPacketSize:
+			p.MaximumPacketSize, offset, err = decodeUint32(bt, offset)
+		case PropWildcardSubAvailable:
+			p.WildcardSubAvailable, offset, err = decodeByte(bt, offset)
+			p.WildcardSubAvailableFlag = true
+		case PropSubIDAvailable:
+			p.SubIDAvailable, offset, err = decodeByte(bt, offset)
+			p.SubIDAvailableFlag = true
+		case PropSharedSubAvailable:
+			p.SharedSubAvailable, offset, err = decodeByte(bt, offset)
+			p.SharedSubAvailableFlag = true
+		}
+
+		if err != nil {
+			return n + bu, err
+		}
+	}
+
+	return n + bu, nil
+}
diff --git a/packets/properties_test.go b/packets/properties_test.go
new file mode 100644
index 0000000..b0a2f10
--- /dev/null
+++ b/packets/properties_test.go
@@ -0,0 +1,333 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+var (
+	// propertiesStruct is a fully-populated Properties value whose wire
+	// encoding is propertiesBytes below; the two are cross-checked by the
+	// Encode/Decode tests.
+	propertiesStruct = Properties{
+		PayloadFormat:             byte(1), // UTF-8 Format
+		PayloadFormatFlag:         true,
+		MessageExpiryInterval:     uint32(2),
+		ContentType:               "text/plain",
+		ResponseTopic:             "a/b/c",
+		CorrelationData:           []byte("data"),
+		SubscriptionIdentifier:    []int{322122},
+		SessionExpiryInterval:     uint32(120),
+		SessionExpiryIntervalFlag: true,
+		AssignedClientID:          "mochi-v5",
+		ServerKeepAlive:           uint16(20),
+		ServerKeepAliveFlag:       true,
+		AuthenticationMethod:      "SHA-1",
+		AuthenticationData:        []byte("auth-data"),
+		RequestProblemInfo:        byte(1),
+		RequestProblemInfoFlag:    true,
+		WillDelayInterval:         uint32(600),
+		RequestResponseInfo:       byte(1),
+		ResponseInfo:              "response",
+		ServerReference:           "mochi-2",
+		ReasonString:              "reason",
+		ReceiveMaximum:            uint16(500),
+		TopicAliasMaximum:         uint16(999),
+		TopicAlias:                uint16(3),
+		TopicAliasFlag:            true,
+		MaximumQos:                byte(1),
+		MaximumQosFlag:            true,
+		RetainAvailable:           byte(1),
+		RetainAvailableFlag:       true,
+		User: []UserProperty{
+			{
+				Key: "hello",
+				Val: "世界",
+			},
+			{
+				Key: "key2",
+				Val: "value2",
+			},
+		},
+		MaximumPacketSize:        uint32(32000),
+		WildcardSubAvailable:     byte(1),
+		WildcardSubAvailableFlag: true,
+		SubIDAvailable:           byte(1),
+		SubIDAvailableFlag:       true,
+		SharedSubAvailable:       byte(1),
+		SharedSubAvailableFlag:   true,
+	}
+
+	// propertiesBytes is the wire-format encoding of propertiesStruct: a
+	// 2-byte variable-byte-integer length (172) followed by each property.
+	// The (vbi:N) notes track the running byte offset within the section.
+	propertiesBytes = []byte{
+		172, 1, // VBI
+
+		// Payload Format (1) (vbi:2)
+		1, 1,
+
+		// Message Expiry (2) (vbi:7)
+		2, 0, 0, 0, 2,
+
+		// Content Type (3) (vbi:20)
+		3,
+		0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n',
+
+		// Response Topic (8) (vbi:28)
+		8,
+		0, 5, 'a', '/', 'b', '/', 'c',
+
+		// Correlations Data (9) (vbi:35)
+		9,
+		0, 4, 'd', 'a', 't', 'a',
+
+		// Subscription Identifier (11) (vbi:39)
+		11,
+		202, 212, 19,
+
+		// Session Expiry Interval (17) (vbi:43)
+		17,
+		0, 0, 0, 120,
+
+		// Assigned Client ID (18) (vbi:55)
+		18,
+		0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5',
+
+		// Server Keep Alive (19) (vbi:58)
+		19,
+		0, 20,
+
+		// Authentication Method (21) (vbi:66)
+		21,
+		0, 5, 'S', 'H', 'A', '-', '1',
+
+		// Authentication Data (22) (vbi:78)
+		22,
+		0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a',
+
+		// Request Problem Info (23) (vbi:80)
+		23, 1,
+
+		// Will Delay Interval (24) (vbi:85)
+		24,
+		0, 0, 2, 88,
+
+		// Request Response Info (25) (vbi:87)
+		25, 1,
+
+		// Response Info (26) (vbi:98)
+		26,
+		0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e',
+
+		// Server Reference (28) (vbi:108)
+		28,
+		0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2',
+
+		// Reason String (31) (vbi:117)
+		31,
+		0, 6, 'r', 'e', 'a', 's', 'o', 'n',
+
+		// Receive Maximum (33) (vbi:120)
+		33,
+		1, 244,
+
+		// Topic Alias Maximum (34) (vbi:123)
+		34,
+		3, 231,
+
+		// Topic Alias (35) (vbi:126)
+		35,
+		0, 3,
+
+		// Maximum Qos (36) (vbi:128)
+		36, 1,
+
+		// Retain Available (37) (vbi: 130)
+		37, 1,
+
+		// User Properties (38) (vbi:161)
+		38,
+		0, 5, 'h', 'e', 'l', 'l', 'o',
+		0, 6, 228, 184, 150, 231, 149, 140,
+		38,
+		0, 4, 'k', 'e', 'y', '2',
+		0, 6, 'v', 'a', 'l', 'u', 'e', '2',
+
+		// Maximum Packet Size (39) (vbi:166)
+		39,
+		0, 0, 125, 0,
+
+		// Wildcard Subscriptions Available (40) (vbi:168)
+		40, 1,
+
+		// Subscription ID Available (41) (vbi:170)
+		41, 1,
+
+		// Shared Subscriptions Available (42) (vbi:172)
+		42, 1,
+	}
+)
+
+// init marks every property identifier as valid for the Reserved packet
+// type, letting the tests encode and decode all properties through a single
+// (otherwise unused) packet type.
+func init() {
+	all := []byte{
+		PropPayloadFormat, PropMessageExpiryInterval, PropContentType,
+		PropResponseTopic, PropCorrelationData, PropSubscriptionIdentifier,
+		PropSessionExpiryInterval, PropAssignedClientID, PropServerKeepAlive,
+		PropAuthenticationMethod, PropAuthenticationData, PropRequestProblemInfo,
+		PropWillDelayInterval, PropRequestResponseInfo, PropResponseInfo,
+		PropServerReference, PropReasonString, PropReceiveMaximum,
+		PropTopicAliasMaximum, PropTopicAlias, PropMaximumQos,
+		PropRetainAvailable, PropUser, PropMaximumPacketSize,
+		PropWildcardSubAvailable, PropSubIDAvailable, PropSharedSubAvailable,
+	}
+	for _, k := range all {
+		validPacketProperties[k][Reserved] = 1
+	}
+}
+
+func TestEncodeProperties(t *testing.T) {
+ props := propertiesStruct
+ b := bytes.NewBuffer([]byte{})
+ props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
+ require.Equal(t, propertiesBytes, b.Bytes())
+}
+
+func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
+ props := propertiesStruct
+ b := bytes.NewBuffer([]byte{})
+ props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
+ require.NotEqual(t, propertiesBytes, b.Bytes())
+ require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
+ require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
+ require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8}))
+}
+
+// TestEncodePropertiesDisallowResponseInfo confirms that the response topic
+// (8) and correlation data (9) properties are omitted when response
+// information is disallowed.
+func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
+	props := propertiesStruct
+	b := bytes.NewBuffer([]byte{})
+	props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
+	require.NotEqual(t, propertiesBytes, b.Bytes())
+	// require.NotContains on a []byte container compares individual bytes
+	// against the []byte needle and can never match, so the previous
+	// assertions passed vacuously; use bytes.Contains to genuinely assert
+	// the encoded forms of the response topic and correlation data are
+	// absent (matching the style of the DisallowProblemInfo test).
+	require.False(t, bytes.Contains(b.Bytes(), []byte{8, 0, 5, 'a', '/', 'b', '/', 'c'}))
+	require.False(t, bytes.Contains(b.Bytes(), []byte{9, 0, 4, 'd', 'a', 't', 'a'}))
+}
+
+// TestEncodePropertiesNil confirms Encode is a no-op on a nil receiver.
+func TestEncodePropertiesNil(t *testing.T) {
+	var props *Properties
+	b := bytes.NewBuffer([]byte{})
+	props.Encode(Reserved, Mods{}, b, 0)
+	require.Equal(t, []byte{}, b.Bytes())
+}
+
+func TestEncodeZeroProperties(t *testing.T) {
+ // [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
+ props := new(Properties)
+ b := bytes.NewBuffer([]byte{})
+ props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
+ require.Equal(t, []byte{0x00}, b.Bytes())
+}
+
+// TestDecodeProperties confirms the reference byte sequence decodes back to
+// the fully-populated reference struct.
+func TestDecodeProperties(t *testing.T) {
+	props := new(Properties)
+	n, err := props.Decode(Reserved, bytes.NewBuffer(propertiesBytes))
+	require.NoError(t, err)
+	require.Equal(t, 172+2, n) // 172-byte section + 2-byte length VBI
+	require.EqualValues(t, propertiesStruct, *props)
+}
+
+// TestDecodePropertiesNil confirms Decode is a no-op on a nil receiver.
+func TestDecodePropertiesNil(t *testing.T) {
+	var props *Properties
+	n, err := props.Decode(Reserved, bytes.NewBuffer(propertiesBytes))
+	require.NoError(t, err)
+	require.Equal(t, 0, n)
+}
+
+// TestDecodePropertiesBadInitialVBI confirms an over-long variable byte
+// integer length prefix is rejected.
+func TestDecodePropertiesBadInitialVBI(t *testing.T) {
+	b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255})
+	props := new(Properties)
+	_, err := props.Decode(Reserved, b)
+	require.Error(t, err)
+	// require.ErrorIs takes (t, err, target); the arguments were previously
+	// reversed, checking whether the sentinel wrapped err instead of the
+	// other way around (cf. the correct order in TestDecodePropertiesBadKeyByte).
+	require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
+}
+
+// TestDecodePropertiesZeroLengthVBI confirms a zero property length decodes
+// cleanly and leaves the struct untouched.
+func TestDecodePropertiesZeroLengthVBI(t *testing.T) {
+	props := new(Properties)
+	_, err := props.Decode(Reserved, bytes.NewBuffer([]byte{0}))
+	require.NoError(t, err)
+	require.Equal(t, props, new(Properties))
+}
+
+// TestDecodePropertiesBadKeyByte confirms a truncated/invalid property key
+// byte surfaces ErrMalformedOffsetByteOutOfRange.
+func TestDecodePropertiesBadKeyByte(t *testing.T) {
+	props := new(Properties)
+	_, err := props.Decode(Reserved, bytes.NewBuffer([]byte{64, 1}))
+	require.Error(t, err)
+	require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange)
+}
+
+// TestDecodePropertiesInvalidForPacket confirms a property key not listed
+// for the packet type is rejected as a protocol violation.
+func TestDecodePropertiesInvalidForPacket(t *testing.T) {
+	props := new(Properties)
+	_, err := props.Decode(Reserved, bytes.NewBuffer([]byte{1, 99}))
+	require.Error(t, err)
+	require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty)
+}
+
+// TestDecodePropertiesGeneralFailure confirms a declared length longer than
+// the available property bytes produces an error.
+func TestDecodePropertiesGeneralFailure(t *testing.T) {
+	props := new(Properties)
+	_, err := props.Decode(Reserved, bytes.NewBuffer([]byte{10, 11, 202, 212, 19}))
+	require.Error(t, err)
+}
+
+// TestDecodePropertiesBadSubscriptionID confirms an over-long subscription
+// identifier VBI produces an error.
+func TestDecodePropertiesBadSubscriptionID(t *testing.T) {
+	props := new(Properties)
+	_, err := props.Decode(Reserved, bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255}))
+	require.Error(t, err)
+}
+
+// TestDecodePropertiesBadUserProps confirms a malformed user-property string
+// length produces an error.
+func TestDecodePropertiesBadUserProps(t *testing.T) {
+	props := new(Properties)
+	_, err := props.Decode(Reserved, bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255}))
+	require.Error(t, err)
+}
+
+// TestCopyProperties confirms a transfer-allowed copy is value-equal to the
+// original, including the topic alias.
+func TestCopyProperties(t *testing.T) {
+	copied := propertiesStruct.Copy(true)
+	require.EqualValues(t, propertiesStruct, copied)
+}
+
+// TestCopyPropertiesNoTransfer confirms connection-specific properties are
+// zeroed when transfer is not allowed.
+func TestCopyPropertiesNoTransfer(t *testing.T) {
+	copied := propertiesStruct.Copy(false)
+
+	// Topic aliases must never transfer from one connection to another.
+	require.Equal(t, uint16(0), copied.TopicAlias)
+}
diff --git a/packets/tpackets.go b/packets/tpackets.go
new file mode 100644
index 0000000..79a6f58
--- /dev/null
+++ b/packets/tpackets.go
@@ -0,0 +1,4031 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+// TPacketCase contains data for cross-checking the encoding and decoding
+// of packets and expected scenarios.
+type TPacketCase struct {
+	RawBytes     []byte  // the bytes that make the packet
+	ActualBytes  []byte  // the actual byte array that is created in the event of a byte mutation
+	Group        string  // a group that should run the test, blank for all
+	Desc         string  // a description of the test
+	FailFirst    error   // expected fail result to be run immediately after the method is called
+	Packet       *Packet // the packet that is expected
+	ActualPacket *Packet // the actual packet after mutations
+	Expect       error   // generic expected fail result to be checked
+	Isolate      bool    // isolate can be used to isolate a test
+	Primary      bool    // primary is a test that should be run using readPackets
+	Case         byte    // the identifying byte of the case (a T* constant below)
+}
+
+// TPacketCases is a slice of TPacketCase test scenarios, searchable by
+// case byte via Get.
+type TPacketCases []TPacketCase
+
+// Get returns the case whose Case byte matches b, or a zero-value
+// TPacketCase if no case matches.
+func (f TPacketCases) Get(b byte) TPacketCase {
+	for i := range f {
+		if f[i].Case == b {
+			return f[i]
+		}
+	}
+
+	return TPacketCase{}
+}
+
+// Identifying bytes for every packet test case, grouped by packet type.
+// Values are assigned by iota, so new cases must be appended within their
+// group without reordering existing entries.
+const (
+	// Connect cases.
+	TConnectMqtt31 byte = iota
+	TConnectMqtt311
+	TConnectMqtt5
+	TConnectMqtt5LWT
+	TConnectClean
+	TConnectUserPass
+	TConnectUserPassLWT
+	TConnectMalProtocolName
+	TConnectMalProtocolVersion
+	TConnectMalFlags
+	TConnectMalKeepalive
+	TConnectMalClientID
+	TConnectMalWillTopic
+	TConnectMalWillFlag
+	TConnectMalUsername
+	TConnectMalPassword
+	TConnectMalFixedHeader
+	TConnectMalReservedBit
+	TConnectMalProperties
+	TConnectMalWillProperties
+	TConnectInvalidProtocolName
+	TConnectInvalidProtocolVersion
+	TConnectInvalidProtocolVersion2
+	TConnectInvalidReservedBit
+	TConnectInvalidClientIDTooLong
+	TConnectInvalidFlagNoUsername
+	TConnectInvalidFlagNoPassword
+	TConnectInvalidUsernameNoFlag
+	TConnectInvalidPasswordNoFlag
+	TConnectInvalidUsernameTooLong
+	TConnectInvalidPasswordTooLong
+	TConnectInvalidWillFlagNoPayload
+	TConnectInvalidWillFlagQosOutOfRange
+	TConnectInvalidWillSurplusRetain
+	TConnectZeroByteUsername
+	TConnectSpecInvalidUTF8D800
+	TConnectSpecInvalidUTF8DFFF
+	TConnectSpecInvalidUTF80000
+	TConnectSpecInvalidUTF8NoSkip
+	// Connack cases.
+	TConnackAcceptedNoSession
+	TConnackAcceptedSessionExists
+	TConnackAcceptedMqtt5
+	TConnackAcceptedAdjustedExpiryInterval
+	TConnackMinMqtt5
+	TConnackMinCleanMqtt5
+	TConnackServerKeepalive
+	TConnackInvalidMinMqtt5
+	TConnackBadProtocolVersion
+	TConnackProtocolViolationNoSession
+	TConnackBadClientID
+	TConnackServerUnavailable
+	TConnackBadUsernamePassword
+	TConnackBadUsernamePasswordNoSession
+	TConnackMqtt5BadUsernamePasswordNoSession
+	TConnackNotAuthorised
+	TConnackMalSessionPresent
+	TConnackMalReturnCode
+	TConnackMalProperties
+	TConnackDropProperties
+	TConnackDropPropertiesPartial
+	// Publish cases.
+	TPublishNoPayload
+	TPublishBasic
+	TPublishBasicTopicAliasOnly
+	TPublishBasicMqtt5
+	TPublishMqtt5
+	TPublishQos1
+	TPublishQos1Mqtt5
+	TPublishQos1NoPayload
+	TPublishQos1Dup
+	TPublishQos2
+	TPublishQos2Mqtt5
+	TPublishQos2Upgraded
+	TPublishSubscriberIdentifier
+	TPublishRetain
+	TPublishRetainMqtt5
+	TPublishDup
+	TPublishMalTopicName
+	TPublishMalPacketID
+	TPublishMalProperties
+	TPublishCopyBasic
+	TPublishSpecQos0NoPacketID
+	TPublishSpecQosMustPacketID
+	TPublishDropOversize
+	TPublishInvalidQos0NoPacketID
+	TPublishInvalidQosMustPacketID
+	TPublishInvalidSurplusSubID
+	TPublishInvalidSurplusWildcard
+	TPublishInvalidSurplusWildcard2
+	TPublishInvalidNoTopic
+	TPublishInvalidTopicAlias
+	TPublishInvalidExcessTopicAlias
+	TPublishSpecDenySysTopic
+	// Puback cases.
+	TPuback
+	TPubackMqtt5
+	TPubackMqtt5NotAuthorized
+	TPubackMalPacketID
+	TPubackMalProperties
+	TPubackUnexpectedError
+	// Pubrec cases.
+	TPubrec
+	TPubrecMqtt5
+	TPubrecMqtt5IDInUse
+	TPubrecMqtt5NotAuthorized
+	TPubrecMalPacketID
+	TPubrecMalProperties
+	TPubrecMalReasonCode
+	TPubrecInvalidReason
+	// Pubrel cases.
+	TPubrel
+	TPubrelMqtt5
+	TPubrelMqtt5AckNoPacket
+	TPubrelMalPacketID
+	TPubrelMalProperties
+	TPubrelInvalidReason
+	// Pubcomp cases.
+	TPubcomp
+	TPubcompMqtt5
+	TPubcompMqtt5AckNoPacket
+	TPubcompMalPacketID
+	TPubcompMalProperties
+	TPubcompInvalidReason
+	// Subscribe cases.
+	TSubscribe
+	TSubscribeMany
+	TSubscribeMqtt5
+	TSubscribeRetainHandling1
+	TSubscribeRetainHandling2
+	TSubscribeRetainAsPublished
+	TSubscribeMalPacketID
+	TSubscribeMalTopic
+	TSubscribeMalQos
+	TSubscribeMalQosRange
+	TSubscribeMalProperties
+	TSubscribeInvalidQosMustPacketID
+	TSubscribeSpecQosMustPacketID
+	TSubscribeInvalidNoFilters
+	TSubscribeInvalidSharedNoLocal
+	TSubscribeInvalidFilter
+	TSubscribeInvalidIdentifierOversize
+	// Suback cases.
+	TSuback
+	TSubackMany
+	TSubackDeny
+	TSubackUnspecifiedError
+	TSubackUnspecifiedErrorMqtt5
+	TSubackMqtt5
+	TSubackPacketIDInUse
+	TSubackInvalidFilter
+	TSubackInvalidSharedNoLocal
+	TSubackMalPacketID
+	TSubackMalProperties
+	// Unsubscribe cases.
+	TUnsubscribe
+	TUnsubscribeMany
+	TUnsubscribeMqtt5
+	TUnsubscribeMalPacketID
+	TUnsubscribeMalTopicName
+	TUnsubscribeMalProperties
+	TUnsubscribeInvalidQosMustPacketID
+	TUnsubscribeSpecQosMustPacketID
+	TUnsubscribeInvalidNoFilters
+	// Unsuback cases.
+	TUnsuback
+	TUnsubackMany
+	TUnsubackMqtt5
+	TUnsubackPacketIDInUse
+	TUnsubackMalPacketID
+	TUnsubackMalProperties
+	// Ping cases.
+	TPingreq
+	TPingresp
+	// Disconnect cases.
+	TDisconnect
+	TDisconnectTakeover
+	TDisconnectMqtt5
+	TDisconnectMqtt5DisconnectWithWillMessage
+	TDisconnectSecondConnect
+	TDisconnectReceiveMaximum
+	TDisconnectDropProperties
+	TDisconnectShuttingDown
+	TDisconnectMalProperties
+	TDisconnectMalReasonCode
+	TDisconnectZeroNonZeroExpiry
+	// Auth cases.
+	TAuth
+	TAuthMalReasonCode
+	TAuthMalProperties
+	TAuthInvalidReason
+	TAuthInvalidReason2
+)
+
+// TPacketData contains individual encoding and decoding scenarios for each packet type.
+var TPacketData = map[byte]TPacketCases{
+ Connect: {
+ {
+ Case: TConnectMqtt31,
+ Desc: "mqtt v3.1",
+ Primary: true,
+ RawBytes: []byte{
+ Connect << 4, 17, // Fixed header
+ 0, 6, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
+ 3, // Protocol Version
+ 0, // Packet Flags
+ 0, 30, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 17,
+ },
+ ProtocolVersion: 3,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQIsdp"),
+ Clean: false,
+ Keepalive: 30,
+ ClientIdentifier: "zen",
+ },
+ },
+ },
+ {
+ Case: TConnectMqtt311,
+ Desc: "mqtt v3.1.1",
+ Primary: true,
+ RawBytes: []byte{
+ Connect << 4, 15, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0, // Packet Flags
+ 0, 60, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 15,
+ },
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Clean: false,
+ Keepalive: 60,
+ ClientIdentifier: "zen",
+ },
+ },
+ },
+ {
+ Case: TConnectMqtt5,
+ Desc: "mqtt v5",
+ Primary: true,
+ RawBytes: []byte{
+ Connect << 4, 87, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 5, // Protocol Version
+ 0, // Packet Flags
+ 0, 30, // Keepalive
+
+ // Properties
+ 71, // length
+ 17, 0, 0, 0, 120, // Session Expiry Interval (17)
+ 21, 0, 5, 'S', 'H', 'A', '-', '1', // Authentication Method (21)
+ 22, 0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', // Authentication Data (22)
+ 23, 1, // Request Problem Info (23)
+ 25, 1, // Request Response Info (25)
+ 33, 1, 244, // Receive Maximum (33)
+ 34, 3, 231, // Topic Alias Maximum (34)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ 38, // User Properties (38)
+ 0, 4, 'k', 'e', 'y', '2',
+ 0, 6, 'v', 'a', 'l', 'u', 'e', '2',
+ 39, 0, 0, 125, 0, // Maximum Packet Size (39)
+
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 87,
+ },
+ ProtocolVersion: 5,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Clean: false,
+ Keepalive: 30,
+ ClientIdentifier: "zen",
+ },
+ Properties: Properties{
+ SessionExpiryInterval: uint32(120),
+ SessionExpiryIntervalFlag: true,
+ AuthenticationMethod: "SHA-1",
+ AuthenticationData: []byte("auth-data"),
+ RequestProblemInfo: byte(1),
+ RequestProblemInfoFlag: true,
+ RequestResponseInfo: byte(1),
+ ReceiveMaximum: uint16(500),
+ TopicAliasMaximum: uint16(999),
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ {
+ Key: "key2",
+ Val: "value2",
+ },
+ },
+ MaximumPacketSize: uint32(32000),
+ },
+ },
+ },
+ {
+ Case: TConnectClean,
+ Desc: "mqtt 3.1.1, clean session",
+ RawBytes: []byte{
+ Connect << 4, 15, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 2, // Packet Flags
+ 0, 45, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 15,
+ },
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Clean: true,
+ Keepalive: 45,
+ ClientIdentifier: "zen",
+ },
+ },
+ },
+ {
+ Case: TConnectMqtt5LWT,
+ Desc: "mqtt 5 clean session, lwt",
+ RawBytes: []byte{
+ Connect << 4, 47, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 5, // Protocol Version
+ 14, // Packet Flags
+ 0, 30, // Keepalive
+
+ // Properties
+ 10, // length
+ 17, 0, 0, 0, 120, // Session Expiry Interval (17)
+ 39, 0, 0, 125, 0, // Maximum Packet Size (39)
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 5, // will properties length
+ 24, 0, 0, 2, 88, // will delay interval (24)
+
+ 0, 3, // Will Topic - MSB+LSB
+ 'l', 'w', 't',
+ 0, 8, // Will Message MSB+LSB
+ 'n', 'o', 't', 'a', 'g', 'a', 'i', 'n',
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 42,
+ },
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Clean: true,
+ Keepalive: 30,
+ ClientIdentifier: "zen",
+ WillFlag: true,
+ WillTopic: "lwt",
+ WillPayload: []byte("notagain"),
+ WillQos: 1,
+ WillProperties: Properties{
+ WillDelayInterval: uint32(600),
+ },
+ },
+ Properties: Properties{
+ SessionExpiryInterval: uint32(120),
+ SessionExpiryIntervalFlag: true,
+ MaximumPacketSize: uint32(32000),
+ },
+ },
+ },
+ {
+ Case: TConnectUserPass,
+ Desc: "mqtt 3.1.1, username, password",
+ RawBytes: []byte{
+ Connect << 4, 28, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0 | 1<<6 | 1<<7, // Packet Flags
+ 0, 20, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 0, 5, // Username MSB+LSB
+ 'm', 'o', 'c', 'h', 'i',
+ 0, 4, // Password MSB+LSB
+ ',', '.', '/', ';',
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 28,
+ },
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Clean: false,
+ Keepalive: 20,
+ ClientIdentifier: "zen",
+ UsernameFlag: true,
+ PasswordFlag: true,
+ Username: []byte("mochi"),
+ Password: []byte(",./;"),
+ },
+ },
+ },
+ {
+ Case: TConnectUserPassLWT,
+ Desc: "mqtt 3.1.1, username, password, lwt",
+ Primary: true,
+ RawBytes: []byte{
+ Connect << 4, 44, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 206, // Packet Flags
+ 0, 120, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 0, 3, // Will Topic - MSB+LSB
+ 'l', 'w', 't',
+ 0, 9, // Will Message MSB+LSB
+ 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n',
+ 0, 5, // Username MSB+LSB
+ 'm', 'o', 'c', 'h', 'i',
+ 0, 4, // Password MSB+LSB
+ ',', '.', '/', ';',
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 44,
+ },
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Clean: true,
+ Keepalive: 120,
+ ClientIdentifier: "zen",
+ UsernameFlag: true,
+ PasswordFlag: true,
+ Username: []byte("mochi"),
+ Password: []byte(",./;"),
+ WillFlag: true,
+ WillTopic: "lwt",
+ WillPayload: []byte("not again"),
+ WillQos: 1,
+ },
+ },
+ },
+ {
+ Case: TConnectZeroByteUsername,
+ Desc: "username flag but 0 byte username",
+ Group: "decode",
+ RawBytes: []byte{
+ Connect << 4, 23, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 5, // Protocol Version
+ 130, // Packet Flags
+ 0, 30, // Keepalive
+ 5, // length
+ 17, 0, 0, 0, 120, // Session Expiry Interval (17)
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 0, 0, // Username MSB+LSB
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 23,
+ },
+ ProtocolVersion: 5,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Clean: true,
+ Keepalive: 30,
+ ClientIdentifier: "zen",
+ Username: []byte{},
+ UsernameFlag: true,
+ },
+ Properties: Properties{
+ SessionExpiryInterval: uint32(120),
+ SessionExpiryIntervalFlag: true,
+ },
+ },
+ },
+
+ // Fail States
+ {
+ Case: TConnectMalProtocolName,
+ Desc: "malformed protocol name",
+ Group: "decode",
+ FailFirst: ErrMalformedProtocolName,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 7, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'I', 's', 'd', // Protocol Name
+ },
+ },
+ {
+ Case: TConnectMalProtocolVersion,
+ Desc: "malformed protocol version",
+ Group: "decode",
+ FailFirst: ErrMalformedProtocolVersion,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ },
+ },
+ {
+ Case: TConnectMalFlags,
+ Desc: "malformed flags",
+ Group: "decode",
+ FailFirst: ErrMalformedFlags,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+
+ },
+ },
+ {
+ Case: TConnectMalKeepalive,
+ Desc: "malformed keepalive",
+ Group: "decode",
+ FailFirst: ErrMalformedKeepalive,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0, // Flags
+ },
+ },
+ {
+ Case: TConnectMalClientID,
+ Desc: "malformed client id",
+ Group: "decode",
+ FailFirst: ErrClientIdentifierNotValid,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0, // Flags
+ 0, 20, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', // Client ID "zen"
+ },
+ },
+ {
+ Case: TConnectMalWillTopic,
+ Desc: "malformed will topic",
+ Group: "decode",
+ FailFirst: ErrMalformedWillTopic,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 14, // Flags
+ 0, 20, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 0, 6, // Will Topic - MSB+LSB
+ 'l',
+ },
+ },
+ {
+ Case: TConnectMalWillFlag,
+ Desc: "malformed will flag",
+ Group: "decode",
+ FailFirst: ErrMalformedWillPayload,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 14, // Flags
+ 0, 20, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 0, 3, // Will Topic - MSB+LSB
+ 'l', 'w', 't',
+ 0, 9, // Will Message MSB+LSB
+ 'n', 'o', 't', ' ', 'a',
+ },
+ },
+ {
+ Case: TConnectMalUsername,
+ Desc: "malformed username",
+ Group: "decode",
+ FailFirst: ErrMalformedUsername,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 206, // Flags
+ 0, 20, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 0, 3, // Will Topic - MSB+LSB
+ 'l', 'w', 't',
+ 0, 9, // Will Message MSB+LSB
+ 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n',
+ 0, 5, // Username MSB+LSB
+ 'm', 'o', 'c',
+ },
+ },
+
+ {
+ Case: TConnectInvalidFlagNoUsername,
+ Desc: "username flag with no username bytes",
+ Group: "decode",
+ FailFirst: ErrProtocolViolationFlagNoUsername,
+ RawBytes: []byte{
+ Connect << 4, 17, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 5, // Protocol Version
+ 130, // Flags
+ 0, 20, // Keepalive
+ 0,
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ },
+ },
+ {
+ Case: TConnectMalPassword,
+ Desc: "malformed password",
+ Group: "decode",
+ FailFirst: ErrMalformedPassword,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 206, // Flags
+ 0, 20, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 0, 3, // Will Topic - MSB+LSB
+ 'l', 'w', 't',
+ 0, 9, // Will Message MSB+LSB
+ 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n',
+ 0, 5, // Username MSB+LSB
+ 'm', 'o', 'c', 'h', 'i',
+ 0, 4, // Password MSB+LSB
+ ',', '.',
+ },
+ },
+ {
+ Case: TConnectMalFixedHeader,
+ Desc: "malformed fixedheader oversize",
+ Group: "decode",
+ FailFirst: ErrMalformedProtocolName, // packet test doesn't test fixedheader oversize
+ RawBytes: []byte{
+ Connect << 4, 255, 255, 255, 255, 255, // Fixed header
+ },
+ },
+ {
+ Case: TConnectMalReservedBit,
+ Desc: "reserved bit not 0",
+ Group: "nodecode",
+ FailFirst: ErrProtocolViolation,
+ RawBytes: []byte{
+ Connect << 4, 15, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 1, // Packet Flags
+ 0, 45, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n',
+ },
+ },
+ {
+ Case: TConnectMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Connect << 4, 47, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 5, // Protocol Version
+ 14, // Packet Flags
+ 0, 30, // Keepalive
+ 10,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ {
+ Case: TConnectMalWillProperties,
+ Desc: "malformed will properties",
+ Group: "decode",
+ FailFirst: ErrMalformedWillProperties,
+ RawBytes: []byte{
+ Connect << 4, 47, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 5, // Protocol Version
+ 14, // Packet Flags
+ 0, 30, // Keepalive
+ 10, // length
+ 17, 0, 0, 0, 120, // Session Expiry Interval (17)
+ 39, 0, 0, 125, 0, // Maximum Packet Size (39)
+ 0, 3, // Client ID - MSB+LSB
+ 'z', 'e', 'n', // Client ID "zen"
+ 5, // will properties length
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+
+ // Validation Tests
+ {
+ Case: TConnectInvalidProtocolName,
+ Desc: "invalid protocol name",
+ Group: "validate",
+ Expect: ErrProtocolViolationProtocolName,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ Connect: ConnectParams{
+ ProtocolName: []byte("stuff"),
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidProtocolVersion,
+ Desc: "invalid protocol version",
+ Group: "validate",
+ Expect: ErrProtocolViolationProtocolVersion,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 2,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidProtocolVersion2,
+ Desc: "invalid protocol version",
+ Group: "validate",
+ Expect: ErrProtocolViolationProtocolVersion,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 2,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQIsdp"),
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidReservedBit,
+ Desc: "reserved bit not 0",
+ Group: "validate",
+ Expect: ErrProtocolViolationReservedBit,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ },
+ ReservedBit: 1,
+ },
+ },
+ {
+ Case: TConnectInvalidClientIDTooLong,
+ Desc: "client id too long",
+ Group: "validate",
+ Expect: ErrClientIdentifierNotValid,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ ClientIdentifier: func() string {
+ return string(make([]byte, 65536))
+ }(),
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidUsernameNoFlag,
+ Desc: "has username but no flag",
+ Group: "validate",
+ Expect: ErrProtocolViolationUsernameNoFlag,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Username: []byte("username"),
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidFlagNoPassword,
+ Desc: "has password flag but no password",
+ Group: "validate",
+ Expect: ErrProtocolViolationFlagNoPassword,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ PasswordFlag: true,
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidPasswordNoFlag,
+ Desc: "has password flag but no password",
+ Group: "validate",
+ Expect: ErrProtocolViolationPasswordNoFlag,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Password: []byte("password"),
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidUsernameTooLong,
+ Desc: "username too long",
+ Group: "validate",
+ Expect: ErrProtocolViolationUsernameTooLong,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ UsernameFlag: true,
+ Username: func() []byte {
+ return make([]byte, 65536)
+ }(),
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidPasswordTooLong,
+ Desc: "password too long",
+ Group: "validate",
+ Expect: ErrProtocolViolationPasswordTooLong,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ UsernameFlag: true,
+ Username: []byte{},
+ PasswordFlag: true,
+ Password: func() []byte {
+ return make([]byte, 65536)
+ }(),
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidWillFlagNoPayload,
+ Desc: "will flag no payload",
+ Group: "validate",
+ Expect: ErrProtocolViolationWillFlagNoPayload,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ WillFlag: true,
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidWillFlagQosOutOfRange,
+ Desc: "will flag no payload",
+ Group: "validate",
+ Expect: ErrProtocolViolationQosOutOfRange,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ WillFlag: true,
+ WillTopic: "a/b/c",
+ WillPayload: []byte{'b'},
+ WillQos: 4,
+ },
+ },
+ },
+ {
+ Case: TConnectInvalidWillSurplusRetain,
+ Desc: "no will flag surplus retain",
+ Group: "validate",
+ Expect: ErrProtocolViolationWillFlagSurplusRetain,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{Type: Connect},
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ WillRetain: true,
+ },
+ },
+ },
+
+ // Spec Tests
+ {
+ Case: TConnectSpecInvalidUTF8D800,
+ Desc: "invalid utf8 string (a) - code point U+D800",
+ Group: "decode",
+ FailFirst: ErrClientIdentifierNotValid,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0, // Flags
+ 0, 20, // Keepalive
+ 0, 4, // Client ID - MSB+LSB
+ 'e', 0xed, 0xa0, 0x80, // Client id bearing U+D800
+ },
+ },
+ {
+ Case: TConnectSpecInvalidUTF8DFFF,
+ Desc: "invalid utf8 string (b) - code point U+DFFF",
+ Group: "decode",
+ FailFirst: ErrClientIdentifierNotValid,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0, // Flags
+ 0, 20, // Keepalive
+ 0, 4, // Client ID - MSB+LSB
+ 'e', 0xed, 0xa3, 0xbf, // Client id bearing U+D8FF
+ },
+ },
+
+ {
+ Case: TConnectSpecInvalidUTF80000,
+ Desc: "invalid utf8 string (c) - code point U+0000",
+ Group: "decode",
+ FailFirst: ErrClientIdentifierNotValid,
+ RawBytes: []byte{
+ Connect << 4, 0, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0, // Flags
+ 0, 20, // Keepalive
+ 0, 3, // Client ID - MSB+LSB
+ 'e', 0xc0, 0x80, // Client id bearing U+0000
+ },
+ },
+
+ {
+ Case: TConnectSpecInvalidUTF8NoSkip,
+ Desc: "utf8 string must not skip or strip code point U+FEFF",
+ //Group: "decode",
+ //FailFirst: ErrMalformedClientID,
+ RawBytes: []byte{
+ Connect << 4, 18, // Fixed header
+ 0, 4, // Protocol Name - MSB+LSB
+ 'M', 'Q', 'T', 'T', // Protocol Name
+ 4, // Protocol Version
+ 0, // Flags
+ 0, 20, // Keepalive
+ 0, 6, // Client ID - MSB+LSB
+ 'e', 'b', 0xEF, 0xBB, 0xBF, 'd', // Client id bearing U+FEFF
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connect,
+ Remaining: 16,
+ },
+ ProtocolVersion: 4,
+ Connect: ConnectParams{
+ ProtocolName: []byte("MQTT"),
+ Keepalive: 20,
+ ClientIdentifier: string([]byte{'e', 'b', 0xEF, 0xBB, 0xBF, 'd'}),
+ },
+ },
+ },
+ },
+ Connack: {
+ {
+ Case: TConnackAcceptedNoSession,
+ Desc: "accepted, no session",
+ Primary: true,
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 0, // No existing session
+ CodeSuccess.Code,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ SessionPresent: false,
+ ReasonCode: CodeSuccess.Code,
+ },
+ },
+ {
+ Case: TConnackAcceptedSessionExists,
+ Desc: "accepted, session exists",
+ Primary: true,
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 1, // Session present
+ CodeSuccess.Code,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ SessionPresent: true,
+ ReasonCode: CodeSuccess.Code,
+ },
+ },
+ {
+ Case: TConnackAcceptedAdjustedExpiryInterval,
+ Desc: "accepted, no session, adjusted expiry interval mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Connack << 4, 8, // fixed header
+ 0, // Session present
+ CodeSuccess.Code,
+ 5, // length
+ 17, 0, 0, 0, 120, // Session Expiry Interval (17)
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 8,
+ },
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ SessionExpiryInterval: uint32(120),
+ SessionExpiryIntervalFlag: true,
+ },
+ },
+ },
+ {
+ Case: TConnackAcceptedMqtt5,
+ Desc: "accepted no session mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Connack << 4, 124, // fixed header
+ 0, // No existing session
+ CodeSuccess.Code,
+ // Properties
+ 121, // length
+ 17, 0, 0, 0, 120, // Session Expiry Interval (17)
+ 18, 0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5', // Assigned Client ID (18)
+ 19, 0, 20, // Server Keep Alive (19)
+ 21, 0, 5, 'S', 'H', 'A', '-', '1', // Authentication Method (21)
+ 22, 0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', // Authentication Data (22)
+ 26, 0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e', // Response Info (26)
+ 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28)
+ 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31)
+ 33, 1, 244, // Receive Maximum (33)
+ 34, 3, 231, // Topic Alias Maximum (34)
+ 36, 1, // Maximum Qos (36)
+ 37, 1, // Retain Available (37)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ 38, // User Properties (38)
+ 0, 4, 'k', 'e', 'y', '2',
+ 0, 6, 'v', 'a', 'l', 'u', 'e', '2',
+ 39, 0, 0, 125, 0, // Maximum Packet Size (39)
+ 40, 1, // Wildcard Subscriptions Available (40)
+ 41, 1, // Subscription ID Available (41)
+ 42, 1, // Shared Subscriptions Available (42)
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 124,
+ },
+ SessionPresent: false,
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ SessionExpiryInterval: uint32(120),
+ SessionExpiryIntervalFlag: true,
+ AssignedClientID: "mochi-v5",
+ ServerKeepAlive: uint16(20),
+ ServerKeepAliveFlag: true,
+ AuthenticationMethod: "SHA-1",
+ AuthenticationData: []byte("auth-data"),
+ ResponseInfo: "response",
+ ServerReference: "mochi-2",
+ ReasonString: "reason",
+ ReceiveMaximum: uint16(500),
+ TopicAliasMaximum: uint16(999),
+ MaximumQos: byte(1),
+ MaximumQosFlag: true,
+ RetainAvailable: byte(1),
+ RetainAvailableFlag: true,
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ {
+ Key: "key2",
+ Val: "value2",
+ },
+ },
+ MaximumPacketSize: uint32(32000),
+ WildcardSubAvailable: byte(1),
+ WildcardSubAvailableFlag: true,
+ SubIDAvailable: byte(1),
+ SubIDAvailableFlag: true,
+ SharedSubAvailable: byte(1),
+ SharedSubAvailableFlag: true,
+ },
+ },
+ },
+ {
+ Case: TConnackMinMqtt5,
+ Desc: "accepted min properties mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Connack << 4, 13, // fixed header
+ 1, // existing session
+ CodeSuccess.Code,
+ 10, // Properties length
+ 18, 0, 5, 'm', 'o', 'c', 'h', 'i', // Assigned Client ID (18)
+ 36, 1, // Maximum Qos (36)
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 13,
+ },
+ SessionPresent: true,
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ AssignedClientID: "mochi",
+ MaximumQos: byte(1),
+ MaximumQosFlag: true,
+ },
+ },
+ },
+ {
+ Case: TConnackMinCleanMqtt5,
+ Desc: "accepted min properties mqtt5b",
+ Primary: true,
+ RawBytes: []byte{
+ Connack << 4, 3, // fixed header
+ 0, // existing session
+ CodeSuccess.Code,
+ 0, // Properties length
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 16,
+ },
+ SessionPresent: false,
+ ReasonCode: CodeSuccess.Code,
+ },
+ },
+ {
+ Case: TConnackServerKeepalive,
+ Desc: "server set keepalive",
+ Primary: true,
+ RawBytes: []byte{
+ Connack << 4, 6, // fixed header
+ 1, // existing session
+ CodeSuccess.Code,
+ 3, // Properties length
+ 19, 0, 10, // server keepalive
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 6,
+ },
+ SessionPresent: true,
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ ServerKeepAlive: uint16(10),
+ ServerKeepAliveFlag: true,
+ },
+ },
+ },
+ {
+ Case: TConnackInvalidMinMqtt5,
+ Desc: "failure min properties mqtt5",
+ Primary: true,
+ RawBytes: append([]byte{
+ Connack << 4, 23, // fixed header
+ 0, // No existing session
+ ErrUnspecifiedError.Code,
+ // Properties
+ 20, // length
+ 31, 0, 17, // Reason String (31)
+ }, []byte(ErrUnspecifiedError.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 23,
+ },
+ SessionPresent: false,
+ ReasonCode: ErrUnspecifiedError.Code,
+ Properties: Properties{
+ ReasonString: ErrUnspecifiedError.Reason,
+ },
+ },
+ },
+
+ {
+ Case: TConnackProtocolViolationNoSession,
+ Desc: "miscellaneous protocol violation",
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 0, // Session present
+ ErrProtocolViolation.Code,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ ReasonCode: ErrProtocolViolation.Code,
+ },
+ },
+ {
+ Case: TConnackBadProtocolVersion,
+ Desc: "bad protocol version",
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 1, // Session present
+ ErrProtocolViolationProtocolVersion.Code,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ SessionPresent: true,
+ ReasonCode: ErrProtocolViolationProtocolVersion.Code,
+ },
+ },
+ {
+ Case: TConnackBadClientID,
+ Desc: "bad client id",
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 1, // Session present
+ ErrClientIdentifierNotValid.Code,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ SessionPresent: true,
+ ReasonCode: ErrClientIdentifierNotValid.Code,
+ },
+ },
+ {
+ Case: TConnackServerUnavailable,
+ Desc: "server unavailable",
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 1, // Session present
+ ErrServerUnavailable.Code,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ SessionPresent: true,
+ ReasonCode: ErrServerUnavailable.Code,
+ },
+ },
+ {
+ Case: TConnackBadUsernamePassword,
+ Desc: "bad username or password",
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 1, // Session present
+ ErrBadUsernameOrPassword.Code,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ SessionPresent: true,
+ ReasonCode: ErrBadUsernameOrPassword.Code,
+ },
+ },
+ {
+ Case: TConnackBadUsernamePasswordNoSession,
+ Desc: "bad username or password no session",
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 0, // No session present
+ Err3NotAuthorized.Code, // use v3 remapping
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ ReasonCode: Err3NotAuthorized.Code,
+ },
+ },
+ {
+ Case: TConnackMqtt5BadUsernamePasswordNoSession,
+ Desc: "mqtt5 bad username or password no session",
+ RawBytes: []byte{
+ Connack << 4, 3, // fixed header
+ 0, // No session present
+ ErrBadUsernameOrPassword.Code,
+ 0,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ ReasonCode: ErrBadUsernameOrPassword.Code,
+ },
+ },
+
+ {
+ Case: TConnackNotAuthorised,
+ Desc: "not authorised",
+ RawBytes: []byte{
+ Connack << 4, 2, // fixed header
+ 1, // Session present
+ ErrNotAuthorized.Code,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 2,
+ },
+ SessionPresent: true,
+ ReasonCode: ErrNotAuthorized.Code,
+ },
+ },
+ {
+ Case: TConnackDropProperties,
+ Desc: "drop oversize properties",
+ Group: "encode",
+ RawBytes: []byte{
+ Connack << 4, 40, // fixed header
+ 0, // No existing session
+ CodeSuccess.Code,
+ 19, // length
+ 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28)
+ 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ ActualBytes: []byte{
+ Connack << 4, 13, // fixed header
+ 0, // No existing session
+ CodeSuccess.Code,
+ 10, // length
+ 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28)
+ },
+ Packet: &Packet{
+ Mods: Mods{
+ MaxSize: 5,
+ },
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 40,
+ },
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ ReasonString: "reason",
+ ServerReference: "mochi-2",
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TConnackDropPropertiesPartial,
+ Desc: "drop oversize properties partial",
+ Group: "encode",
+ RawBytes: []byte{
+ Connack << 4, 40, // fixed header
+ 0, // No existing session
+ CodeSuccess.Code,
+ 19, // length
+ 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28)
+ 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ ActualBytes: []byte{
+ Connack << 4, 22, // fixed header
+ 0, // No existing session
+ CodeSuccess.Code,
+ 19, // length
+ 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28)
+ 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31)
+ },
+ Packet: &Packet{
+ Mods: Mods{
+ MaxSize: 18,
+ },
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Connack,
+ Remaining: 40,
+ },
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ ReasonString: "reason",
+ ServerReference: "mochi-2",
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ // Fail States
+ {
+ Case: TConnackMalSessionPresent,
+ Desc: "malformed session present",
+ Group: "decode",
+ FailFirst: ErrMalformedSessionPresent,
+ RawBytes: []byte{
+ Connect << 4, 2, // Fixed header
+ },
+ },
+ {
+ Case: TConnackMalReturnCode,
+ Desc: "malformed bad return Code",
+ Group: "decode",
+ //Primary: true,
+ FailFirst: ErrMalformedReasonCode,
+ RawBytes: []byte{
+ Connect << 4, 2, // Fixed header
+ 0,
+ },
+ },
+ {
+ Case: TConnackMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Connack << 4, 40, // fixed header
+ 0, // No existing session
+ CodeSuccess.Code,
+ 19, // length
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ },
+
+ Publish: {
+ {
+ Case: TPublishNoPayload,
+ Desc: "no payload",
+ Primary: true,
+ RawBytes: []byte{
+ Publish << 4, 7, // Fixed header
+ 0, 5, // Topic Name - MSB+LSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 7,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte{},
+ },
+ },
+ {
+ Case: TPublishBasic,
+ Desc: "basic",
+ Primary: true,
+ RawBytes: []byte{
+ Publish << 4, 18, // Fixed header
+ 0, 5, // Topic Name - MSB+LSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 18,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello mochi"),
+ },
+ },
+
+ {
+ Case: TPublishMqtt5,
+ Desc: "mqtt v5",
+ Primary: true,
+ RawBytes: []byte{
+ Publish << 4, 77, // Fixed header
+ 0, 5, // Topic Name - MSB+LSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 58, // length
+ 1, 1, // Payload Format (1)
+ 2, 0, 0, 0, 2, // Message Expiry (2)
+ 3, 0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n', // Content Type (3)
+ 8, 0, 5, 'a', '/', 'b', '/', 'c', // Response Topic (8)
+ 9, 0, 4, 'd', 'a', 't', 'a', // Correlation Data (9)
+ 11, 202, 212, 19, // Subscription Identifier (11)
+ 35, 0, 3, // Topic Alias (35)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 77,
+ },
+ TopicName: "a/b/c",
+ Properties: Properties{
+ PayloadFormat: byte(1), // UTF-8 Format
+ PayloadFormatFlag: true,
+ MessageExpiryInterval: uint32(2),
+ ContentType: "text/plain",
+ ResponseTopic: "a/b/c",
+ CorrelationData: []byte("data"),
+ SubscriptionIdentifier: []int{322122},
+ TopicAlias: uint16(3),
+ TopicAliasFlag: true,
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ Payload: []byte("hello mochi"),
+ },
+ },
+ {
+ Case: TPublishBasicTopicAliasOnly,
+ Desc: "mqtt v5 topic alias only",
+ Primary: true,
+ RawBytes: []byte{
+ Publish << 4, 17, // Fixed header
+ 0, 0, // Topic Name - LSB+MSB
+ 3, // length
+ 35, 0, 1, // Topic Alias (35)
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 17,
+ },
+ Properties: Properties{
+ TopicAlias: 1,
+ TopicAliasFlag: true,
+ },
+ Payload: []byte("hello mochi"),
+ },
+ },
+ {
+ Case: TPublishBasicMqtt5,
+ Desc: "mqtt basic v5",
+ Primary: true,
+ RawBytes: []byte{
+ Publish << 4, 22, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 3, // length
+ 35, 0, 1, // Topic Alias (35)
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 22,
+ },
+ TopicName: "a/b/c",
+ Properties: Properties{
+ TopicAlias: uint16(1),
+ TopicAliasFlag: true,
+ },
+ Payload: []byte("hello mochi"),
+ },
+ },
+
+ {
+ Case: TPublishQos1,
+ Desc: "qos:1, packet id",
+ Primary: true,
+ RawBytes: []byte{
+ Publish<<4 | 1<<1, 20, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, 7, // Packet ID - LSB+MSB
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Qos: 1,
+ Remaining: 20,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello mochi"),
+ PacketID: 7,
+ },
+ },
+ {
+ Case: TPublishQos1Mqtt5,
+ Desc: "mqtt v5",
+ Primary: true,
+ RawBytes: []byte{
+ Publish<<4 | 1<<1, 37, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, 7, // Packet ID - LSB+MSB
+ // Properties
+ 16, // length
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 37,
+ Qos: 1,
+ },
+ PacketID: 7,
+ TopicName: "a/b/c",
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ Payload: []byte("hello mochi"),
+ },
+ },
+
+ {
+ Case: TPublishQos1Dup,
+ Desc: "qos:1, dup:true, packet id",
+ Primary: true,
+ RawBytes: []byte{
+ Publish<<4 | 2 | 8, 20, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, 7, // Packet ID - LSB+MSB
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Qos: 1,
+ Remaining: 20,
+ Dup: true,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello mochi"),
+ PacketID: 7,
+ },
+ },
+ {
+ Case: TPublishQos1NoPayload,
+ Desc: "qos:1, packet id, no payload",
+ Primary: true,
+ RawBytes: []byte{
+ Publish<<4 | 2, 9, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'y', '/', 'u', '/', 'i', // Topic Name
+ 0, 7, // Packet ID - LSB+MSB
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Qos: 1,
+ Remaining: 9,
+ },
+ TopicName: "y/u/i",
+ PacketID: 7,
+ Payload: []byte{},
+ },
+ },
+ {
+ Case: TPublishQos2,
+ Desc: "qos:2, packet id",
+ Primary: true,
+ RawBytes: []byte{
+ Publish<<4 | 2<<1, 14, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, 7, // Packet ID - LSB+MSB
+ 'h', 'e', 'l', 'l', 'o', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Qos: 2,
+ Remaining: 14,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello"),
+ PacketID: 7,
+ },
+ },
+ {
+ Case: TPublishQos2Mqtt5,
+ Desc: "mqtt v5",
+ Primary: true,
+ RawBytes: []byte{
+ Publish<<4 | 2<<1, 37, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, 7, // Packet ID - LSB+MSB
+ // Properties
+ 16, // length
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 37,
+ Qos: 2,
+ },
+ PacketID: 7,
+ TopicName: "a/b/c",
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ Payload: []byte("hello mochi"),
+ },
+ },
+ {
+ Case: TPublishSubscriberIdentifier,
+ Desc: "subscription identifiers",
+ Primary: true,
+ RawBytes: []byte{
+ Publish << 4, 23, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 4, // properties length
+ 11, 2, // Subscription Identifier (11)
+ 11, 3, // Subscription Identifier (11)
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 23,
+ },
+ TopicName: "a/b/c",
+ Properties: Properties{
+ SubscriptionIdentifier: []int{2, 3},
+ },
+ Payload: []byte("hello mochi"),
+ },
+ },
+
+ {
+ Case: TPublishQos2Upgraded,
+ Desc: "qos:2, upgraded from publish to qos2 sub",
+ Primary: true,
+ RawBytes: []byte{
+ Publish<<4 | 2<<1, 20, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, 1, // Packet ID - LSB+MSB
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Qos: 2,
+ Remaining: 18,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello mochi"),
+ PacketID: 1,
+ },
+ },
+ {
+ Case: TPublishRetain,
+ Desc: "retain",
+ RawBytes: []byte{
+ Publish<<4 | 1<<0, 18, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Retain: true,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello mochi"),
+ },
+ },
+ {
+ Case: TPublishRetainMqtt5,
+ Desc: "retain mqtt5",
+ RawBytes: []byte{
+ Publish<<4 | 1<<0, 19, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, // properties length
+ 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Retain: true,
+ Remaining: 19,
+ },
+ TopicName: "a/b/c",
+ Properties: Properties{},
+ Payload: []byte("hello mochi"),
+ },
+ },
+ {
+ Case: TPublishDup,
+ Desc: "dup",
+ RawBytes: []byte{
+ Publish<<4 | 8, 10, // Fixed header
+ 0, 3, // Topic Name - LSB+MSB
+ 'a', '/', 'b', // Topic Name
+ 'h', 'e', 'l', 'l', 'o', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Dup: true,
+ },
+ TopicName: "a/b",
+ Payload: []byte("hello"),
+ },
+ },
+
+ // Fail States
+ {
+ Case: TPublishMalTopicName,
+ Desc: "malformed topic name",
+ Group: "decode",
+ FailFirst: ErrMalformedTopic,
+ RawBytes: []byte{
+ Publish << 4, 7, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/',
+ 0, 11, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TPublishMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Publish<<4 | 2, 7, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'x', '/', 'y', '/', 'z', // Topic Name
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TPublishMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Publish << 4, 35, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 16,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+
+ // Copy tests
+ {
+ Case: TPublishCopyBasic,
+ Desc: "basic copyable",
+ Group: "copy",
+ RawBytes: []byte{
+ Publish << 4, 18, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'z', '/', 'e', '/', 'n', // Topic Name
+ 'm', 'o', 'c', 'h', 'i', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Dup: true,
+ Retain: true,
+ Qos: 1,
+ },
+ TopicName: "z/e/n",
+ Payload: []byte("mochi mochi"),
+ },
+ },
+
+ // Spec tests
+ {
+ Case: TPublishSpecQos0NoPacketID,
+ Desc: "packet id must be 0 if qos is 0 (a)",
+ Group: "encode",
+ // this version tests for correct byte array mutation.
+ // this does not check if -incoming- Packets are parsed as correct,
+ // it is impossible for the parser to determine if the payload start is incorrect.
+ RawBytes: []byte{
+ Publish << 4, 12, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, 3, // Packet ID - LSB+MSB
+ 'h', 'e', 'l', 'l', 'o', // Payload
+ },
+ ActualBytes: []byte{
+ Publish << 4, 12, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ // Packet ID is removed.
+ 'h', 'e', 'l', 'l', 'o', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 12,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello"),
+ },
+ },
+ {
+ Case: TPublishSpecQosMustPacketID,
+ Desc: "no packet id with qos > 0",
+ Group: "encode",
+ Expect: ErrProtocolViolationNoPacketID,
+ RawBytes: []byte{
+ Publish<<4 | 2, 14, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, 0, // Packet ID - LSB+MSB
+ 'h', 'e', 'l', 'l', 'o', // Payload
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Qos: 1,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello"),
+ PacketID: 0,
+ },
+ },
+ {
+ Case: TPublishDropOversize,
+ Desc: "drop oversized publish packet",
+ Group: "encode",
+ FailFirst: ErrPacketTooLarge,
+ RawBytes: []byte{
+ Publish << 4, 10, // Fixed header
+ 0, 3, // Topic Name - LSB+MSB
+ 'a', '/', 'b', // Topic Name
+ 'h', 'e', 'l', 'l', 'o', // Payload
+ },
+ Packet: &Packet{
+ Mods: Mods{
+ MaxSize: 2,
+ },
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ },
+ TopicName: "a/b",
+ Payload: []byte("hello"),
+ },
+ },
+
+ // Validation Tests
+ {
+ Case: TPublishInvalidQos0NoPacketID,
+ Desc: "packet id must be 0 if qos is 0 (b)",
+ Group: "validate",
+ Expect: ErrProtocolViolationSurplusPacketID,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Remaining: 12,
+ Qos: 0,
+ },
+ TopicName: "a/b/c",
+ Payload: []byte("hello"),
+ PacketID: 3,
+ },
+ },
+ {
+ Case: TPublishInvalidQosMustPacketID,
+ Desc: "no packet id with qos > 0",
+ Group: "validate",
+ Expect: ErrProtocolViolationNoPacketID,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ Qos: 1,
+ },
+ PacketID: 0,
+ },
+ },
+ {
+ Case: TPublishInvalidSurplusSubID,
+ Desc: "surplus subscription identifier",
+ Group: "validate",
+ Expect: ErrProtocolViolationSurplusSubID,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ },
+ Properties: Properties{
+ SubscriptionIdentifier: []int{1},
+ },
+ TopicName: "a/b",
+ },
+ },
+ {
+ Case: TPublishInvalidSurplusWildcard,
+ Desc: "topic contains wildcards",
+ Group: "validate",
+ Expect: ErrProtocolViolationSurplusWildcard,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ },
+ TopicName: "a/+",
+ },
+ },
+ {
+ Case: TPublishInvalidSurplusWildcard2,
+ Desc: "topic contains wildcards 2",
+ Group: "validate",
+ Expect: ErrProtocolViolationSurplusWildcard,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ },
+ TopicName: "a/#",
+ },
+ },
+ {
+ Case: TPublishInvalidNoTopic,
+ Desc: "no topic or alias specified",
+ Group: "validate",
+ Expect: ErrProtocolViolationNoTopic,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ },
+ },
+ },
+ {
+ Case: TPublishInvalidExcessTopicAlias,
+ Desc: "topic alias over maximum",
+ Group: "validate",
+ Expect: ErrTopicAliasInvalid,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ },
+ Properties: Properties{
+ TopicAlias: 1025,
+ },
+ TopicName: "a/b",
+ },
+ },
+ {
+ Case: TPublishInvalidTopicAlias,
+ Desc: "topic alias flag and no alias",
+ Group: "validate",
+ Expect: ErrTopicAliasInvalid,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ },
+ Properties: Properties{
+ TopicAliasFlag: true,
+ TopicAlias: 0,
+ },
+ TopicName: "a/b/",
+ },
+ },
+ {
+ Case: TPublishSpecDenySysTopic,
+ Desc: "deny publishing to $SYS topics",
+ Group: "validate",
+ Expect: CodeSuccess,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Publish,
+ },
+ TopicName: "$SYS/any",
+ Payload: []byte("y"),
+ },
+ RawBytes: []byte{
+ Publish << 4, 11, // Fixed header
+ 0, 5, // Topic Name - LSB+MSB
+ '$', 'S', 'Y', 'S', '/', 'a', 'n', 'y', // Topic Name
+ 'y', // Payload
+ },
+ },
+ },
+
+ Puback: {
+ {
+ Case: TPuback,
+ Desc: "puback",
+ Primary: true,
+ RawBytes: []byte{
+ Puback << 4, 2, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Puback,
+ Remaining: 2,
+ },
+ PacketID: 7,
+ },
+ },
+ {
+ Case: TPubackMqtt5,
+ Desc: "mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Puback << 4, 20, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ CodeGrantedQos0.Code, // Reason Code
+ 16, // Properties Length
+ // 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Puback,
+ Remaining: 20,
+ },
+ PacketID: 7,
+ ReasonCode: CodeGrantedQos0.Code,
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TPubackMqtt5NotAuthorized,
+ Desc: "QOS 1 publish not authorized mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Puback << 4, 37, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ ErrNotAuthorized.Code, // Reason Code
+ 33, // Properties Length
+ 31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u',
+ 't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Puback,
+ Remaining: 31,
+ },
+ PacketID: 7,
+ ReasonCode: ErrNotAuthorized.Code,
+ Properties: Properties{
+ ReasonString: ErrNotAuthorized.Reason,
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TPubackUnexpectedError,
+ Desc: "unexpected error",
+ Group: "decode",
+ RawBytes: []byte{
+ Puback << 4, 29, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ ErrPayloadFormatInvalid.Code, // Reason Code
+ 25, // Properties Length
+ 31, 0, 22, 'p', 'a', 'y', 'l', 'o', 'a', 'd',
+ ' ', 'f', 'o', 'r', 'm', 'a', 't',
+ ' ', 'i', 'n', 'v', 'a', 'l', 'i', 'd', // Reason String (31)
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Puback,
+ Remaining: 28,
+ },
+ PacketID: 7,
+ ReasonCode: ErrPayloadFormatInvalid.Code,
+ Properties: Properties{
+ ReasonString: ErrPayloadFormatInvalid.Reason,
+ },
+ },
+ },
+
+ // Fail states
+ {
+ Case: TPubackMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Puback << 4, 2, // Fixed header
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TPubackMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Puback << 4, 20, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ CodeGrantedQos0.Code, // Reason Code
+ 16,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ },
+ Pubrec: {
+ {
+ Case: TPubrec,
+ Desc: "pubrec",
+ Primary: true,
+ RawBytes: []byte{
+ Pubrec << 4, 2, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Pubrec,
+ Remaining: 2,
+ },
+ PacketID: 7,
+ },
+ },
+ {
+ Case: TPubrecMqtt5,
+ Desc: "pubrec mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Pubrec << 4, 20, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ CodeSuccess.Code, // Reason Code
+ 16, // Properties Length
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Pubrec,
+ Remaining: 20,
+ },
+ PacketID: 7,
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TPubrecMqtt5IDInUse,
+ Desc: "packet id in use mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Pubrec << 4, 47, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ ErrPacketIdentifierInUse.Code, // Reason Code
+ 43, // Properties Length
+ 31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't',
+ ' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r',
+ ' ', 'i', 'n',
+ ' ', 'u', 's', 'e', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Pubrec,
+ Remaining: 31,
+ },
+ PacketID: 7,
+ ReasonCode: ErrPacketIdentifierInUse.Code,
+ Properties: Properties{
+ ReasonString: ErrPacketIdentifierInUse.Reason,
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TPubrecMqtt5NotAuthorized,
+ Desc: "QOS 2 publish not authorized mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Pubrec << 4, 37, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ ErrNotAuthorized.Code, // Reason Code
+ 33, // Properties Length
+ 31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u',
+ 't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Pubrec,
+ Remaining: 31,
+ },
+ PacketID: 7,
+ ReasonCode: ErrNotAuthorized.Code,
+ Properties: Properties{
+ ReasonString: ErrNotAuthorized.Reason,
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TPubrecMalReasonCode,
+ Desc: "malformed reason code",
+ Group: "decode",
+ FailFirst: ErrMalformedReasonCode,
+ RawBytes: []byte{
+ Pubrec << 4, 31, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ // Validation
+ {
+ Case: TPubrecInvalidReason,
+ Desc: "invalid reason code",
+ Group: "validate",
+ FailFirst: ErrProtocolViolationInvalidReason,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Pubrec,
+ },
+ PacketID: 7,
+ ReasonCode: ErrConnectionRateExceeded.Code,
+ },
+ },
+ // Fail states
+ {
+ Case: TPubrecMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Pubrec << 4, 2, // Fixed header
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TPubrecMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Pubrec << 4, 31, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ ErrPacketIdentifierInUse.Code, // Reason Code
+ 27,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ },
+ Pubrel: {
+ {
+ Case: TPubrel,
+ Desc: "pubrel",
+ Primary: true,
+ RawBytes: []byte{
+ Pubrel<<4 | 1<<1, 2, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Pubrel,
+ Remaining: 2,
+ Qos: 1,
+ },
+ PacketID: 7,
+ },
+ },
+ {
+ Case: TPubrelMqtt5,
+ Desc: "pubrel mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Pubrel<<4 | 1<<1, 20, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ CodeSuccess.Code, // Reason Code
+ 16, // Properties Length
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Pubrel,
+ Remaining: 20,
+ Qos: 1,
+ },
+ PacketID: 7,
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TPubrelMqtt5AckNoPacket,
+ Desc: "mqtt5 no packet id ack",
+ Primary: true,
+ RawBytes: append([]byte{
+ Pubrel<<4 | 1<<1, 34, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ ErrPacketIdentifierNotFound.Code, // Reason Code
+ 30, // Properties Length
+ 31, 0, byte(len(ErrPacketIdentifierNotFound.Reason)),
+ }, []byte(ErrPacketIdentifierNotFound.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Pubrel,
+ Remaining: 34,
+ Qos: 1,
+ },
+ PacketID: 7,
+ ReasonCode: ErrPacketIdentifierNotFound.Code,
+ Properties: Properties{
+ ReasonString: ErrPacketIdentifierNotFound.Reason,
+ },
+ },
+ },
+ // Validation
+ {
+ Case: TPubrelInvalidReason,
+ Desc: "invalid reason code",
+ Group: "validate",
+ FailFirst: ErrProtocolViolationInvalidReason,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Pubrel,
+ },
+ PacketID: 7,
+ ReasonCode: ErrConnectionRateExceeded.Code,
+ },
+ },
+ // Fail states
+ {
+ Case: TPubrelMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Pubrel << 4, 2, // Fixed header
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TPubrelMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Pubrel<<4 | 1<<1, 20, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ CodeSuccess.Code, // Reason Code
+ 16,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ },
+ Pubcomp: {
+ {
+ Case: TPubcomp,
+ Desc: "pubcomp",
+ Primary: true,
+ RawBytes: []byte{
+ Pubcomp << 4, 2, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Pubcomp,
+ Remaining: 2,
+ },
+ PacketID: 7,
+ },
+ },
+ {
+ Case: TPubcompMqtt5,
+ Desc: "mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Pubcomp << 4, 20, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ CodeSuccess.Code, // Reason Code
+ 16, // Properties Length
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Pubcomp,
+ Remaining: 20,
+ },
+ PacketID: 7,
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TPubcompMqtt5AckNoPacket,
+ Desc: "mqtt5 no packet id ack",
+ Primary: true,
+ RawBytes: append([]byte{
+ Pubcomp << 4, 34, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ ErrPacketIdentifierNotFound.Code, // Reason Code
+ 30, // Properties Length
+ 31, 0, byte(len(ErrPacketIdentifierNotFound.Reason)),
+ }, []byte(ErrPacketIdentifierNotFound.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Pubcomp,
+ Remaining: 34,
+ },
+ PacketID: 7,
+ ReasonCode: ErrPacketIdentifierNotFound.Code,
+ Properties: Properties{
+ ReasonString: ErrPacketIdentifierNotFound.Reason,
+ },
+ },
+ },
+ // Validation
+ {
+ Case: TPubcompInvalidReason,
+ Desc: "invalid reason code",
+ Group: "validate",
+ FailFirst: ErrProtocolViolationInvalidReason,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Pubcomp,
+ },
+ ReasonCode: ErrConnectionRateExceeded.Code,
+ },
+ },
+ // Fail states
+ {
+ Case: TPubcompMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Pubcomp << 4, 2, // Fixed header
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TPubcompMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Pubcomp << 4, 34, // Fixed header
+ 0, 7, // Packet ID - LSB+MSB
+ ErrPacketIdentifierNotFound.Code, // Reason Code
+ 22,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ },
+ Subscribe: {
+ {
+ Case: TSubscribe,
+ Desc: "subscribe",
+ Primary: true,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 10, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0, // QoS
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Remaining: 10,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {Filter: "a/b/c"},
+ },
+ },
+ },
+ {
+ Case: TSubscribeMany,
+ Desc: "many",
+ Primary: true,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 30, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+
+ 0, 3, // Topic Name - LSB+MSB
+ 'a', '/', 'b', // Topic Name
+ 0, // QoS
+
+ 0, 11, // Topic Name - LSB+MSB
+ 'd', '/', 'e', '/', 'f', '/', 'g', '/', 'h', '/', 'i', // Topic Name
+ 1, // QoS
+
+ 0, 5, // Topic Name - LSB+MSB
+ 'x', '/', 'y', '/', 'z', // Topic Name
+ 2, // QoS
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Remaining: 30,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {Filter: "a/b", Qos: 0},
+ {Filter: "d/e/f/g/h/i", Qos: 1},
+ {Filter: "x/y/z", Qos: 2},
+ },
+ },
+ },
+ {
+ Case: TSubscribeMqtt5,
+ Desc: "mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 31, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 20,
+ 11, 202, 212, 19, // Subscription Identifier (11)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+
+ 0, 5, 'a', '/', 'b', '/', 'c', // Topic Name
+ 46, // subscription options
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Remaining: 31,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {
+ Filter: "a/b/c",
+ Qos: 2,
+ NoLocal: true,
+ RetainAsPublished: true,
+ RetainHandling: 2,
+ Identifier: 322122,
+ },
+ },
+ Properties: Properties{
+ SubscriptionIdentifier: []int{322122},
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TSubscribeRetainHandling1,
+ Desc: "retain handling 1",
+ Primary: true,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 11, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, // no properties
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0 | 1<<4, // subscription options
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Remaining: 11,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {
+ Filter: "a/b/c",
+ RetainHandling: 1,
+ },
+ },
+ },
+ },
+ {
+ Case: TSubscribeRetainHandling2,
+ Desc: "retain handling 2",
+ Primary: true,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 11, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, // no properties
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0 | 2<<4, // subscription options
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Remaining: 11,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {
+ Filter: "a/b/c",
+ RetainHandling: 2,
+ },
+ },
+ },
+ },
+ {
+ Case: TSubscribeRetainAsPublished,
+ Desc: "retain as published",
+ Primary: true,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 11, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, // no properties
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 0 | 1<<3, // subscription options
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Remaining: 11,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {
+ Filter: "a/b/c",
+ RetainAsPublished: true,
+ },
+ },
+ },
+ },
+ {
+ Case: TSubscribeInvalidFilter,
+ Desc: "invalid filter",
+ Group: "reference",
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {Filter: "$SHARE/#", Identifier: 5},
+ },
+ },
+ },
+ {
+ Case: TSubscribeInvalidSharedNoLocal,
+ Desc: "shared and no local",
+ Group: "reference",
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {Filter: "$SHARE/tmp/a/b/c", Identifier: 5, NoLocal: true},
+ },
+ },
+ },
+
+ // Fail states
+ {
+ Case: TSubscribeMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 2, // Fixed header
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TSubscribeMalTopic,
+ Desc: "malformed topic",
+ Group: "decode",
+ FailFirst: ErrMalformedTopic,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 2, // Fixed header
+ 0, 21, // Packet ID - LSB+MSB
+
+ 0, 3, // Topic Name - LSB+MSB
+ 'a', '/',
+ },
+ },
+ {
+ Case: TSubscribeMalQos,
+ Desc: "malformed subscribe - qos",
+ Group: "decode",
+ FailFirst: ErrMalformedQos,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 2, // Fixed header
+ 0, 22, // Packet ID - LSB+MSB
+
+ 0, 3, // Topic Name - LSB+MSB
+ 'j', '/', 'b', // Topic Name
+
+ },
+ },
+ {
+ Case: TSubscribeMalQosRange,
+ Desc: "malformed qos out of range",
+ Group: "decode",
+ FailFirst: ErrProtocolViolationQosOutOfRange,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 2, // Fixed header
+ 0, 22, // Packet ID - LSB+MSB
+
+ 0, 3, // Topic Name - LSB+MSB
+ 'c', '/', 'd', // Topic Name
+ 5, // QoS
+
+ },
+ },
+ {
+ Case: TSubscribeMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Subscribe << 4, 11, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 4,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+
+ // Validation
+ {
+ Case: TSubscribeInvalidQosMustPacketID,
+ Desc: "no packet id with qos > 0",
+ Group: "validate",
+ Expect: ErrProtocolViolationNoPacketID,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Qos: 1,
+ },
+ PacketID: 0,
+ Filters: Subscriptions{
+ {Filter: "a/b"},
+ },
+ },
+ },
+ {
+ Case: TSubscribeInvalidNoFilters,
+ Desc: "no filters",
+ Group: "validate",
+ Expect: ErrProtocolViolationNoFilters,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Qos: 1,
+ },
+ PacketID: 2,
+ },
+ },
+
+ {
+ Case: TSubscribeInvalidIdentifierOversize,
+ Desc: "oversize identifier",
+ Group: "validate",
+ Expect: ErrProtocolViolationOversizeSubID,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Qos: 1,
+ },
+ PacketID: 2,
+ Filters: Subscriptions{
+ {Filter: "a/b", Identifier: 5},
+ {Filter: "d/f", Identifier: 268435456},
+ },
+ },
+ },
+
+ // Spec tests
+ {
+ Case: TSubscribeSpecQosMustPacketID,
+ Desc: "no packet id with qos > 0",
+ Group: "encode",
+ Expect: ErrProtocolViolationNoPacketID,
+ RawBytes: []byte{
+ Subscribe<<4 | 1<<1, 10, // Fixed header
+ 0, 0, // Packet ID - LSB+MSB
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ 1, // QoS
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Subscribe,
+ Qos: 1,
+ Remaining: 10,
+ },
+ Filters: Subscriptions{
+ {Filter: "a/b/c", Qos: 1},
+ },
+ PacketID: 0,
+ },
+ },
+ },
+ Suback: {
+ {
+ Case: TSuback,
+ Desc: "suback",
+ Primary: true,
+ RawBytes: []byte{
+ Suback << 4, 3, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, // Return Code QoS 0
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Suback,
+ Remaining: 3,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{0},
+ },
+ },
+ {
+ Case: TSubackMany,
+ Desc: "many",
+ Primary: true,
+ RawBytes: []byte{
+ Suback << 4, 6, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, // Return Code QoS 0
+ 1, // Return Code QoS 1
+ 2, // Return Code QoS 2
+ 0x80, // Return Code fail
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Suback,
+ Remaining: 6,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{0, 1, 2, 0x80},
+ },
+ },
+ {
+ Case: TSubackDeny,
+ Desc: "deny mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Suback << 4, 4, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, // no properties
+ ErrNotAuthorized.Code, // Return Code QoS 0
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Suback,
+ Remaining: 4,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{ErrNotAuthorized.Code},
+ },
+ },
+ {
+ Case: TSubackUnspecifiedError,
+ Desc: "unspecified error",
+ Primary: true,
+ RawBytes: []byte{
+ Suback << 4, 3, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ ErrUnspecifiedError.Code, // Return Code QoS 0
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Suback,
+ Remaining: 3,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{ErrUnspecifiedError.Code},
+ },
+ },
+ {
+ Case: TSubackUnspecifiedErrorMqtt5,
+ Desc: "unspecified error mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Suback << 4, 4, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, // no properties
+ ErrUnspecifiedError.Code, // Return Code QoS 0
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Suback,
+ Remaining: 4,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{ErrUnspecifiedError.Code},
+ },
+ },
+ {
+ Case: TSubackMqtt5,
+ Desc: "mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Suback << 4, 20, // Fixed header
+ 0, 15, // Packet ID
+ 16, // Properties Length
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ CodeGrantedQos2.Code, // Return Code QoS 0
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Suback,
+ Remaining: 20,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{CodeGrantedQos2.Code},
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TSubackPacketIDInUse,
+ Desc: "packet id in use",
+ Primary: true,
+ RawBytes: []byte{
+ Suback << 4, 47, // Fixed header
+ 0, 15, // Packet ID
+ 43, // Properties Length
+ 31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't',
+ ' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r',
+ ' ', 'i', 'n',
+ ' ', 'u', 's', 'e', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ ErrPacketIdentifierInUse.Code,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Suback,
+ Remaining: 47,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{ErrPacketIdentifierInUse.Code},
+ Properties: Properties{
+ ReasonString: ErrPacketIdentifierInUse.Reason,
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+
+ // Fail states
+ {
+ Case: TSubackMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Suback << 4, 2, // Fixed header
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TSubackMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Suback << 4, 47,
+ 0, 15,
+ 43,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+
+ {
+ Case: TSubackInvalidFilter,
+ Desc: "malformed packet id",
+ Group: "reference",
+ RawBytes: []byte{
+ Suback << 4, 4,
+ 0, 15,
+ 0, // no properties
+ ErrTopicFilterInvalid.Code,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ {
+ Case: TSubackInvalidSharedNoLocal,
+ Desc: "invalid shared no local",
+ Group: "reference",
+ RawBytes: []byte{
+ Suback << 4, 4,
+ 0, 15,
+ 0, // no properties
+ ErrProtocolViolationInvalidSharedNoLocal.Code,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ },
+
+ Unsubscribe: {
+ {
+ Case: TUnsubscribe,
+ Desc: "unsubscribe",
+ Primary: true,
+ RawBytes: []byte{
+ Unsubscribe<<4 | 1<<1, 9, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Unsubscribe,
+ Remaining: 9,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Filters: Subscriptions{
+ {Filter: "a/b/c"},
+ },
+ },
+ },
+ {
+ Case: TUnsubscribeMany,
+ Desc: "unsubscribe many",
+ Primary: true,
+ RawBytes: []byte{
+ Unsubscribe<<4 | 1<<1, 27, // Fixed header
+ 0, 35, // Packet ID - LSB+MSB
+
+ 0, 3, // Topic Name - LSB+MSB
+ 'a', '/', 'b', // Topic Name
+
+ 0, 11, // Topic Name - LSB+MSB
+ 'd', '/', 'e', '/', 'f', '/', 'g', '/', 'h', '/', 'i', // Topic Name
+
+ 0, 5, // Topic Name - LSB+MSB
+ 'x', '/', 'y', '/', 'z', // Topic Name
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Unsubscribe,
+ Remaining: 27,
+ Qos: 1,
+ },
+ PacketID: 35,
+ Filters: Subscriptions{
+ {Filter: "a/b"},
+ {Filter: "d/e/f/g/h/i"},
+ {Filter: "x/y/z"},
+ },
+ },
+ },
+ {
+ Case: TUnsubscribeMqtt5,
+ Desc: "mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Unsubscribe<<4 | 1<<1, 31, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+
+ 16,
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+
+ 0, 3, // Topic Name - LSB+MSB
+ 'a', '/', 'b',
+
+ 0, 5, // Topic Name - LSB+MSB
+ 'x', '/', 'y', '/', 'w',
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Unsubscribe,
+ Remaining: 31,
+ Qos: 1,
+ },
+ PacketID: 15,
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ Filters: Subscriptions{
+ {Filter: "a/b"},
+ {Filter: "x/y/w"},
+ },
+ },
+ },
+
+ // Fail states
+ {
+ Case: TUnsubscribeMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Unsubscribe << 4, 2, // Fixed header
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TUnsubscribeMalTopicName,
+ Desc: "malformed topic",
+ Group: "decode",
+ FailFirst: ErrMalformedTopic,
+ RawBytes: []byte{
+ Unsubscribe << 4, 2, // Fixed header
+ 0, 21, // Packet ID - LSB+MSB
+ 0, 3, // Topic Name - LSB+MSB
+ 'a', '/',
+ },
+ },
+ {
+ Case: TUnsubscribeMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Unsubscribe<<4 | 1<<1, 31, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 16,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+
+ {
+ Case: TUnsubscribeInvalidQosMustPacketID,
+ Desc: "no packet id with qos > 0",
+ Group: "validate",
+ Expect: ErrProtocolViolationNoPacketID,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Unsubscribe,
+ Qos: 1,
+ },
+ PacketID: 0,
+ Filters: Subscriptions{
+ Subscription{Filter: "a/b"},
+ },
+ },
+ },
+ {
+ Case: TUnsubscribeInvalidNoFilters,
+ Desc: "no filters",
+ Group: "validate",
+ Expect: ErrProtocolViolationNoFilters,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Unsubscribe,
+ Qos: 1,
+ },
+ PacketID: 2,
+ },
+ },
+
+ {
+ Case: TUnsubscribeSpecQosMustPacketID,
+ Desc: "no packet id with qos > 0",
+ Group: "encode",
+ Expect: ErrProtocolViolationNoPacketID,
+ RawBytes: []byte{
+ Unsubscribe<<4 | 1<<1, 9, // Fixed header
+ 0, 0, // Packet ID - LSB+MSB
+ 0, 5, // Topic Name - LSB+MSB
+ 'a', '/', 'b', '/', 'c', // Topic Name
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Unsubscribe,
+ Qos: 1,
+ Remaining: 9,
+ },
+ PacketID: 0,
+ Filters: Subscriptions{
+ {Filter: "a/b/c"},
+ },
+ },
+ },
+ },
+ Unsuback: {
+ {
+ Case: TUnsuback,
+ Desc: "unsuback",
+ Primary: true,
+ RawBytes: []byte{
+ Unsuback << 4, 2, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Unsuback,
+ Remaining: 2,
+ },
+ PacketID: 15,
+ },
+ },
+ {
+ Case: TUnsubackMany,
+ Desc: "unsuback many",
+ Primary: true,
+ RawBytes: []byte{
+ Unsuback << 4, 5, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 0,
+ CodeSuccess.Code, CodeSuccess.Code,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Unsuback,
+ Remaining: 5,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{CodeSuccess.Code, CodeSuccess.Code},
+ },
+ },
+ {
+ Case: TUnsubackMqtt5,
+ Desc: "mqtt5",
+ Primary: true,
+ RawBytes: []byte{
+ Unsuback << 4, 21, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 16, // Properties Length
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ CodeSuccess.Code, CodeNoSubscriptionExisted.Code,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Unsuback,
+ Remaining: 21,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{CodeSuccess.Code, CodeNoSubscriptionExisted.Code},
+ Properties: Properties{
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TUnsubackPacketIDInUse,
+ Desc: "packet id in use",
+ Primary: true,
+ RawBytes: []byte{
+ Unsuback << 4, 48, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 43, // Properties Length
+ 31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't',
+ ' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r',
+ ' ', 'i', 'n',
+ ' ', 'u', 's', 'e', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ ErrPacketIdentifierInUse.Code, ErrPacketIdentifierInUse.Code,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Unsuback,
+ Remaining: 48,
+ },
+ PacketID: 15,
+ ReasonCodes: []byte{ErrPacketIdentifierInUse.Code, ErrPacketIdentifierInUse.Code},
+ Properties: Properties{
+ ReasonString: ErrPacketIdentifierInUse.Reason,
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+
+ // Fail states
+ {
+ Case: TUnsubackMalPacketID,
+ Desc: "malformed packet id",
+ Group: "decode",
+ FailFirst: ErrMalformedPacketID,
+ RawBytes: []byte{
+ Unsuback << 4, 2, // Fixed header
+ 0, // Packet ID - LSB+MSB
+ },
+ },
+ {
+ Case: TUnsubackMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Unsuback << 4, 48, // Fixed header
+ 0, 15, // Packet ID - LSB+MSB
+ 43,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ },
+
+ Pingreq: {
+ {
+ Case: TPingreq,
+ Desc: "ping request",
+ Primary: true,
+ RawBytes: []byte{
+ Pingreq << 4, 0, // fixed header
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Pingreq,
+ Remaining: 0,
+ },
+ },
+ },
+ },
+ Pingresp: {
+ {
+ Case: TPingresp,
+ Desc: "ping response",
+ Primary: true,
+ RawBytes: []byte{
+ Pingresp << 4, 0, // fixed header
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Pingresp,
+ Remaining: 0,
+ },
+ },
+ },
+ },
+
+ Disconnect: {
+ {
+ Case: TDisconnect,
+ Desc: "disconnect",
+ Primary: true,
+ RawBytes: []byte{
+ Disconnect << 4, 0, // fixed header
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 0,
+ },
+ },
+ },
+ {
+ Case: TDisconnectTakeover,
+ Desc: "takeover",
+ Primary: true,
+ RawBytes: append([]byte{
+ Disconnect << 4, 21, // fixed header
+ ErrSessionTakenOver.Code, // Reason Code
+ 19, // Properties Length
+ 31, 0, 16, // Reason String (31)
+ }, []byte(ErrSessionTakenOver.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 0,
+ },
+ ReasonCode: ErrSessionTakenOver.Code,
+ Properties: Properties{
+ ReasonString: ErrSessionTakenOver.Reason,
+ },
+ },
+ },
+ {
+ Case: TDisconnectShuttingDown,
+ Desc: "shutting down",
+ Primary: true,
+ RawBytes: append([]byte{
+ Disconnect << 4, 25, // fixed header
+ ErrServerShuttingDown.Code, // Reason Code
+ 23, // Properties Length
+ 31, 0, 20, // Reason String (31)
+ }, []byte(ErrServerShuttingDown.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 0,
+ },
+ ReasonCode: ErrServerShuttingDown.Code,
+ Properties: Properties{
+ ReasonString: ErrServerShuttingDown.Reason,
+ },
+ },
+ },
+ {
+ Case: TDisconnectMqtt5,
+ Desc: "mqtt5",
+ Primary: true,
+ RawBytes: append([]byte{
+ Disconnect << 4, 22, // fixed header
+ CodeDisconnect.Code, // Reason Code
+ 20, // Properties Length
+ 17, 0, 0, 0, 120, // Session Expiry Interval (17)
+ 31, 0, 12, // Reason String (31)
+ }, []byte(CodeDisconnect.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 22,
+ },
+ ReasonCode: CodeDisconnect.Code,
+ Properties: Properties{
+ ReasonString: CodeDisconnect.Reason,
+ SessionExpiryInterval: 120,
+ SessionExpiryIntervalFlag: true,
+ },
+ },
+ },
+ {
+ Case: TDisconnectMqtt5DisconnectWithWillMessage,
+ Desc: "mqtt5 disconnect with will message",
+ Primary: true,
+ RawBytes: append([]byte{
+ Disconnect << 4, 38, // fixed header
+ CodeDisconnectWillMessage.Code, // Reason Code
+ 36, // Properties Length
+ 17, 0, 0, 0, 120, // Session Expiry Interval (17)
+ 31, 0, 28, // Reason String (31)
+ }, []byte(CodeDisconnectWillMessage.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 22,
+ },
+ ReasonCode: CodeDisconnectWillMessage.Code,
+ Properties: Properties{
+ ReasonString: CodeDisconnectWillMessage.Reason,
+ SessionExpiryInterval: 120,
+ SessionExpiryIntervalFlag: true,
+ },
+ },
+ },
+ {
+ Case: TDisconnectSecondConnect,
+ Desc: "second connect packet mqtt5",
+ RawBytes: append([]byte{
+ Disconnect << 4, 46, // fixed header
+ ErrProtocolViolationSecondConnect.Code,
+ 44,
+ 31, 0, 41, // Reason String (31)
+ }, []byte(ErrProtocolViolationSecondConnect.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 45,
+ },
+ ReasonCode: ErrProtocolViolationSecondConnect.Code,
+ Properties: Properties{
+ ReasonString: ErrProtocolViolationSecondConnect.Reason,
+ },
+ },
+ },
+ {
+ Case: TDisconnectZeroNonZeroExpiry,
+ Desc: "zero non zero expiry",
+ RawBytes: []byte{
+ Disconnect << 4, 2, // fixed header
+ ErrProtocolViolationZeroNonZeroExpiry.Code,
+ 0,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 2,
+ },
+ ReasonCode: ErrProtocolViolationZeroNonZeroExpiry.Code,
+ },
+ },
+ {
+ Case: TDisconnectReceiveMaximum,
+ Desc: "receive maximum mqtt5",
+ RawBytes: append([]byte{
+ Disconnect << 4, 29, // fixed header
+ ErrReceiveMaximum.Code,
+ 27, // Properties Length
+ 31, 0, 24, // Reason String (31)
+ }, []byte(ErrReceiveMaximum.Reason)...),
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 29,
+ },
+ ReasonCode: ErrReceiveMaximum.Code,
+ Properties: Properties{
+ ReasonString: ErrReceiveMaximum.Reason,
+ },
+ },
+ },
+ {
+ Case: TDisconnectDropProperties,
+ Desc: "drop oversize properties partial",
+ Group: "encode",
+ RawBytes: []byte{
+ Disconnect << 4, 39, // fixed header
+ CodeDisconnect.Code,
+ 19, // length
+ 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28)
+ 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ ActualBytes: []byte{
+ Disconnect << 4, 12, // fixed header
+ CodeDisconnect.Code,
+ 10, // length
+ 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28)
+ },
+ Packet: &Packet{
+ Mods: Mods{
+ MaxSize: 3,
+ },
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Disconnect,
+ Remaining: 40,
+ },
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ ReasonString: "reason",
+ ServerReference: "mochi-2",
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ // fail states
+ {
+ Case: TDisconnectMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Disconnect << 4, 48, // fixed header
+ CodeDisconnect.Code, // Reason Code
+ 46,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ {
+ Case: TDisconnectMalReasonCode,
+ Desc: "malformed reason code",
+ Group: "decode",
+ FailFirst: ErrMalformedReasonCode,
+ RawBytes: []byte{
+ Disconnect << 4, 48, // fixed header
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ },
+ Auth: {
+ {
+ Case: TAuth,
+ Desc: "auth",
+ Primary: true,
+ RawBytes: []byte{
+ Auth << 4, 47,
+ CodeSuccess.Code, // reason code
+ 45,
+ 21, 0, 5, 'S', 'H', 'A', '-', '1', // Authentication Method (21)
+ 22, 0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', // Authentication Data (22)
+ 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31)
+ 38, // User Properties (38)
+ 0, 5, 'h', 'e', 'l', 'l', 'o',
+ 0, 6, 228, 184, 150, 231, 149, 140,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ FixedHeader: FixedHeader{
+ Type: Auth,
+ Remaining: 47,
+ },
+ ReasonCode: CodeSuccess.Code,
+ Properties: Properties{
+ AuthenticationMethod: "SHA-1",
+ AuthenticationData: []byte("auth-data"),
+ ReasonString: "reason",
+ User: []UserProperty{
+ {
+ Key: "hello",
+ Val: "世界",
+ },
+ },
+ },
+ },
+ },
+ {
+ Case: TAuthMalReasonCode,
+ Desc: "malformed reason code",
+ Group: "decode",
+ FailFirst: ErrMalformedReasonCode,
+ RawBytes: []byte{
+ Auth << 4, 47,
+ },
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Auth,
+ },
+ ReasonCode: CodeNoMatchingSubscribers.Code,
+ },
+ },
+ // fail states
+ {
+ Case: TAuthMalProperties,
+ Desc: "malformed properties",
+ Group: "decode",
+ FailFirst: ErrMalformedProperties,
+ RawBytes: []byte{
+ Auth << 4, 3,
+ CodeSuccess.Code,
+ 12,
+ },
+ Packet: &Packet{
+ ProtocolVersion: 5,
+ },
+ },
+ // Validation
+ {
+ Case: TAuthInvalidReason,
+ Desc: "invalid reason code",
+ Group: "validate",
+ Expect: ErrProtocolViolationInvalidReason,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Auth,
+ },
+ ReasonCode: CodeNoMatchingSubscribers.Code,
+ },
+ },
+ {
+ Case: TAuthInvalidReason2,
+ Desc: "invalid reason code",
+ Group: "validate",
+ Expect: ErrProtocolViolationInvalidReason,
+ Packet: &Packet{
+ FixedHeader: FixedHeader{
+ Type: Auth,
+ },
+ ReasonCode: CodeNoMatchingSubscribers.Code,
+ },
+ },
+ },
+}
diff --git a/packets/tpackets_test.go b/packets/tpackets_test.go
new file mode 100644
index 0000000..8114207
--- /dev/null
+++ b/packets/tpackets_test.go
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package packets
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func encodeTestOK(wanted TPacketCase) bool {
+ if wanted.RawBytes == nil {
+ return false
+ }
+ if wanted.Group != "" && wanted.Group != "encode" {
+ return false
+ }
+ return true
+}
+
+func decodeTestOK(wanted TPacketCase) bool {
+ if wanted.Group != "" && wanted.Group != "decode" {
+ return false
+ }
+ return true
+}
+
+func TestTPacketCaseGet(t *testing.T) {
+ require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311))
+ require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128)))
+}
diff --git a/system/config.yaml b/system/config.yaml
new file mode 100644
index 0000000..6691526
--- /dev/null
+++ b/system/config.yaml
@@ -0,0 +1,70 @@
+listeners:
+ - type: "tcp"
+ id: "file-tcp1"
+ address: ":1883"
+ - type: "ws"
+ id: "file-websocket"
+ address: ":1882"
+ - type: "healthcheck"
+ id: "file-healthcheck"
+ address: ":1880"
+hooks:
+ debug:
+ enable: true
+ storage:
+ badger:
+ path: badger.db
+ gc_interval: 3
+ gc_discard_ratio: 0.5
+ pebble:
+ path: pebble.db
+ mode: "NoSync"
+ bolt:
+ path: bolt.db
+ bucket: "mochi"
+ redis:
+ h_prefix: "mc"
+ username: "mochi"
+ password: "melon"
+ address: "localhost:6379"
+ database: 1
+ auth:
+ allow_all: false
+ ledger:
+ auth:
+ - username: peach
+ password: password1
+ allow: true
+ acl:
+ - remote: 127.0.0.1:*
+ - username: melon
+ filters:
+ melon/#: 3
+ updates/#: 2
+options:
+ client_net_write_buffer_size: 2048
+ client_net_read_buffer_size: 2048
+ sys_topic_resend_interval: 10
+ inline_client: true
+ capabilities:
+ maximum_message_expiry_interval: 100
+ maximum_client_writes_pending: 8192
+ maximum_session_expiry_interval: 86400
+ maximum_packet_size: 0
+ receive_maximum: 1024
+ maximum_inflight: 8192
+ topic_alias_maximum: 65535
+ shared_sub_available: 1
+ minimum_protocol_version: 3
+ maximum_qos: 2
+ retain_available: 1
+ wildcard_sub_available: 1
+ sub_id_available: 1
+ compatibilities:
+ obscure_not_authorized: true
+ passive_client_disconnect: false
+ always_return_response_info: false
+ restore_sys_info_on_restart: false
+ no_inherited_properties_on_ack: false
+logging:
+ level: INFO
diff --git a/system/system.go b/system/system.go
new file mode 100644
index 0000000..1ceda94
--- /dev/null
+++ b/system/system.go
@@ -0,0 +1,62 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
+// SPDX-FileContributor: mochi-co
+
+package system
+
+import "sync/atomic"
+
// Info contains atomic counters and values for various server statistics
// commonly found in $SYS topics (and others).
// The int64 fields are read with sync/atomic loads in Clone; writers are
// presumably expected to use atomic stores/adds as well — see Clone below.
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
type Info struct {
	Version             string `json:"version"`              // the current version of the server
	Started             int64  `json:"started"`              // the time the server started in unix seconds
	Time                int64  `json:"time"`                 // current time on the server
	Uptime              int64  `json:"uptime"`               // the number of seconds the server has been online
	BytesReceived       int64  `json:"bytes_received"`       // total number of bytes received since the broker started
	BytesSent           int64  `json:"bytes_sent"`           // total number of bytes sent since the broker started
	ClientsConnected    int64  `json:"clients_connected"`    // number of currently connected clients
	ClientsDisconnected int64  `json:"clients_disconnected"` // total number of persistent clients (with clean session disabled) that are registered at the broker but are currently disconnected
	ClientsMaximum      int64  `json:"clients_maximum"`      // maximum number of active clients that have been connected
	ClientsTotal        int64  `json:"clients_total"`        // total number of connected and disconnected clients with a persistent session currently connected and registered
	MessagesReceived    int64  `json:"messages_received"`    // total number of publish messages received
	MessagesSent        int64  `json:"messages_sent"`        // total number of publish messages sent
	MessagesDropped     int64  `json:"messages_dropped"`     // total number of publish messages dropped to slow subscriber
	Retained            int64  `json:"retained"`             // total number of retained messages active on the broker
	Inflight            int64  `json:"inflight"`             // the number of messages currently in-flight
	InflightDropped     int64  `json:"inflight_dropped"`     // the number of inflight messages which were dropped
	Subscriptions       int64  `json:"subscriptions"`        // total number of subscriptions active on the broker
	PacketsReceived     int64  `json:"packets_received"`     // the total number of publish messages received
	PacketsSent         int64  `json:"packets_sent"`         // total number of messages of any type sent since the broker started
	MemoryAlloc         int64  `json:"memory_alloc"`         // memory currently allocated
	Threads             int64  `json:"threads"`              // number of active goroutines, named as threads for platform ambiguity
}
+
+// Clone makes a copy of Info using atomic operation
+func (i *Info) Clone() *Info {
+ return &Info{
+ Version: i.Version,
+ Started: atomic.LoadInt64(&i.Started),
+ Time: atomic.LoadInt64(&i.Time),
+ Uptime: atomic.LoadInt64(&i.Uptime),
+ BytesReceived: atomic.LoadInt64(&i.BytesReceived),
+ BytesSent: atomic.LoadInt64(&i.BytesSent),
+ ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
+ ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
+ ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
+ ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
+ MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
+ MessagesSent: atomic.LoadInt64(&i.MessagesSent),
+ MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
+ Retained: atomic.LoadInt64(&i.Retained),
+ Inflight: atomic.LoadInt64(&i.Inflight),
+ InflightDropped: atomic.LoadInt64(&i.InflightDropped),
+ Subscriptions: atomic.LoadInt64(&i.Subscriptions),
+ PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
+ PacketsSent: atomic.LoadInt64(&i.PacketsSent),
+ MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
+ Threads: atomic.LoadInt64(&i.Threads),
+ }
+}
diff --git a/system/system_test.go b/system/system_test.go
new file mode 100644
index 0000000..b76df21
--- /dev/null
+++ b/system/system_test.go
@@ -0,0 +1,37 @@
+package system
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestClone(t *testing.T) {
+ o := &Info{
+ Version: "version",
+ Started: 1,
+ Time: 2,
+ Uptime: 3,
+ BytesReceived: 4,
+ BytesSent: 5,
+ ClientsConnected: 6,
+ ClientsMaximum: 7,
+ ClientsTotal: 8,
+ ClientsDisconnected: 9,
+ MessagesReceived: 10,
+ MessagesSent: 11,
+ MessagesDropped: 20,
+ Retained: 12,
+ Inflight: 13,
+ InflightDropped: 14,
+ Subscriptions: 15,
+ PacketsReceived: 16,
+ PacketsSent: 17,
+ MemoryAlloc: 18,
+ Threads: 19,
+ }
+
+ n := o.Clone()
+
+ require.Equal(t, o, n)
+}