commit 84e5e65ee7a0e1c15edef83ffac3ab6fa6f98dd5 Author: iuu <2167162990@qq.com> Date: Wed Aug 21 15:32:05 2024 +0800 代码整理 diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..dbeae59 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml new file mode 100644 index 0000000..a99ea19 --- /dev/null +++ b/.idea/workspace.xml @@ -0,0 +1,72 @@ + + + + + + + + + + + + + + { + "associatedIndex": 8 +} + + + + { + "keyToString": { + "DefaultGoTemplateProperty": "Go File", + "RunOnceActivity.ShowReadmeOnStart": "true", + "RunOnceActivity.go.formatter.settings.were.checked": "true", + "RunOnceActivity.go.migrated.go.modules.settings": "true", + "RunOnceActivity.go.modules.automatic.dependencies.download": "true", + "RunOnceActivity.go.modules.go.list.on.any.changes.was.set": "true", + "git-widget-placeholder": "main", + "go.import.settings.migrated": "true", + "last_opened_file_path": "E:/Work/Go/src/testmqtt", + "node.js.detected.package.eslint": "true", + "node.js.selected.package.eslint": "(autodetect)", + "nodejs_package_manager_path": "npm" + } +} + + + + + + + + + + + + + + + + + + + + true + + \ No newline at end of file diff --git a/config-bac.yaml b/config-bac.yaml new file mode 100644 index 0000000..b1aeaac --- /dev/null +++ b/config-bac.yaml @@ -0,0 +1,101 @@ + + +# 监听端口 +listeners: + - type: "tcp" + id: "file-tcp1" + address: ":1883" + - type: "ws" + id: "file-websocket" + address: ":1882" + - type: "healthcheck" + id: "file-healthcheck" + address: ":1880" + + +hooks: + debug: + enable: true + storage: + badger: + path: badger.db + gc_interval: 3 + gc_discard_ratio: 0.5 + pebble: + path: pebble.db + mode: "NoSync" + bolt: + path: bolt.db + bucket: "mochi" + redis: + h_prefix: "mc" + username: "mochi" + password: "melon" + address: "localhost:6379" + database: 1 + auth: + allow_all: false + ledger: + auth: + - username: peach + password: password1 + allow: true + acl: + - remote: 127.0.0.1:* + - username: melon + filters: + melon/#: 3 + updates/#: 2 + +# MQTT协议相关参数配置 +options: + # 客户端网络写缓冲区的大小,单位为字节。此配置项设置了发送数据到客户端时的缓冲区大小,设置为 2048 字节。 + client_net_write_buffer_size: 2048 + # 客户端网络读缓冲区的大小,单位为字节。此配置项设置了从客户端接收数据时的缓冲区大小,设置为 2048 字节。 + client_net_read_buffer_size: 2048 + # 系统主题消息重新发送的间隔时间,单位为秒。系统主题通常包含监控和管理信息,这个选项设定了这些信息更新的频率。 + sys_topic_resend_interval: 10 + # 如果设为 true,客户端的操作可能会同步执行(而不是异步执行)。这可能影响性能,但在某些情况下能简化处理逻辑。 + inline_client: true + # 这些选项定义了 MQTT Broker 支持的功能及其限制。 + capabilities: + # 消息过期的最大时间间隔,单位为秒。超过这个时间的消息将被丢弃。 + maximum_message_expiry_interval: 100 + # 允许挂起的客户端写操作的最大数量。此选项设置了在发送消息给客户端之前,Broker 可以排队的最大消息数量。 + maximum_client_writes_pending: 8192 + # 客户端会话的最大过期时间,单位为秒。这个配置决定了在客户端断开连接后,Broker 会保留客户端会话状态的最大时间。 + maximum_session_expiry_interval: 86400 + # 接收的 MQTT 包的最大大小,单位为字节。如果设置为 0,表示不限制包的大小。 + maximum_packet_size: 0 + # Broker 能够同时接收的最大 QoS 1 和 QoS 2 消息数量 + receive_maximum: 1024 + # 允许的最大“飞行中”(尚未完成确认)的消息数量。 + maximum_inflight: 8192 + # 允许的最大主题别名数量。主题别名用于减少传输过程中重复主题字符串的传输,特别在长主题中有助于节省带宽。 + topic_alias_maximum: 65535 + # 指定是否支持共享订阅,1 表示支持。 + shared_sub_available: 1 + # 支持的最小 MQTT 协议版本。3 表示 MQTT 3.1.1。 + minimum_protocol_version: 3 + # 支持的最高服务质量(QoS)级别。2 表示 Broker 支持 QoS 0、QoS 1 和 QoS 2。 + maximum_qos: 2 + # 指定是否支持保留消息,1 表示支持。 + retain_available: 1 + # 指定是否支持通配符订阅,1 表示支持。 + wildcard_sub_available: 1 + # 指定是否支持订阅标识符,1 表示支持。 + sub_id_available: 1 + # 这些选项是为了保持与旧版本或特殊客户端的兼容性。 + compatibilities: + # 是否隐藏未授权错误的具体原因。true 表示隐藏,可能只返回一个泛泛的错误信息以提高安全性。 + obscure_not_authorized: true + # 是否在客户端无响应时被动断开连接。false 表示不被动断开,可能会主动断开连接。 + passive_client_disconnect: false + # 是否总是返回响应信息。false 表示只在需要时返回响应信息。 + always_return_response_info: false + # 是否在重启时恢复系统信息。false 表示重启时不会恢复之前的系统状态信息。 + restore_sys_info_on_restart: false + # 是否在 ACK 时不继承属性。false 表示继承属性。 + no_inherited_properties_on_ack: false +logging: + level: INFO diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..666ad35 --- /dev/null +++ b/config.yaml @@ -0,0 +1,25 @@ +listeners: + - type: "tcp" + id: "file-tcp1" + address: ":1883" + - type: "ws" + id: "file-websocket" + address: ":1882" + - type: "healthcheck" + id: "file-healthcheck" + address: ":1880" +hooks: + debug: + enable: true + storage: + redis: + h_prefix: "m-" + password: "iuuiuuiuu" + address: "192.168.0.9:6379" + database: 5 + auth: + allow_all: true + + +logging: + level: DEBUG diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..31c48d1 --- /dev/null +++ b/config/config.go @@ -0,0 +1,47 @@ +package config + +import ( + "encoding/json" + "gopkg.in/yaml.v3" + "testmqtt/listeners" + "testmqtt/mqtt" +) + +// config 定义配置结构 +// config defines the structure of configuration data to be parsed from a config source. +type config struct { + Options mqtt.Options // Mqtt协议配置 + Listeners []listeners.Config `yaml:"listeners" json:"listeners"` // 监听端口配置 + HookConfigs HookConfigs `yaml:"hooks" json:"hooks"` // 钩子配置 + LoggingConfig LoggingConfig `yaml:"logging" json:"logging"` // 日志记录配置 +} + +// FromBytes unmarshals a byte slice of JSON or YAML config data into a valid server options value. +// Any hooks configurations are converted into Hooks using the toHooks methods in this package. +func FromBytes(b []byte) (*mqtt.Options, error) { + c := new(config) + o := mqtt.Options{} + + if len(b) == 0 { + return nil, nil + } + + if b[0] == '{' { + err := json.Unmarshal(b, c) + if err != nil { + return nil, err + } + } else { + err := yaml.Unmarshal(b, c) + if err != nil { + return nil, err + } + } + + o = c.Options + o.Hooks = c.HookConfigs.ToHooks() + o.Listeners = c.Listeners + o.Logger = c.LoggingConfig.ToLogger() + + return &o, nil +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..a1b9698 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package config + +import ( + "log/slog" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "testmqtt/hooks/auth" + "testmqtt/hooks/storage/badger" + "testmqtt/hooks/storage/bolt" + "testmqtt/hooks/storage/pebble" + "testmqtt/hooks/storage/redis" + "testmqtt/listeners" + + "testmqtt/mqtt" +) + +var ( + yamlBytes = []byte(` +listeners: + - type: "tcp" + id: "file-tcp1" + address: ":1883" +hooks: + auth: + allow_all: true +options: + client_net_write_buffer_size: 2048 + capabilities: + minimum_protocol_version: 3 + compatibilities: + restore_sys_info_on_restart: true +`) + + jsonBytes = []byte(`{ + "listeners": [ + { + "type": "tcp", + "id": "file-tcp1", + "address": ":1883" + } + ], + "hooks": { + "auth": { + "allow_all": true + } + }, + "options": { + "client_net_write_buffer_size": 2048, + "capabilities": { + "minimum_protocol_version": 3, + "compatibilities": { + "restore_sys_info_on_restart": true + } + } + } +} +`) + parsedOptions = mqtt.Options{ + Listeners: []listeners.Config{ + { + Type: listeners.TypeTCP, + ID: "file-tcp1", + Address: ":1883", + }, + }, + Hooks: []mqtt.HookLoadConfig{ + { + Hook: new(auth.AllowHook), + }, + }, + ClientNetWriteBufferSize: 2048, + Capabilities: &mqtt.Capabilities{ + MinimumProtocolVersion: 3, + Compatibilities: mqtt.Compatibilities{ + RestoreSysInfoOnRestart: true, + }, + }, + Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: new(slog.LevelVar), + })), + } +) + +func TestFromBytesEmptyL(t *testing.T) { + _, err := FromBytes([]byte{}) + require.NoError(t, err) +} + +func TestFromBytesYAML(t *testing.T) { + o, err := FromBytes(yamlBytes) + require.NoError(t, err) + require.Equal(t, parsedOptions, *o) +} + +func TestFromBytesYAMLError(t *testing.T) { + _, err := FromBytes(append(yamlBytes, 'a')) + require.Error(t, err) +} + +func TestFromBytesJSON(t *testing.T) { + o, err := FromBytes(jsonBytes) + require.NoError(t, err) + require.Equal(t, parsedOptions, *o) +} + +func TestFromBytesJSONError(t *testing.T) { + _, err := FromBytes(append(jsonBytes, 'a')) + require.Error(t, err) +} + +func TestToHooksAuthAllowAll(t *testing.T) { + hc := HookConfigs{ + Auth: &HookAuthConfig{ + AllowAll: true, + }, + } + + th := hc.toHooksAuth() + expect := []mqtt.HookLoadConfig{ + {Hook: new(auth.AllowHook)}, + } + require.Equal(t, expect, th) +} + +func TestToHooksAuthAllowLedger(t *testing.T) { + hc := HookConfigs{ + Auth: &HookAuthConfig{ + Ledger: auth.Ledger{ + Auth: auth.AuthRules{ + {Username: "peach", Password: "password1", Allow: true}, + }, + }, + }, + } + + th := hc.toHooksAuth() + expect := []mqtt.HookLoadConfig{ + { + Hook: new(auth.Hook), + Config: &auth.Options{ + Ledger: &auth.Ledger{ // avoid copying sync.Locker + Auth: auth.AuthRules{ + {Username: "peach", Password: "password1", Allow: true}, + }, + }, + }, + }, + } + require.Equal(t, expect, th) +} + +func TestToHooksStorageBadger(t *testing.T) { + hc := HookConfigs{ + Storage: &HookStorageConfig{ + Badger: &badger.Options{ + Path: "badger", + }, + }, + } + + th := hc.toHooksStorage() + expect := []mqtt.HookLoadConfig{ + { + Hook: new(badger.Hook), + Config: hc.Storage.Badger, + }, + } + + require.Equal(t, expect, th) +} + +func TestToHooksStorageBolt(t *testing.T) { + hc := HookConfigs{ + Storage: &HookStorageConfig{ + Bolt: &bolt.Options{ + Path: "bolt", + Bucket: "mochi", + }, + }, + } + + th := hc.toHooksStorage() + expect := []mqtt.HookLoadConfig{ + { + Hook: new(bolt.Hook), + Config: hc.Storage.Bolt, + }, + } + + require.Equal(t, expect, th) +} + +func TestToHooksStorageRedis(t *testing.T) { + hc := HookConfigs{ + Storage: &HookStorageConfig{ + Redis: &redis.Options{ + Username: "test", + }, + }, + } + + th := hc.toHooksStorage() + expect := []mqtt.HookLoadConfig{ + { + Hook: new(redis.Hook), + Config: hc.Storage.Redis, + }, + } + + require.Equal(t, expect, th) +} + +func TestToHooksStoragePebble(t *testing.T) { + hc := HookConfigs{ + Storage: &HookStorageConfig{ + Pebble: &pebble.Options{ + Path: "pebble", + }, + }, + } + + th := hc.toHooksStorage() + expect := []mqtt.HookLoadConfig{ + { + Hook: new(pebble.Hook), + Config: hc.Storage.Pebble, + }, + } + + require.Equal(t, expect, th) +} diff --git a/config/hook.go b/config/hook.go new file mode 100644 index 0000000..ddc627a --- /dev/null +++ b/config/hook.go @@ -0,0 +1,123 @@ +package config + +import ( + "testmqtt/hooks/auth" + "testmqtt/hooks/debug" + "testmqtt/hooks/storage/badger" + "testmqtt/hooks/storage/bolt" + "testmqtt/hooks/storage/pebble" + "testmqtt/hooks/storage/redis" + "testmqtt/mqtt" +) + +// HookConfigs 全部Hook的配置 +// HookConfigs contains configurations to enable individual hooks. +type HookConfigs struct { + Auth *HookAuthConfig `yaml:"auth" json:"auth"` // Auth AuthHook配置 + Storage *HookStorageConfig `yaml:"storage" json:"storage"` // Storage StorageHook配置 + Debug *debug.Options `yaml:"debug" json:"debug"` // Debug DebugHook配置 +} + +// HookAuthConfig AuthHook的配置 +// HookAuthConfig contains configurations for the auth hook. +type HookAuthConfig struct { + Ledger auth.Ledger `yaml:"ledger" json:"ledger"` + AllowAll bool `yaml:"allow_all" json:"allow_all"` +} + +// HookStorageConfig StorageHook的配置 +// HookStorageConfig contains configurations for the different storage hooks. +type HookStorageConfig struct { + Badger *badger.Options `yaml:"badger" json:"badger"` + Bolt *bolt.Options `yaml:"bolt" json:"bolt"` + Pebble *pebble.Options `yaml:"pebble" json:"pebble"` + Redis *redis.Options `yaml:"redis" json:"redis"` +} + +// HookDebugConfig DebugHook的配置 +// HookDebugConfig contains configuration settings for the debug output. +type HookDebugConfig struct { + Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config + ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false) + ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false) + ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false) +} + +// ToHooks converts Hook file configurations into Hooks to be added to the server. +func (hc HookConfigs) ToHooks() []mqtt.HookLoadConfig { + var hlc []mqtt.HookLoadConfig + + if hc.Auth != nil { + hlc = append(hlc, hc.toHooksAuth()...) + } + + if hc.Storage != nil { + hlc = append(hlc, hc.toHooksStorage()...) + } + + if hc.Debug != nil { + hlc = append(hlc, mqtt.HookLoadConfig{ + Hook: new(debug.Hook), + Config: hc.Debug, + }) + } + + return hlc +} + +// toHooksAuth converts auth hook configurations into auth hooks. +func (hc HookConfigs) toHooksAuth() []mqtt.HookLoadConfig { + var hlc []mqtt.HookLoadConfig + if hc.Auth.AllowAll { + hlc = append(hlc, mqtt.HookLoadConfig{ + Hook: new(auth.AllowHook), + }) + } else { + hlc = append(hlc, mqtt.HookLoadConfig{ + Hook: new(auth.Hook), + Config: &auth.Options{ + Ledger: &auth.Ledger{ // avoid copying sync.Locker + Users: hc.Auth.Ledger.Users, + Auth: hc.Auth.Ledger.Auth, + ACL: hc.Auth.Ledger.ACL, + }, + }, + }) + } + return hlc +} + +// toHooksAuth converts storage hook configurations into storage hooks. +func (hc HookConfigs) toHooksStorage() []mqtt.HookLoadConfig { + + var hlc []mqtt.HookLoadConfig + + if hc.Storage.Badger != nil { + hlc = append(hlc, mqtt.HookLoadConfig{ + Hook: new(badger.Hook), + Config: hc.Storage.Badger, + }) + } + + if hc.Storage.Bolt != nil { + hlc = append(hlc, mqtt.HookLoadConfig{ + Hook: new(bolt.Hook), + Config: hc.Storage.Bolt, + }) + } + + if hc.Storage.Redis != nil { + hlc = append(hlc, mqtt.HookLoadConfig{ + Hook: new(redis.Hook), + Config: hc.Storage.Redis, + }) + } + + if hc.Storage.Pebble != nil { + hlc = append(hlc, mqtt.HookLoadConfig{ + Hook: new(pebble.Hook), + Config: hc.Storage.Pebble, + }) + } + return hlc +} diff --git a/config/logger.go b/config/logger.go new file mode 100644 index 0000000..0309d5c --- /dev/null +++ b/config/logger.go @@ -0,0 +1,21 @@ +package config + +import "os" +import "log/slog" + +type LoggingConfig struct { + Level string +} + +func (lc LoggingConfig) ToLogger() *slog.Logger { + var level slog.Level + if err := level.UnmarshalText([]byte(lc.Level)); err != nil { + level = slog.LevelInfo + } + + leveler := new(slog.LevelVar) + leveler.Set(level) + return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: leveler, + })) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..36182ea --- /dev/null +++ b/go.mod @@ -0,0 +1,54 @@ +module testmqtt + +go 1.22 + +require ( + github.com/jinzhu/copier v0.4.0 + github.com/mochi-mqtt/server/v2 v2.6.5 + github.com/rs/xid v1.5.0 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/DataDog/zstd v1.4.5 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cockroachdb/errors v1.11.1 // indirect + github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect + github.com/cockroachdb/pebble v1.1.0 // indirect + github.com/cockroachdb/redact v1.1.5 // indirect + github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgraph-io/badger/v4 v4.2.0 // indirect + github.com/dgraph-io/ristretto v0.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/getsentry/sentry-go v0.18.0 // indirect + github.com/go-redis/redis/v8 v8.11.5 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/glog v1.0.0 // indirect + github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/snappy v0.0.4 // indirect + github.com/google/flatbuffers v1.12.1 // indirect + github.com/gorilla/websocket v1.5.0 // indirect + github.com/klauspost/compress v1.15.15 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_golang v1.12.0 // indirect + github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a // indirect + github.com/prometheus/common v0.32.1 // indirect + github.com/prometheus/procfs v0.7.3 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + go.etcd.io/bbolt v1.3.5 // indirect + go.opencensus.io v0.22.5 // indirect + golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df // indirect + golang.org/x/net v0.23.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a585726 --- /dev/null +++ b/go.sum @@ -0,0 +1,573 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= +cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= +cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= +cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= +cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= +cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= +cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= +cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= +cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= +cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= +cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= +cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= +cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= +cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= +cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= +cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= +cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= +cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= +cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= +cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= +cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= +cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= +cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= +cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= +cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= +cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= +cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= +cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= +cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= +cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ= +github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw= +github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cockroachdb/datadriven v1.0.3-0.20230413201302-be42291fc80f h1:otljaYPt5hWxV3MUfO5dFPFiOXg9CyG5/kCfayTqsJ4= +github.com/cockroachdb/datadriven v1.0.3-0.20230413201302-be42291fc80f/go.mod h1:a9RdTaap04u637JoCzcUoIcDmvwSUtcUFtT/C3kJlTU= +github.com/cockroachdb/errors v1.11.1 h1:xSEW75zKaKCWzR3OfxXUxgrk/NtT4G1MiOv5lWZazG8= +github.com/cockroachdb/errors v1.11.1/go.mod h1:8MUxA3Gi6b25tYlFEBGLf+D8aISL+M4MIpiWMSNRfxw= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZeQy818SGhaone5OnYfxFR/+AzdY3sf5aE= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs= +github.com/cockroachdb/pebble v1.1.0 h1:pcFh8CdCIt2kmEpK0OIatq67Ln9uGDYY3d5XnE0LJG4= +github.com/cockroachdb/pebble v1.1.0/go.mod h1:sEHm5NOXxyiAoKWhoFxT8xMgd/f3RA6qUqQ1BXKrh2E= +github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwPJ30= +github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= +github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 h1:zuQyyAKVxetITBuuhv3BI9cMrmStnpT18zmgmTxunpo= +github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06/go.mod h1:7nc4anLGjupUW/PeY5qiNYsdNXj7zopG+eqsS7To5IQ= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/badger/v4 v4.2.0 h1:kJrlajbXXL9DFTNuhhu9yCx7JJa4qpYWxtE8BzuWsEs= +github.com/dgraph-io/badger/v4 v4.2.0/go.mod h1:qfCqhPoWDFJRx1gp5QwwyGo8xk1lbHUxvK9nK0OGAak= +github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= +github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/getsentry/sentry-go v0.18.0 h1:MtBW5H9QgdcJabtZcuJG80BMOwaBpkRDZkxRkNC1sN0= +github.com/getsentry/sentry-go v0.18.0/go.mod h1:Kgon4Mby+FJ7ZWHFUAZgVaIa8sxHtnRJRLTXZr51aKQ= +github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= +github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ= +github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/flatbuffers v1.12.1 h1:MVlul7pQNoDzWRLTw5imwYsl+usrS1TXG2H4jg6ImGw= +github.com/google/flatbuffers v1.12.1/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= +github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= +github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/mochi-mqtt/server/v2 v2.6.5 h1:9PiQ6EJt/Dx0ut0Fuuir4F6WinO/5Bpz9szujNwm+q8= +github.com/mochi-mqtt/server/v2 v2.6.5/go.mod h1:TqztjKGO0/ArOjJt9x9idk0kqPT3CVN8Pb+l+PS5Gdo= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= +github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= +github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= +github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg= +github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a h1:CmF68hwI0XsOQ5UwlBopMi2Ow4Pbg32akc4KIVCOm+Y= +github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= +github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= +github.com/prometheus/common v0.32.1 h1:hWIdL3N2HoUx3B8j3YN9mWor0qhY/NlEKZEaXxuIRh4= +github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= +github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU= +github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw= +github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA= +go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= +go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= +go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= +go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.5 h1:dntmOdLpSpHlVqbW5Eay97DelsZHe+55D+xC6i0dDS0= +go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= +golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= +golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= +golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= +golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME= +golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= +golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= +golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= +golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= +google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= +google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= +google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= +google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= +google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= +google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= +google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= +google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= +google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/hooks/auth/allow_all.go b/hooks/auth/allow_all.go new file mode 100644 index 0000000..319a18f --- /dev/null +++ b/hooks/auth/allow_all.go @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package auth + +import ( + "bytes" + + "testmqtt/mqtt" + "testmqtt/packets" +) + +// AllowHook is an authentication hook which allows connection access +// for all users and read and write access to all topics. +type AllowHook struct { + mqtt.HookBase +} + +// ID returns the ID of the hook. +func (h *AllowHook) ID() string { + return "allow-all-auth" +} + +// Provides indicates which hook methods this hook provides. +func (h *AllowHook) Provides(b byte) bool { + return bytes.Contains([]byte{ + mqtt.OnConnectAuthenticate, + mqtt.OnACLCheck, + }, []byte{b}) +} + +// OnConnectAuthenticate returns true/allowed for all requests. +func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { + return true +} + +// OnACLCheck returns true/allowed for all checks. +func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { + return true +} diff --git a/hooks/auth/allow_all_test.go b/hooks/auth/allow_all_test.go new file mode 100644 index 0000000..6d2fe43 --- /dev/null +++ b/hooks/auth/allow_all_test.go @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package auth + +import ( + "testing" + + "github.com/stretchr/testify/require" + "testmqtt/mqtt" + "testmqtt/packets" +) + +func TestAllowAllID(t *testing.T) { + h := new(AllowHook) + require.Equal(t, "allow-all-auth", h.ID()) +} + +func TestAllowAllProvides(t *testing.T) { + h := new(AllowHook) + require.True(t, h.Provides(mqtt.OnACLCheck)) + require.True(t, h.Provides(mqtt.OnConnectAuthenticate)) + require.False(t, h.Provides(mqtt.OnPublished)) +} + +func TestAllowAllOnConnectAuthenticate(t *testing.T) { + h := new(AllowHook) + require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{})) +} + +func TestAllowAllOnACLCheck(t *testing.T) { + h := new(AllowHook) + require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true)) +} diff --git a/hooks/auth/auth.go b/hooks/auth/auth.go new file mode 100644 index 0000000..0f6233f --- /dev/null +++ b/hooks/auth/auth.go @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package auth + +import ( + "bytes" + + "testmqtt/mqtt" + "testmqtt/packets" +) + +// Options contains the configuration/rules data for the auth ledger. +type Options struct { + Data []byte + Ledger *Ledger +} + +// Hook is an authentication hook which implements an auth ledger. +type Hook struct { + mqtt.HookBase + config *Options + ledger *Ledger +} + +// ID returns the ID of the hook. +func (h *Hook) ID() string { + return "auth-ledger" +} + +// Provides indicates which hook methods this hook provides. +func (h *Hook) Provides(b byte) bool { + return bytes.Contains([]byte{ + mqtt.OnConnectAuthenticate, + mqtt.OnACLCheck, + }, []byte{b}) +} + +// Init configures the hook with the auth ledger to be used for checking. +func (h *Hook) Init(config any) error { + if _, ok := config.(*Options); !ok && config != nil { + return mqtt.ErrInvalidConfigType + } + + if config == nil { + config = new(Options) + } + + h.config = config.(*Options) + + var err error + if h.config.Ledger != nil { + h.ledger = h.config.Ledger + } else if len(h.config.Data) > 0 { + h.ledger = new(Ledger) + err = h.ledger.Unmarshal(h.config.Data) + } + if err != nil { + return err + } + + if h.ledger == nil { + h.ledger = &Ledger{ + Auth: AuthRules{}, + ACL: ACLRules{}, + } + } + + h.Log.Info("loaded auth rules", + "authentication", len(h.ledger.Auth), + "acl", len(h.ledger.ACL)) + + return nil +} + +// OnConnectAuthenticate returns true if the connecting client has rules which provide access +// in the auth ledger. +func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { + if _, ok := h.ledger.AuthOk(cl, pk); ok { + return true + } + + h.Log.Info("client failed authentication check", + "username", string(pk.Connect.Username), + "remote", cl.Net.Remote) + return false +} + +// OnACLCheck returns true if the connecting client has matching read or write access to subscribe +// or publish to a given topic. +func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { + if _, ok := h.ledger.ACLOk(cl, topic, write); ok { + return true + } + + h.Log.Debug("client failed allowed ACL check", + "client", cl.ID, + "username", string(cl.Properties.Username), + "topic", topic) + + return false +} diff --git a/hooks/auth/auth_test.go b/hooks/auth/auth_test.go new file mode 100644 index 0000000..1c8b41e --- /dev/null +++ b/hooks/auth/auth_test.go @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package auth + +import ( + "log/slog" + "os" + "testing" + + "github.com/stretchr/testify/require" + "testmqtt/mqtt" + "testmqtt/packets" +) + +var logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + +// func teardown(t *testing.T, path string, h *Hook) { +// h.Stop() +// } + +func TestBasicID(t *testing.T) { + h := new(Hook) + require.Equal(t, "auth-ledger", h.ID()) +} + +func TestBasicProvides(t *testing.T) { + h := new(Hook) + require.True(t, h.Provides(mqtt.OnACLCheck)) + require.True(t, h.Provides(mqtt.OnConnectAuthenticate)) + require.False(t, h.Provides(mqtt.OnPublish)) +} + +func TestBasicInitBadConfig(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(map[string]any{}) + require.Error(t, err) +} + +func TestBasicInitDefaultConfig(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(nil) + require.NoError(t, err) +} + +func TestBasicInitWithLedgerPointer(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + ln := &Ledger{ + Auth: []AuthRule{ + { + Remote: "127.0.0.1", + Allow: true, + }, + }, + ACL: []ACLRule{ + { + Remote: "127.0.0.1", + Filters: Filters{ + "#": ReadWrite, + }, + }, + }, + } + + err := h.Init(&Options{ + Ledger: ln, + }) + + require.NoError(t, err) + require.Same(t, ln, h.ledger) +} + +func TestBasicInitWithLedgerJSON(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + require.Nil(t, h.ledger) + err := h.Init(&Options{ + Data: ledgerJSON, + }) + + require.NoError(t, err) + require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username) + require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client) +} + +func TestBasicInitWithLedgerYAML(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + require.Nil(t, h.ledger) + err := h.Init(&Options{ + Data: ledgerYAML, + }) + + require.NoError(t, err) + require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username) + require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client) +} + +func TestBasicInitWithLedgerBadDAta(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + require.Nil(t, h.ledger) + err := h.Init(&Options{ + Data: []byte("fdsfdsafasd"), + }) + + require.Error(t, err) +} + +func TestOnConnectAuthenticate(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + ln := new(Ledger) + ln.Auth = checkLedger.Auth + ln.ACL = checkLedger.ACL + err := h.Init( + &Options{ + Ledger: ln, + }, + ) + + require.NoError(t, err) + + require.True(t, h.OnConnectAuthenticate( + &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, + )) + + require.False(t, h.OnConnectAuthenticate( + &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, + )) + + require.False(t, h.OnConnectAuthenticate( + &mqtt.Client{}, + packets.Packet{}, + )) +} + +func TestOnACL(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + ln := new(Ledger) + ln.Auth = checkLedger.Auth + ln.ACL = checkLedger.ACL + err := h.Init( + &Options{ + Ledger: ln, + }, + ) + + require.NoError(t, err) + + require.True(t, h.OnACLCheck( + &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + "mochi/info", + true, + )) + + require.False(t, h.OnACLCheck( + &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + "d/j/f", + true, + )) + + require.True(t, h.OnACLCheck( + &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + "readonly", + false, + )) + + require.False(t, h.OnACLCheck( + &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + "readonly", + true, + )) +} diff --git a/hooks/auth/ledger.go b/hooks/auth/ledger.go new file mode 100644 index 0000000..90c5fb8 --- /dev/null +++ b/hooks/auth/ledger.go @@ -0,0 +1,246 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package auth + +import ( + "encoding/json" + "strings" + "sync" + + "gopkg.in/yaml.v3" + + "testmqtt/mqtt" + "testmqtt/packets" +) + +const ( + Deny Access = iota // user cannot access the topic + ReadOnly // user can only subscribe to the topic + WriteOnly // user can only publish to the topic + ReadWrite // user can both publish and subscribe to the topic +) + +// Access determines the read/write privileges for an ACL rule. +type Access byte + +// Users contains a map of access rules for specific users, keyed on username. +type Users map[string]UserRule + +// UserRule defines a set of access rules for a specific user. +type UserRule struct { + Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user + Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user + ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired + Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user +} + +// AuthRules defines generic access rules applicable to all users. +type AuthRules []AuthRule + +type AuthRule struct { + Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client + Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user + Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or + Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user + Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users +} + +// ACLRules defines generic topic or filter access rules applicable to all users. +type ACLRules []ACLRule + +// ACLRule defines access rules for a specific topic or filter. +type ACLRule struct { + Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client + Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user + Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or + Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match +} + +// Filters is a map of Access rules keyed on filter. +type Filters map[RString]Access + +// RString is a rule value string. +type RString string + +// Matches returns true if the rule matches a given string. +func (r RString) Matches(a string) bool { + rr := string(r) + if r == "" || r == "*" || a == rr { + return true + } + + i := strings.Index(rr, "*") + if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 { + return true + } + + return false +} + +// FilterMatches returns true if a filter matches a topic rule. +func (r RString) FilterMatches(a string) bool { + _, ok := MatchTopic(string(r), a) + return ok +} + +// MatchTopic checks if a given topic matches a filter, accounting for filter +// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c. +func MatchTopic(filter string, topic string) (elements []string, matched bool) { + filterParts := strings.Split(filter, "/") + topicParts := strings.Split(topic, "/") + + elements = make([]string, 0) + for i := 0; i < len(filterParts); i++ { + if i >= len(topicParts) { + matched = false + return + } + + if filterParts[i] == "+" { + elements = append(elements, topicParts[i]) + continue + } + + if filterParts[i] == "#" { + matched = true + elements = append(elements, strings.Join(topicParts[i:], "/")) + return + } + + if filterParts[i] != topicParts[i] { + matched = false + return + } + } + + return elements, true +} + +// Ledger is an auth ledger containing access rules for users and topics. +type Ledger struct { + sync.Mutex `json:"-" yaml:"-"` + Users Users `json:"users" yaml:"users"` + Auth AuthRules `json:"auth" yaml:"auth"` + ACL ACLRules `json:"acl" yaml:"acl"` +} + +// Update updates the internal values of the ledger. +func (l *Ledger) Update(ln *Ledger) { + l.Lock() + defer l.Unlock() + l.Auth = ln.Auth + l.ACL = ln.ACL +} + +// AuthOk returns true if the rules indicate the user is allowed to authenticate. +func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) { + // If the users map is set, always check for a predefined user first instead + // of iterating through global rules. + if l.Users != nil { + if u, ok := l.Users[string(cl.Properties.Username)]; ok && + u.Password != "" && + u.Password == RString(pk.Connect.Password) { + return 0, !u.Disallow + } + } + + // If there's no users map, or no user was found, attempt to find a matching + // rule (which may also contain a user). + for n, rule := range l.Auth { + if rule.Client.Matches(cl.ID) && + rule.Username.Matches(string(cl.Properties.Username)) && + rule.Password.Matches(string(pk.Connect.Password)) && + rule.Remote.Matches(cl.Net.Remote) { + return n, rule.Allow + } + } + + return 0, false +} + +// ACLOk returns true if the rules indicate the user is allowed to read or write to +// a specific filter or topic respectively, based on the `write` bool. +func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) { + // If the users map is set, always check for a predefined user first instead + // of iterating through global rules. + if l.Users != nil { + if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 { + for filter, access := range u.ACL { + if filter.FilterMatches(topic) { + if !write && (access == ReadOnly || access == ReadWrite) { + return n, true + } else if write && (access == WriteOnly || access == ReadWrite) { + return n, true + } else { + return n, false + } + } + } + } + } + + for n, rule := range l.ACL { + if rule.Client.Matches(cl.ID) && + rule.Username.Matches(string(cl.Properties.Username)) && + rule.Remote.Matches(cl.Net.Remote) { + if len(rule.Filters) == 0 { + return n, true + } + + if write { + for filter, access := range rule.Filters { + if access == WriteOnly || access == ReadWrite { + if filter.FilterMatches(topic) { + return n, true + } + } + } + } + + if !write { + for filter, access := range rule.Filters { + if access == ReadOnly || access == ReadWrite { + if filter.FilterMatches(topic) { + return n, true + } + } + } + } + + for filter := range rule.Filters { + if filter.FilterMatches(topic) { + return n, false + } + } + } + } + + return 0, true +} + +// ToJSON encodes the values into a JSON string. +func (l *Ledger) ToJSON() (data []byte, err error) { + return json.Marshal(l) +} + +// ToYAML encodes the values into a YAML string. +func (l *Ledger) ToYAML() (data []byte, err error) { + return yaml.Marshal(l) +} + +// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct. +func (l *Ledger) Unmarshal(data []byte) error { + l.Lock() + defer l.Unlock() + if len(data) == 0 { + return nil + } + + if data[0] == '{' { + return json.Unmarshal(data, l) + } + + return yaml.Unmarshal(data, &l) +} diff --git a/hooks/auth/ledger_test.go b/hooks/auth/ledger_test.go new file mode 100644 index 0000000..ab4fb3d --- /dev/null +++ b/hooks/auth/ledger_test.go @@ -0,0 +1,610 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package auth + +import ( + "testing" + + "github.com/stretchr/testify/require" + "testmqtt/mqtt" + "testmqtt/packets" +) + +var ( + checkLedger = Ledger{ + Users: Users{ // users are allowed by default + "mochi-co": { + Password: "melon", + ACL: Filters{ + "d/+/f": Deny, + "mochi-co/#": ReadWrite, + "readonly": ReadOnly, + }, + }, + "suspended-username": { + Password: "any", + Disallow: true, + }, + "mochi": { // ACL only, will defer to AuthRules for authentication + ACL: Filters{ + "special/mochi": ReadOnly, + "secret/mochi": Deny, + "ignored": ReadWrite, + }, + }, + }, + Auth: AuthRules{ + {Username: "banned-user"}, // never allow specific username + {Remote: "127.0.0.1", Allow: true}, // always allow localhost + {Remote: "123.123.123.123"}, // disallow any from specific address + {Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address + {Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username) + {Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass + {Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map) + }, + ACL: ACLRules{ + { + Username: "mochi", // allow matching user/pass + Filters: Filters{ + "a/b/c": Deny, + "d/+/f": Deny, + "mochi/#": ReadWrite, + "updates/#": WriteOnly, + "readonly": ReadOnly, + "ignored": Deny, + }, + }, + {Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost + {Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin + {Remote: "001.002.003.004"}, // Allow all with no filter + {Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others + }, + } +) + +func TestRStringMatches(t *testing.T) { + require.True(t, RString("*").Matches("any")) + require.True(t, RString("*").Matches("")) + require.True(t, RString("").Matches("any")) + require.True(t, RString("").Matches("")) + require.False(t, RString("no").Matches("any")) + require.False(t, RString("no").Matches("")) +} + +func TestCanAuthenticate(t *testing.T) { + tt := []struct { + desc string + client *mqtt.Client + pk packets.Packet + n int + ok bool + }{ + { + desc: "allow all local 127.0.0.1", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + Net: mqtt.ClientConnection{ + Remote: "127.0.0.1", + }, + }, + pk: packets.Packet{Connect: packets.ConnectParams{}}, + ok: true, + n: 1, + }, + { + desc: "allow username/password", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, + ok: true, + n: 5, + }, + { + desc: "deny username/password", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, + ok: false, + n: 0, + }, + { + desc: "allow all local 127.0.0.1", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + Net: mqtt.ClientConnection{ + Remote: "127.0.0.1", + }, + }, + pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, + ok: true, + n: 1, + }, + { + desc: "allow username/password", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, + ok: true, + n: 5, + }, + { + desc: "deny username/password", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, + ok: false, + n: 0, + }, + { + desc: "deny client from address", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("not-mochi"), + }, + Net: mqtt.ClientConnection{ + Remote: "111.144.155.166", + }, + }, + pk: packets.Packet{}, + ok: false, + n: 3, + }, + { + desc: "allow remote wildcard", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + Net: mqtt.ClientConnection{ + Remote: "111.0.0.1", + }, + }, + pk: packets.Packet{}, + ok: true, + n: 4, + }, + { + desc: "never allow username", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("banned-user"), + }, + Net: mqtt.ClientConnection{ + Remote: "127.0.0.1", + }, + }, + pk: packets.Packet{}, + ok: false, + n: 0, + }, + { + desc: "matching user in users", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi-co"), + }, + }, + pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, + ok: true, + n: 0, + }, + { + desc: "never user in users", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("suspended-user"), + }, + }, + pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}}, + ok: false, + n: 0, + }, + } + + for _, d := range tt { + t.Run(d.desc, func(t *testing.T) { + n, ok := checkLedger.AuthOk(d.client, d.pk) + require.Equal(t, d.n, n) + require.Equal(t, d.ok, ok) + }) + } +} + +func TestCanACL(t *testing.T) { + tt := []struct { + client *mqtt.Client + desc string + topic string + n int + write bool + ok bool + }{ + { + desc: "allow normal write on any other filter", + client: &mqtt.Client{}, + topic: "default/acl/write/access", + write: true, + ok: true, + }, + { + desc: "allow normal read on any other filter", + client: &mqtt.Client{}, + topic: "default/acl/read/access", + write: false, + ok: true, + }, + { + desc: "deny user on literal filter", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "a/b/c", + }, + { + desc: "deny user on partial filter", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "d/j/f", + }, + { + desc: "allow read/write to user path", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "mochi/read/write", + write: true, + ok: true, + }, + { + desc: "deny read on write-only path", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "updates/no/reading", + write: false, + ok: false, + }, + { + desc: "deny read on write-only path ext", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "updates/mochi", + write: false, + ok: false, + }, + { + desc: "allow read on not-acl path (no #)", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "updates", + write: false, + ok: true, + }, + { + desc: "allow write on write-only path", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "updates/mochi", + write: true, + ok: true, + }, + { + desc: "deny write on read-only path", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "readonly", + write: true, + ok: false, + }, + { + desc: "allow read on read-only path", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "readonly", + write: false, + ok: true, + }, + { + desc: "allow $sys access to localhost", + client: &mqtt.Client{ + Net: mqtt.ClientConnection{ + Remote: "localhost", + }, + }, + topic: "$SYS/test", + write: false, + ok: true, + n: 1, + }, + { + desc: "allow $sys access to admin", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("admin"), + }, + }, + topic: "$SYS/test", + write: false, + ok: true, + n: 2, + }, + { + desc: "deny $sys access to all others", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "$SYS/test", + write: false, + ok: false, + n: 4, + }, + { + desc: "allow all with no filter", + client: &mqtt.Client{ + Net: mqtt.ClientConnection{ + Remote: "001.002.003.004", + }, + }, + topic: "any/path", + write: true, + ok: true, + n: 3, + }, + { + desc: "use users embedded acl deny", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "secret/mochi", + write: true, + ok: false, + }, + { + desc: "use users embedded acl any", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "any/mochi", + write: true, + ok: true, + }, + { + desc: "use users embedded acl write on read-only", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "special/mochi", + write: true, + ok: false, + }, + { + desc: "use users embedded acl read on read-only", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "special/mochi", + write: false, + ok: true, + }, + { + desc: "preference users embedded acl", + client: &mqtt.Client{ + Properties: mqtt.ClientProperties{ + Username: []byte("mochi"), + }, + }, + topic: "ignored", + write: true, + ok: true, + }, + } + + for _, d := range tt { + t.Run(d.desc, func(t *testing.T) { + n, ok := checkLedger.ACLOk(d.client, d.topic, d.write) + require.Equal(t, d.n, n) + require.Equal(t, d.ok, ok) + }) + } +} + +func TestMatchTopic(t *testing.T) { + el, matched := MatchTopic("a/+/c/+", "a/b/c/d") + require.True(t, matched) + require.Equal(t, []string{"b", "d"}, el) + + el, matched = MatchTopic("a/+/+/+", "a/b/c/d") + require.True(t, matched) + require.Equal(t, []string{"b", "c", "d"}, el) + + el, matched = MatchTopic("stuff/#", "stuff/things/yeah") + require.True(t, matched) + require.Equal(t, []string{"things/yeah"}, el) + + el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds") + require.True(t, matched) + require.Equal(t, []string{"b", "c/d/as/dds"}, el) + + el, matched = MatchTopic("test", "test") + require.True(t, matched) + require.Equal(t, make([]string, 0), el) + + el, matched = MatchTopic("things/stuff//", "things/stuff/") + require.False(t, matched) + require.Equal(t, make([]string, 0), el) + + el, matched = MatchTopic("t", "t2") + require.False(t, matched) + require.Equal(t, make([]string, 0), el) + + el, matched = MatchTopic(" ", " ") + require.False(t, matched) + require.Equal(t, make([]string, 0), el) +} + +var ( + ledgerStruct = Ledger{ + Users: Users{ + "mochi": { + Password: "peach", + ACL: Filters{ + "readonly": ReadOnly, + "deny": Deny, + }, + }, + }, + Auth: AuthRules{ + { + Client: "*", + Username: "mochi-co", + Password: "melon", + Remote: "192.168.1.*", + Allow: true, + }, + }, + ACL: ACLRules{ + { + Client: "*", + Username: "mochi-co", + Remote: "127.*", + Filters: Filters{ + "readonly": ReadOnly, + "writeonly": WriteOnly, + "readwrite": ReadWrite, + "deny": Deny, + }, + }, + }, + } + + ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`) + ledgerYAML = []byte(`users: + mochi: + password: peach + acl: + deny: 0 + readonly: 1 +auth: + - client: '*' + username: mochi-co + remote: 192.168.1.* + password: melon + allow: true +acl: + - client: '*' + username: mochi-co + remote: 127.* + filters: + deny: 0 + readonly: 1 + readwrite: 3 + writeonly: 2 +`) +) + +func TestLedgerUpdate(t *testing.T) { + old := &Ledger{ + Auth: AuthRules{ + {Remote: "127.0.0.1", Allow: true}, + }, + } + + n := &Ledger{ + Auth: AuthRules{ + {Remote: "127.0.0.1", Allow: true}, + {Remote: "192.168.*", Allow: true}, + }, + } + + old.Update(n) + require.Len(t, old.Auth, 2) + require.Equal(t, RString("192.168.*"), old.Auth[1].Remote) + require.NotSame(t, n, old) +} + +func TestLedgerToJSON(t *testing.T) { + data, err := ledgerStruct.ToJSON() + require.NoError(t, err) + require.Equal(t, ledgerJSON, data) +} + +func TestLedgerToYAML(t *testing.T) { + data, err := ledgerStruct.ToYAML() + require.NoError(t, err) + require.Equal(t, ledgerYAML, data) +} + +func TestLedgerUnmarshalFromYAML(t *testing.T) { + l := new(Ledger) + err := l.Unmarshal(ledgerYAML) + require.NoError(t, err) + require.Equal(t, &ledgerStruct, l) + require.NotSame(t, l, &ledgerStruct) +} + +func TestLedgerUnmarshalFromJSON(t *testing.T) { + l := new(Ledger) + err := l.Unmarshal(ledgerJSON) + require.NoError(t, err) + require.Equal(t, &ledgerStruct, l) + require.NotSame(t, l, &ledgerStruct) +} + +func TestLedgerUnmarshalNil(t *testing.T) { + l := new(Ledger) + err := l.Unmarshal([]byte{}) + require.NoError(t, err) + require.Equal(t, new(Ledger), l) +} diff --git a/hooks/debug/debug.go b/hooks/debug/debug.go new file mode 100644 index 0000000..5fa852b --- /dev/null +++ b/hooks/debug/debug.go @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package debug + +import ( + "fmt" + "log/slog" + "strings" + + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" +) + +// Options contains configuration settings for the debug output. +type Options struct { + Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config + ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false) + ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false) + ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false) +} + +// Hook is a debugging hook which logs additional low-level information from the server. +type Hook struct { + mqtt.HookBase + config *Options + Log *slog.Logger +} + +// ID returns the ID of the hook. +func (h *Hook) ID() string { + return "debug" +} + +// Provides indicates that this hook provides all methods. +func (h *Hook) Provides(b byte) bool { + return true +} + +// Init is called when the hook is initialized. +func (h *Hook) Init(config any) error { + if _, ok := config.(*Options); !ok && config != nil { + return mqtt.ErrInvalidConfigType + } + + if config == nil { + config = new(Options) + } + + h.config = config.(*Options) + + return nil +} + +// SetOpts is called when the hook receives inheritable server parameters. +func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) { + h.Log = l + h.Log.Debug("", "method", "SetOpts") +} + +// Stop is called when the hook is stopped. +func (h *Hook) Stop() error { + h.Log.Debug("", "method", "Stop") + return nil +} + +// OnStarted is called when the server starts. +func (h *Hook) OnStarted() { + h.Log.Debug("", "method", "OnStarted") +} + +// OnStopped is called when the server stops. +func (h *Hook) OnStopped() { + h.Log.Debug("", "method", "OnStopped") +} + +// OnPacketRead is called when a new packet is received from a client. +func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { + if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings { + return pk, nil + } + + h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk)) + return pk, nil +} + +// OnPacketSent is called when a packet is sent to a client. +func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) { + if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings { + return + } + + h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk)) +} + +// OnRetainMessage is called when a published message is retained (or retain deleted/modified). +func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { + h.Log.Debug("retained message on topic", "m", h.packetMeta(pk)) +} + +// OnQosPublish is called when a publish packet with Qos is issued to a subscriber. +func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { + h.Log.Debug("inflight out", "m", h.packetMeta(pk)) +} + +// OnQosComplete is called when the Qos flow for a message has been completed. +func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { + h.Log.Debug("inflight complete", "m", h.packetMeta(pk)) +} + +// OnQosDropped is called the Qos flow for a message expires. +func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { + h.Log.Debug("inflight dropped", "m", h.packetMeta(pk)) +} + +// OnLWTSent is called when a Will Message has been issued from a disconnecting client. +func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) { + h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID) +} + +// OnRetainedExpired is called when the server clears expired retained messages. +func (h *Hook) OnRetainedExpired(filter string) { + h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter) +} + +// OnClientExpired is called when the server clears an expired client. +func (h *Hook) OnClientExpired(cl *mqtt.Client) { + h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID) +} + +// StoredClients is called when the server restores clients from a store. +func (h *Hook) StoredClients() (v []storage.Client, err error) { + h.Log.Debug("", "method", "StoredClients") + + return v, nil +} + +// StoredSubscriptions is called when the server restores subscriptions from a store. +func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { + h.Log.Debug("", "method", "StoredSubscriptions") + return v, nil +} + +// StoredRetainedMessages is called when the server restores retained messages from a store. +func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { + h.Log.Debug("", "method", "StoredRetainedMessages") + return v, nil +} + +// StoredInflightMessages is called when the server restores inflight messages from a store. +func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { + h.Log.Debug("", "method", "StoredInflightMessages") + return v, nil +} + +// StoredSysInfo is called when the server restores system info from a store. +func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { + h.Log.Debug("", "method", "StoredSysInfo") + + return v, nil +} + +// packetMeta adds additional type-specific metadata to the debug logs. +func (h *Hook) packetMeta(pk packets.Packet) map[string]any { + m := map[string]any{} + switch pk.FixedHeader.Type { + case packets.Connect: + m["id"] = pk.Connect.ClientIdentifier + m["clean"] = pk.Connect.Clean + m["keepalive"] = pk.Connect.Keepalive + m["version"] = pk.ProtocolVersion + m["username"] = string(pk.Connect.Username) + if h.config.ShowPasswords { + m["password"] = string(pk.Connect.Password) + } + if pk.Connect.WillFlag { + m["will_topic"] = pk.Connect.WillTopic + m["will_payload"] = string(pk.Connect.WillPayload) + } + case packets.Publish: + m["topic"] = pk.TopicName + m["payload"] = string(pk.Payload) + m["raw"] = pk.Payload + m["qos"] = pk.FixedHeader.Qos + m["id"] = pk.PacketID + case packets.Connack: + fallthrough + case packets.Disconnect: + fallthrough + case packets.Puback: + fallthrough + case packets.Pubrec: + fallthrough + case packets.Pubrel: + fallthrough + case packets.Pubcomp: + m["id"] = pk.PacketID + m["reason"] = int(pk.ReasonCode) + if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 { + m["reason_string"] = pk.Properties.ReasonString + } + case packets.Subscribe: + f := map[string]int{} + ids := map[string]int{} + for _, v := range pk.Filters { + f[v.Filter] = int(v.Qos) + ids[v.Filter] = v.Identifier + } + m["filters"] = f + m["subids"] = f + + case packets.Unsubscribe: + f := []string{} + for _, v := range pk.Filters { + f = append(f, v.Filter) + } + m["filters"] = f + case packets.Suback: + fallthrough + case packets.Unsuback: + r := []int{} + for _, v := range pk.ReasonCodes { + r = append(r, int(v)) + } + m["reasons"] = r + case packets.Auth: + // tbd + } + + if h.config.ShowPacketData { + m["packet"] = pk + } + + return m +} diff --git a/hooks/storage/badger/badger.go b/hooks/storage/badger/badger.go new file mode 100644 index 0000000..5cafa38 --- /dev/null +++ b/hooks/storage/badger/badger.go @@ -0,0 +1,576 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, gsagula, werbenhu + +package badger + +import ( + "bytes" + "errors" + "fmt" + "strings" + "time" + + badgerdb "github.com/dgraph-io/badger/v4" + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" + "testmqtt/system" +) + +const ( + // defaultDbFile is the default file path for the badger db file. + defaultDbFile = ".badger" + defaultGcInterval = 5 * 60 // gc interval in seconds + defaultGcDiscardRatio = 0.5 +) + +// clientKey returns a primary key for a client. +func clientKey(cl *mqtt.Client) string { + return storage.ClientKey + "_" + cl.ID +} + +// subscriptionKey returns a primary key for a subscription. +func subscriptionKey(cl *mqtt.Client, filter string) string { + return storage.SubscriptionKey + "_" + cl.ID + ":" + filter +} + +// retainedKey returns a primary key for a retained message. +func retainedKey(topic string) string { + return storage.RetainedKey + "_" + topic +} + +// inflightKey returns a primary key for an inflight message. +func inflightKey(cl *mqtt.Client, pk packets.Packet) string { + return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID() +} + +// sysInfoKey returns a primary key for system info. +func sysInfoKey() string { + return storage.SysInfoKey +} + +// Serializable is an interface for objects that can be serialized and deserialized. +type Serializable interface { + UnmarshalBinary([]byte) error + MarshalBinary() (data []byte, err error) +} + +// Options contains configuration settings for the BadgerDB instance. +type Options struct { + Options *badgerdb.Options + Path string `yaml:"path" json:"path"` + // GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard. + // Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value + // would result in more space reclaims at the cost of increased activity on the LSM tree. + // discardRatio must be in the range (0.0, 1.0), both endpoints excluded, otherwise, it will be set to the default value of 0.5. + GcDiscardRatio float64 `yaml:"gc_discard_ratio" json:"gc_discard_ratio"` + GcInterval int64 `yaml:"gc_interval" json:"gc_interval"` +} + +// Hook is a persistent storage hook based using BadgerDB file store as a backend. +type Hook struct { + mqtt.HookBase + config *Options // options for configuring the BadgerDB instance. + gcTicker *time.Ticker // Ticker for BadgerDB garbage collection. + db *badgerdb.DB // the BadgerDB instance. +} + +// ID returns the id of the hook. +func (h *Hook) ID() string { + return "badger-db" +} + +// Provides indicates which hook methods this hook provides. +func (h *Hook) Provides(b byte) bool { + return bytes.Contains([]byte{ + mqtt.OnSessionEstablished, + mqtt.OnDisconnect, + mqtt.OnSubscribed, + mqtt.OnUnsubscribed, + mqtt.OnRetainMessage, + mqtt.OnWillSent, + mqtt.OnQosPublish, + mqtt.OnQosComplete, + mqtt.OnQosDropped, + mqtt.OnSysInfoTick, + mqtt.OnClientExpired, + mqtt.OnRetainedExpired, + mqtt.StoredClients, + mqtt.StoredInflightMessages, + mqtt.StoredRetainedMessages, + mqtt.StoredSubscriptions, + mqtt.StoredSysInfo, + }, []byte{b}) +} + +// GcLoop periodically runs the garbage collection process to reclaim space in the value log files. +// It uses a ticker to trigger the garbage collection at regular intervals specified by the configuration. +// Refer to: https://dgraph.io/docs/badger/get-started/#garbage-collection +func (h *Hook) gcLoop() { + for range h.gcTicker.C { + again: + // Run the garbage collection process with a threshold. + // If the process returns nil (success), repeat the process. + err := h.db.RunValueLogGC(h.config.GcDiscardRatio) + if err == nil { + goto again // Retry garbage collection if successful. + } + } +} + +// Init initializes and connects to the badger instance. +func (h *Hook) Init(config any) error { + if _, ok := config.(*Options); !ok && config != nil { + return mqtt.ErrInvalidConfigType + } + + if config == nil { + h.config = new(Options) + } else { + h.config = config.(*Options) + } + + if len(h.config.Path) == 0 { + h.config.Path = defaultDbFile + } + + if h.config.GcInterval == 0 { + h.config.GcInterval = defaultGcInterval + } + + if h.config.GcDiscardRatio <= 0.0 || h.config.GcDiscardRatio >= 1.0 { + h.config.GcDiscardRatio = defaultGcDiscardRatio + } + + if h.config.Options == nil { + defaultOpts := badgerdb.DefaultOptions(h.config.Path) + h.config.Options = &defaultOpts + } + h.config.Options.Logger = h + + var err error + h.db, err = badgerdb.Open(*h.config.Options) + if err != nil { + return err + } + + h.gcTicker = time.NewTicker(time.Duration(h.config.GcInterval) * time.Second) + go h.gcLoop() + + return nil +} + +// Stop closes the badger instance. +func (h *Hook) Stop() error { + if h.gcTicker != nil { + h.gcTicker.Stop() + } + return h.db.Close() +} + +// OnSessionEstablished adds a client to the store when their session is established. +func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { + h.updateClient(cl) +} + +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. +func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { + h.updateClient(cl) +} + +// updateClient writes the client data to the store. +func (h *Hook) updateClient(cl *mqtt.Client) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + props := cl.Properties.Props.Copy(false) + in := &storage.Client{ + ID: cl.ID, + T: storage.ClientKey, + Remote: cl.Net.Remote, + Listener: cl.Net.Listener, + Username: cl.Properties.Username, + Clean: cl.Properties.Clean, + ProtocolVersion: cl.Properties.ProtocolVersion, + Properties: storage.ClientProperties{ + SessionExpiryInterval: props.SessionExpiryInterval, + AuthenticationMethod: props.AuthenticationMethod, + AuthenticationData: props.AuthenticationData, + RequestProblemInfo: props.RequestProblemInfo, + RequestResponseInfo: props.RequestResponseInfo, + ReceiveMaximum: props.ReceiveMaximum, + TopicAliasMaximum: props.TopicAliasMaximum, + User: props.User, + MaximumPacketSize: props.MaximumPacketSize, + }, + Will: storage.ClientWill(cl.Properties.Will), + } + + err := h.setKv(clientKey(cl), in) + if err != nil { + h.Log.Error("failed to upsert client data", "error", err, "data", in) + } +} + +// OnDisconnect removes a client from the store if their session has expired. +func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + h.updateClient(cl) + + if !expire { + return + } + + if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) { + return + } + + _ = h.delKv(clientKey(cl)) +} + +// OnSubscribed adds one or more client subscriptions to the store. +func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + var in *storage.Subscription + for i := 0; i < len(pk.Filters); i++ { + in = &storage.Subscription{ + ID: subscriptionKey(cl, pk.Filters[i].Filter), + T: storage.SubscriptionKey, + Client: cl.ID, + Qos: reasonCodes[i], + Filter: pk.Filters[i].Filter, + Identifier: pk.Filters[i].Identifier, + NoLocal: pk.Filters[i].NoLocal, + RetainHandling: pk.Filters[i].RetainHandling, + RetainAsPublished: pk.Filters[i].RetainAsPublished, + } + + _ = h.setKv(in.ID, in) + } +} + +// OnUnsubscribed removes one or more client subscriptions from the store. +func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + for i := 0; i < len(pk.Filters); i++ { + _ = h.delKv(subscriptionKey(cl, pk.Filters[i].Filter)) + } +} + +// OnRetainMessage adds a retained message for a topic to the store. +func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + if r == -1 { + _ = h.delKv(retainedKey(pk.TopicName)) + return + } + + props := pk.Properties.Copy(false) + in := &storage.Message{ + ID: retainedKey(pk.TopicName), + T: storage.RetainedKey, + FixedHeader: pk.FixedHeader, + TopicName: pk.TopicName, + Payload: pk.Payload, + Created: pk.Created, + Origin: pk.Origin, + Properties: storage.MessageProperties{ + PayloadFormat: props.PayloadFormat, + MessageExpiryInterval: props.MessageExpiryInterval, + ContentType: props.ContentType, + ResponseTopic: props.ResponseTopic, + CorrelationData: props.CorrelationData, + SubscriptionIdentifier: props.SubscriptionIdentifier, + TopicAlias: props.TopicAlias, + User: props.User, + }, + } + + _ = h.setKv(in.ID, in) +} + +// OnQosPublish adds or updates an inflight message in the store. +func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + props := pk.Properties.Copy(false) + in := &storage.Message{ + ID: inflightKey(cl, pk), + T: storage.InflightKey, + Origin: pk.Origin, + PacketID: pk.PacketID, + FixedHeader: pk.FixedHeader, + TopicName: pk.TopicName, + Payload: pk.Payload, + Sent: sent, + Created: pk.Created, + Properties: storage.MessageProperties{ + PayloadFormat: props.PayloadFormat, + MessageExpiryInterval: props.MessageExpiryInterval, + ContentType: props.ContentType, + ResponseTopic: props.ResponseTopic, + CorrelationData: props.CorrelationData, + SubscriptionIdentifier: props.SubscriptionIdentifier, + TopicAlias: props.TopicAlias, + User: props.User, + }, + } + + _ = h.setKv(in.ID, in) +} + +// OnQosComplete removes a resolved inflight message from the store. +func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + _ = h.delKv(inflightKey(cl, pk)) +} + +// OnQosDropped removes a dropped inflight message from the store. +func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + } + + h.OnQosComplete(cl, pk) +} + +// OnSysInfoTick stores the latest system info in the store. +func (h *Hook) OnSysInfoTick(sys *system.Info) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + in := &storage.SystemInfo{ + ID: sysInfoKey(), + T: storage.SysInfoKey, + Info: *sys.Clone(), + } + + _ = h.setKv(in.ID, in) +} + +// OnRetainedExpired deletes expired retained messages from the store. +func (h *Hook) OnRetainedExpired(filter string) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + _ = h.delKv(retainedKey(filter)) +} + +// OnClientExpired deleted expired clients from the store. +func (h *Hook) OnClientExpired(cl *mqtt.Client) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + _ = h.delKv(clientKey(cl)) +} + +// StoredClients returns all stored clients from the store. +func (h *Hook) StoredClients() (v []storage.Client, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + err = h.iterKv(storage.ClientKey, func(value []byte) error { + obj := storage.Client{} + err = obj.UnmarshalBinary(value) + if err == nil { + v = append(v, obj) + } + return err + }) + return +} + +// StoredSubscriptions returns all stored subscriptions from the store. +func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + v = make([]storage.Subscription, 0) + err = h.iterKv(storage.SubscriptionKey, func(value []byte) error { + obj := storage.Subscription{} + err = obj.UnmarshalBinary(value) + if err == nil { + v = append(v, obj) + } + return err + }) + return +} + +// StoredRetainedMessages returns all stored retained messages from the store. +func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + v = make([]storage.Message, 0) + err = h.iterKv(storage.RetainedKey, func(value []byte) error { + obj := storage.Message{} + err = obj.UnmarshalBinary(value) + if err == nil { + v = append(v, obj) + } + return err + }) + + if err != nil && !errors.Is(err, badgerdb.ErrKeyNotFound) { + return + } + return +} + +// StoredInflightMessages returns all stored inflight messages from the store. +func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + v = make([]storage.Message, 0) + err = h.iterKv(storage.InflightKey, func(value []byte) error { + obj := storage.Message{} + err = obj.UnmarshalBinary(value) + if err == nil { + v = append(v, obj) + } + return err + }) + return +} + +// StoredSysInfo returns the system info from the store. +func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + err = h.getKv(storage.SysInfoKey, &v) + if err != nil && !errors.Is(err, badgerdb.ErrKeyNotFound) { + return + } + return v, nil +} + +// Errorf satisfies the badger interface for an error logger. +func (h *Hook) Errorf(m string, v ...any) { + h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) + +} + +// Warningf satisfies the badger interface for a warning logger. +func (h *Hook) Warningf(m string, v ...any) { + h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) +} + +// Infof satisfies the badger interface for an info logger. +func (h *Hook) Infof(m string, v ...any) { + h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) +} + +// Debugf satisfies the badger interface for a debug logger. +func (h *Hook) Debugf(m string, v ...any) { + h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) +} + +// setKv stores a key-value pair in the database. +func (h *Hook) setKv(k string, v storage.Serializable) error { + err := h.db.Update(func(txn *badgerdb.Txn) error { + data, _ := v.MarshalBinary() + return txn.Set([]byte(k), data) + }) + if err != nil { + h.Log.Error("failed to upsert data", "error", err, "key", k) + } + return err +} + +// delKv deletes a key-value pair from the database. +func (h *Hook) delKv(k string) error { + err := h.db.Update(func(txn *badgerdb.Txn) error { + return txn.Delete([]byte(k)) + }) + + if err != nil { + h.Log.Error("failed to delete data", "error", err, "key", k) + } + return err +} + +// getKv retrieves the value associated with a key from the database. +func (h *Hook) getKv(k string, v storage.Serializable) error { + return h.db.View(func(txn *badgerdb.Txn) error { + item, err := txn.Get([]byte(k)) + if err != nil { + return err + } + value, err := item.ValueCopy(nil) + if err != nil { + return err + } + return v.UnmarshalBinary(value) + }) +} + +// iterKv iterates over key-value pairs with keys having the specified prefix in the database. +func (h *Hook) iterKv(prefix string, visit func([]byte) error) error { + err := h.db.View(func(txn *badgerdb.Txn) error { + iterator := txn.NewIterator(badgerdb.DefaultIteratorOptions) + defer iterator.Close() + + for iterator.Seek([]byte(prefix)); iterator.ValidForPrefix([]byte(prefix)); iterator.Next() { + item := iterator.Item() + value, err := item.ValueCopy(nil) + + if err != nil { + return err + } + + if err := visit(value); err != nil { + return err + } + } + return nil + }) + if err != nil { + h.Log.Error("failed to find data", "error", err, "prefix", prefix) + } + return err +} diff --git a/hooks/storage/badger/badger_test.go b/hooks/storage/badger/badger_test.go new file mode 100644 index 0000000..5f65fd4 --- /dev/null +++ b/hooks/storage/badger/badger_test.go @@ -0,0 +1,809 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, werbenhu + +package badger + +import ( + "errors" + "log/slog" + "os" + "strings" + "testing" + "time" + + badgerdb "github.com/dgraph-io/badger/v4" + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" + "testmqtt/system" + + "github.com/stretchr/testify/require" +) + +var ( + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + + client = &mqtt.Client{ + ID: "test", + Net: mqtt.ClientConnection{ + Remote: "test.addr", + Listener: "listener", + }, + Properties: mqtt.ClientProperties{ + Username: []byte("username"), + Clean: false, + }, + } + + pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}} +) + +func teardown(t *testing.T, path string, h *Hook) { + _ = h.Stop() + _ = h.db.Close() + err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1)) + require.NoError(t, err) +} + +func TestSetGetDelKv(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.Init(nil) + defer teardown(t, h.config.Path, h) + + key := "testKey" + value := &storage.Client{ID: "cl1"} + err := h.setKv(key, value) + require.NoError(t, err) + + var client storage.Client + err = h.getKv(key, &client) + require.NoError(t, err) + require.Equal(t, "cl1", client.ID) + + err = h.delKv(key) + require.NoError(t, err) + + err = h.getKv(key, &client) + require.ErrorIs(t, badgerdb.ErrKeyNotFound, err) +} + +func TestClientKey(t *testing.T) { + k := clientKey(&mqtt.Client{ID: "cl1"}) + require.Equal(t, storage.ClientKey+"_cl1", k) +} + +func TestSubscriptionKey(t *testing.T) { + k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c") + require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k) +} + +func TestRetainedKey(t *testing.T) { + k := retainedKey("a/b/c") + require.Equal(t, storage.RetainedKey+"_a/b/c", k) +} + +func TestInflightKey(t *testing.T) { + k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1}) + require.Equal(t, storage.InflightKey+"_cl1:1", k) +} + +func TestSysInfoKey(t *testing.T) { + require.Equal(t, storage.SysInfoKey, sysInfoKey()) +} + +func TestID(t *testing.T) { + h := new(Hook) + require.Equal(t, "badger-db", h.ID()) +} + +func TestProvides(t *testing.T) { + h := new(Hook) + require.True(t, h.Provides(mqtt.OnSessionEstablished)) + require.True(t, h.Provides(mqtt.OnDisconnect)) + require.True(t, h.Provides(mqtt.OnSubscribed)) + require.True(t, h.Provides(mqtt.OnUnsubscribed)) + require.True(t, h.Provides(mqtt.OnRetainMessage)) + require.True(t, h.Provides(mqtt.OnQosPublish)) + require.True(t, h.Provides(mqtt.OnQosComplete)) + require.True(t, h.Provides(mqtt.OnQosDropped)) + require.True(t, h.Provides(mqtt.OnSysInfoTick)) + require.True(t, h.Provides(mqtt.StoredClients)) + require.True(t, h.Provides(mqtt.StoredInflightMessages)) + require.True(t, h.Provides(mqtt.StoredRetainedMessages)) + require.True(t, h.Provides(mqtt.StoredSubscriptions)) + require.True(t, h.Provides(mqtt.StoredSysInfo)) + require.False(t, h.Provides(mqtt.OnACLCheck)) + require.False(t, h.Provides(mqtt.OnConnectAuthenticate)) +} + +func TestInitBadConfig(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(map[string]any{}) + require.Error(t, err) +} + +func TestInitBadOption(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(&Options{ + Options: &badgerdb.Options{ + NumCompactors: 1, + }, + }) + // Cannot have 1 compactor. Need at least 2 + require.Error(t, err) +} + +func TestInitUseDefaults(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + require.Equal(t, defaultDbFile, h.config.Path) +} + +func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + h.OnSessionEstablished(client, packets.Packet{}) + + r := new(storage.Client) + err = h.getKv(clientKey(client), r) + require.NoError(t, err) + require.Equal(t, client.ID, r.ID) + require.Equal(t, client.Properties.Username, r.Username) + require.Equal(t, client.Properties.Clean, r.Clean) + require.Equal(t, client.Net.Remote, r.Remote) + require.Equal(t, client.Net.Listener, r.Listener) + require.NotSame(t, client, r) + + h.OnDisconnect(client, nil, false) + r2 := new(storage.Client) + err = h.getKv(clientKey(client), r2) + require.NoError(t, err) + require.Equal(t, client.ID, r.ID) + + h.OnDisconnect(client, nil, true) + r3 := new(storage.Client) + err = h.getKv(clientKey(client), r3) + require.Error(t, err) + require.ErrorIs(t, badgerdb.ErrKeyNotFound, err) + require.Empty(t, r3.ID) +} + +func TestOnClientExpired(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + cl := &mqtt.Client{ID: "cl1"} + clientKey := clientKey(cl) + + err = h.setKv(clientKey, &storage.Client{ID: cl.ID}) + require.NoError(t, err) + + r := new(storage.Client) + err = h.getKv(clientKey, r) + + require.NoError(t, err) + require.Equal(t, cl.ID, r.ID) + + h.OnClientExpired(cl) + err = h.getKv(clientKey, r) + require.Error(t, err) + require.ErrorIs(t, badgerdb.ErrKeyNotFound, err) +} + +func TestOnClientExpiredNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnClientExpired(client) +} + +func TestOnClientExpiredClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnClientExpired(client) +} + +func TestOnSessionEstablishedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSessionEstablished(client, packets.Packet{}) +} + +func TestOnSessionEstablishedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSessionEstablished(client, packets.Packet{}) +} + +func TestOnWillSent(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + c1 := client + c1.Properties.Will.Flag = 1 + h.OnWillSent(c1, packets.Packet{}) + + r := new(storage.Client) + err = h.getKv(clientKey(client), r) + require.NoError(t, err) + + require.Equal(t, uint32(1), r.Will.Flag) + require.NotSame(t, client, r) +} + +func TestOnDisconnectNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnDisconnect(client, nil, false) +} + +func TestOnDisconnectClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnDisconnect(client, nil, false) +} + +func TestOnDisconnectSessionTakenOver(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + + testClient := &mqtt.Client{ + ID: "test", + Net: mqtt.ClientConnection{ + Remote: "test.addr", + Listener: "listener", + }, + Properties: mqtt.ClientProperties{ + Username: []byte("username"), + Clean: false, + }, + } + + testClient.Stop(packets.ErrSessionTakenOver) + teardown(t, h.config.Path, h) + h.OnDisconnect(testClient, nil, true) +} + +func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + h.OnSubscribed(client, pkf, []byte{0}) + r := new(storage.Subscription) + + err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r) + require.NoError(t, err) + require.Equal(t, client.ID, r.Client) + require.Equal(t, pkf.Filters[0].Filter, r.Filter) + require.Equal(t, byte(0), r.Qos) + + h.OnUnsubscribed(client, pkf) + err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r) + require.Error(t, err) + require.Equal(t, badgerdb.ErrKeyNotFound, err) +} + +func TestOnSubscribedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSubscribed(client, pkf, []byte{0}) +} + +func TestOnSubscribedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSubscribed(client, pkf, []byte{0}) +} + +func TestOnUnsubscribedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnUnsubscribed(client, pkf) +} + +func TestOnUnsubscribedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnUnsubscribed(client, pkf) +} + +func TestOnRetainMessageThenUnset(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Retain: true, + }, + Payload: []byte("hello"), + TopicName: "a/b/c", + } + + h.OnRetainMessage(client, pk, 1) + + r := new(storage.Message) + err = h.getKv(retainedKey(pk.TopicName), r) + require.NoError(t, err) + require.Equal(t, pk.TopicName, r.TopicName) + require.Equal(t, pk.Payload, r.Payload) + + h.OnRetainMessage(client, pk, -1) + err = h.getKv(retainedKey(pk.TopicName), r) + require.Error(t, err) + require.ErrorIs(t, err, badgerdb.ErrKeyNotFound) + + // coverage: delete deleted + h.OnRetainMessage(client, pk, -1) + err = h.getKv(retainedKey(pk.TopicName), r) + require.Error(t, err) + require.ErrorIs(t, err, badgerdb.ErrKeyNotFound) +} + +func TestOnRetainedExpired(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + m := &storage.Message{ + ID: retainedKey("a/b/c"), + T: storage.RetainedKey, + TopicName: "a/b/c", + } + + err = h.setKv(m.ID, m) + require.NoError(t, err) + + r := new(storage.Message) + err = h.getKv(m.ID, r) + require.NoError(t, err) + require.Equal(t, m.TopicName, r.TopicName) + + h.OnRetainedExpired(m.TopicName) + err = h.getKv(m.ID, r) + require.Error(t, err) + require.ErrorIs(t, err, badgerdb.ErrKeyNotFound) +} + +func TestOnRetainExpiredNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnRetainedExpired("a/b/c") +} + +func TestOnRetainExpiredClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnRetainedExpired("a/b/c") +} + +func TestOnRetainMessageNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnRetainMessage(client, packets.Packet{}, 0) +} + +func TestOnRetainMessageClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnRetainMessage(client, packets.Packet{}, 0) +} + +func TestOnQosPublishThenQOSComplete(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Retain: true, + Qos: 2, + }, + Payload: []byte("hello"), + TopicName: "a/b/c", + } + + h.OnQosPublish(client, pk, time.Now().Unix(), 0) + + r := new(storage.Message) + err = h.getKv(inflightKey(client, pk), r) + require.NoError(t, err) + require.Equal(t, pk.TopicName, r.TopicName) + require.Equal(t, pk.Payload, r.Payload) + + // ensure dates are properly saved + require.True(t, r.Sent > 0) + require.True(t, time.Now().Unix()-1 < r.Sent) + + // OnQosDropped is a passthrough to OnQosComplete here + h.OnQosDropped(client, pk) + err = h.getKv(inflightKey(client, pk), r) + require.Error(t, err) + require.ErrorIs(t, err, badgerdb.ErrKeyNotFound) +} + +func TestOnQosPublishNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) +} + +func TestOnQosPublishClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) +} + +func TestOnQosCompleteNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosComplete(client, packets.Packet{}) +} + +func TestOnQosCompleteClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnQosComplete(client, packets.Packet{}) +} + +func TestOnQosDroppedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosDropped(client, packets.Packet{}) +} + +func TestOnSysInfoTick(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + info := &system.Info{ + Version: "2.0.0", + BytesReceived: 100, + } + + h.OnSysInfoTick(info) + + r := new(storage.SystemInfo) + err = h.getKv(storage.SysInfoKey, r) + require.NoError(t, err) + require.Equal(t, info.Version, r.Version) + require.Equal(t, info.BytesReceived, r.BytesReceived) + require.NotSame(t, info, r) +} + +func TestOnSysInfoTickNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSysInfoTick(new(system.Info)) +} + +func TestOnSysInfoTickClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSysInfoTick(new(system.Info)) +} + +func TestStoredClients(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with clients + err = h.setKv(storage.ClientKey+"_cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}) + require.NoError(t, err) + + err = h.setKv(storage.ClientKey+"_cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}) + require.NoError(t, err) + + err = h.setKv(storage.ClientKey+"_cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}) + require.NoError(t, err) + + r, err := h.StoredClients() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "cl1", r[0].ID) + require.Equal(t, "cl2", r[1].ID) + require.Equal(t, "cl3", r[2].ID) +} + +func TestStoredClientsNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredClients() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredSubscriptions(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with subscriptions + err = h.setKv(storage.SubscriptionKey+"_sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}) + require.NoError(t, err) + + err = h.setKv(storage.SubscriptionKey+"_sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}) + require.NoError(t, err) + + err = h.setKv(storage.SubscriptionKey+"_sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}) + require.NoError(t, err) + + r, err := h.StoredSubscriptions() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "sub1", r[0].ID) + require.Equal(t, "sub2", r[1].ID) + require.Equal(t, "sub3", r[2].ID) +} + +func TestStoredSubscriptionsNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredSubscriptions() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredRetainedMessages(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with messages + err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1", T: storage.RetainedKey}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_m2", &storage.Message{ID: "m2", T: storage.RetainedKey}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_m3", &storage.Message{ID: "m3", T: storage.RetainedKey}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3", T: storage.InflightKey}) + require.NoError(t, err) + + r, err := h.StoredRetainedMessages() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "m1", r[0].ID) + require.Equal(t, "m2", r[1].ID) + require.Equal(t, "m3", r[2].ID) +} + +func TestStoredRetainedMessagesNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredRetainedMessages() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredInflightMessages(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with messages + err = h.setKv(storage.InflightKey+"_i1", &storage.Message{ID: "i1", T: storage.InflightKey}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_i2", &storage.Message{ID: "i2", T: storage.InflightKey}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3", T: storage.InflightKey}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1", T: storage.RetainedKey}) + require.NoError(t, err) + + r, err := h.StoredInflightMessages() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "i1", r[0].ID) + require.Equal(t, "i2", r[1].ID) + require.Equal(t, "i3", r[2].ID) +} + +func TestStoredInflightMessagesNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredInflightMessages() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredSysInfo(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with messages + err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{ + ID: storage.SysInfoKey, + Info: system.Info{ + Version: "2.0.0", + }, + T: storage.SysInfoKey, + }) + require.NoError(t, err) + + r, err := h.StoredSysInfo() + require.NoError(t, err) + require.Equal(t, "2.0.0", r.Info.Version) +} + +func TestStoredSysInfoNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredSysInfo() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestErrorf(t *testing.T) { + // coverage: one day check log hook + h := new(Hook) + h.SetOpts(logger, nil) + h.Errorf("test", 1, 2, 3) +} + +func TestWarningf(t *testing.T) { + // coverage: one day check log hook + h := new(Hook) + h.SetOpts(logger, nil) + h.Warningf("test", 1, 2, 3) +} + +func TestInfof(t *testing.T) { + // coverage: one day check log hook + h := new(Hook) + h.SetOpts(logger, nil) + h.Infof("test", 1, 2, 3) +} + +func TestDebugf(t *testing.T) { + // coverage: one day check log hook + h := new(Hook) + h.SetOpts(logger, nil) + h.Debugf("test", 1, 2, 3) +} + +func TestGcLoop(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + opts := badgerdb.DefaultOptions(defaultDbFile) + opts.ValueLogFileSize = 1 << 20 + h.Init(&Options{ + GcInterval: 2, // Set the interval for garbage collection. + Options: &opts, + }) + defer teardown(t, defaultDbFile, h) + h.OnSessionEstablished(client, packets.Packet{}) + h.OnDisconnect(client, nil, true) + time.Sleep(3 * time.Second) +} + +func TestGetSetDelKv(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + defer teardown(t, h.config.Path, h) + require.NoError(t, err) + + err = h.setKv("testKey", &storage.Client{ID: "testId"}) + require.NoError(t, err) + + var obj storage.Client + err = h.getKv("testKey", &obj) + require.NoError(t, err) + + err = h.delKv("testKey") + require.NoError(t, err) + + err = h.getKv("testKey", &obj) + require.Error(t, err) + require.ErrorIs(t, badgerdb.ErrKeyNotFound, err) +} + +func TestIterKv(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(nil) + defer teardown(t, h.config.Path, h) + require.NoError(t, err) + + h.setKv("prefix_a_1", &storage.Client{ID: "1"}) + h.setKv("prefix_a_2", &storage.Client{ID: "2"}) + h.setKv("prefix_b_2", &storage.Client{ID: "3"}) + + var clients []storage.Client + err = h.iterKv("prefix_a", func(data []byte) error { + var item storage.Client + item.UnmarshalBinary(data) + clients = append(clients, item) + return nil + }) + require.Equal(t, 2, len(clients)) + require.NoError(t, err) + + visitErr := errors.New("iter visit error") + err = h.iterKv("prefix_b", func(data []byte) error { + return visitErr + }) + require.ErrorIs(t, visitErr, err) +} diff --git a/hooks/storage/bolt/bolt.go b/hooks/storage/bolt/bolt.go new file mode 100644 index 0000000..6d6235a --- /dev/null +++ b/hooks/storage/bolt/bolt.go @@ -0,0 +1,525 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, werbenhu + +// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead. +package bolt + +import ( + "bytes" + "errors" + "time" + + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" + "testmqtt/system" + + "go.etcd.io/bbolt" +) + +var ( + ErrBucketNotFound = errors.New("bucket not found") + ErrKeyNotFound = errors.New("key not found") +) + +const ( + // defaultDbFile is the default file path for the boltdb file. + defaultDbFile = ".bolt" + + // defaultTimeout is the default time to hold a connection to the file. + defaultTimeout = 250 * time.Millisecond + + defaultBucket = "mochi" +) + +// clientKey returns a primary key for a client. +func clientKey(cl *mqtt.Client) string { + return storage.ClientKey + "_" + cl.ID +} + +// subscriptionKey returns a primary key for a subscription. +func subscriptionKey(cl *mqtt.Client, filter string) string { + return storage.SubscriptionKey + "_" + cl.ID + ":" + filter +} + +// retainedKey returns a primary key for a retained message. +func retainedKey(topic string) string { + return storage.RetainedKey + "_" + topic +} + +// inflightKey returns a primary key for an inflight message. +func inflightKey(cl *mqtt.Client, pk packets.Packet) string { + return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID() +} + +// sysInfoKey returns a primary key for system info. +func sysInfoKey() string { + return storage.SysInfoKey +} + +// Options contains configuration settings for the bolt instance. +type Options struct { + Options *bbolt.Options + Bucket string `yaml:"bucket" json:"bucket"` + Path string `yaml:"path" json:"path"` +} + +// Hook is a persistent storage hook based using boltdb file store as a backend. +type Hook struct { + mqtt.HookBase + config *Options // options for configuring the boltdb instance. + db *bbolt.DB // the boltdb instance. +} + +// ID returns the id of the hook. +func (h *Hook) ID() string { + return "bolt-db" +} + +// Provides indicates which hook methods this hook provides. +func (h *Hook) Provides(b byte) bool { + return bytes.Contains([]byte{ + mqtt.OnSessionEstablished, + mqtt.OnDisconnect, + mqtt.OnSubscribed, + mqtt.OnUnsubscribed, + mqtt.OnRetainMessage, + mqtt.OnWillSent, + mqtt.OnQosPublish, + mqtt.OnQosComplete, + mqtt.OnQosDropped, + mqtt.OnSysInfoTick, + mqtt.OnClientExpired, + mqtt.OnRetainedExpired, + mqtt.StoredClients, + mqtt.StoredInflightMessages, + mqtt.StoredRetainedMessages, + mqtt.StoredSubscriptions, + mqtt.StoredSysInfo, + }, []byte{b}) +} + +// Init initializes and connects to the boltdb instance. +func (h *Hook) Init(config any) error { + if _, ok := config.(*Options); !ok && config != nil { + return mqtt.ErrInvalidConfigType + } + + if config == nil { + config = new(Options) + } + + h.config = config.(*Options) + if h.config.Options == nil { + h.config.Options = &bbolt.Options{ + Timeout: defaultTimeout, + } + } + if len(h.config.Path) == 0 { + h.config.Path = defaultDbFile + } + + if len(h.config.Bucket) == 0 { + h.config.Bucket = defaultBucket + } + + var err error + h.db, err = bbolt.Open(h.config.Path, 0600, h.config.Options) + if err != nil { + return err + } + + err = h.db.Update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists([]byte(h.config.Bucket)) + return err + }) + return err +} + +// Stop closes the boltdb instance. +func (h *Hook) Stop() error { + err := h.db.Close() + h.db = nil + return err +} + +// OnSessionEstablished adds a client to the store when their session is established. +func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { + h.updateClient(cl) +} + +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. +func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { + h.updateClient(cl) +} + +// updateClient writes the client data to the store. +func (h *Hook) updateClient(cl *mqtt.Client) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + props := cl.Properties.Props.Copy(false) + in := &storage.Client{ + ID: cl.ID, + T: storage.ClientKey, + Remote: cl.Net.Remote, + Listener: cl.Net.Listener, + Username: cl.Properties.Username, + Clean: cl.Properties.Clean, + ProtocolVersion: cl.Properties.ProtocolVersion, + Properties: storage.ClientProperties{ + SessionExpiryInterval: props.SessionExpiryInterval, + AuthenticationMethod: props.AuthenticationMethod, + AuthenticationData: props.AuthenticationData, + RequestProblemInfo: props.RequestProblemInfo, + RequestResponseInfo: props.RequestResponseInfo, + ReceiveMaximum: props.ReceiveMaximum, + TopicAliasMaximum: props.TopicAliasMaximum, + User: props.User, + MaximumPacketSize: props.MaximumPacketSize, + }, + Will: storage.ClientWill(cl.Properties.Will), + } + + _ = h.setKv(clientKey(cl), in) +} + +// OnDisconnect removes a client from the store if they were using a clean session. +func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + if !expire { + return + } + + if cl.StopCause() == packets.ErrSessionTakenOver { + return + } + + _ = h.delKv(clientKey(cl)) +} + +// OnSubscribed adds one or more client subscriptions to the store. +func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + var in *storage.Subscription + for i := 0; i < len(pk.Filters); i++ { + in = &storage.Subscription{ + ID: subscriptionKey(cl, pk.Filters[i].Filter), + T: storage.SubscriptionKey, + Client: cl.ID, + Qos: reasonCodes[i], + Filter: pk.Filters[i].Filter, + Identifier: pk.Filters[i].Identifier, + NoLocal: pk.Filters[i].NoLocal, + RetainHandling: pk.Filters[i].RetainHandling, + RetainAsPublished: pk.Filters[i].RetainAsPublished, + } + _ = h.setKv(in.ID, in) + } +} + +// OnUnsubscribed removes one or more client subscriptions from the store. +func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + for i := 0; i < len(pk.Filters); i++ { + _ = h.delKv(subscriptionKey(cl, pk.Filters[i].Filter)) + } +} + +// OnRetainMessage adds a retained message for a topic to the store. +func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + if r == -1 { + _ = h.delKv(retainedKey(pk.TopicName)) + return + } + + props := pk.Properties.Copy(false) + in := &storage.Message{ + ID: retainedKey(pk.TopicName), + T: storage.RetainedKey, + FixedHeader: pk.FixedHeader, + TopicName: pk.TopicName, + Payload: pk.Payload, + Created: pk.Created, + Origin: pk.Origin, + Properties: storage.MessageProperties{ + PayloadFormat: props.PayloadFormat, + MessageExpiryInterval: props.MessageExpiryInterval, + ContentType: props.ContentType, + ResponseTopic: props.ResponseTopic, + CorrelationData: props.CorrelationData, + SubscriptionIdentifier: props.SubscriptionIdentifier, + TopicAlias: props.TopicAlias, + User: props.User, + }, + } + + _ = h.setKv(in.ID, in) +} + +// OnQosPublish adds or updates an inflight message in the store. +func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + props := pk.Properties.Copy(false) + in := &storage.Message{ + ID: inflightKey(cl, pk), + T: storage.InflightKey, + Origin: pk.Origin, + FixedHeader: pk.FixedHeader, + TopicName: pk.TopicName, + Payload: pk.Payload, + Sent: sent, + Created: pk.Created, + Properties: storage.MessageProperties{ + PayloadFormat: props.PayloadFormat, + MessageExpiryInterval: props.MessageExpiryInterval, + ContentType: props.ContentType, + ResponseTopic: props.ResponseTopic, + CorrelationData: props.CorrelationData, + SubscriptionIdentifier: props.SubscriptionIdentifier, + TopicAlias: props.TopicAlias, + User: props.User, + }, + } + + _ = h.setKv(in.ID, in) +} + +// OnQosComplete removes a resolved inflight message from the store. +func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + _ = h.delKv(inflightKey(cl, pk)) +} + +// OnQosDropped removes a dropped inflight message from the store. +func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + } + + h.OnQosComplete(cl, pk) +} + +// OnSysInfoTick stores the latest system info in the store. +func (h *Hook) OnSysInfoTick(sys *system.Info) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + in := &storage.SystemInfo{ + ID: sysInfoKey(), + T: storage.SysInfoKey, + Info: *sys, + } + + _ = h.setKv(in.ID, in) +} + +// OnRetainedExpired deletes expired retained messages from the store. +func (h *Hook) OnRetainedExpired(filter string) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + _ = h.delKv(retainedKey(filter)) +} + +// OnClientExpired deleted expired clients from the store. +func (h *Hook) OnClientExpired(cl *mqtt.Client) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + _ = h.delKv(clientKey(cl)) +} + +// StoredClients returns all stored clients from the store. +func (h *Hook) StoredClients() (v []storage.Client, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return v, storage.ErrDBFileNotOpen + } + + err = h.iterKv(storage.ClientKey, func(value []byte) error { + obj := storage.Client{} + err = obj.UnmarshalBinary(value) + if err == nil { + v = append(v, obj) + } + return err + }) + return v, nil +} + +// StoredSubscriptions returns all stored subscriptions from the store. +func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return v, storage.ErrDBFileNotOpen + } + + v = make([]storage.Subscription, 0) + err = h.iterKv(storage.SubscriptionKey, func(value []byte) error { + obj := storage.Subscription{} + err = obj.UnmarshalBinary(value) + if err == nil { + v = append(v, obj) + } + return err + }) + return +} + +// StoredRetainedMessages returns all stored retained messages from the store. +func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return v, storage.ErrDBFileNotOpen + } + + v = make([]storage.Message, 0) + err = h.iterKv(storage.RetainedKey, func(value []byte) error { + obj := storage.Message{} + err = obj.UnmarshalBinary(value) + if err == nil { + v = append(v, obj) + } + return err + }) + return +} + +// StoredInflightMessages returns all stored inflight messages from the store. +func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return v, storage.ErrDBFileNotOpen + } + + v = make([]storage.Message, 0) + err = h.iterKv(storage.InflightKey, func(value []byte) error { + obj := storage.Message{} + err = obj.UnmarshalBinary(value) + if err == nil { + v = append(v, obj) + } + return err + }) + return +} + +// StoredSysInfo returns the system info from the store. +func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return v, storage.ErrDBFileNotOpen + } + + err = h.getKv(storage.SysInfoKey, &v) + if err != nil && !errors.Is(err, ErrKeyNotFound) { + return + } + + return v, nil +} + +// setKv stores a key-value pair in the database. +func (h *Hook) setKv(k string, v storage.Serializable) error { + err := h.db.Update(func(tx *bbolt.Tx) error { + + bucket := tx.Bucket([]byte(h.config.Bucket)) + data, _ := v.MarshalBinary() + err := bucket.Put([]byte(k), data) + if err != nil { + return err + } + return nil + }) + if err != nil { + h.Log.Error("failed to upsert data", "error", err, "key", k) + } + return err +} + +// delKv deletes a key-value pair from the database. +func (h *Hook) delKv(k string) error { + err := h.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket([]byte(h.config.Bucket)) + err := bucket.Delete([]byte(k)) + if err != nil { + return err + } + return nil + }) + if err != nil { + h.Log.Error("failed to delete data", "error", err, "key", k) + } + return err +} + +// getKv retrieves the value associated with a key from the database. +func (h *Hook) getKv(k string, v storage.Serializable) error { + err := h.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket([]byte(h.config.Bucket)) + + value := bucket.Get([]byte(k)) + if value == nil { + return ErrKeyNotFound + } + + return v.UnmarshalBinary(value) + }) + if err != nil { + h.Log.Error("failed to get data", "error", err, "key", k) + } + return err +} + +// iterKv iterates over key-value pairs with keys having the specified prefix in the database. +func (h *Hook) iterKv(prefix string, visit func([]byte) error) error { + err := h.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket([]byte(h.config.Bucket)) + + c := bucket.Cursor() + for k, v := c.Seek([]byte(prefix)); k != nil && string(k[:len(prefix)]) == prefix; k, v = c.Next() { + if err := visit(v); err != nil { + return err + } + } + return nil + }) + + if err != nil { + h.Log.Error("failed to iter data", "error", err, "prefix", prefix) + } + return err +} diff --git a/hooks/storage/bolt/bolt_test.go b/hooks/storage/bolt/bolt_test.go new file mode 100644 index 0000000..780162f --- /dev/null +++ b/hooks/storage/bolt/bolt_test.go @@ -0,0 +1,791 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, werbenhu + +package bolt + +import ( + "errors" + "log/slog" + "os" + "testing" + "time" + + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" + "testmqtt/system" + + "github.com/stretchr/testify/require" +) + +var ( + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + + client = &mqtt.Client{ + ID: "test", + Net: mqtt.ClientConnection{ + Remote: "test.addr", + Listener: "listener", + }, + Properties: mqtt.ClientProperties{ + Username: []byte("username"), + Clean: false, + }, + } + + pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}} +) + +func teardown(t *testing.T, path string, h *Hook) { + _ = h.Stop() + err := os.Remove(path) + require.NoError(t, err) +} + +func TestClientKey(t *testing.T) { + k := clientKey(&mqtt.Client{ID: "cl1"}) + require.Equal(t, "CL_cl1", k) +} + +func TestSubscriptionKey(t *testing.T) { + k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c") + require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k) +} + +func TestRetainedKey(t *testing.T) { + k := retainedKey("a/b/c") + require.Equal(t, storage.RetainedKey+"_a/b/c", k) +} + +func TestInflightKey(t *testing.T) { + k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1}) + require.Equal(t, storage.InflightKey+"_cl1:1", k) +} + +func TestSysInfoKey(t *testing.T) { + require.Equal(t, storage.SysInfoKey, sysInfoKey()) +} + +func TestID(t *testing.T) { + h := new(Hook) + require.Equal(t, "bolt-db", h.ID()) +} + +func TestProvides(t *testing.T) { + h := new(Hook) + require.True(t, h.Provides(mqtt.OnSessionEstablished)) + require.True(t, h.Provides(mqtt.OnDisconnect)) + require.True(t, h.Provides(mqtt.OnSubscribed)) + require.True(t, h.Provides(mqtt.OnUnsubscribed)) + require.True(t, h.Provides(mqtt.OnRetainMessage)) + require.True(t, h.Provides(mqtt.OnQosPublish)) + require.True(t, h.Provides(mqtt.OnQosComplete)) + require.True(t, h.Provides(mqtt.OnQosDropped)) + require.True(t, h.Provides(mqtt.OnSysInfoTick)) + require.True(t, h.Provides(mqtt.StoredClients)) + require.True(t, h.Provides(mqtt.StoredInflightMessages)) + require.True(t, h.Provides(mqtt.StoredRetainedMessages)) + require.True(t, h.Provides(mqtt.StoredSubscriptions)) + require.True(t, h.Provides(mqtt.StoredSysInfo)) + require.False(t, h.Provides(mqtt.OnACLCheck)) + require.False(t, h.Provides(mqtt.OnConnectAuthenticate)) +} + +func TestInitBadConfig(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(map[string]any{}) + require.Error(t, err) +} + +func TestInitUseDefaults(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + require.Equal(t, defaultTimeout, h.config.Options.Timeout) + require.Equal(t, defaultDbFile, h.config.Path) +} + +func TestInitBadPath(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(&Options{ + Path: "..", + }) + require.Error(t, err) +} + +func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + h.OnSessionEstablished(client, packets.Packet{}) + + r := new(storage.Client) + err = h.getKv(clientKey(client), r) + require.NoError(t, err) + require.Equal(t, client.ID, r.ID) + require.Equal(t, client.Net.Remote, r.Remote) + require.Equal(t, client.Net.Listener, r.Listener) + require.Equal(t, client.Properties.Username, r.Username) + require.Equal(t, client.Properties.Clean, r.Clean) + require.NotSame(t, client, r) + + h.OnDisconnect(client, nil, false) + r2 := new(storage.Client) + err = h.getKv(clientKey(client), r2) + require.NoError(t, err) + require.Equal(t, client.ID, r.ID) + + h.OnDisconnect(client, nil, true) + r3 := new(storage.Client) + err = h.getKv(clientKey(client), r3) + require.Error(t, err) + require.ErrorIs(t, ErrKeyNotFound, err) + require.Empty(t, r3.ID) +} + +func TestOnSessionEstablishedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSessionEstablished(client, packets.Packet{}) +} + +func TestOnSessionEstablishedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSessionEstablished(client, packets.Packet{}) +} + +func TestOnWillSent(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + c1 := client + c1.Properties.Will.Flag = 1 + h.OnWillSent(c1, packets.Packet{}) + + r := new(storage.Client) + err = h.getKv(clientKey(client), r) + require.NoError(t, err) + + require.Equal(t, uint32(1), r.Will.Flag) + require.NotSame(t, client, r) +} + +func TestOnClientExpired(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + cl := &mqtt.Client{ID: "cl1"} + clientKey := clientKey(cl) + + err = h.setKv(clientKey, &storage.Client{ID: cl.ID}) + require.NoError(t, err) + + r := new(storage.Client) + err = h.getKv(clientKey, r) + require.NoError(t, err) + require.Equal(t, cl.ID, r.ID) + + h.OnClientExpired(cl) + err = h.getKv(clientKey, r) + require.Error(t, err) + require.ErrorIs(t, ErrKeyNotFound, err) +} + +func TestOnClientExpiredClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnClientExpired(client) +} + +func TestOnClientExpiredNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnClientExpired(client) +} + +func TestOnDisconnectNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnDisconnect(client, nil, false) +} + +func TestOnDisconnectClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnDisconnect(client, nil, false) +} + +func TestOnDisconnectSessionTakenOver(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + + testClient := &mqtt.Client{ + ID: "test", + Net: mqtt.ClientConnection{ + Remote: "test.addr", + Listener: "listener", + }, + Properties: mqtt.ClientProperties{ + Username: []byte("username"), + Clean: false, + }, + } + + testClient.Stop(packets.ErrSessionTakenOver) + teardown(t, h.config.Path, h) + h.OnDisconnect(testClient, nil, true) +} + +func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + h.OnSubscribed(client, pkf, []byte{0}) + r := new(storage.Subscription) + + err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r) + require.NoError(t, err) + require.Equal(t, client.ID, r.Client) + require.Equal(t, pkf.Filters[0].Filter, r.Filter) + require.Equal(t, byte(0), r.Qos) + + h.OnUnsubscribed(client, pkf) + err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r) + require.Error(t, err) + require.Equal(t, ErrKeyNotFound, err) +} + +func TestOnSubscribedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSubscribed(client, pkf, []byte{0}) +} + +func TestOnSubscribedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSubscribed(client, pkf, []byte{0}) +} + +func TestOnUnsubscribedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnUnsubscribed(client, pkf) +} + +func TestOnUnsubscribedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnUnsubscribed(client, pkf) +} + +func TestOnRetainMessageThenUnset(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Retain: true, + }, + Payload: []byte("hello"), + TopicName: "a/b/c", + } + + h.OnRetainMessage(client, pk, 1) + + r := new(storage.Message) + err = h.getKv(retainedKey(pk.TopicName), r) + require.NoError(t, err) + require.Equal(t, pk.TopicName, r.TopicName) + require.Equal(t, pk.Payload, r.Payload) + + h.OnRetainMessage(client, pk, -1) + err = h.getKv(retainedKey(pk.TopicName), r) + require.Error(t, err) + require.Equal(t, ErrKeyNotFound, err) + + // coverage: delete deleted + h.OnRetainMessage(client, pk, -1) + err = h.getKv(retainedKey(pk.TopicName), r) + require.Error(t, err) + require.Equal(t, ErrKeyNotFound, err) +} + +func TestOnRetainedExpired(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + m := &storage.Message{ + ID: retainedKey("a/b/c"), + T: storage.RetainedKey, + TopicName: "a/b/c", + } + + err = h.setKv(m.ID, m) + require.NoError(t, err) + + r := new(storage.Message) + err = h.getKv(m.ID, r) + require.NoError(t, err) + require.Equal(t, m.TopicName, r.TopicName) + + h.OnRetainedExpired(m.TopicName) + err = h.getKv(m.ID, r) + require.Error(t, err) + require.Equal(t, ErrKeyNotFound, err) +} + +func TestOnRetainedExpiredClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnRetainedExpired("a/b/c") +} + +func TestOnRetainedExpiredNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnRetainedExpired("a/b/c") +} + +func TestOnRetainMessageNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnRetainMessage(client, packets.Packet{}, 0) +} + +func TestOnRetainMessageClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnRetainMessage(client, packets.Packet{}, 0) +} + +func TestOnQosPublishThenQOSComplete(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Retain: true, + Qos: 2, + }, + Payload: []byte("hello"), + TopicName: "a/b/c", + } + + h.OnQosPublish(client, pk, time.Now().Unix(), 0) + + r := new(storage.Message) + err = h.getKv(inflightKey(client, pk), r) + require.NoError(t, err) + require.Equal(t, pk.TopicName, r.TopicName) + require.Equal(t, pk.Payload, r.Payload) + + // ensure dates are properly saved to bolt + require.True(t, r.Sent > 0) + require.True(t, time.Now().Unix()-1 < r.Sent) + + // OnQosDropped is a passthrough to OnQosComplete here + h.OnQosDropped(client, pk) + err = h.getKv(inflightKey(client, pk), r) + require.Error(t, err) + require.Equal(t, ErrKeyNotFound, err) +} + +func TestOnQosPublishNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) +} + +func TestOnQosPublishClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) +} + +func TestOnQosCompleteNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosComplete(client, packets.Packet{}) +} + +func TestOnQosCompleteClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnQosComplete(client, packets.Packet{}) +} + +func TestOnQosDroppedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosDropped(client, packets.Packet{}) +} + +func TestOnSysInfoTick(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + info := &system.Info{ + Version: "2.0.0", + BytesReceived: 100, + } + + h.OnSysInfoTick(info) + + r := new(storage.SystemInfo) + err = h.getKv(storage.SysInfoKey, r) + require.NoError(t, err) + require.Equal(t, info.Version, r.Version) + require.Equal(t, info.BytesReceived, r.BytesReceived) + require.NotSame(t, info, r) +} + +func TestOnSysInfoTickNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSysInfoTick(new(system.Info)) +} + +func TestOnSysInfoTickClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSysInfoTick(new(system.Info)) +} + +func TestStoredClients(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with clients + err = h.setKv(storage.ClientKey+"_"+"cl1", &storage.Client{ID: "cl1"}) + require.NoError(t, err) + + err = h.setKv(storage.ClientKey+"_"+"cl2", &storage.Client{ID: "cl2"}) + require.NoError(t, err) + + err = h.setKv(storage.ClientKey+"_"+"cl3", &storage.Client{ID: "cl3"}) + require.NoError(t, err) + + r, err := h.StoredClients() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "cl1", r[0].ID) + require.Equal(t, "cl2", r[1].ID) + require.Equal(t, "cl3", r[2].ID) +} + +func TestStoredClientsNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredClients() + require.Empty(t, v) + require.ErrorIs(t, storage.ErrDBFileNotOpen, err) +} + +func TestStoredClientsClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + v, err := h.StoredClients() + require.Empty(t, v) + require.ErrorIs(t, storage.ErrDBFileNotOpen, err) +} + +func TestStoredSubscriptions(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with subscriptions + err = h.setKv(storage.SubscriptionKey+"_"+"sub1", &storage.Subscription{ID: "sub1"}) + require.NoError(t, err) + + err = h.setKv(storage.SubscriptionKey+"_"+"sub2", &storage.Subscription{ID: "sub2"}) + require.NoError(t, err) + + err = h.setKv(storage.SubscriptionKey+"_"+"sub3", &storage.Subscription{ID: "sub3"}) + require.NoError(t, err) + + r, err := h.StoredSubscriptions() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "sub1", r[0].ID) + require.Equal(t, "sub2", r[1].ID) + require.Equal(t, "sub3", r[2].ID) +} + +func TestStoredSubscriptionsNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredSubscriptions() + require.Empty(t, v) + require.ErrorIs(t, storage.ErrDBFileNotOpen, err) +} + +func TestStoredSubscriptionsClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + v, err := h.StoredSubscriptions() + require.Empty(t, v) + require.ErrorIs(t, storage.ErrDBFileNotOpen, err) +} + +func TestStoredRetainedMessages(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with messages + err = h.setKv(storage.RetainedKey+"_"+"m1", &storage.Message{ID: "m1"}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_"+"m2", &storage.Message{ID: "m2"}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_"+"m3", &storage.Message{ID: "m3"}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_"+"i3", &storage.Message{ID: "i3"}) + require.NoError(t, err) + + r, err := h.StoredRetainedMessages() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "m1", r[0].ID) + require.Equal(t, "m2", r[1].ID) + require.Equal(t, "m3", r[2].ID) +} + +func TestStoredRetainedMessagesNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredRetainedMessages() + require.Empty(t, v) + require.ErrorIs(t, storage.ErrDBFileNotOpen, err) +} + +func TestStoredRetainedMessagesClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + v, err := h.StoredRetainedMessages() + require.Empty(t, v) + require.Error(t, err) +} + +func TestStoredInflightMessages(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with messages + err = h.setKv(storage.InflightKey+"_"+"i1", &storage.Message{ID: "i1"}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_"+"i2", &storage.Message{ID: "i2"}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_"+"i3", &storage.Message{ID: "i3"}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_"+"m1", &storage.Message{ID: "m1"}) + require.NoError(t, err) + + r, err := h.StoredInflightMessages() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "i1", r[0].ID) + require.Equal(t, "i2", r[1].ID) + require.Equal(t, "i3", r[2].ID) +} + +func TestStoredInflightMessagesNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredInflightMessages() + require.Empty(t, v) + require.ErrorIs(t, storage.ErrDBFileNotOpen, err) +} + +func TestStoredInflightMessagesClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + v, err := h.StoredInflightMessages() + require.Empty(t, v) + require.ErrorIs(t, storage.ErrDBFileNotOpen, err) +} + +func TestStoredSysInfo(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with sys info + err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{ + ID: storage.SysInfoKey, + Info: system.Info{ + Version: "2.0.0", + }, + T: storage.SysInfoKey, + }) + require.NoError(t, err) + + r, err := h.StoredSysInfo() + require.NoError(t, err) + require.Equal(t, "2.0.0", r.Info.Version) +} + +func TestStoredSysInfoNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredSysInfo() + require.Empty(t, v) + require.ErrorIs(t, storage.ErrDBFileNotOpen, err) +} + +func TestStoredSysInfoClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + v, err := h.StoredSysInfo() + require.Empty(t, v) + require.Error(t, err) +} + +func TestGetSetDelKv(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + defer teardown(t, h.config.Path, h) + require.NoError(t, err) + + err = h.setKv("testId", &storage.Client{ID: "testId"}) + require.NoError(t, err) + + var obj storage.Client + err = h.getKv("testId", &obj) + require.NoError(t, err) + + err = h.delKv("testId") + require.NoError(t, err) + + err = h.getKv("testId", &obj) + require.Error(t, err) + require.ErrorIs(t, ErrKeyNotFound, err) +} + +func TestIterKv(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(nil) + defer teardown(t, h.config.Path, h) + require.NoError(t, err) + + h.setKv("prefix_a_1", &storage.Client{ID: "1"}) + h.setKv("prefix_a_2", &storage.Client{ID: "2"}) + h.setKv("prefix_b_2", &storage.Client{ID: "3"}) + + var clients []storage.Client + err = h.iterKv("prefix_a", func(data []byte) error { + var item storage.Client + item.UnmarshalBinary(data) + clients = append(clients, item) + return nil + }) + require.Equal(t, 2, len(clients)) + require.NoError(t, err) + + visitErr := errors.New("iter visit error") + err = h.iterKv("prefix_b", func(data []byte) error { + return visitErr + }) + require.ErrorIs(t, visitErr, err) +} diff --git a/hooks/storage/pebble/pebble.go b/hooks/storage/pebble/pebble.go new file mode 100644 index 0000000..2066338 --- /dev/null +++ b/hooks/storage/pebble/pebble.go @@ -0,0 +1,524 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: werbenhu + +package pebble + +import ( + "bytes" + "errors" + "fmt" + "strings" + + pebbledb "github.com/cockroachdb/pebble" + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" + "testmqtt/system" +) + +const ( + // defaultDbFile is the default file path for the pebble db file. + defaultDbFile = ".pebble" +) + +// clientKey returns a primary key for a client. +func clientKey(cl *mqtt.Client) string { + return storage.ClientKey + "_" + cl.ID +} + +// subscriptionKey returns a primary key for a subscription. +func subscriptionKey(cl *mqtt.Client, filter string) string { + return storage.SubscriptionKey + "_" + cl.ID + ":" + filter +} + +// retainedKey returns a primary key for a retained message. +func retainedKey(topic string) string { + return storage.RetainedKey + "_" + topic +} + +// inflightKey returns a primary key for an inflight message. +func inflightKey(cl *mqtt.Client, pk packets.Packet) string { + return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID() +} + +// sysInfoKey returns a primary key for system info. +func sysInfoKey() string { + return storage.SysInfoKey +} + +// keyUpperBound returns the upper bound for a given byte slice by incrementing the last byte. +// It returns nil if all bytes are incremented and equal to 0. +func keyUpperBound(b []byte) []byte { + end := make([]byte, len(b)) + copy(end, b) + for i := len(end) - 1; i >= 0; i-- { + end[i] = end[i] + 1 + if end[i] != 0 { + return end[:i+1] + } + } + return nil +} + +const ( + NoSync = "NoSync" // NoSync specifies the default write options for writes which do not synchronize to disk. + Sync = "Sync" // Sync specifies the default write options for writes which synchronize to disk. +) + +// Options contains configuration settings for the pebble DB instance. +type Options struct { + Options *pebbledb.Options + Mode string `yaml:"mode" json:"mode"` + Path string `yaml:"path" json:"path"` +} + +// Hook is a persistent storage hook based using pebble DB file store as a backend. +type Hook struct { + mqtt.HookBase + config *Options // options for configuring the pebble DB instance. + db *pebbledb.DB // the pebble DB instance + mode *pebbledb.WriteOptions // mode holds the optional per-query parameters for Set and Delete operations +} + +// ID returns the id of the hook. +func (h *Hook) ID() string { + return "pebble-db" +} + +// Provides indicates which hook methods this hook provides. +func (h *Hook) Provides(b byte) bool { + return bytes.Contains([]byte{ + mqtt.OnSessionEstablished, + mqtt.OnDisconnect, + mqtt.OnSubscribed, + mqtt.OnUnsubscribed, + mqtt.OnRetainMessage, + mqtt.OnWillSent, + mqtt.OnQosPublish, + mqtt.OnQosComplete, + mqtt.OnQosDropped, + mqtt.OnSysInfoTick, + mqtt.OnClientExpired, + mqtt.OnRetainedExpired, + mqtt.StoredClients, + mqtt.StoredInflightMessages, + mqtt.StoredRetainedMessages, + mqtt.StoredSubscriptions, + mqtt.StoredSysInfo, + }, []byte{b}) +} + +// Init initializes and connects to the pebble instance. +func (h *Hook) Init(config any) error { + if _, ok := config.(*Options); !ok && config != nil { + return mqtt.ErrInvalidConfigType + } + + if config == nil { + h.config = new(Options) + } else { + h.config = config.(*Options) + } + + if len(h.config.Path) == 0 { + h.config.Path = defaultDbFile + } + + if h.config.Options == nil { + h.config.Options = &pebbledb.Options{} + } + + h.mode = pebbledb.NoSync + if strings.EqualFold(h.config.Mode, "Sync") { + h.mode = pebbledb.Sync + } + + var err error + h.db, err = pebbledb.Open(h.config.Path, h.config.Options) + if err != nil { + return err + } + + return nil +} + +// Stop closes the pebble instance. +func (h *Hook) Stop() error { + err := h.db.Close() + h.db = nil + return err +} + +// OnSessionEstablished adds a client to the store when their session is established. +func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { + h.updateClient(cl) +} + +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. +func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { + h.updateClient(cl) +} + +// updateClient writes the client data to the store. +func (h *Hook) updateClient(cl *mqtt.Client) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + props := cl.Properties.Props.Copy(false) + in := &storage.Client{ + ID: cl.ID, + T: storage.ClientKey, + Remote: cl.Net.Remote, + Listener: cl.Net.Listener, + Username: cl.Properties.Username, + Clean: cl.Properties.Clean, + ProtocolVersion: cl.Properties.ProtocolVersion, + Properties: storage.ClientProperties{ + SessionExpiryInterval: props.SessionExpiryInterval, + AuthenticationMethod: props.AuthenticationMethod, + AuthenticationData: props.AuthenticationData, + RequestProblemInfo: props.RequestProblemInfo, + RequestResponseInfo: props.RequestResponseInfo, + ReceiveMaximum: props.ReceiveMaximum, + TopicAliasMaximum: props.TopicAliasMaximum, + User: props.User, + MaximumPacketSize: props.MaximumPacketSize, + }, + Will: storage.ClientWill(cl.Properties.Will), + } + h.setKv(clientKey(cl), in) +} + +// OnDisconnect removes a client from the store if their session has expired. +func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + h.updateClient(cl) + + if !expire { + return + } + + if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) { + return + } + + h.delKv(clientKey(cl)) +} + +// OnSubscribed adds one or more client subscriptions to the store. +func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + var in *storage.Subscription + for i := 0; i < len(pk.Filters); i++ { + in = &storage.Subscription{ + ID: subscriptionKey(cl, pk.Filters[i].Filter), + T: storage.SubscriptionKey, + Client: cl.ID, + Qos: reasonCodes[i], + Filter: pk.Filters[i].Filter, + Identifier: pk.Filters[i].Identifier, + NoLocal: pk.Filters[i].NoLocal, + RetainHandling: pk.Filters[i].RetainHandling, + RetainAsPublished: pk.Filters[i].RetainAsPublished, + } + h.setKv(in.ID, in) + } +} + +// OnUnsubscribed removes one or more client subscriptions from the store. +func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + for i := 0; i < len(pk.Filters); i++ { + h.delKv(subscriptionKey(cl, pk.Filters[i].Filter)) + } +} + +// OnRetainMessage adds a retained message for a topic to the store. +func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + if r == -1 { + h.delKv(retainedKey(pk.TopicName)) + return + } + + props := pk.Properties.Copy(false) + in := &storage.Message{ + ID: retainedKey(pk.TopicName), + T: storage.RetainedKey, + FixedHeader: pk.FixedHeader, + TopicName: pk.TopicName, + Payload: pk.Payload, + Created: pk.Created, + Origin: pk.Origin, + Properties: storage.MessageProperties{ + PayloadFormat: props.PayloadFormat, + MessageExpiryInterval: props.MessageExpiryInterval, + ContentType: props.ContentType, + ResponseTopic: props.ResponseTopic, + CorrelationData: props.CorrelationData, + SubscriptionIdentifier: props.SubscriptionIdentifier, + TopicAlias: props.TopicAlias, + User: props.User, + }, + } + + h.setKv(in.ID, in) +} + +// OnQosPublish adds or updates an inflight message in the store. +func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + props := pk.Properties.Copy(false) + in := &storage.Message{ + ID: inflightKey(cl, pk), + T: storage.InflightKey, + Origin: pk.Origin, + PacketID: pk.PacketID, + FixedHeader: pk.FixedHeader, + TopicName: pk.TopicName, + Payload: pk.Payload, + Sent: sent, + Created: pk.Created, + Properties: storage.MessageProperties{ + PayloadFormat: props.PayloadFormat, + MessageExpiryInterval: props.MessageExpiryInterval, + ContentType: props.ContentType, + ResponseTopic: props.ResponseTopic, + CorrelationData: props.CorrelationData, + SubscriptionIdentifier: props.SubscriptionIdentifier, + TopicAlias: props.TopicAlias, + User: props.User, + }, + } + h.setKv(in.ID, in) +} + +// OnQosComplete removes a resolved inflight message from the store. +func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + h.delKv(inflightKey(cl, pk)) +} + +// OnQosDropped removes a dropped inflight message from the store. +func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + } + + h.OnQosComplete(cl, pk) +} + +// OnSysInfoTick stores the latest system info in the store. +func (h *Hook) OnSysInfoTick(sys *system.Info) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + in := &storage.SystemInfo{ + ID: sysInfoKey(), + T: storage.SysInfoKey, + Info: *sys.Clone(), + } + h.setKv(in.ID, in) +} + +// OnRetainedExpired deletes expired retained messages from the store. +func (h *Hook) OnRetainedExpired(filter string) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + h.delKv(retainedKey(filter)) +} + +// OnClientExpired deleted expired clients from the store. +func (h *Hook) OnClientExpired(cl *mqtt.Client) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + h.delKv(clientKey(cl)) +} + +// StoredClients returns all stored clients from the store. +func (h *Hook) StoredClients() (v []storage.Client, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + iter, _ := h.db.NewIter(&pebbledb.IterOptions{ + LowerBound: []byte(storage.ClientKey), + UpperBound: keyUpperBound([]byte(storage.ClientKey)), + }) + + for iter.First(); iter.Valid(); iter.Next() { + item := storage.Client{} + if err := item.UnmarshalBinary(iter.Value()); err == nil { + v = append(v, item) + } + } + return v, nil +} + +// StoredSubscriptions returns all stored subscriptions from the store. +func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + iter, _ := h.db.NewIter(&pebbledb.IterOptions{ + LowerBound: []byte(storage.SubscriptionKey), + UpperBound: keyUpperBound([]byte(storage.SubscriptionKey)), + }) + + for iter.First(); iter.Valid(); iter.Next() { + item := storage.Subscription{} + if err := item.UnmarshalBinary(iter.Value()); err == nil { + v = append(v, item) + } + } + return v, nil +} + +// StoredRetainedMessages returns all stored retained messages from the store. +func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + iter, _ := h.db.NewIter(&pebbledb.IterOptions{ + LowerBound: []byte(storage.RetainedKey), + UpperBound: keyUpperBound([]byte(storage.RetainedKey)), + }) + + for iter.First(); iter.Valid(); iter.Next() { + item := storage.Message{} + if err := item.UnmarshalBinary(iter.Value()); err == nil { + v = append(v, item) + } + } + return v, nil +} + +// StoredInflightMessages returns all stored inflight messages from the store. +func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + iter, _ := h.db.NewIter(&pebbledb.IterOptions{ + LowerBound: []byte(storage.InflightKey), + UpperBound: keyUpperBound([]byte(storage.InflightKey)), + }) + + for iter.First(); iter.Valid(); iter.Next() { + item := storage.Message{} + if err := item.UnmarshalBinary(iter.Value()); err == nil { + v = append(v, item) + } + } + return v, nil +} + +// StoredSysInfo returns the system info from the store. +func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + err = h.getKv(sysInfoKey(), &v) + if errors.Is(err, pebbledb.ErrNotFound) { + return v, nil + } + + return +} + +// Errorf satisfies the pebble interface for an error logger. +func (h *Hook) Errorf(m string, v ...any) { + h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) + +} + +// Warningf satisfies the pebble interface for a warning logger. +func (h *Hook) Warningf(m string, v ...any) { + h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) +} + +// Infof satisfies the pebble interface for an info logger. +func (h *Hook) Infof(m string, v ...any) { + h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) +} + +// Debugf satisfies the pebble interface for a debug logger. +func (h *Hook) Debugf(m string, v ...any) { + h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) +} + +// delKv deletes a key-value pair from the database. +func (h *Hook) delKv(k string) error { + err := h.db.Delete([]byte(k), h.mode) + if err != nil { + h.Log.Error("failed to delete data", "error", err, "key", k) + return err + } + return nil +} + +// setKv stores a key-value pair in the database. +func (h *Hook) setKv(k string, v storage.Serializable) error { + bs, _ := v.MarshalBinary() + err := h.db.Set([]byte(k), bs, h.mode) + if err != nil { + h.Log.Error("failed to update data", "error", err, "key", k) + return err + } + return nil +} + +// getKv retrieves the value associated with a key from the database. +func (h *Hook) getKv(k string, v storage.Serializable) error { + value, closer, err := h.db.Get([]byte(k)) + if err != nil { + return err + } + + defer func() { + if closer != nil { + closer.Close() + } + }() + return v.UnmarshalBinary(value) +} diff --git a/hooks/storage/pebble/pebble_test.go b/hooks/storage/pebble/pebble_test.go new file mode 100644 index 0000000..eafc3f4 --- /dev/null +++ b/hooks/storage/pebble/pebble_test.go @@ -0,0 +1,812 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: werbenhu + +package pebble + +import ( + "log/slog" + "os" + "strings" + "testing" + "time" + + pebbledb "github.com/cockroachdb/pebble" + "github.com/stretchr/testify/require" + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" + "testmqtt/system" +) + +var ( + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + + client = &mqtt.Client{ + ID: "test", + Net: mqtt.ClientConnection{ + Remote: "test.addr", + Listener: "listener", + }, + Properties: mqtt.ClientProperties{ + Username: []byte("username"), + Clean: false, + }, + } + + pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}} +) + +func teardown(t *testing.T, path string, h *Hook) { + _ = h.Stop() + err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1)) + require.NoError(t, err) +} + +func TestKeyUpperBound(t *testing.T) { + // Test case 1: Non-nil case + input1 := []byte{97, 98, 99} // "abc" + require.NotNil(t, keyUpperBound(input1)) + + // Test case 2: All bytes are 255 + input2 := []byte{255, 255, 255} + require.Nil(t, keyUpperBound(input2)) + + // Test case 3: Empty slice + input3 := []byte{} + require.Nil(t, keyUpperBound(input3)) + + // Test case 4: Nil case + input4 := []byte{255, 255, 255} + require.Nil(t, keyUpperBound(input4)) +} + +func TestClientKey(t *testing.T) { + k := clientKey(&mqtt.Client{ID: "cl1"}) + require.Equal(t, storage.ClientKey+"_cl1", k) +} + +func TestSubscriptionKey(t *testing.T) { + k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c") + require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k) +} + +func TestRetainedKey(t *testing.T) { + k := retainedKey("a/b/c") + require.Equal(t, storage.RetainedKey+"_a/b/c", k) +} + +func TestInflightKey(t *testing.T) { + k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1}) + require.Equal(t, storage.InflightKey+"_cl1:1", k) +} + +func TestSysInfoKey(t *testing.T) { + require.Equal(t, storage.SysInfoKey, sysInfoKey()) +} + +func TestID(t *testing.T) { + h := new(Hook) + require.Equal(t, "pebble-db", h.ID()) +} + +func TestProvides(t *testing.T) { + h := new(Hook) + require.True(t, h.Provides(mqtt.OnSessionEstablished)) + require.True(t, h.Provides(mqtt.OnDisconnect)) + require.True(t, h.Provides(mqtt.OnSubscribed)) + require.True(t, h.Provides(mqtt.OnUnsubscribed)) + require.True(t, h.Provides(mqtt.OnRetainMessage)) + require.True(t, h.Provides(mqtt.OnQosPublish)) + require.True(t, h.Provides(mqtt.OnQosComplete)) + require.True(t, h.Provides(mqtt.OnQosDropped)) + require.True(t, h.Provides(mqtt.OnSysInfoTick)) + require.True(t, h.Provides(mqtt.StoredClients)) + require.True(t, h.Provides(mqtt.StoredInflightMessages)) + require.True(t, h.Provides(mqtt.StoredRetainedMessages)) + require.True(t, h.Provides(mqtt.StoredSubscriptions)) + require.True(t, h.Provides(mqtt.StoredSysInfo)) + require.False(t, h.Provides(mqtt.OnACLCheck)) + require.False(t, h.Provides(mqtt.OnConnectAuthenticate)) +} + +func TestInitBadConfig(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(map[string]any{}) + require.Error(t, err) +} + +func TestInitErr(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(&Options{ + Options: &pebbledb.Options{ + ReadOnly: true, + }, + }) + require.Error(t, err) +} + +func TestInitUseDefaults(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + require.Equal(t, defaultDbFile, h.config.Path) +} + +func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + h.OnSessionEstablished(client, packets.Packet{}) + + r := new(storage.Client) + err = h.getKv(clientKey(client), r) + require.NoError(t, err) + require.Equal(t, client.ID, r.ID) + require.Equal(t, client.Properties.Username, r.Username) + require.Equal(t, client.Properties.Clean, r.Clean) + require.Equal(t, client.Net.Remote, r.Remote) + require.Equal(t, client.Net.Listener, r.Listener) + require.NotSame(t, client, r) + + h.OnDisconnect(client, nil, false) + r2 := new(storage.Client) + err = h.getKv(clientKey(client), r2) + require.NoError(t, err) + require.Equal(t, client.ID, r.ID) + + h.OnDisconnect(client, nil, true) + r3 := new(storage.Client) + err = h.getKv(clientKey(client), r3) + require.Error(t, err) + require.ErrorIs(t, pebbledb.ErrNotFound, err) + require.Empty(t, r3.ID) + +} + +func TestOnClientExpired(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + cl := &mqtt.Client{ID: "cl1"} + clientKey := clientKey(cl) + + err = h.setKv(clientKey, &storage.Client{ID: cl.ID}) + require.NoError(t, err) + + r := new(storage.Client) + err = h.getKv(clientKey, r) + require.NoError(t, err) + require.Equal(t, cl.ID, r.ID) + + h.OnClientExpired(cl) + err = h.getKv(clientKey, r) + require.Error(t, err) + require.ErrorIs(t, err, pebbledb.ErrNotFound) +} + +func TestOnClientExpiredNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnClientExpired(client) +} + +func TestOnClientExpiredClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnClientExpired(client) +} + +func TestOnSessionEstablishedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSessionEstablished(client, packets.Packet{}) +} + +func TestOnSessionEstablishedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSessionEstablished(client, packets.Packet{}) +} + +func TestOnWillSent(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + c1 := client + c1.Properties.Will.Flag = 1 + h.OnWillSent(c1, packets.Packet{}) + + r := new(storage.Client) + err = h.getKv(clientKey(client), r) + require.NoError(t, err) + + require.Equal(t, uint32(1), r.Will.Flag) + require.NotSame(t, client, r) +} + +func TestOnDisconnectNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnDisconnect(client, nil, false) +} + +func TestOnDisconnectClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnDisconnect(client, nil, false) +} + +func TestOnDisconnectSessionTakenOver(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + + testClient := &mqtt.Client{ + ID: "test", + Net: mqtt.ClientConnection{ + Remote: "test.addr", + Listener: "listener", + }, + Properties: mqtt.ClientProperties{ + Username: []byte("username"), + Clean: false, + }, + } + + testClient.Stop(packets.ErrSessionTakenOver) + h.OnDisconnect(testClient, nil, true) + teardown(t, h.config.Path, h) +} + +func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + h.OnSubscribed(client, pkf, []byte{0}) + r := new(storage.Subscription) + + err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r) + require.NoError(t, err) + require.Equal(t, client.ID, r.Client) + require.Equal(t, pkf.Filters[0].Filter, r.Filter) + require.Equal(t, byte(0), r.Qos) + + h.OnUnsubscribed(client, pkf) + err = h.getKv(subscriptionKey(client, pkf.Filters[0].Filter), r) + require.Error(t, err) + require.Equal(t, pebbledb.ErrNotFound, err) +} + +func TestOnSubscribedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSubscribed(client, pkf, []byte{0}) +} + +func TestOnSubscribedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSubscribed(client, pkf, []byte{0}) +} + +func TestOnUnsubscribedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnUnsubscribed(client, pkf) +} + +func TestOnUnsubscribedClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnUnsubscribed(client, pkf) +} + +func TestOnRetainMessageThenUnset(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Retain: true, + }, + Payload: []byte("hello"), + TopicName: "a/b/c", + } + + h.OnRetainMessage(client, pk, 1) + + r := new(storage.Message) + err = h.getKv(retainedKey(pk.TopicName), r) + require.NoError(t, err) + require.Equal(t, pk.TopicName, r.TopicName) + require.Equal(t, pk.Payload, r.Payload) + + h.OnRetainMessage(client, pk, -1) + err = h.getKv(retainedKey(pk.TopicName), r) + require.Error(t, err) + require.ErrorIs(t, err, pebbledb.ErrNotFound) + + // coverage: delete deleted + h.OnRetainMessage(client, pk, -1) + err = h.getKv(retainedKey(pk.TopicName), r) + require.Error(t, err) + require.ErrorIs(t, err, pebbledb.ErrNotFound) +} + +func TestOnRetainedExpired(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + m := &storage.Message{ + ID: retainedKey("a/b/c"), + T: storage.RetainedKey, + TopicName: "a/b/c", + } + + err = h.setKv(m.ID, m) + require.NoError(t, err) + + r := new(storage.Message) + err = h.getKv(m.ID, r) + require.NoError(t, err) + require.Equal(t, m.TopicName, r.TopicName) + + h.OnRetainedExpired(m.TopicName) + err = h.getKv(m.ID, r) + require.Error(t, err) + require.ErrorIs(t, err, pebbledb.ErrNotFound) +} + +func TestOnRetainExpiredNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnRetainedExpired("a/b/c") +} + +func TestOnRetainExpiredClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnRetainedExpired("a/b/c") +} + +func TestOnRetainMessageNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnRetainMessage(client, packets.Packet{}, 0) +} + +func TestOnRetainMessageClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnRetainMessage(client, packets.Packet{}, 0) +} + +func TestOnQosPublishThenQOSComplete(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Retain: true, + Qos: 2, + }, + Payload: []byte("hello"), + TopicName: "a/b/c", + } + + h.OnQosPublish(client, pk, time.Now().Unix(), 0) + + r := new(storage.Message) + err = h.getKv(inflightKey(client, pk), r) + require.NoError(t, err) + require.Equal(t, pk.TopicName, r.TopicName) + require.Equal(t, pk.Payload, r.Payload) + + // ensure dates are properly saved + require.True(t, r.Sent > 0) + require.True(t, time.Now().Unix()-1 < r.Sent) + + // OnQosDropped is a passthrough to OnQosComplete here + h.OnQosDropped(client, pk) + err = h.getKv(inflightKey(client, pk), r) + require.Error(t, err) + require.ErrorIs(t, err, pebbledb.ErrNotFound) +} + +func TestOnQosPublishNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) +} + +func TestOnQosPublishClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) +} + +func TestOnQosCompleteNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosComplete(client, packets.Packet{}) +} + +func TestOnQosCompleteClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnQosComplete(client, packets.Packet{}) +} + +func TestOnQosDroppedNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnQosDropped(client, packets.Packet{}) +} + +func TestOnSysInfoTick(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + info := &system.Info{ + Version: "2.0.0", + BytesReceived: 100, + } + + h.OnSysInfoTick(info) + + r := new(storage.SystemInfo) + err = h.getKv(storage.SysInfoKey, r) + require.NoError(t, err) + require.Equal(t, info.Version, r.Version) + require.Equal(t, info.BytesReceived, r.BytesReceived) + require.NotSame(t, info, r) +} + +func TestOnSysInfoTickNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + h.OnSysInfoTick(new(system.Info)) +} + +func TestOnSysInfoTickClosedDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + teardown(t, h.config.Path, h) + h.OnSysInfoTick(new(system.Info)) +} + +func TestStoredClients(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with clients + err = h.setKv(storage.ClientKey+"_cl1", &storage.Client{ID: "cl1"}) + require.NoError(t, err) + + err = h.setKv(storage.ClientKey+"_cl2", &storage.Client{ID: "cl2"}) + require.NoError(t, err) + + err = h.setKv(storage.ClientKey+"_cl3", &storage.Client{ID: "cl3"}) + require.NoError(t, err) + + r, err := h.StoredClients() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "cl1", r[0].ID) + require.Equal(t, "cl2", r[1].ID) + require.Equal(t, "cl3", r[2].ID) +} + +func TestStoredClientsNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredClients() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredSubscriptions(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with subscriptions + err = h.setKv(storage.SubscriptionKey+"_sub1", &storage.Subscription{ID: "sub1"}) + require.NoError(t, err) + + err = h.setKv(storage.SubscriptionKey+"_sub2", &storage.Subscription{ID: "sub2"}) + require.NoError(t, err) + + err = h.setKv(storage.SubscriptionKey+"_sub3", &storage.Subscription{ID: "sub3"}) + require.NoError(t, err) + + r, err := h.StoredSubscriptions() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "sub1", r[0].ID) + require.Equal(t, "sub2", r[1].ID) + require.Equal(t, "sub3", r[2].ID) +} + +func TestStoredSubscriptionsNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredSubscriptions() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredRetainedMessages(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with messages + err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1"}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_m2", &storage.Message{ID: "m2"}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_m3", &storage.Message{ID: "m3"}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3"}) + require.NoError(t, err) + + r, err := h.StoredRetainedMessages() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "m1", r[0].ID) + require.Equal(t, "m2", r[1].ID) + require.Equal(t, "m3", r[2].ID) +} + +func TestStoredRetainedMessagesNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredRetainedMessages() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredInflightMessages(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + // populate with messages + err = h.setKv(storage.InflightKey+"_i1", &storage.Message{ID: "i1"}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_i2", &storage.Message{ID: "i2"}) + require.NoError(t, err) + + err = h.setKv(storage.InflightKey+"_i3", &storage.Message{ID: "i3"}) + require.NoError(t, err) + + err = h.setKv(storage.RetainedKey+"_m1", &storage.Message{ID: "m1"}) + require.NoError(t, err) + + r, err := h.StoredInflightMessages() + require.NoError(t, err) + require.Len(t, r, 3) + require.Equal(t, "i1", r[0].ID) + require.Equal(t, "i2", r[1].ID) + require.Equal(t, "i3", r[2].ID) +} + +func TestStoredInflightMessagesNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredInflightMessages() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredSysInfo(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + r, err := h.StoredSysInfo() + require.NoError(t, err) + + // populate with messages + err = h.setKv(storage.SysInfoKey, &storage.SystemInfo{ + ID: storage.SysInfoKey, + Info: system.Info{ + Version: "2.0.0", + }, + T: storage.SysInfoKey, + }) + require.NoError(t, err) + + r, err = h.StoredSysInfo() + require.NoError(t, err) + require.Equal(t, "2.0.0", r.Info.Version) +} + +func TestStoredSysInfoNoDB(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + v, err := h.StoredSysInfo() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestErrorf(t *testing.T) { + // coverage: one day check log hook + h := new(Hook) + h.SetOpts(logger, nil) + h.Errorf("test", 1, 2, 3) +} + +func TestWarningf(t *testing.T) { + // coverage: one day check log hook + h := new(Hook) + h.SetOpts(logger, nil) + h.Warningf("test", 1, 2, 3) +} + +func TestInfof(t *testing.T) { + // coverage: one day check log hook + h := new(Hook) + h.SetOpts(logger, nil) + h.Infof("test", 1, 2, 3) +} + +func TestDebugf(t *testing.T) { + // coverage: one day check log hook + h := new(Hook) + h.SetOpts(logger, nil) + h.Debugf("test", 1, 2, 3) +} + +func TestGetSetDelKv(t *testing.T) { + opts := []struct { + name string + opt *Options + }{ + { + name: "NoSync", + opt: &Options{ + Mode: NoSync, + }, + }, + { + name: "Sync", + opt: &Options{ + Mode: Sync, + }, + }, + } + for _, tt := range opts { + t.Run(tt.name, func(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(tt.opt) + defer teardown(t, h.config.Path, h) + require.NoError(t, err) + + err = h.setKv("testKey", &storage.Client{ID: "testId"}) + require.NoError(t, err) + + var obj storage.Client + err = h.getKv("testKey", &obj) + require.NoError(t, err) + + err = h.delKv("testKey") + require.NoError(t, err) + + err = h.getKv("testKey", &obj) + require.Error(t, err) + require.ErrorIs(t, pebbledb.ErrNotFound, err) + }) + } +} + +func TestGetSetDelKvErr(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(&Options{ + Mode: Sync, + }) + require.NoError(t, err) + + err = h.setKv("testKey", &storage.Client{ID: "testId"}) + require.NoError(t, err) + h.Stop() + + h = new(Hook) + h.SetOpts(logger, nil) + + err = h.Init(&Options{ + Mode: Sync, + Options: &pebbledb.Options{ + ReadOnly: true, + }, + }) + require.NoError(t, err) + defer teardown(t, h.config.Path, h) + + err = h.setKv("testKey", &storage.Client{ID: "testId"}) + require.Error(t, err) + + err = h.delKv("testKey") + require.Error(t, err) +} diff --git a/hooks/storage/redis/redis.go b/hooks/storage/redis/redis.go new file mode 100644 index 0000000..24ed872 --- /dev/null +++ b/hooks/storage/redis/redis.go @@ -0,0 +1,532 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package redis + +import ( + "bytes" + "context" + "errors" + "fmt" + + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" + "testmqtt/system" + + "github.com/go-redis/redis/v8" +) + +// defaultAddr 默认Redis地址 +// defaultAddr is the default address to the redis service. +const defaultAddr = "localhost:6379" + +// defaultHPrefix 默认前缀 +// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt. +const defaultHPrefix = "mochi-" + +// clientKey returns a primary key for a client. +func clientKey(cl *mqtt.Client) string { + return cl.ID +} + +// subscriptionKey returns a primary key for a subscription. +func subscriptionKey(cl *mqtt.Client, filter string) string { + return cl.ID + ":" + filter +} + +// retainedKey returns a primary key for a retained message. +func retainedKey(topic string) string { + return topic +} + +// inflightKey returns a primary key for an inflight message. +func inflightKey(cl *mqtt.Client, pk packets.Packet) string { + return cl.ID + ":" + pk.FormatID() +} + +// sysInfoKey returns a primary key for system info. +func sysInfoKey() string { + return storage.SysInfoKey +} + +// Options contains configuration settings for the bolt instance. +type Options struct { + Address string `yaml:"address" json:"address"` + Username string `yaml:"username" json:"username"` + Password string `yaml:"password" json:"password"` + Database int `yaml:"database" json:"database"` + HPrefix string `yaml:"h_prefix" json:"h_prefix"` + Options *redis.Options +} + +// Hook is a persistent storage hook based using Redis as a backend. +type Hook struct { + mqtt.HookBase + config *Options // options for connecting to the Redis instance. + db *redis.Client // the Redis instance + ctx context.Context // a context for the connection +} + +// ID returns the id of the hook. +func (h *Hook) ID() string { + return "redis-db" +} + +// Provides indicates which hook methods this hook provides. +func (h *Hook) Provides(b byte) bool { + return bytes.Contains([]byte{ + mqtt.OnSessionEstablished, + mqtt.OnDisconnect, + mqtt.OnSubscribed, + mqtt.OnUnsubscribed, + mqtt.OnRetainMessage, + mqtt.OnQosPublish, + mqtt.OnQosComplete, + mqtt.OnQosDropped, + mqtt.OnWillSent, + mqtt.OnSysInfoTick, + mqtt.OnClientExpired, + mqtt.OnRetainedExpired, + mqtt.StoredClients, + mqtt.StoredInflightMessages, + mqtt.StoredRetainedMessages, + mqtt.StoredSubscriptions, + mqtt.StoredSysInfo, + }, []byte{b}) +} + +// hKey returns a hash set key with a unique prefix. +func (h *Hook) hKey(s string) string { + return h.config.HPrefix + s +} + +// Init 初始化并连接到 Redis 服务 +// Init initializes and connects to the redis service. +func (h *Hook) Init(config any) error { + + if _, ok := config.(*Options); !ok && config != nil { + return mqtt.ErrInvalidConfigType + } + + h.ctx = context.Background() + + if config == nil { + config = new(Options) + } + h.config = config.(*Options) + if h.config.Options == nil { + h.config.Options = &redis.Options{ + Addr: defaultAddr, + } + h.config.Options.Addr = h.config.Address + h.config.Options.DB = h.config.Database + //h.config.Options.Username = h.config.Username + h.config.Options.Password = h.config.Password + } + + if h.config.HPrefix == "" { + h.config.HPrefix = defaultHPrefix + } + + h.Log.Info( + "connecting to redis service", + "prefix", h.config.HPrefix, + "address", h.config.Options.Addr, + "username", h.config.Options.Username, + "password-len", len(h.config.Options.Password), + "db", h.config.Options.DB, + ) + + h.db = redis.NewClient(h.config.Options) + _, err := h.db.Ping(context.Background()).Result() + if err != nil { + return fmt.Errorf("failed to ping service: %w", err) + } + + h.Log.Info("connected to redis service") + + return nil +} + +// Stop closes the redis connection. +func (h *Hook) Stop() error { + h.Log.Info("disconnecting from redis service") + + return h.db.Close() +} + +// OnSessionEstablished adds a client to the store when their session is established. +func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { + h.updateClient(cl) +} + +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. +func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { + h.updateClient(cl) +} + +// updateClient writes the client data to the store. +func (h *Hook) updateClient(cl *mqtt.Client) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + props := cl.Properties.Props.Copy(false) + in := &storage.Client{ + ID: clientKey(cl), + T: storage.ClientKey, + Remote: cl.Net.Remote, + Listener: cl.Net.Listener, + Username: cl.Properties.Username, + Clean: cl.Properties.Clean, + ProtocolVersion: cl.Properties.ProtocolVersion, + Properties: storage.ClientProperties{ + SessionExpiryInterval: props.SessionExpiryInterval, + AuthenticationMethod: props.AuthenticationMethod, + AuthenticationData: props.AuthenticationData, + RequestProblemInfo: props.RequestProblemInfo, + RequestResponseInfo: props.RequestResponseInfo, + ReceiveMaximum: props.ReceiveMaximum, + TopicAliasMaximum: props.TopicAliasMaximum, + User: props.User, + MaximumPacketSize: props.MaximumPacketSize, + }, + Will: storage.ClientWill(cl.Properties.Will), + } + + err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err() + if err != nil { + h.Log.Error("failed to hset client data", "error", err, "data", in) + } +} + +// OnDisconnect removes a client from the store if they were using a clean session. +func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + if !expire { + return + } + + if cl.StopCause() == packets.ErrSessionTakenOver { + return + } + + err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() + if err != nil { + h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl)) + } +} + +// OnSubscribed adds one or more client subscriptions to the store. +func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + var in *storage.Subscription + for i := 0; i < len(pk.Filters); i++ { + in = &storage.Subscription{ + ID: subscriptionKey(cl, pk.Filters[i].Filter), + T: storage.SubscriptionKey, + Client: cl.ID, + Qos: reasonCodes[i], + Filter: pk.Filters[i].Filter, + Identifier: pk.Filters[i].Identifier, + NoLocal: pk.Filters[i].NoLocal, + RetainHandling: pk.Filters[i].RetainHandling, + RetainAsPublished: pk.Filters[i].RetainAsPublished, + } + + err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err() + if err != nil { + h.Log.Error("failed to hset subscription data", "error", err, "data", in) + } + } +} + +// OnUnsubscribed removes one or more client subscriptions from the store. +func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + for i := 0; i < len(pk.Filters); i++ { + err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err() + if err != nil { + h.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl)) + } + } +} + +// OnRetainMessage adds a retained message for a topic to the store. +func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + if r == -1 { + err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err() + if err != nil { + h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName)) + } + + return + } + + props := pk.Properties.Copy(false) + in := &storage.Message{ + ID: retainedKey(pk.TopicName), + T: storage.RetainedKey, + FixedHeader: pk.FixedHeader, + TopicName: pk.TopicName, + Payload: pk.Payload, + Created: pk.Created, + Origin: pk.Origin, + Properties: storage.MessageProperties{ + PayloadFormat: props.PayloadFormat, + MessageExpiryInterval: props.MessageExpiryInterval, + ContentType: props.ContentType, + ResponseTopic: props.ResponseTopic, + CorrelationData: props.CorrelationData, + SubscriptionIdentifier: props.SubscriptionIdentifier, + TopicAlias: props.TopicAlias, + User: props.User, + }, + } + + err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err() + if err != nil { + h.Log.Error("failed to hset retained message data", "error", err, "data", in) + } +} + +// OnQosPublish adds or updates an inflight message in the store. +func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + props := pk.Properties.Copy(false) + in := &storage.Message{ + ID: inflightKey(cl, pk), + T: storage.InflightKey, + Origin: pk.Origin, + FixedHeader: pk.FixedHeader, + TopicName: pk.TopicName, + Payload: pk.Payload, + Sent: sent, + Created: pk.Created, + Properties: storage.MessageProperties{ + PayloadFormat: props.PayloadFormat, + MessageExpiryInterval: props.MessageExpiryInterval, + ContentType: props.ContentType, + ResponseTopic: props.ResponseTopic, + CorrelationData: props.CorrelationData, + SubscriptionIdentifier: props.SubscriptionIdentifier, + TopicAlias: props.TopicAlias, + User: props.User, + }, + } + + err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err() + if err != nil { + h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in) + } +} + +// OnQosComplete removes a resolved inflight message from the store. +func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err() + if err != nil { + h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk)) + } +} + +// OnQosDropped removes a dropped inflight message from the store. +func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + } + + h.OnQosComplete(cl, pk) +} + +// OnSysInfoTick stores the latest system info in the store. +func (h *Hook) OnSysInfoTick(sys *system.Info) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + in := &storage.SystemInfo{ + ID: sysInfoKey(), + T: storage.SysInfoKey, + Info: *sys, + } + + err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err() + if err != nil { + h.Log.Error("failed to hset server info data", "error", err, "data", in) + } +} + +// OnRetainedExpired deletes expired retained messages from the store. +func (h *Hook) OnRetainedExpired(filter string) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err() + if err != nil { + h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter)) + } +} + +// OnClientExpired deleted expired clients from the store. +func (h *Hook) OnClientExpired(cl *mqtt.Client) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() + if err != nil { + h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl)) + } +} + +// StoredClients returns all stored clients from the store. +func (h *Hook) StoredClients() (v []storage.Client, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result() + if err != nil && !errors.Is(err, redis.Nil) { + h.Log.Error("failed to HGetAll client data", "error", err) + return + } + + for _, row := range rows { + var d storage.Client + if err = d.UnmarshalBinary([]byte(row)); err != nil { + h.Log.Error("failed to unmarshal client data", "error", err, "data", row) + } + + v = append(v, d) + } + + return v, nil +} + +// StoredSubscriptions returns all stored subscriptions from the store. +func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result() + if err != nil && !errors.Is(err, redis.Nil) { + h.Log.Error("failed to HGetAll subscription data", "error", err) + return + } + + for _, row := range rows { + var d storage.Subscription + if err = d.UnmarshalBinary([]byte(row)); err != nil { + h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row) + } + + v = append(v, d) + } + + return v, nil +} + +// StoredRetainedMessages returns all stored retained messages from the store. +func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result() + if err != nil && !errors.Is(err, redis.Nil) { + h.Log.Error("failed to HGetAll retained message data", "error", err) + return + } + + for _, row := range rows { + var d storage.Message + if err = d.UnmarshalBinary([]byte(row)); err != nil { + h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row) + } + + v = append(v, d) + } + + return v, nil +} + +// StoredInflightMessages returns all stored inflight messages from the store. +func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() + if err != nil && !errors.Is(err, redis.Nil) { + h.Log.Error("failed to HGetAll inflight message data", "error", err) + return + } + + for _, row := range rows { + var d storage.Message + if err = d.UnmarshalBinary([]byte(row)); err != nil { + h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row) + } + + v = append(v, d) + } + + return v, nil +} + +// StoredSysInfo returns the system info from the store. +func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { + if h.db == nil { + h.Log.Error("", "error", storage.ErrDBFileNotOpen) + return + } + + row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return + } + + if err = v.UnmarshalBinary([]byte(row)); err != nil { + h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row) + } + + return v, nil +} diff --git a/hooks/storage/redis/redis_test.go b/hooks/storage/redis/redis_test.go new file mode 100644 index 0000000..e3a5c90 --- /dev/null +++ b/hooks/storage/redis/redis_test.go @@ -0,0 +1,834 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package redis + +import ( + "log/slog" + "os" + "sort" + "testing" + "time" + + "testmqtt/hooks/storage" + "testmqtt/mqtt" + "testmqtt/packets" + "testmqtt/system" + + miniredis "github.com/alicebob/miniredis/v2" + redis "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/require" +) + +var ( + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + + client = &mqtt.Client{ + ID: "test", + Net: mqtt.ClientConnection{ + Remote: "test.addr", + Listener: "listener", + }, + Properties: mqtt.ClientProperties{ + Username: []byte("username"), + Clean: false, + }, + } + + pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}} +) + +func newHook(t *testing.T, addr string) *Hook { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(&Options{ + Options: &redis.Options{ + Addr: addr, + }, + }) + require.NoError(t, err) + + return h +} + +func teardown(t *testing.T, h *Hook) { + if h.db != nil { + err := h.db.FlushAll(h.ctx).Err() + require.NoError(t, err) + h.Stop() + } +} + +func TestClientKey(t *testing.T) { + k := clientKey(&mqtt.Client{ID: "cl1"}) + require.Equal(t, "cl1", k) +} + +func TestSubscriptionKey(t *testing.T) { + k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c") + require.Equal(t, "cl1:a/b/c", k) +} + +func TestRetainedKey(t *testing.T) { + k := retainedKey("a/b/c") + require.Equal(t, "a/b/c", k) +} + +func TestInflightKey(t *testing.T) { + k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1}) + require.Equal(t, "cl1:1", k) +} + +func TestSysInfoKey(t *testing.T) { + require.Equal(t, storage.SysInfoKey, sysInfoKey()) +} + +func TestID(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + require.Equal(t, "redis-db", h.ID()) +} + +func TestProvides(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + require.True(t, h.Provides(mqtt.OnSessionEstablished)) + require.True(t, h.Provides(mqtt.OnDisconnect)) + require.True(t, h.Provides(mqtt.OnSubscribed)) + require.True(t, h.Provides(mqtt.OnUnsubscribed)) + require.True(t, h.Provides(mqtt.OnRetainMessage)) + require.True(t, h.Provides(mqtt.OnQosPublish)) + require.True(t, h.Provides(mqtt.OnQosComplete)) + require.True(t, h.Provides(mqtt.OnQosDropped)) + require.True(t, h.Provides(mqtt.OnSysInfoTick)) + require.True(t, h.Provides(mqtt.StoredClients)) + require.True(t, h.Provides(mqtt.StoredInflightMessages)) + require.True(t, h.Provides(mqtt.StoredRetainedMessages)) + require.True(t, h.Provides(mqtt.StoredSubscriptions)) + require.True(t, h.Provides(mqtt.StoredSysInfo)) + require.False(t, h.Provides(mqtt.OnACLCheck)) + require.False(t, h.Provides(mqtt.OnConnectAuthenticate)) +} + +func TestHKey(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.SetOpts(logger, nil) + require.Equal(t, defaultHPrefix+"test", h.hKey("test")) +} + +func TestInitUseDefaults(t *testing.T) { + s := miniredis.RunT(t) + s.StartAddr(defaultAddr) + defer s.Close() + + h := newHook(t, defaultAddr) + h.SetOpts(logger, nil) + err := h.Init(nil) + require.NoError(t, err) + defer teardown(t, h) + + require.Equal(t, defaultHPrefix, h.config.HPrefix) + require.Equal(t, defaultAddr, h.config.Options.Addr) +} + +func TestInitUsePassConfig(t *testing.T) { + s := miniredis.RunT(t) + s.StartAddr(defaultAddr) + defer s.Close() + + h := newHook(t, "") + h.SetOpts(logger, nil) + + err := h.Init(&Options{ + Address: defaultAddr, + Username: "username", + Password: "password", + Database: 2, + }) + require.Error(t, err) + h.db.FlushAll(h.ctx) + + require.Equal(t, defaultAddr, h.config.Options.Addr) + require.Equal(t, "username", h.config.Options.Username) + require.Equal(t, "password", h.config.Options.Password) + require.Equal(t, 2, h.config.Options.DB) +} + +func TestInitBadConfig(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + + err := h.Init(map[string]any{}) + require.Error(t, err) +} + +func TestInitBadAddr(t *testing.T) { + h := new(Hook) + h.SetOpts(logger, nil) + err := h.Init(&Options{ + Options: &redis.Options{ + Addr: "abc:123", + }, + }) + require.Error(t, err) +} + +func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + h.OnSessionEstablished(client, packets.Packet{}) + + r := new(storage.Client) + row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result() + require.NoError(t, err) + err = r.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + + require.Equal(t, client.ID, r.ID) + require.Equal(t, client.Net.Remote, r.Remote) + require.Equal(t, client.Net.Listener, r.Listener) + require.Equal(t, client.Properties.Username, r.Username) + require.Equal(t, client.Properties.Clean, r.Clean) + require.NotSame(t, client, r) + + h.OnDisconnect(client, nil, false) + r2 := new(storage.Client) + row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result() + require.NoError(t, err) + err = r2.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + require.Equal(t, client.ID, r.ID) + + h.OnDisconnect(client, nil, true) + r3 := new(storage.Client) + _, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result() + require.Error(t, err) + require.ErrorIs(t, err, redis.Nil) + require.Empty(t, r3.ID) +} + +func TestOnSessionEstablishedNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + + h.db = nil + h.OnSessionEstablished(client, packets.Packet{}) +} + +func TestOnSessionEstablishedClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnSessionEstablished(client, packets.Packet{}) +} + +func TestOnWillSent(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + c1 := client + c1.Properties.Will.Flag = 1 + h.OnWillSent(c1, packets.Packet{}) + + r := new(storage.Client) + row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result() + require.NoError(t, err) + err = r.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + + require.Equal(t, uint32(1), r.Will.Flag) + require.NotSame(t, client, r) +} + +func TestOnClientExpired(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + cl := &mqtt.Client{ID: "cl1"} + clientKey := clientKey(cl) + + err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err() + require.NoError(t, err) + + r := new(storage.Client) + row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result() + require.NoError(t, err) + err = r.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + require.Equal(t, clientKey, r.ID) + + h.OnClientExpired(cl) + _, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result() + require.Error(t, err) + require.ErrorIs(t, redis.Nil, err) +} + +func TestOnClientExpiredClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnClientExpired(client) +} + +func TestOnClientExpiredNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnClientExpired(client) +} + +func TestOnDisconnectNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnDisconnect(client, nil, false) +} + +func TestOnDisconnectClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnDisconnect(client, nil, false) +} + +func TestOnDisconnectSessionTakenOver(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + + testClient := &mqtt.Client{ + ID: "test", + Net: mqtt.ClientConnection{ + Remote: "test.addr", + Listener: "listener", + }, + Properties: mqtt.ClientProperties{ + Username: []byte("username"), + Clean: false, + }, + } + + testClient.Stop(packets.ErrSessionTakenOver) + teardown(t, h) + h.OnDisconnect(testClient, nil, true) +} + +func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + h.OnSubscribed(client, pkf, []byte{0}) + + r := new(storage.Subscription) + row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result() + require.NoError(t, err) + err = r.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + require.Equal(t, client.ID, r.Client) + require.Equal(t, pkf.Filters[0].Filter, r.Filter) + require.Equal(t, byte(0), r.Qos) + + h.OnUnsubscribed(client, pkf) + _, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result() + require.Error(t, err) + require.ErrorIs(t, err, redis.Nil) +} + +func TestOnSubscribedNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnSubscribed(client, pkf, []byte{0}) +} + +func TestOnSubscribedClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnSubscribed(client, pkf, []byte{0}) +} + +func TestOnUnsubscribedNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnUnsubscribed(client, pkf) +} + +func TestOnUnsubscribedClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnUnsubscribed(client, pkf) +} + +func TestOnRetainMessageThenUnset(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Retain: true, + }, + Payload: []byte("hello"), + TopicName: "a/b/c", + } + + h.OnRetainMessage(client, pk, 1) + + r := new(storage.Message) + row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result() + require.NoError(t, err) + err = r.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + + require.NoError(t, err) + require.Equal(t, pk.TopicName, r.TopicName) + require.Equal(t, pk.Payload, r.Payload) + + h.OnRetainMessage(client, pk, -1) + _, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result() + require.Error(t, err) + require.ErrorIs(t, err, redis.Nil) + + // coverage: delete deleted + h.OnRetainMessage(client, pk, -1) + _, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result() + require.Error(t, err) + require.ErrorIs(t, err, redis.Nil) +} + +func TestOnRetainedExpired(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + m := &storage.Message{ + ID: retainedKey("a/b/c"), + T: storage.RetainedKey, + TopicName: "a/b/c", + } + + err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err() + require.NoError(t, err) + + r := new(storage.Message) + row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result() + require.NoError(t, err) + err = r.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + require.NoError(t, err) + require.Equal(t, m.TopicName, r.TopicName) + + h.OnRetainedExpired(m.TopicName) + + _, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result() + require.Error(t, err) + require.ErrorIs(t, err, redis.Nil) +} + +func TestOnRetainedExpiredClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnRetainedExpired("a/b/c") +} + +func TestOnRetainedExpiredNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnRetainedExpired("a/b/c") +} + +func TestOnRetainMessageNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnRetainMessage(client, packets.Packet{}, 0) +} + +func TestOnRetainMessageClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnRetainMessage(client, packets.Packet{}, 0) +} + +func TestOnQosPublishThenQOSComplete(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Retain: true, + Qos: 2, + }, + Payload: []byte("hello"), + TopicName: "a/b/c", + } + + h.OnQosPublish(client, pk, time.Now().Unix(), 0) + + r := new(storage.Message) + row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result() + require.NoError(t, err) + err = r.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + require.Equal(t, pk.TopicName, r.TopicName) + require.Equal(t, pk.Payload, r.Payload) + + // ensure dates are properly saved to bolt + require.True(t, r.Sent > 0) + require.True(t, time.Now().Unix()-1 < r.Sent) + + // OnQosDropped is a passthrough to OnQosComplete here + h.OnQosDropped(client, pk) + _, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result() + require.Error(t, err) + require.ErrorIs(t, err, redis.Nil) +} + +func TestOnQosPublishNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) +} + +func TestOnQosPublishClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) +} + +func TestOnQosCompleteNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnQosComplete(client, packets.Packet{}) +} + +func TestOnQosCompleteClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnQosComplete(client, packets.Packet{}) +} + +func TestOnQosDroppedNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnQosDropped(client, packets.Packet{}) +} + +func TestOnSysInfoTick(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + info := &system.Info{ + Version: "2.0.0", + BytesReceived: 100, + } + + h.OnSysInfoTick(info) + + r := new(storage.SystemInfo) + row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result() + require.NoError(t, err) + err = r.UnmarshalBinary([]byte(row)) + require.NoError(t, err) + require.Equal(t, info.Version, r.Version) + require.Equal(t, info.BytesReceived, r.BytesReceived) + require.NotSame(t, info, r) +} + +func TestOnSysInfoTickClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + h.OnSysInfoTick(new(system.Info)) +} +func TestOnSysInfoTickNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + h.OnSysInfoTick(new(system.Info)) +} + +func TestStoredClients(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + // populate with clients + err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err() + require.NoError(t, err) + + r, err := h.StoredClients() + require.NoError(t, err) + require.Len(t, r, 3) + + sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID }) + require.Equal(t, "cl1", r[0].ID) + require.Equal(t, "cl2", r[1].ID) + require.Equal(t, "cl3", r[2].ID) +} + +func TestStoredClientsNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + v, err := h.StoredClients() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredClientsClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + + v, err := h.StoredClients() + require.Empty(t, v) + require.Error(t, err) +} + +func TestStoredSubscriptions(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + // populate with subscriptions + err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err() + require.NoError(t, err) + + r, err := h.StoredSubscriptions() + require.NoError(t, err) + require.Len(t, r, 3) + sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID }) + require.Equal(t, "sub1", r[0].ID) + require.Equal(t, "sub2", r[1].ID) + require.Equal(t, "sub3", r[2].ID) +} + +func TestStoredSubscriptionsNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + v, err := h.StoredSubscriptions() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredSubscriptionsClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + + v, err := h.StoredSubscriptions() + require.Empty(t, v) + require.Error(t, err) +} + +func TestStoredRetainedMessages(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + // populate with messages + err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err() + require.NoError(t, err) + + r, err := h.StoredRetainedMessages() + require.NoError(t, err) + require.Len(t, r, 3) + sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID }) + require.Equal(t, "m1", r[0].ID) + require.Equal(t, "m2", r[1].ID) + require.Equal(t, "m3", r[2].ID) +} + +func TestStoredRetainedMessagesNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + v, err := h.StoredRetainedMessages() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredRetainedMessagesClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + + v, err := h.StoredRetainedMessages() + require.Empty(t, v) + require.Error(t, err) +} + +func TestStoredInflightMessages(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + // populate with messages + err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err() + require.NoError(t, err) + + err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err() + require.NoError(t, err) + + r, err := h.StoredInflightMessages() + require.NoError(t, err) + require.Len(t, r, 3) + sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID }) + require.Equal(t, "i1", r[0].ID) + require.Equal(t, "i2", r[1].ID) + require.Equal(t, "i3", r[2].ID) +} + +func TestStoredInflightMessagesNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + v, err := h.StoredInflightMessages() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredInflightMessagesClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + + v, err := h.StoredInflightMessages() + require.Empty(t, v) + require.Error(t, err) +} + +func TestStoredSysInfo(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + defer teardown(t, h) + + // populate with sys info + err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey, + &storage.SystemInfo{ + ID: storage.SysInfoKey, + Info: system.Info{ + Version: "2.0.0", + }, + T: storage.SysInfoKey, + }).Err() + require.NoError(t, err) + + r, err := h.StoredSysInfo() + require.NoError(t, err) + require.Equal(t, "2.0.0", r.Info.Version) +} + +func TestStoredSysInfoNoDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + h.db = nil + v, err := h.StoredSysInfo() + require.Empty(t, v) + require.NoError(t, err) +} + +func TestStoredSysInfoClosedDB(t *testing.T) { + s := miniredis.RunT(t) + defer s.Close() + h := newHook(t, s.Addr()) + teardown(t, h) + + v, err := h.StoredSysInfo() + require.Empty(t, v) + require.Error(t, err) +} diff --git a/hooks/storage/storage.go b/hooks/storage/storage.go new file mode 100644 index 0000000..7fd3d13 --- /dev/null +++ b/hooks/storage/storage.go @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package storage + +import ( + "encoding/json" + "errors" + + "testmqtt/packets" + "testmqtt/system" +) + +const ( + // SubscriptionKey 唯一标识订阅信息。在存储系统中,这个键用于检索和管理客户端的订阅信息 + // SubscriptionKey unique key to denote Subscriptions in a store + SubscriptionKey = "SUB" + // SysInfoKey 唯一标识服务器系统信息。在存储系统中,这个键用于存储和检索与服务器状态或系统配置相关的信息。 + // SysInfoKey unique key to denote server system information in a store + SysInfoKey = "SYS" + // RetainedKey 唯一标识保留的消息。保留消息是在订阅时立即发送的消息,即使订阅者在消息发布时不在线。这个键用于存储这些保留消息。 + // RetainedKey unique key to denote retained messages in a store + RetainedKey = "RET" + // InflightKey 唯一标识飞行中的消息。飞行中的消息指的是已经发布但尚未确认的消息,这个键用于跟踪这些消息的状态。 + // InflightKey unique key to denote inflight messages in a store + InflightKey = "IFM" + // ClientKey 唯一标识客户端信息。在存储系统中,这个键用于检索和管理有关客户端的信息,如客户端状态、连接信息等。 + // ClientKey unique key to denote clients in a store + ClientKey = "CL" +) + +var ( + // ErrDBFileNotOpen 数据库文件没有打开 + // ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading. + ErrDBFileNotOpen = errors.New("数据库文件没有打开") +) + +// Serializable is an interface for objects that can be serialized and deserialized. +type Serializable interface { + UnmarshalBinary([]byte) error + MarshalBinary() (data []byte, err error) +} + +// Client is a storable representation of an MQTT client. +type Client struct { + Will ClientWill `json:"will"` // will topic and payload data if applicable + Properties ClientProperties `json:"properties"` // the connect properties for the client + Username []byte `json:"username"` // the username of the client + ID string `json:"id" storm:"id"` // the client id / storage key + T string `json:"t"` // the data type (client) + Remote string `json:"remote"` // the remote address of the client + Listener string `json:"listener"` // the listener the client connected on + ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client + Clean bool `json:"clean"` // if the client requested a clean start/session +} + +// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection. +type ClientProperties struct { + AuthenticationData []byte `json:"authenticationData,omitempty"` + User []packets.UserProperty `json:"user,omitempty"` + AuthenticationMethod string `json:"authenticationMethod,omitempty"` + SessionExpiryInterval uint32 `json:"sessionExpiryInterval,omitempty"` + MaximumPacketSize uint32 `json:"maximumPacketSize,omitempty"` + ReceiveMaximum uint16 `json:"receiveMaximum,omitempty"` + TopicAliasMaximum uint16 `json:"topicAliasMaximum,omitempty"` + SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag,omitempty"` + RequestProblemInfo byte `json:"requestProblemInfo,omitempty"` + RequestProblemInfoFlag bool `json:"requestProblemInfoFlag,omitempty"` + RequestResponseInfo byte `json:"requestResponseInfo,omitempty"` +} + +// ClientWill contains a will message for a client, and limited mqtt v5 properties. +type ClientWill struct { + Payload []byte `json:"payload,omitempty"` + User []packets.UserProperty `json:"user,omitempty"` + TopicName string `json:"topicName,omitempty"` + Flag uint32 `json:"flag,omitempty"` + WillDelayInterval uint32 `json:"willDelayInterval,omitempty"` + Qos byte `json:"qos,omitempty"` + Retain bool `json:"retain,omitempty"` +} + +// MarshalBinary encodes the values into a json string. +func (d Client) MarshalBinary() (data []byte, err error) { + return json.Marshal(d) +} + +// UnmarshalBinary decodes a json string into a struct. +func (d *Client) UnmarshalBinary(data []byte) error { + if len(data) == 0 { + return nil + } + return json.Unmarshal(data, d) +} + +// Message is a storable representation of an MQTT message (specifically publish). +type Message struct { + Properties MessageProperties `json:"properties"` // - + Payload []byte `json:"payload"` // the message payload (if retained) + T string `json:"t,omitempty"` // the data type + ID string `json:"id,omitempty" storm:"id"` // the storage key + Origin string `json:"origin,omitempty"` // the id of the client who sent the message + TopicName string `json:"topic_name,omitempty"` // the topic the message was sent to (if retained) + FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message + Created int64 `json:"created,omitempty"` // the time the message was created in unixtime + Sent int64 `json:"sent,omitempty"` // the last time the message was sent (for retries) in unixtime (if inflight) + PacketID uint16 `json:"packet_id,omitempty"` // the unique id of the packet (if inflight) +} + +// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages. +type MessageProperties struct { + CorrelationData []byte `json:"correlationData,omitempty"` + SubscriptionIdentifier []int `json:"subscriptionIdentifier,omitempty"` + User []packets.UserProperty `json:"user,omitempty"` + ContentType string `json:"contentType,omitempty"` + ResponseTopic string `json:"responseTopic,omitempty"` + MessageExpiryInterval uint32 `json:"messageExpiry,omitempty"` + TopicAlias uint16 `json:"topicAlias,omitempty"` + PayloadFormat byte `json:"payloadFormat,omitempty"` + PayloadFormatFlag bool `json:"payloadFormatFlag,omitempty"` +} + +// MarshalBinary encodes the values into a json string. +func (d Message) MarshalBinary() (data []byte, err error) { + return json.Marshal(d) +} + +// UnmarshalBinary decodes a json string into a struct. +func (d *Message) UnmarshalBinary(data []byte) error { + if len(data) == 0 { + return nil + } + return json.Unmarshal(data, d) +} + +// ToPacket converts a storage.Message to a standard packet. +func (d *Message) ToPacket() packets.Packet { + pk := packets.Packet{ + FixedHeader: d.FixedHeader, + PacketID: d.PacketID, + TopicName: d.TopicName, + Payload: d.Payload, + Origin: d.Origin, + Created: d.Created, + Properties: packets.Properties{ + PayloadFormat: d.Properties.PayloadFormat, + PayloadFormatFlag: d.Properties.PayloadFormatFlag, + MessageExpiryInterval: d.Properties.MessageExpiryInterval, + ContentType: d.Properties.ContentType, + ResponseTopic: d.Properties.ResponseTopic, + CorrelationData: d.Properties.CorrelationData, + SubscriptionIdentifier: d.Properties.SubscriptionIdentifier, + TopicAlias: d.Properties.TopicAlias, + User: d.Properties.User, + }, + } + + // Return a deep copy of the packet data otherwise the slices will + // continue pointing at the values from the storage packet. + pk = pk.Copy(true) + pk.FixedHeader.Dup = d.FixedHeader.Dup + + return pk +} + +// Subscription is a storable representation of an MQTT subscription. +type Subscription struct { + T string `json:"t,omitempty"` + ID string `json:"id,omitempty" storm:"id"` + Client string `json:"client,omitempty"` + Filter string `json:"filter"` + Identifier int `json:"identifier,omitempty"` + RetainHandling byte `json:"retain_handling,omitempty"` + Qos byte `json:"qos"` + RetainAsPublished bool `json:"retain_as_pub,omitempty"` + NoLocal bool `json:"no_local,omitempty"` +} + +// MarshalBinary encodes the values into a json string. +func (d Subscription) MarshalBinary() (data []byte, err error) { + return json.Marshal(d) +} + +// UnmarshalBinary decodes a json string into a struct. +func (d *Subscription) UnmarshalBinary(data []byte) error { + if len(data) == 0 { + return nil + } + return json.Unmarshal(data, d) +} + +// SystemInfo is a storable representation of the system information values. +type SystemInfo struct { + system.Info // embed the system info struct + T string `json:"t"` // the data type + ID string `json:"id" storm:"id"` // the storage key +} + +// MarshalBinary 将值编码为json字符串 +// MarshalBinary encodes the values into a json string. +func (d SystemInfo) MarshalBinary() (data []byte, err error) { + return json.Marshal(d) +} + +// UnmarshalBinary 将json字符串解码为结构体 +// UnmarshalBinary decodes a json string into a struct. +func (d *SystemInfo) UnmarshalBinary(data []byte) error { + if len(data) == 0 { + return nil + } + return json.Unmarshal(data, d) +} diff --git a/hooks/storage/storage_test.go b/hooks/storage/storage_test.go new file mode 100644 index 0000000..835cbae --- /dev/null +++ b/hooks/storage/storage_test.go @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package storage + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "testmqtt/packets" + "testmqtt/system" +) + +var ( + clientStruct = Client{ + ID: "test", + T: "client", + Remote: "remote", + Listener: "listener", + Username: []byte("mochi"), + Clean: true, + Properties: ClientProperties{ + SessionExpiryInterval: 2, + SessionExpiryIntervalFlag: true, + AuthenticationMethod: "a", + AuthenticationData: []byte("test"), + RequestProblemInfo: 1, + RequestProblemInfoFlag: true, + RequestResponseInfo: 1, + ReceiveMaximum: 128, + TopicAliasMaximum: 256, + User: []packets.UserProperty{ + {Key: "k", Val: "v"}, + }, + MaximumPacketSize: 120, + }, + Will: ClientWill{ + Qos: 1, + Payload: []byte("abc"), + TopicName: "a/b/c", + Flag: 1, + Retain: true, + WillDelayInterval: 2, + User: []packets.UserProperty{ + {Key: "k2", Val: "v2"}, + }, + }, + } + clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`) + + messageStruct = Message{ + T: "message", + Payload: []byte("payload"), + FixedHeader: packets.FixedHeader{ + Remaining: 2, + Type: 3, + Qos: 1, + Dup: true, + Retain: true, + }, + ID: "id", + Origin: "mochi", + TopicName: "topic", + Properties: MessageProperties{ + PayloadFormat: 1, + PayloadFormatFlag: true, + MessageExpiryInterval: 20, + ContentType: "type", + ResponseTopic: "a/b/r", + CorrelationData: []byte("r"), + SubscriptionIdentifier: []int{1}, + TopicAlias: 2, + User: []packets.UserProperty{ + {Key: "k2", Val: "v2"}, + }, + }, + Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(), + Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(), + PacketID: 100, + } + messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`) + + subscriptionStruct = Subscription{ + T: "subscription", + ID: "id", + Client: "mochi", + Filter: "a/b/c", + Qos: 1, + } + subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","qos":1}`) + + sysInfoStruct = SystemInfo{ + T: "info", + ID: "id", + Info: system.Info{ + Version: "2.0.0", + Started: 1, + Uptime: 2, + BytesReceived: 3, + BytesSent: 4, + ClientsConnected: 5, + ClientsMaximum: 7, + MessagesReceived: 10, + MessagesSent: 11, + MessagesDropped: 20, + PacketsReceived: 12, + PacketsSent: 13, + Retained: 15, + Inflight: 16, + InflightDropped: 17, + }, + } + sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`) +) + +func TestClientMarshalBinary(t *testing.T) { + data, err := clientStruct.MarshalBinary() + require.NoError(t, err) + require.JSONEq(t, string(clientJSON), string(data)) +} + +func TestClientUnmarshalBinary(t *testing.T) { + d := clientStruct + err := d.UnmarshalBinary(clientJSON) + require.NoError(t, err) + require.Equal(t, clientStruct, d) +} + +func TestClientUnmarshalBinaryEmpty(t *testing.T) { + d := Client{} + err := d.UnmarshalBinary([]byte{}) + require.NoError(t, err) + require.Equal(t, Client{}, d) +} + +func TestMessageMarshalBinary(t *testing.T) { + data, err := messageStruct.MarshalBinary() + require.NoError(t, err) + require.JSONEq(t, string(messageJSON), string(data)) +} + +func TestMessageUnmarshalBinary(t *testing.T) { + d := messageStruct + err := d.UnmarshalBinary(messageJSON) + require.NoError(t, err) + require.Equal(t, messageStruct, d) +} + +func TestMessageUnmarshalBinaryEmpty(t *testing.T) { + d := Message{} + err := d.UnmarshalBinary([]byte{}) + require.NoError(t, err) + require.Equal(t, Message{}, d) +} + +func TestSubscriptionMarshalBinary(t *testing.T) { + data, err := subscriptionStruct.MarshalBinary() + require.NoError(t, err) + require.JSONEq(t, string(subscriptionJSON), string(data)) +} + +func TestSubscriptionUnmarshalBinary(t *testing.T) { + d := subscriptionStruct + err := d.UnmarshalBinary(subscriptionJSON) + require.NoError(t, err) + require.Equal(t, subscriptionStruct, d) +} + +func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) { + d := Subscription{} + err := d.UnmarshalBinary([]byte{}) + require.NoError(t, err) + require.Equal(t, Subscription{}, d) +} + +func TestSysInfoMarshalBinary(t *testing.T) { + data, err := sysInfoStruct.MarshalBinary() + require.NoError(t, err) + require.JSONEq(t, string(sysInfoJSON), string(data)) +} + +func TestSysInfoUnmarshalBinary(t *testing.T) { + d := sysInfoStruct + err := d.UnmarshalBinary(sysInfoJSON) + require.NoError(t, err) + require.Equal(t, sysInfoStruct, d) +} + +func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) { + d := SystemInfo{} + err := d.UnmarshalBinary([]byte{}) + require.NoError(t, err) + require.Equal(t, SystemInfo{}, d) +} + +func TestMessageToPacket(t *testing.T) { + d := messageStruct + pk := d.ToPacket() + + require.Equal(t, packets.Packet{ + Payload: []byte("payload"), + FixedHeader: packets.FixedHeader{ + Remaining: d.FixedHeader.Remaining, + Type: d.FixedHeader.Type, + Qos: d.FixedHeader.Qos, + Dup: d.FixedHeader.Dup, + Retain: d.FixedHeader.Retain, + }, + Origin: d.Origin, + TopicName: d.TopicName, + Properties: packets.Properties{ + PayloadFormat: d.Properties.PayloadFormat, + PayloadFormatFlag: d.Properties.PayloadFormatFlag, + MessageExpiryInterval: d.Properties.MessageExpiryInterval, + ContentType: d.Properties.ContentType, + ResponseTopic: d.Properties.ResponseTopic, + CorrelationData: d.Properties.CorrelationData, + SubscriptionIdentifier: d.Properties.SubscriptionIdentifier, + TopicAlias: d.Properties.TopicAlias, + User: d.Properties.User, + }, + PacketID: 100, + Created: d.Created, + }, pk) + +} diff --git a/listeners/http_healthcheck.go b/listeners/http_healthcheck.go new file mode 100644 index 0000000..fc0f13d --- /dev/null +++ b/listeners/http_healthcheck.go @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Derek Duncan + +package listeners + +import ( + "context" + "log/slog" + "net/http" + "sync" + "sync/atomic" + "time" +) + +const TypeHealthCheck = "healthcheck" + +// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint. +type HTTPHealthCheck struct { + sync.RWMutex + id string // the internal id of the listener + address string // the network address to bind to + config Config // configuration values for the listener + listen *http.Server // the http server + end uint32 // ensure the close methods are only called once +} + +// NewHTTPHealthCheck initializes and returns a new HTTP listener, listening on an address. +func NewHTTPHealthCheck(config Config) *HTTPHealthCheck { + return &HTTPHealthCheck{ + id: config.ID, + address: config.Address, + config: config, + } +} + +// ID returns the id of the listener. +func (l *HTTPHealthCheck) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *HTTPHealthCheck) Address() string { + return l.address +} + +// Protocol returns the address of the listener. +func (l *HTTPHealthCheck) Protocol() string { + if l.listen != nil && l.listen.TLSConfig != nil { + return "https" + } + + return "http" +} + +// Init initializes the listener. +func (l *HTTPHealthCheck) Init(_ *slog.Logger) error { + mux := http.NewServeMux() + mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + } + }) + l.listen = &http.Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + Addr: l.address, + Handler: mux, + } + + if l.config.TLSConfig != nil { + l.listen.TLSConfig = l.config.TLSConfig + } + + return nil +} + +// Serve starts listening for new connections and serving responses. +func (l *HTTPHealthCheck) Serve(establish EstablishFn) { + if l.listen.TLSConfig != nil { + _ = l.listen.ListenAndServeTLS("", "") + } else { + _ = l.listen.ListenAndServe() + } +} + +// Close closes the listener and any client connections. +func (l *HTTPHealthCheck) Close(closeClients CloseFn) { + l.Lock() + defer l.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = l.listen.Shutdown(ctx) + } + + closeClients(l.id) +} diff --git a/listeners/http_healthcheck_test.go b/listeners/http_healthcheck_test.go new file mode 100644 index 0000000..66a693c --- /dev/null +++ b/listeners/http_healthcheck_test.go @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Derek Duncan + +package listeners + +import ( + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewHTTPHealthCheck(t *testing.T) { + l := NewHTTPHealthCheck(basicConfig) + require.Equal(t, basicConfig.ID, l.id) + require.Equal(t, basicConfig.Address, l.address) +} + +func TestHTTPHealthCheckID(t *testing.T) { + l := NewHTTPHealthCheck(basicConfig) + require.Equal(t, basicConfig.ID, l.ID()) +} + +func TestHTTPHealthCheckAddress(t *testing.T) { + l := NewHTTPHealthCheck(basicConfig) + require.Equal(t, basicConfig.Address, l.Address()) +} + +func TestHTTPHealthCheckProtocol(t *testing.T) { + l := NewHTTPHealthCheck(basicConfig) + require.Equal(t, "http", l.Protocol()) +} + +func TestHTTPHealthCheckTLSProtocol(t *testing.T) { + l := NewHTTPHealthCheck(tlsConfig) + _ = l.Init(logger) + require.Equal(t, "https", l.Protocol()) +} + +func TestHTTPHealthCheckInit(t *testing.T) { + l := NewHTTPHealthCheck(basicConfig) + err := l.Init(logger) + require.NoError(t, err) + + require.NotNil(t, l.listen) + require.Equal(t, basicConfig.Address, l.listen.Addr) +} + +func TestHTTPHealthCheckServeAndClose(t *testing.T) { + // setup http stats listener + l := NewHTTPHealthCheck(basicConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + // call healthcheck + resp, err := http.Get("http://localhost" + testAddr + "/healthcheck") + require.NoError(t, err) + require.NotNil(t, resp) + + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // ensure listening is closed + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.Equal(t, true, closed) + + _, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck") + require.Error(t, err) + <-o +} + +func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) { + // setup http stats listener + l := NewHTTPHealthCheck(basicConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + // make disallowed method type http request + resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody) + require.NoError(t, err) + require.NotNil(t, resp) + + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // ensure listening is closed + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.Equal(t, true, closed) + + _, err = http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody) + require.Error(t, err) + <-o +} + +func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) { + l := NewHTTPHealthCheck(tlsConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + l.Close(MockCloser) +} diff --git a/listeners/http_sysinfo.go b/listeners/http_sysinfo.go new file mode 100644 index 0000000..03b2da8 --- /dev/null +++ b/listeners/http_sysinfo.go @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "sync" + "sync/atomic" + "time" + + "testmqtt/system" +) + +const TypeSysInfo = "sysinfo" + +// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. +type HTTPStats struct { + sync.RWMutex + id string // the internal id of the listener + address string // the network address to bind to + config Config // configuration values for the listener + listen *http.Server // the http server + sysInfo *system.Info // pointers to the server data + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once +} + +// NewHTTPStats initializes and returns a new HTTP listener, listening on an address. +func NewHTTPStats(config Config, sysInfo *system.Info) *HTTPStats { + return &HTTPStats{ + sysInfo: sysInfo, + id: config.ID, + address: config.Address, + config: config, + } +} + +// ID returns the id of the listener. +func (l *HTTPStats) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *HTTPStats) Address() string { + return l.address +} + +// Protocol returns the address of the listener. +func (l *HTTPStats) Protocol() string { + if l.listen != nil && l.listen.TLSConfig != nil { + return "https" + } + + return "http" +} + +// Init initializes the listener. +func (l *HTTPStats) Init(log *slog.Logger) error { + l.log = log + mux := http.NewServeMux() + mux.HandleFunc("/", l.jsonHandler) + l.listen = &http.Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + Addr: l.address, + Handler: mux, + } + + if l.config.TLSConfig != nil { + l.listen.TLSConfig = l.config.TLSConfig + } + + return nil +} + +// Serve starts listening for new connections and serving responses. +func (l *HTTPStats) Serve(establish EstablishFn) { + + var err error + if l.listen.TLSConfig != nil { + err = l.listen.ListenAndServeTLS("", "") + } else { + err = l.listen.ListenAndServe() + } + + // After the listener has been shutdown, no need to print the http.ErrServerClosed error. + if err != nil && atomic.LoadUint32(&l.end) == 0 { + l.log.Error("failed to serve.", "error", err, "listener", l.id) + } +} + +// Close closes the listener and any client connections. +func (l *HTTPStats) Close(closeClients CloseFn) { + l.Lock() + defer l.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = l.listen.Shutdown(ctx) + } + + closeClients(l.id) +} + +// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON. +func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) { + info := *l.sysInfo.Clone() + + out, err := json.MarshalIndent(info, "", "\t") + if err != nil { + _, _ = io.WriteString(w, err.Error()) + } + + _, _ = w.Write(out) +} diff --git a/listeners/http_sysinfo_test.go b/listeners/http_sysinfo_test.go new file mode 100644 index 0000000..c35c077 --- /dev/null +++ b/listeners/http_sysinfo_test.go @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "testmqtt/system" + + "github.com/stretchr/testify/require" +) + +func TestNewHTTPStats(t *testing.T) { + l := NewHTTPStats(basicConfig, nil) + require.Equal(t, "t1", l.id) + require.Equal(t, testAddr, l.address) +} + +func TestHTTPStatsID(t *testing.T) { + l := NewHTTPStats(basicConfig, nil) + require.Equal(t, "t1", l.ID()) +} + +func TestHTTPStatsAddress(t *testing.T) { + l := NewHTTPStats(basicConfig, nil) + require.Equal(t, testAddr, l.Address()) +} + +func TestHTTPStatsProtocol(t *testing.T) { + l := NewHTTPStats(basicConfig, nil) + require.Equal(t, "http", l.Protocol()) +} + +func TestHTTPStatsTLSProtocol(t *testing.T) { + l := NewHTTPStats(tlsConfig, nil) + _ = l.Init(logger) + require.Equal(t, "https", l.Protocol()) +} + +func TestHTTPStatsInit(t *testing.T) { + sysInfo := new(system.Info) + l := NewHTTPStats(basicConfig, sysInfo) + err := l.Init(logger) + require.NoError(t, err) + + require.NotNil(t, l.sysInfo) + require.Equal(t, sysInfo, l.sysInfo) + require.NotNil(t, l.listen) + require.Equal(t, testAddr, l.listen.Addr) +} + +func TestHTTPStatsServeAndClose(t *testing.T) { + sysInfo := &system.Info{ + Version: "test", + } + + // setup http stats listener + l := NewHTTPStats(basicConfig, sysInfo) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + // get body from stats address + resp, err := http.Get("http://localhost" + testAddr) + require.NoError(t, err) + require.NotNil(t, resp) + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // decode body from json and check data + v := new(system.Info) + err = json.Unmarshal(body, v) + require.NoError(t, err) + require.Equal(t, "test", v.Version) + + // ensure listening is closed + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.Equal(t, true, closed) + + _, err = http.Get("http://localhost" + testAddr) + require.Error(t, err) + <-o +} + +func TestHTTPStatsServeTLSAndClose(t *testing.T) { + sysInfo := &system.Info{ + Version: "test", + } + + l := NewHTTPStats(tlsConfig, sysInfo) + + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + l.Close(MockCloser) +} + +func TestHTTPStatsFailedToServe(t *testing.T) { + sysInfo := &system.Info{ + Version: "test", + } + + // setup http stats listener + config := basicConfig + config.Address = "wrong_addr" + l := NewHTTPStats(config, sysInfo) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + <-o + // ensure listening is closed + var closed bool + l.Close(func(id string) { + closed = true + }) + require.Equal(t, true, closed) +} diff --git a/listeners/listeners.go b/listeners/listeners.go new file mode 100644 index 0000000..ded7c37 --- /dev/null +++ b/listeners/listeners.go @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "crypto/tls" + "net" + "sync" + + "log/slog" +) + +// Config contains configuration values for a listener. +type Config struct { + Type string + ID string + Address string + // TLSConfig is a tls.Config configuration to be used with the listener. See examples folder for basic and mutual-tls use. + TLSConfig *tls.Config +} + +// EstablishFn is a callback function for establishing new clients. +type EstablishFn func(id string, c net.Conn) error + +// CloseFn is a callback function for closing all listener clients. +type CloseFn func(id string) + +// Listener is an interface for network listeners. A network listener listens +// for incoming client connections and adds them to the server. +type Listener interface { + Init(*slog.Logger) error // open the network address + Serve(EstablishFn) // starting actively listening for new connections + ID() string // return the id of the listener + Address() string // the address of the listener + Protocol() string // the protocol in use by the listener + Close(CloseFn) // stop and close the listener +} + +// Listeners contains the network listeners for the broker. +type Listeners struct { + ClientsWg sync.WaitGroup // a waitgroup that waits for all clients in all listeners to finish. + internal map[string]Listener // a map of active listeners. + sync.RWMutex +} + +// New returns a new instance of Listeners. +func New() *Listeners { + return &Listeners{ + internal: map[string]Listener{}, + } +} + +// Add adds a new listener to the listeners map, keyed on id. +func (l *Listeners) Add(val Listener) { + l.Lock() + defer l.Unlock() + l.internal[val.ID()] = val +} + +// Get returns the value of a listener if it exists. +func (l *Listeners) Get(id string) (Listener, bool) { + l.RLock() + defer l.RUnlock() + val, ok := l.internal[id] + return val, ok +} + +// Len returns the length of the listeners map. +func (l *Listeners) Len() int { + l.RLock() + defer l.RUnlock() + return len(l.internal) +} + +// Delete removes a listener from the internal map. +func (l *Listeners) Delete(id string) { + l.Lock() + defer l.Unlock() + delete(l.internal, id) +} + +// Serve starts a listener serving from the internal map. +func (l *Listeners) Serve(id string, establisher EstablishFn) { + l.RLock() + defer l.RUnlock() + listener := l.internal[id] + + go func(e EstablishFn) { + listener.Serve(e) + }(establisher) +} + +// ServeAll starts all listeners serving from the internal map. +func (l *Listeners) ServeAll(establisher EstablishFn) { + l.RLock() + i := 0 + ids := make([]string, len(l.internal)) + for id := range l.internal { + ids[i] = id + i++ + } + l.RUnlock() + + for _, id := range ids { + l.Serve(id, establisher) + } +} + +// Close stops a listener from the internal map. +func (l *Listeners) Close(id string, closer CloseFn) { + l.RLock() + defer l.RUnlock() + if listener, ok := l.internal[id]; ok { + listener.Close(closer) + } +} + +// CloseAll iterates and closes all registered listeners. +func (l *Listeners) CloseAll(closer CloseFn) { + l.RLock() + i := 0 + ids := make([]string, len(l.internal)) + for id := range l.internal { + ids[i] = id + i++ + } + l.RUnlock() + + for _, id := range ids { + l.Close(id, closer) + } + l.ClientsWg.Wait() +} diff --git a/listeners/listeners_test.go b/listeners/listeners_test.go new file mode 100644 index 0000000..6b4b025 --- /dev/null +++ b/listeners/listeners_test.go @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "crypto/tls" + "log" + "os" + "testing" + "time" + + "log/slog" + + "github.com/stretchr/testify/require" +) + +const testAddr = ":22222" + +var ( + basicConfig = Config{ID: "t1", Address: testAddr} + tlsConfig = Config{ID: "t1", Address: testAddr, TLSConfig: tlsConfigBasic} + + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + + testCertificate = []byte(`-----BEGIN CERTIFICATE----- +MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB +VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV +BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD +VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x +DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3 +AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi +OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI +MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD +gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ +qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy +zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw= +-----END CERTIFICATE-----`) + + testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o +FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA +rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB +AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K +UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m +n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ +mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6 +INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z +AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt +/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32 +WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy +w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3 +OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc= +-----END RSA PRIVATE KEY-----`) + + tlsConfigBasic *tls.Config +) + +func init() { + cert, err := tls.X509KeyPair(testCertificate, testPrivateKey) + if err != nil { + log.Fatal(err) + } + + // Basic TLS Config + tlsConfigBasic = &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + tlsConfig.TLSConfig = tlsConfigBasic +} + +func TestNew(t *testing.T) { + l := New() + require.NotNil(t, l.internal) +} + +func TestAddListener(t *testing.T) { + l := New() + l.Add(NewMockListener("t1", testAddr)) + require.Contains(t, l.internal, "t1") +} + +func TestGetListener(t *testing.T) { + l := New() + l.Add(NewMockListener("t1", testAddr)) + l.Add(NewMockListener("t2", testAddr)) + require.Contains(t, l.internal, "t1") + require.Contains(t, l.internal, "t2") + + g, ok := l.Get("t1") + require.True(t, ok) + require.Equal(t, g.ID(), "t1") +} + +func TestLenListener(t *testing.T) { + l := New() + l.Add(NewMockListener("t1", testAddr)) + l.Add(NewMockListener("t2", testAddr)) + require.Contains(t, l.internal, "t1") + require.Contains(t, l.internal, "t2") + require.Equal(t, 2, l.Len()) +} + +func TestDeleteListener(t *testing.T) { + l := New() + l.Add(NewMockListener("t1", testAddr)) + require.Contains(t, l.internal, "t1") + l.Delete("t1") + _, ok := l.Get("t1") + require.False(t, ok) + require.Nil(t, l.internal["t1"]) +} + +func TestServeListener(t *testing.T) { + l := New() + l.Add(NewMockListener("t1", testAddr)) + l.Serve("t1", MockEstablisher) + time.Sleep(time.Millisecond) + require.True(t, l.internal["t1"].(*MockListener).IsServing()) + + l.Close("t1", MockCloser) + require.False(t, l.internal["t1"].(*MockListener).IsServing()) +} + +func TestServeAllListeners(t *testing.T) { + l := New() + l.Add(NewMockListener("t1", testAddr)) + l.Add(NewMockListener("t2", testAddr)) + l.Add(NewMockListener("t3", testAddr)) + l.ServeAll(MockEstablisher) + time.Sleep(time.Millisecond) + + require.True(t, l.internal["t1"].(*MockListener).IsServing()) + require.True(t, l.internal["t2"].(*MockListener).IsServing()) + require.True(t, l.internal["t3"].(*MockListener).IsServing()) + + l.Close("t1", MockCloser) + l.Close("t2", MockCloser) + l.Close("t3", MockCloser) + + require.False(t, l.internal["t1"].(*MockListener).IsServing()) + require.False(t, l.internal["t2"].(*MockListener).IsServing()) + require.False(t, l.internal["t3"].(*MockListener).IsServing()) +} + +func TestCloseListener(t *testing.T) { + l := New() + mocked := NewMockListener("t1", testAddr) + l.Add(mocked) + l.Serve("t1", MockEstablisher) + time.Sleep(time.Millisecond) + var closed bool + l.Close("t1", func(id string) { + closed = true + }) + require.True(t, closed) +} + +func TestCloseAllListeners(t *testing.T) { + l := New() + l.Add(NewMockListener("t1", testAddr)) + l.Add(NewMockListener("t2", testAddr)) + l.Add(NewMockListener("t3", testAddr)) + l.ServeAll(MockEstablisher) + time.Sleep(time.Millisecond) + require.True(t, l.internal["t1"].(*MockListener).IsServing()) + require.True(t, l.internal["t2"].(*MockListener).IsServing()) + require.True(t, l.internal["t3"].(*MockListener).IsServing()) + + closed := make(map[string]bool) + l.CloseAll(func(id string) { + closed[id] = true + }) + require.Contains(t, closed, "t1") + require.Contains(t, closed, "t2") + require.Contains(t, closed, "t3") + require.True(t, closed["t1"]) + require.True(t, closed["t2"]) + require.True(t, closed["t3"]) +} diff --git a/listeners/mock.go b/listeners/mock.go new file mode 100644 index 0000000..1a67d89 --- /dev/null +++ b/listeners/mock.go @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "fmt" + "net" + "sync" + + "log/slog" +) + +const TypeMock = "mock" + +// MockEstablisher is a function signature which can be used in testing. +func MockEstablisher(id string, c net.Conn) error { + return nil +} + +// MockCloser is a function signature which can be used in testing. +func MockCloser(id string) {} + +// MockListener is a mock listener for establishing client connections. +type MockListener struct { + sync.RWMutex + id string // the id of the listener + address string // the network address the listener binds to + Config *Config // configuration for the listener + done chan bool // indicate the listener is done + Serving bool // indicate the listener is serving + Listening bool // indiciate the listener is listening + ErrListen bool // throw an error on listen +} + +// NewMockListener returns a new instance of MockListener. +func NewMockListener(id, address string) *MockListener { + return &MockListener{ + id: id, + address: address, + done: make(chan bool), + } +} + +// Serve serves the mock listener. +func (l *MockListener) Serve(establisher EstablishFn) { + l.Lock() + l.Serving = true + l.Unlock() + + for range l.done { + return + } +} + +// Init initializes the listener. +func (l *MockListener) Init(log *slog.Logger) error { + if l.ErrListen { + return fmt.Errorf("listen failure") + } + + l.Lock() + defer l.Unlock() + l.Listening = true + return nil +} + +// ID returns the id of the mock listener. +func (l *MockListener) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *MockListener) Address() string { + return l.address +} + +// Protocol returns the address of the listener. +func (l *MockListener) Protocol() string { + return "mock" +} + +// Close closes the mock listener. +func (l *MockListener) Close(closer CloseFn) { + l.Lock() + defer l.Unlock() + l.Serving = false + closer(l.id) + close(l.done) +} + +// IsServing indicates whether the mock listener is serving. +func (l *MockListener) IsServing() bool { + l.Lock() + defer l.Unlock() + return l.Serving +} + +// IsListening indicates whether the mock listener is listening. +func (l *MockListener) IsListening() bool { + l.Lock() + defer l.Unlock() + return l.Listening +} diff --git a/listeners/mock_test.go b/listeners/mock_test.go new file mode 100644 index 0000000..46aa922 --- /dev/null +++ b/listeners/mock_test.go @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestMockEstablisher(t *testing.T) { + _, w := net.Pipe() + err := MockEstablisher("t1", w) + require.NoError(t, err) + _ = w.Close() +} + +func TestNewMockListener(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + require.Equal(t, "t1", mocked.id) + require.Equal(t, testAddr, mocked.address) +} +func TestMockListenerID(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + require.Equal(t, "t1", mocked.ID()) +} + +func TestMockListenerAddress(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + require.Equal(t, testAddr, mocked.Address()) +} +func TestMockListenerProtocol(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + require.Equal(t, "mock", mocked.Protocol()) +} + +func TestNewMockListenerIsListening(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + require.Equal(t, false, mocked.IsListening()) +} + +func TestNewMockListenerIsServing(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + require.Equal(t, false, mocked.IsServing()) +} + +func TestNewMockListenerInit(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + require.Equal(t, "t1", mocked.id) + require.Equal(t, testAddr, mocked.address) + + require.Equal(t, false, mocked.IsListening()) + err := mocked.Init(nil) + require.NoError(t, err) + require.Equal(t, true, mocked.IsListening()) +} + +func TestNewMockListenerInitFailure(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + mocked.ErrListen = true + err := mocked.Init(nil) + require.Error(t, err) +} + +func TestMockListenerServe(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + require.Equal(t, false, mocked.IsServing()) + + o := make(chan bool) + go func(o chan bool) { + mocked.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) // easy non-channel wait for start of serving + require.Equal(t, true, mocked.IsServing()) + + var closed bool + mocked.Close(func(id string) { + closed = true + }) + require.Equal(t, true, closed) + <-o + + _ = mocked.Init(nil) +} + +func TestMockListenerClose(t *testing.T) { + mocked := NewMockListener("t1", testAddr) + var closed bool + mocked.Close(func(id string) { + closed = true + }) + require.Equal(t, true, closed) +} diff --git a/listeners/net.go b/listeners/net.go new file mode 100644 index 0000000..fa4ef3d --- /dev/null +++ b/listeners/net.go @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Jeroen Rinzema + +package listeners + +import ( + "net" + "sync" + "sync/atomic" + + "log/slog" +) + +// Net is a listener for establishing client connections on basic TCP protocol. +type Net struct { // [MQTT-4.2.0-1] + mu sync.Mutex + listener net.Listener // a net.Listener which will listen for new clients + id string // the internal id of the listener + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once +} + +// NewNet initialises and returns a listener serving incoming connections on the given net.Listener +func NewNet(id string, listener net.Listener) *Net { + return &Net{ + id: id, + listener: listener, + } +} + +// ID returns the id of the listener. +func (l *Net) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *Net) Address() string { + return l.listener.Addr().String() +} + +// Protocol returns the network of the listener. +func (l *Net) Protocol() string { + return l.listener.Addr().Network() +} + +// Init initializes the listener. +func (l *Net) Init(log *slog.Logger) error { + l.log = log + return nil +} + +// Serve starts waiting for new TCP connections, and calls the establish +// connection callback for any received. +func (l *Net) Serve(establish EstablishFn) { + for { + if atomic.LoadUint32(&l.end) == 1 { + return + } + + conn, err := l.listener.Accept() + if err != nil { + return + } + + if atomic.LoadUint32(&l.end) == 0 { + go func() { + err = establish(l.id, conn) + if err != nil { + l.log.Warn("", "error", err) + } + }() + } + } +} + +// Close closes the listener and any client connections. +func (l *Net) Close(closeClients CloseFn) { + l.mu.Lock() + defer l.mu.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + closeClients(l.id) + } + + if l.listener != nil { + err := l.listener.Close() + if err != nil { + return + } + } +} diff --git a/listeners/net_test.go b/listeners/net_test.go new file mode 100644 index 0000000..14a1ad6 --- /dev/null +++ b/listeners/net_test.go @@ -0,0 +1,105 @@ +package listeners + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewNet(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "t1", l.id) +} + +func TestNetID(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "t1", l.ID()) +} + +func TestNetAddress(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, n.Addr().String(), l.Address()) +} + +func TestNetProtocol(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "tcp", l.Protocol()) +} + +func TestNetInit(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + l.Close(MockCloser) + require.NoError(t, err) +} + +func TestNetServeAndClose(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.True(t, closed) + <-o + + l.Close(MockCloser) // coverage: close closed + l.Serve(MockEstablisher) // coverage: serve closed +} + +func TestNetEstablishThenEnd(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + established := make(chan bool) + go func() { + l.Serve(func(id string, c net.Conn) error { + established <- true + return errors.New("ending") // return an error to exit immediately + }) + o <- true + }() + + time.Sleep(time.Millisecond) + _, _ = net.Dial("tcp", n.Addr().String()) + require.Equal(t, true, <-established) + l.Close(MockCloser) + <-o +} diff --git a/listeners/tcp.go b/listeners/tcp.go new file mode 100644 index 0000000..014a182 --- /dev/null +++ b/listeners/tcp.go @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "crypto/tls" + "net" + "sync" + "sync/atomic" + + "log/slog" +) + +const TypeTCP = "tcp" + +// TCP is a listener for establishing client connections on basic TCP protocol. +type TCP struct { // [MQTT-4.2.0-1] + sync.RWMutex + id string // the internal id of the listener + address string // the network address to bind to + listen net.Listener // a net.Listener which will listen for new clients + config Config // configuration values for the listener + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once +} + +// NewTCP initializes and returns a new TCP listener, listening on an address. +func NewTCP(config Config) *TCP { + return &TCP{ + id: config.ID, + address: config.Address, + config: config, + } +} + +// ID returns the id of the listener. +func (l *TCP) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *TCP) Address() string { + if l.listen != nil { + return l.listen.Addr().String() + } + return l.address +} + +// Protocol returns the address of the listener. +func (l *TCP) Protocol() string { + return "tcp" +} + +// Init initializes the listener. +func (l *TCP) Init(log *slog.Logger) error { + l.log = log + + var err error + if l.config.TLSConfig != nil { + l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig) + } else { + l.listen, err = net.Listen("tcp", l.address) + } + + return err +} + +// Serve starts waiting for new TCP connections, and calls the establish +// connection callback for any received. +func (l *TCP) Serve(establish EstablishFn) { + for { + if atomic.LoadUint32(&l.end) == 1 { + return + } + + conn, err := l.listen.Accept() + if err != nil { + return + } + + if atomic.LoadUint32(&l.end) == 0 { + go func() { + err = establish(l.id, conn) + if err != nil { + l.log.Warn("", "error", err) + } + }() + } + } +} + +// Close closes the listener and any client connections. +func (l *TCP) Close(closeClients CloseFn) { + l.Lock() + defer l.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + closeClients(l.id) + } + + if l.listen != nil { + err := l.listen.Close() + if err != nil { + return + } + } +} diff --git a/listeners/tcp_test.go b/listeners/tcp_test.go new file mode 100644 index 0000000..4b7496f --- /dev/null +++ b/listeners/tcp_test.go @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewTCP(t *testing.T) { + l := NewTCP(basicConfig) + require.Equal(t, "t1", l.id) + require.Equal(t, testAddr, l.address) +} + +func TestTCPID(t *testing.T) { + l := NewTCP(basicConfig) + require.Equal(t, "t1", l.ID()) +} + +func TestTCPAddress(t *testing.T) { + l := NewTCP(basicConfig) + require.Equal(t, testAddr, l.Address()) +} + +func TestTCPProtocol(t *testing.T) { + l := NewTCP(basicConfig) + require.Equal(t, "tcp", l.Protocol()) +} + +func TestTCPProtocolTLS(t *testing.T) { + l := NewTCP(tlsConfig) + _ = l.Init(logger) + defer l.listen.Close() + require.Equal(t, "tcp", l.Protocol()) +} + +func TestTCPInit(t *testing.T) { + l := NewTCP(basicConfig) + err := l.Init(logger) + l.Close(MockCloser) + require.NoError(t, err) + + l2 := NewTCP(tlsConfig) + err = l2.Init(logger) + l2.Close(MockCloser) + require.NoError(t, err) + require.NotNil(t, l2.config.TLSConfig) +} + +func TestTCPServeAndClose(t *testing.T) { + l := NewTCP(basicConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.True(t, closed) + <-o + + l.Close(MockCloser) // coverage: close closed + l.Serve(MockEstablisher) // coverage: serve closed +} + +func TestTCPServeTLSAndClose(t *testing.T) { + l := NewTCP(tlsConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.Equal(t, true, closed) + <-o +} + +func TestTCPEstablishThenEnd(t *testing.T) { + l := NewTCP(basicConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + established := make(chan bool) + go func() { + l.Serve(func(id string, c net.Conn) error { + established <- true + return errors.New("ending") // return an error to exit immediately + }) + o <- true + }() + + time.Sleep(time.Millisecond) + _, _ = net.Dial("tcp", l.listen.Addr().String()) + require.Equal(t, true, <-established) + l.Close(MockCloser) + <-o +} diff --git a/listeners/unixsock.go b/listeners/unixsock.go new file mode 100644 index 0000000..23df130 --- /dev/null +++ b/listeners/unixsock.go @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: jason@zgwit.com + +package listeners + +import ( + "net" + "os" + "sync" + "sync/atomic" + + "log/slog" +) + +const TypeUnix = "unix" + +// UnixSock is a listener for establishing client connections on basic UnixSock protocol. +type UnixSock struct { + sync.RWMutex + id string // the internal id of the listener. + address string // the network address to bind to. + config Config // configuration values for the listener + listen net.Listener // a net.Listener which will listen for new clients. + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once. +} + +// NewUnixSock initializes and returns a new UnixSock listener, listening on an address. +func NewUnixSock(config Config) *UnixSock { + return &UnixSock{ + id: config.ID, + address: config.Address, + config: config, + } +} + +// ID returns the id of the listener. +func (l *UnixSock) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *UnixSock) Address() string { + return l.address +} + +// Protocol returns the address of the listener. +func (l *UnixSock) Protocol() string { + return "unix" +} + +// Init initializes the listener. +func (l *UnixSock) Init(log *slog.Logger) error { + l.log = log + + var err error + _ = os.Remove(l.address) + l.listen, err = net.Listen("unix", l.address) + return err +} + +// Serve starts waiting for new UnixSock connections, and calls the establish +// connection callback for any received. +func (l *UnixSock) Serve(establish EstablishFn) { + for { + if atomic.LoadUint32(&l.end) == 1 { + return + } + + conn, err := l.listen.Accept() + if err != nil { + return + } + + if atomic.LoadUint32(&l.end) == 0 { + go func() { + err = establish(l.id, conn) + if err != nil { + l.log.Warn("", "error", err) + } + }() + } + } +} + +// Close closes the listener and any client connections. +func (l *UnixSock) Close(closeClients CloseFn) { + l.Lock() + defer l.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + closeClients(l.id) + } + + if l.listen != nil { + err := l.listen.Close() + if err != nil { + return + } + } +} diff --git a/listeners/unixsock_test.go b/listeners/unixsock_test.go new file mode 100644 index 0000000..a3940e6 --- /dev/null +++ b/listeners/unixsock_test.go @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: jason@zgwit.com + +package listeners + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const testUnixAddr = "mochi.sock" + +var ( + unixConfig = Config{ID: "t1", Address: testUnixAddr} +) + +func TestNewUnixSock(t *testing.T) { + l := NewUnixSock(unixConfig) + require.Equal(t, "t1", l.id) + require.Equal(t, testUnixAddr, l.address) +} + +func TestUnixSockID(t *testing.T) { + l := NewUnixSock(unixConfig) + require.Equal(t, "t1", l.ID()) +} + +func TestUnixSockAddress(t *testing.T) { + l := NewUnixSock(unixConfig) + require.Equal(t, testUnixAddr, l.Address()) +} + +func TestUnixSockProtocol(t *testing.T) { + l := NewUnixSock(unixConfig) + require.Equal(t, "unix", l.Protocol()) +} + +func TestUnixSockInit(t *testing.T) { + l := NewUnixSock(unixConfig) + err := l.Init(logger) + l.Close(MockCloser) + require.NoError(t, err) + + t2Config := unixConfig + t2Config.ID = "t2" + l2 := NewUnixSock(t2Config) + err = l2.Init(logger) + l2.Close(MockCloser) + require.NoError(t, err) +} + +func TestUnixSockServeAndClose(t *testing.T) { + l := NewUnixSock(unixConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.True(t, closed) + <-o + + l.Close(MockCloser) // coverage: close closed + l.Serve(MockEstablisher) // coverage: serve closed +} + +func TestUnixSockEstablishThenEnd(t *testing.T) { + l := NewUnixSock(unixConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + established := make(chan bool) + go func() { + l.Serve(func(id string, c net.Conn) error { + established <- true + return errors.New("ending") // return an error to exit immediately + }) + o <- true + }() + + time.Sleep(time.Millisecond) + _, _ = net.Dial("unix", l.listen.Addr().String()) + require.Equal(t, true, <-established) + l.Close(MockCloser) + <-o +} diff --git a/listeners/websocket.go b/listeners/websocket.go new file mode 100644 index 0000000..267daf6 --- /dev/null +++ b/listeners/websocket.go @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "sync" + "sync/atomic" + "time" + + "log/slog" + + "github.com/gorilla/websocket" +) + +const TypeWS = "ws" + +var ( + // ErrInvalidMessage indicates that a message payload was not valid. + ErrInvalidMessage = errors.New("message type not binary") +) + +// Websocket is a listener for establishing websocket connections. +type Websocket struct { // [MQTT-4.2.0-1] + sync.RWMutex + id string // the internal id of the listener + address string // the network address to bind to + config Config // configuration values for the listener + listen *http.Server // a http server for serving websocket connections + log *slog.Logger // server logger + establish EstablishFn // the server's establish connection handler + upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection. + end uint32 // ensure the close methods are only called once +} + +// NewWebsocket initializes and returns a new Websocket listener, listening on an address. +func NewWebsocket(config Config) *Websocket { + return &Websocket{ + id: config.ID, + address: config.Address, + config: config, + upgrader: &websocket.Upgrader{ + Subprotocols: []string{"mqtt"}, + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + } +} + +// ID returns the id of the listener. +func (l *Websocket) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *Websocket) Address() string { + return l.address +} + +// Protocol returns the address of the listener. +func (l *Websocket) Protocol() string { + if l.config.TLSConfig != nil { + return "wss" + } + + return "ws" +} + +// Init initializes the listener. +func (l *Websocket) Init(log *slog.Logger) error { + l.log = log + + mux := http.NewServeMux() + mux.HandleFunc("/", l.handler) + l.listen = &http.Server{ + Addr: l.address, + Handler: mux, + TLSConfig: l.config.TLSConfig, + ReadTimeout: 60 * time.Second, + WriteTimeout: 60 * time.Second, + } + + return nil +} + +// handler upgrades and handles an incoming websocket connection. +func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) { + c, err := l.upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + + err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c}) + if err != nil { + l.log.Warn("", "error", err) + } +} + +// Serve starts waiting for new Websocket connections, and calls the connection +// establishment callback for any received. +func (l *Websocket) Serve(establish EstablishFn) { + var err error + l.establish = establish + + if l.listen.TLSConfig != nil { + err = l.listen.ListenAndServeTLS("", "") + } else { + err = l.listen.ListenAndServe() + } + + // After the listener has been shutdown, no need to print the http.ErrServerClosed error. + if err != nil && atomic.LoadUint32(&l.end) == 0 { + l.log.Error("failed to serve.", "error", err, "listener", l.id) + } +} + +// Close closes the listener and any client connections. +func (l *Websocket) Close(closeClients CloseFn) { + l.Lock() + defer l.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = l.listen.Shutdown(ctx) + } + + closeClients(l.id) +} + +// wsConn is a websocket connection which satisfies the net.Conn interface. +type wsConn struct { + net.Conn + c *websocket.Conn + + // reader for the current message (can be nil) + r io.Reader +} + +// Read reads the next span of bytes from the websocket connection and returns the number of bytes read. +func (ws *wsConn) Read(p []byte) (int, error) { + if ws.r == nil { + op, r, err := ws.c.NextReader() + if err != nil { + return 0, err + } + + if op != websocket.BinaryMessage { + err = ErrInvalidMessage + return 0, err + } + + ws.r = r + } + + var n int + for { + // buffer is full, return what we've read so far + if n == len(p) { + return n, nil + } + + br, err := ws.r.Read(p[n:]) + n += br + if err != nil { + // when ANY error occurs, we consider this the end of the current message (either because it really is, via + // io.EOF, or because something bad happened, in which case we want to drop the remainder) + ws.r = nil + + if errors.Is(err, io.EOF) { + err = nil + } + return n, err + } + } +} + +// Write writes bytes to the websocket connection. +func (ws *wsConn) Write(p []byte) (int, error) { + err := ws.c.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + return 0, err + } + + return len(p), nil +} + +// Close signals the underlying websocket conn to close. +func (ws *wsConn) Close() error { + return ws.Conn.Close() +} diff --git a/listeners/websocket_test.go b/listeners/websocket_test.go new file mode 100644 index 0000000..56d6bc3 --- /dev/null +++ b/listeners/websocket_test.go @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package listeners + +import ( + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +func TestNewWebsocket(t *testing.T) { + l := NewWebsocket(basicConfig) + require.Equal(t, "t1", l.id) + require.Equal(t, testAddr, l.address) +} + +func TestWebsocketID(t *testing.T) { + l := NewWebsocket(basicConfig) + require.Equal(t, "t1", l.ID()) +} + +func TestWebsocketAddress(t *testing.T) { + l := NewWebsocket(basicConfig) + require.Equal(t, testAddr, l.Address()) +} + +func TestWebsocketProtocol(t *testing.T) { + l := NewWebsocket(basicConfig) + require.Equal(t, "ws", l.Protocol()) +} + +func TestWebsocketProtocolTLS(t *testing.T) { + l := NewWebsocket(tlsConfig) + require.Equal(t, "wss", l.Protocol()) +} + +func TestWebsocketInit(t *testing.T) { + l := NewWebsocket(basicConfig) + require.Nil(t, l.listen) + err := l.Init(logger) + require.NoError(t, err) + require.NotNil(t, l.listen) +} + +func TestWebsocketServeAndClose(t *testing.T) { + l := NewWebsocket(basicConfig) + _ = l.Init(logger) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.True(t, closed) + <-o +} + +func TestWebsocketServeTLSAndClose(t *testing.T) { + l := NewWebsocket(tlsConfig) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + var closed bool + l.Close(func(id string) { + closed = true + }) + require.Equal(t, true, closed) + <-o +} + +func TestWebsocketFailedToServe(t *testing.T) { + config := tlsConfig + config.Address = "wrong_addr" + l := NewWebsocket(config) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + <-o + var closed bool + l.Close(func(id string) { + closed = true + }) + require.Equal(t, true, closed) +} + +func TestWebsocketUpgrade(t *testing.T) { + l := NewWebsocket(basicConfig) + _ = l.Init(logger) + + e := make(chan bool) + l.establish = func(id string, c net.Conn) error { + e <- true + return nil + } + + s := httptest.NewServer(http.HandlerFunc(l.handler)) + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil) + require.NoError(t, err) + require.Equal(t, true, <-e) + + s.Close() + _ = ws.Close() +} + +func TestWebsocketConnectionReads(t *testing.T) { + l := NewWebsocket(basicConfig) + _ = l.Init(nil) + + recv := make(chan []byte) + l.establish = func(id string, c net.Conn) error { + var out []byte + for { + buf := make([]byte, 2048) + n, err := c.Read(buf) + require.NoError(t, err) + out = append(out, buf[:n]...) + if n < 2048 { + break + } + } + + recv <- out + return nil + } + + s := httptest.NewServer(http.HandlerFunc(l.handler)) + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil) + require.NoError(t, err) + + pkt := make([]byte, 3000) // make sure this is >2048 + for i := 0; i < len(pkt); i++ { + pkt[i] = byte(i % 100) + } + + err = ws.WriteMessage(websocket.BinaryMessage, pkt) + require.NoError(t, err) + + got := <-recv + require.Equal(t, 3000, len(got)) + require.Equal(t, pkt, got) + + s.Close() + _ = ws.Close() +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..c7c372b --- /dev/null +++ b/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + "testmqtt/config" + "testmqtt/mqtt" +) + +func main() { + + sigs := make(chan os.Signal, 1) + done := make(chan bool, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigs + done <- true + }() + + configBytes, err := os.ReadFile("config.yaml") + if err != nil { + log.Fatal(err) + } + + options, err := config.FromBytes(configBytes) + if err != nil { + log.Fatal(err) + } + + server := mqtt.New(options) + + go func() { + err := server.Serve() + if err != nil { + log.Fatal(err) + } + }() + + <-done + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") + + //tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener") + //wsAddr := flag.String("ws", ":1882", "network address for Websocket listener") + //infoAddr := flag.String("info", ":8080", "network address for web info dashboard listener") + //flag.Parse() + // + //sigs := make(chan os.Signal, 1) + //done := make(chan bool, 1) + //signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + //go func() { + // <-sigs + // done <- true + //}() + // + //server := mqtt.New(nil) + //_ = server.AddHook(new(auth.AllowHook), nil) + // + //tcp := listeners.NewTCP(listeners.Config{ + // ID: "t1", + // Address: *tcpAddr, + //}) + //err := server.AddListener(tcp) + //if err != nil { + // log.Fatal(err) + //} + // + //ws := listeners.NewWebsocket(listeners.Config{ + // ID: "ws1", + // Address: *wsAddr, + //}) + //err = server.AddListener(ws) + //if err != nil { + // log.Fatal(err) + //} + // + //stats := listeners.NewHTTPStats( + // listeners.Config{ + // ID: "info", + // Address: *infoAddr, + // }, + // server.Info, + //) + //err = server.AddListener(stats) + //if err != nil { + // log.Fatal(err) + //} + // + //go func() { + // err := server.Serve() + // if err != nil { + // log.Fatal(err) + // } + //}() + // + //<-done + //server.Log.Warn("caught signal, stopping...") + //_ = server.Close() + //server.Log.Info("mochi mqtt shutdown complete") +} diff --git a/mempool/bufpool.go b/mempool/bufpool.go new file mode 100644 index 0000000..2b35b0c --- /dev/null +++ b/mempool/bufpool.go @@ -0,0 +1,83 @@ +package mempool + +import ( + "bytes" + "sync" +) + +var bufPool = NewBuffer(0) + +// GetBuffer 从默认缓冲池中获取一个缓冲区 +// GetBuffer takes a Buffer from the default buffer pool +func GetBuffer() *bytes.Buffer { return bufPool.Get() } + +// PutBuffer 将 Buffer 返回到默认缓冲池 +// PutBuffer returns Buffer to the default buffer pool +func PutBuffer(x *bytes.Buffer) { bufPool.Put(x) } + +// BufferPool 缓冲池接口 +type BufferPool interface { + Get() *bytes.Buffer + Put(x *bytes.Buffer) +} + +// NewBuffer returns a buffer pool. The max specify the max capacity of the Buffer the pool will +// return. If the Buffer becoomes large than max, it will no longer be returned to the pool. If +// max <= 0, no limit will be enforced. +func NewBuffer(max int) BufferPool { + if max > 0 { + return newBufferWithCap(max) + } + return newBuffer() +} + +// Buffer is a Buffer pool. +type Buffer struct { + pool *sync.Pool +} + +func newBuffer() *Buffer { + return &Buffer{ + pool: &sync.Pool{ + New: func() any { return new(bytes.Buffer) }, + }, + } +} + +// Get a Buffer from the pool. +func (b *Buffer) Get() *bytes.Buffer { + return b.pool.Get().(*bytes.Buffer) +} + +// Put the Buffer back into pool. It resets the Buffer for reuse. +func (b *Buffer) Put(x *bytes.Buffer) { + x.Reset() + b.pool.Put(x) +} + +// BufferWithCap is a Buffer pool that +type BufferWithCap struct { + bp *Buffer + max int +} + +func newBufferWithCap(max int) *BufferWithCap { + return &BufferWithCap{ + bp: newBuffer(), + max: max, + } +} + +// Get a Buffer from the pool. +func (b *BufferWithCap) Get() *bytes.Buffer { + return b.bp.Get() +} + +// Put the Buffer back into the pool if the capacity doesn't exceed the limit. It resets the Buffer +// for reuse. +func (b *BufferWithCap) Put(x *bytes.Buffer) { + if x.Cap() > b.max { + return + } + b.bp.Put(x) +} diff --git a/mempool/bufpool_test.go b/mempool/bufpool_test.go new file mode 100644 index 0000000..560f2e9 --- /dev/null +++ b/mempool/bufpool_test.go @@ -0,0 +1,96 @@ +package mempool + +import ( + "bytes" + "reflect" + "runtime/debug" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewBuffer(t *testing.T) { + defer debug.SetGCPercent(debug.SetGCPercent(-1)) + bp := NewBuffer(1000) + require.Equal(t, "*mempool.BufferWithCap", reflect.TypeOf(bp).String()) + + bp = NewBuffer(0) + require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String()) + + bp = NewBuffer(-1) + require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String()) +} + +func TestBuffer(t *testing.T) { + defer debug.SetGCPercent(debug.SetGCPercent(-1)) + Size := 101 + + bp := NewBuffer(0) + buf := bp.Get() + + for i := 0; i < Size; i++ { + buf.WriteByte('a') + } + + bp.Put(buf) + buf = bp.Get() + require.Equal(t, 0, buf.Len()) +} + +func TestBufferWithCap(t *testing.T) { + defer debug.SetGCPercent(debug.SetGCPercent(-1)) + Size := 101 + bp := NewBuffer(100) + buf := bp.Get() + + for i := 0; i < Size; i++ { + buf.WriteByte('a') + } + + bp.Put(buf) + buf = bp.Get() + require.Equal(t, 0, buf.Len()) + require.Equal(t, 0, buf.Cap()) +} + +func BenchmarkBufferPool(b *testing.B) { + bp := NewBuffer(0) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b := bp.Get() + b.WriteString("this is a test") + bp.Put(b) + } +} + +func BenchmarkBufferPoolWithCapLarger(b *testing.B) { + bp := NewBuffer(64 * 1024) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b := bp.Get() + b.WriteString("this is a test") + bp.Put(b) + } +} + +func BenchmarkBufferPoolWithCapLesser(b *testing.B) { + bp := NewBuffer(10) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b := bp.Get() + b.WriteString("this is a test") + bp.Put(b) + } +} + +func BenchmarkBufferWithoutPool(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + b := new(bytes.Buffer) + b.WriteString("this is a test") + _ = b + } +} diff --git a/mqtt/clients.go b/mqtt/clients.go new file mode 100644 index 0000000..bbfc826 --- /dev/null +++ b/mqtt/clients.go @@ -0,0 +1,649 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package mqtt + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/rs/xid" + + "testmqtt/packets" +) + +const ( + defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds. + defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified). + minimumKeepalive uint16 = 5 // the minimum recommended keepalive - values under with display a warning. +) + +var ( + ErrMinimumKeepalive = errors.New("client keepalive is below minimum recommended value and may exhibit connection instability") +) + +// ReadFn is the function signature for the function used for reading and processing new packets. +type ReadFn func(*Client, packets.Packet) error + +// Clients contains a map of the clients known by the broker. +type Clients struct { + internal map[string]*Client // clients known by the broker, keyed on client id. + sync.RWMutex +} + +// NewClients returns an instance of Clients. +func NewClients() *Clients { + return &Clients{ + internal: make(map[string]*Client), + } +} + +// Add adds a new client to the clients map, keyed on client id. +func (cl *Clients) Add(val *Client) { + cl.Lock() + defer cl.Unlock() + cl.internal[val.ID] = val +} + +// GetAll returns all the clients. +func (cl *Clients) GetAll() map[string]*Client { + cl.RLock() + defer cl.RUnlock() + m := map[string]*Client{} + for k, v := range cl.internal { + m[k] = v + } + return m +} + +// Get returns the value of a client if it exists. +func (cl *Clients) Get(id string) (*Client, bool) { + cl.RLock() + defer cl.RUnlock() + val, ok := cl.internal[id] + return val, ok +} + +// Len returns the length of the clients map. +func (cl *Clients) Len() int { + cl.RLock() + defer cl.RUnlock() + val := len(cl.internal) + return val +} + +// Delete removes a client from the internal map. +func (cl *Clients) Delete(id string) { + cl.Lock() + defer cl.Unlock() + delete(cl.internal, id) +} + +// GetByListener returns clients matching a listener id. +func (cl *Clients) GetByListener(id string) []*Client { + cl.RLock() + defer cl.RUnlock() + clients := make([]*Client, 0, cl.Len()) + for _, client := range cl.internal { + if client.Net.Listener == id && !client.Closed() { + clients = append(clients, client) + } + } + return clients +} + +// Client contains information about a client known by the broker. +type Client struct { + Properties ClientProperties // client properties + State ClientState // the operational state of the client. + Net ClientConnection // network connection state of the client + ID string // the client id. + ops *ops // ops provides a reference to server ops. + sync.RWMutex // mutex +} + +// ClientConnection contains the connection transport and metadata for the client. +type ClientConnection struct { + Conn net.Conn // the net.Conn used to establish the connection + bconn *bufio.Reader // a buffered net.Conn for reading packets + outbuf *bytes.Buffer // a buffer for writing packets + Remote string // the remote address of the client + Listener string // listener id of the client + Inline bool // if true, the client is the built-in 'inline' embedded client +} + +// ClientProperties contains the properties which define the client behaviour. +type ClientProperties struct { + Props packets.Properties + Will Will + Username []byte + ProtocolVersion byte + Clean bool +} + +// Will contains the last will and testament details for a client connection. +type Will struct { + Payload []byte // - + User []packets.UserProperty // - + TopicName string // - + Flag uint32 // 0,1 + WillDelayInterval uint32 // - + Qos byte // - + Retain bool // - +} + +// ClientState tracks the state of the client. +type ClientState struct { + TopicAliases TopicAliases // a map of topic aliases + stopCause atomic.Value // reason for stopping + Inflight *Inflight // a map of in-flight qos messages + Subscriptions *Subscriptions // a map of the subscription filters a client maintains + disconnected int64 // the time the client disconnected in unix time, for calculating expiry + outbound chan *packets.Packet // queue for pending outbound packets + endOnce sync.Once // only end once + isTakenOver uint32 // used to identify orphaned clients + packetID uint32 // the current highest packetID + open context.Context // indicate that the client is open for packet exchange + cancelOpen context.CancelFunc // cancel function for open context + outboundQty int32 // number of messages currently in the outbound queue + Keepalive uint16 // the number of seconds the connection can wait + ServerKeepalive bool // keepalive was set by the server +} + +// newClient returns a new instance of Client. This is almost exclusively used by Server +// for creating new clients, but it lives here because it's not dependent. +func newClient(c net.Conn, o *ops) *Client { + ctx, cancel := context.WithCancel(context.Background()) + cl := &Client{ + State: ClientState{ + Inflight: NewInflights(), + Subscriptions: NewSubscriptions(), + TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum), + open: ctx, + cancelOpen: cancel, + Keepalive: defaultKeepalive, + outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending), + }, + Properties: ClientProperties{ + ProtocolVersion: defaultClientProtocolVersion, // default protocol version + }, + ops: o, + } + + if c != nil { + cl.Net = ClientConnection{ + Conn: c, + bconn: bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize), + Remote: c.RemoteAddr().String(), + } + } + + return cl +} + +// WriteLoop ranges over pending outbound messages and writes them to the client connection. +func (cl *Client) WriteLoop() { + for { + select { + case pk := <-cl.State.outbound: + if err := cl.WritePacket(*pk); err != nil { + // TODO : Figure out what to do with error + cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk) + } + atomic.AddInt32(&cl.State.outboundQty, -1) + case <-cl.State.open.Done(): + return + } + } +} + +// ParseConnect parses the connect parameters and properties for a client. +func (cl *Client) ParseConnect(lid string, pk packets.Packet) { + cl.Net.Listener = lid + + cl.Properties.ProtocolVersion = pk.ProtocolVersion + cl.Properties.Username = pk.Connect.Username + cl.Properties.Clean = pk.Connect.Clean + cl.Properties.Props = pk.Properties.Copy(false) + + if cl.Properties.Props.ReceiveMaximum > cl.ops.options.Capabilities.MaximumInflight { // 3.3.4 Non-normative + cl.Properties.Props.ReceiveMaximum = cl.ops.options.Capabilities.MaximumInflight + } + + if pk.Connect.Keepalive <= minimumKeepalive { + cl.ops.log.Warn( + ErrMinimumKeepalive.Error(), + "client", cl.ID, + "keepalive", pk.Connect.Keepalive, + "recommended", minimumKeepalive, + ) + } + + cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22] + cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client + cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max + cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum) + + cl.ID = pk.Connect.ClientIdentifier + if cl.ID == "" { + cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7] + cl.Properties.Props.AssignedClientID = cl.ID + } + + if pk.Connect.WillFlag { + cl.Properties.Will = Will{ + Qos: pk.Connect.WillQos, + Retain: pk.Connect.WillRetain, + Payload: pk.Connect.WillPayload, + TopicName: pk.Connect.WillTopic, + WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval, + User: pk.Connect.WillProperties.User, + } + if pk.Properties.SessionExpiryIntervalFlag && + pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval { + cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval + } + if pk.Connect.WillFlag { + cl.Properties.Will.Flag = 1 // atomic for checking + } + } +} + +// refreshDeadline refreshes the read/write deadline for the net.Conn connection. +func (cl *Client) refreshDeadline(keepalive uint16) { + var expiry time.Time // nil time can be used to disable deadline if keepalive = 0 + if keepalive > 0 { + expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22] + } + + if cl.Net.Conn != nil { + _ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22] + } +} + +// NextPacketID returns the next available (unused) packet id for the client. +// If no unused packet ids are available, an error is returned and the client +// should be disconnected. +func (cl *Client) NextPacketID() (i uint32, err error) { + cl.Lock() + defer cl.Unlock() + + i = atomic.LoadUint32(&cl.State.packetID) + started := i + overflowed := false + for { + if overflowed && i == started { + return 0, packets.ErrQuotaExceeded + } + + if i >= cl.ops.options.Capabilities.maximumPacketID { + overflowed = true + i = 0 + continue + } + + i++ + + if _, ok := cl.State.Inflight.Get(uint16(i)); !ok { + atomic.StoreUint32(&cl.State.packetID, i) + return i, nil + } + } +} + +// ResendInflightMessages attempts to resend any pending inflight messages to connected clients. +func (cl *Client) ResendInflightMessages(force bool) error { + if cl.State.Inflight.Len() == 0 { + return nil + } + + for _, tk := range cl.State.Inflight.GetAll(false) { + if tk.FixedHeader.Type == packets.Publish { + tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3] + } + + cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0) + err := cl.WritePacket(tk) + if err != nil { + return err + } + + if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp { + if ok := cl.State.Inflight.Delete(tk.PacketID); ok { + cl.ops.hooks.OnQosComplete(cl, tk) + atomic.AddInt64(&cl.ops.info.Inflight, -1) + } + } + } + + return nil +} + +// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session. +func (cl *Client) ClearInflights() { + for _, tk := range cl.State.Inflight.GetAll(false) { + if ok := cl.State.Inflight.Delete(tk.PacketID); ok { + cl.ops.hooks.OnQosDropped(cl, tk) + atomic.AddInt64(&cl.ops.info.Inflight, -1) + } + } +} + +// ClearExpiredInflights deletes any inflight messages which have expired. +func (cl *Client) ClearExpiredInflights(now, maximumExpiry int64) []uint16 { + deleted := []uint16{} + for _, tk := range cl.State.Inflight.GetAll(false) { + expired := tk.ProtocolVersion == 5 && tk.Expiry > 0 && tk.Expiry < now // [MQTT-3.3.2-5] + + // If the maximum message expiry interval is set (greater than 0), and the message + // retention period exceeds the maximum expiry, the message will be forcibly removed. + enforced := maximumExpiry > 0 && now-tk.Created > maximumExpiry + + if expired || enforced { + if ok := cl.State.Inflight.Delete(tk.PacketID); ok { + cl.ops.hooks.OnQosDropped(cl, tk) + atomic.AddInt64(&cl.ops.info.Inflight, -1) + deleted = append(deleted, tk.PacketID) + } + } + } + + return deleted +} + +// Read reads incoming packets from the connected client and transforms them into +// packets to be handled by the packetHandler. +func (cl *Client) Read(packetHandler ReadFn) error { + var err error + + for { + if cl.Closed() { + return nil + } + + cl.refreshDeadline(cl.State.Keepalive) + fh := new(packets.FixedHeader) + err = cl.ReadFixedHeader(fh) + if err != nil { + return err + } + + pk, err := cl.ReadPacket(fh) + if err != nil { + return err + } + + err = packetHandler(cl, pk) // Process inbound packet. + if err != nil { + return err + } + } +} + +// Stop instructs the client to shut down all processing goroutines and disconnect. +func (cl *Client) Stop(err error) { + cl.State.endOnce.Do(func() { + + if cl.Net.Conn != nil { + _ = cl.Net.Conn.Close() // omit close error + } + + if err != nil { + cl.State.stopCause.Store(err) + } + + if cl.State.cancelOpen != nil { + cl.State.cancelOpen() + } + + atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix()) + }) +} + +// StopCause returns the reason the client connection was stopped, if any. +func (cl *Client) StopCause() error { + if cl.State.stopCause.Load() == nil { + return nil + } + return cl.State.stopCause.Load().(error) +} + +// StopTime returns the the time the client disconnected in unix time, else zero. +func (cl *Client) StopTime() int64 { + return atomic.LoadInt64(&cl.State.disconnected) +} + +// Closed returns true if client connection is closed. +func (cl *Client) Closed() bool { + return cl.State.open == nil || cl.State.open.Err() != nil +} + +// ReadFixedHeader reads in the values of the next packet's fixed header. +func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { + if cl.Net.bconn == nil { + return ErrConnectionClosed + } + + b, err := cl.Net.bconn.ReadByte() + if err != nil { + return err + } + + err = fh.Decode(b) + if err != nil { + return err + } + + var bu int + fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn) + if err != nil { + return err + } + + if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize { + return packets.ErrPacketTooLarge // [MQTT-3.2.2-15] + } + + atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1)) + return nil +} + +// ReadPacket reads the remaining buffer into an MQTT packet. +func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) { + atomic.AddInt64(&cl.ops.info.PacketsReceived, 1) + + pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding + pk.FixedHeader = *fh + p := make([]byte, pk.FixedHeader.Remaining) + n, err := io.ReadFull(cl.Net.bconn, p) + if err != nil { + return pk, err + } + + atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n)) + + // Decode the remaining packet values using a fresh copy of the bytes, + // otherwise the next packet will change the data of this one. + px := append([]byte{}, p[:]...) + switch pk.FixedHeader.Type { + case packets.Connect: + err = pk.ConnectDecode(px) + case packets.Disconnect: + err = pk.DisconnectDecode(px) + case packets.Connack: + err = pk.ConnackDecode(px) + case packets.Publish: + err = pk.PublishDecode(px) + if err == nil { + atomic.AddInt64(&cl.ops.info.MessagesReceived, 1) + } + case packets.Puback: + err = pk.PubackDecode(px) + case packets.Pubrec: + err = pk.PubrecDecode(px) + case packets.Pubrel: + err = pk.PubrelDecode(px) + case packets.Pubcomp: + err = pk.PubcompDecode(px) + case packets.Subscribe: + err = pk.SubscribeDecode(px) + case packets.Suback: + err = pk.SubackDecode(px) + case packets.Unsubscribe: + err = pk.UnsubscribeDecode(px) + case packets.Unsuback: + err = pk.UnsubackDecode(px) + case packets.Pingreq: + case packets.Pingresp: + case packets.Auth: + err = pk.AuthDecode(px) + default: + err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type) + } + + if err != nil { + return pk, err + } + + pk, err = cl.ops.hooks.OnPacketRead(cl, pk) + return +} + +// WritePacket encodes and writes a packet to the client. +func (cl *Client) WritePacket(pk packets.Packet) error { + if cl.Closed() { + return ErrConnectionClosed + } + + if cl.Net.Conn == nil { + return nil + } + + if pk.Expiry > 0 { + pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6] + } + + pk.ProtocolVersion = cl.Properties.ProtocolVersion + if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests + pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize + } + + if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 { + pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set + } + + if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo { + pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode + } + + pk = cl.ops.hooks.OnPacketEncode(cl, pk) + + var err error + buf := new(bytes.Buffer) + switch pk.FixedHeader.Type { + case packets.Connect: + err = pk.ConnectEncode(buf) + case packets.Connack: + err = pk.ConnackEncode(buf) + case packets.Publish: + err = pk.PublishEncode(buf) + case packets.Puback: + err = pk.PubackEncode(buf) + case packets.Pubrec: + err = pk.PubrecEncode(buf) + case packets.Pubrel: + err = pk.PubrelEncode(buf) + case packets.Pubcomp: + err = pk.PubcompEncode(buf) + case packets.Subscribe: + err = pk.SubscribeEncode(buf) + case packets.Suback: + err = pk.SubackEncode(buf) + case packets.Unsubscribe: + err = pk.UnsubscribeEncode(buf) + case packets.Unsuback: + err = pk.UnsubackEncode(buf) + case packets.Pingreq: + err = pk.PingreqEncode(buf) + case packets.Pingresp: + err = pk.PingrespEncode(buf) + case packets.Disconnect: + err = pk.DisconnectEncode(buf) + case packets.Auth: + err = pk.AuthEncode(buf) + default: + err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type) + } + if err != nil { + return err + } + + if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize { + return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25] + } + + n, err := func() (int64, error) { + cl.Lock() + defer cl.Unlock() + if len(cl.State.outbound) == 0 { + if cl.Net.outbuf == nil { + return buf.WriteTo(cl.Net.Conn) + } + + // first write to buffer, then flush buffer + n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful + err = cl.flushOutbuf() + return int64(n), err + } + + // there are more writes in the queue + if cl.Net.outbuf == nil { + if buf.Len() >= cl.ops.options.ClientNetWriteBufferSize { + return buf.WriteTo(cl.Net.Conn) + } + cl.Net.outbuf = new(bytes.Buffer) + } + + n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful + if cl.Net.outbuf.Len() < cl.ops.options.ClientNetWriteBufferSize { + return int64(n), nil + } + + err = cl.flushOutbuf() + return int64(n), err + }() + if err != nil { + return err + } + + atomic.AddInt64(&cl.ops.info.BytesSent, n) + atomic.AddInt64(&cl.ops.info.PacketsSent, 1) + if pk.FixedHeader.Type == packets.Publish { + atomic.AddInt64(&cl.ops.info.MessagesSent, 1) + } + + cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes()) + + return err +} + +func (cl *Client) flushOutbuf() (err error) { + if cl.Net.outbuf == nil { + return + } + + _, err = cl.Net.outbuf.WriteTo(cl.Net.Conn) + if err == nil { + cl.Net.outbuf = nil + } + return +} diff --git a/mqtt/clients_test.go b/mqtt/clients_test.go new file mode 100644 index 0000000..15d8215 --- /dev/null +++ b/mqtt/clients_test.go @@ -0,0 +1,930 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package mqtt + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "log/slog" + "net" + "strings" + "sync/atomic" + "testing" + "time" + + "testmqtt/packets" + "testmqtt/system" + + "github.com/stretchr/testify/require" +) + +const pkInfo = "packet type %v, %s" + +var errClientStop = errors.New("test stop") + +func newTestClient() (cl *Client, r net.Conn, w net.Conn) { + r, w = net.Pipe() + + cl = newClient(w, &ops{ + info: new(system.Info), + hooks: new(Hooks), + log: logger, + options: &Options{ + Capabilities: &Capabilities{ + ReceiveMaximum: 10, + MaximumInflight: 5, + TopicAliasMaximum: 10000, + MaximumClientWritesPending: 3, + maximumPacketID: 10, + }, + }, + }) + + cl.ID = "mochi" + cl.State.Inflight.maximumSendQuota = 5 + cl.State.Inflight.sendQuota = 5 + cl.State.Inflight.maximumReceiveQuota = 10 + cl.State.Inflight.receiveQuota = 10 + cl.Properties.Props.TopicAliasMaximum = 0 + cl.Properties.Props.RequestResponseInfo = 0x1 + + go cl.WriteLoop() + + return +} + +func TestNewInflights(t *testing.T) { + require.NotNil(t, NewInflights().internal) +} + +func TestNewClients(t *testing.T) { + cl := NewClients() + require.NotNil(t, cl.internal) +} + +func TestClientsAdd(t *testing.T) { + cl := NewClients() + cl.Add(&Client{ID: "t1"}) + require.Contains(t, cl.internal, "t1") +} + +func TestClientsGet(t *testing.T) { + cl := NewClients() + cl.Add(&Client{ID: "t1"}) + cl.Add(&Client{ID: "t2"}) + require.Contains(t, cl.internal, "t1") + require.Contains(t, cl.internal, "t2") + + client, ok := cl.Get("t1") + require.Equal(t, true, ok) + require.Equal(t, "t1", client.ID) +} + +func TestClientsGetAll(t *testing.T) { + cl := NewClients() + cl.Add(&Client{ID: "t1"}) + cl.Add(&Client{ID: "t2"}) + cl.Add(&Client{ID: "t3"}) + require.Contains(t, cl.internal, "t1") + require.Contains(t, cl.internal, "t2") + require.Contains(t, cl.internal, "t3") + + clients := cl.GetAll() + require.Len(t, clients, 3) +} + +func TestClientsLen(t *testing.T) { + cl := NewClients() + cl.Add(&Client{ID: "t1"}) + cl.Add(&Client{ID: "t2"}) + require.Contains(t, cl.internal, "t1") + require.Contains(t, cl.internal, "t2") + require.Equal(t, 2, cl.Len()) +} + +func TestClientsDelete(t *testing.T) { + cl := NewClients() + cl.Add(&Client{ID: "t1"}) + require.Contains(t, cl.internal, "t1") + + cl.Delete("t1") + _, ok := cl.Get("t1") + require.Equal(t, false, ok) + require.Nil(t, cl.internal["t1"]) +} + +func TestClientsGetByListener(t *testing.T) { + cl := NewClients() + cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}}) + cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}}) + require.Contains(t, cl.internal, "t1") + require.Contains(t, cl.internal, "t2") + + clients := cl.GetByListener("tcp1") + require.NotEmpty(t, clients) + require.Equal(t, 1, len(clients)) + require.Equal(t, "tcp1", clients[0].Net.Listener) +} + +func TestNewClient(t *testing.T) { + cl, _, _ := newTestClient() + + require.NotNil(t, cl) + require.NotNil(t, cl.State.Inflight.internal) + require.NotNil(t, cl.State.Subscriptions) + require.NotNil(t, cl.State.TopicAliases) + require.Equal(t, defaultKeepalive, cl.State.Keepalive) + require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion) + require.NotNil(t, cl.Net.Conn) + require.NotNil(t, cl.Net.bconn) + require.NotNil(t, cl.ops) + require.NotNil(t, cl.ops.options.Capabilities) + require.False(t, cl.Net.Inline) +} + +func TestClientParseConnect(t *testing.T) { + cl, _, _ := newTestClient() + + pk := packets.Packet{ + ProtocolVersion: 4, + Connect: packets.ConnectParams{ + ProtocolName: []byte{'M', 'Q', 'T', 'T'}, + Clean: true, + Keepalive: 60, + ClientIdentifier: "mochi", + WillFlag: true, + WillTopic: "lwt", + WillPayload: []byte("lol gg"), + WillQos: 1, + WillRetain: false, + }, + Properties: packets.Properties{ + ReceiveMaximum: uint16(5), + }, + } + + cl.ParseConnect("tcp1", pk) + require.Equal(t, pk.Connect.ClientIdentifier, cl.ID) + require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive) + require.Equal(t, pk.Connect.Clean, cl.Properties.Clean) + require.Equal(t, pk.Connect.ClientIdentifier, cl.ID) + require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName) + require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload) + require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos) + require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain) + require.Equal(t, uint32(1), cl.Properties.Will.Flag) + require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota) + require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota) + require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota) + require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota) +} + +func TestClientParseConnectReceiveMaxExceedMaxInflight(t *testing.T) { + const MaxInflight uint16 = 1 + cl, _, _ := newTestClient() + cl.ops.options.Capabilities.MaximumInflight = MaxInflight + + pk := packets.Packet{ + ProtocolVersion: 4, + Connect: packets.ConnectParams{ + ProtocolName: []byte{'M', 'Q', 'T', 'T'}, + Clean: true, + Keepalive: 60, + ClientIdentifier: "mochi", + WillFlag: true, + WillTopic: "lwt", + WillPayload: []byte("lol gg"), + WillQos: 1, + WillRetain: false, + }, + Properties: packets.Properties{ + ReceiveMaximum: uint16(5), + }, + } + + cl.ParseConnect("tcp1", pk) + require.Equal(t, pk.Connect.ClientIdentifier, cl.ID) + require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive) + require.Equal(t, pk.Connect.Clean, cl.Properties.Clean) + require.Equal(t, pk.Connect.ClientIdentifier, cl.ID) + require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName) + require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload) + require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos) + require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain) + require.Equal(t, uint32(1), cl.Properties.Will.Flag) + require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota) + require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota) + require.Equal(t, int32(MaxInflight), cl.State.Inflight.sendQuota) + require.Equal(t, int32(MaxInflight), cl.State.Inflight.maximumSendQuota) +} + +func TestClientParseConnectOverrideWillDelay(t *testing.T) { + cl, _, _ := newTestClient() + + pk := packets.Packet{ + ProtocolVersion: 4, + Connect: packets.ConnectParams{ + ProtocolName: []byte{'M', 'Q', 'T', 'T'}, + Clean: true, + Keepalive: 60, + ClientIdentifier: "mochi", + WillFlag: true, + WillProperties: packets.Properties{ + WillDelayInterval: 200, + }, + }, + Properties: packets.Properties{ + SessionExpiryInterval: 100, + SessionExpiryIntervalFlag: true, + }, + } + + cl.ParseConnect("tcp1", pk) + require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval) +} + +func TestClientParseConnectNoID(t *testing.T) { + cl, _, _ := newTestClient() + cl.ParseConnect("tcp1", packets.Packet{}) + require.NotEmpty(t, cl.ID) +} + +func TestClientParseConnectBelowMinimumKeepalive(t *testing.T) { + cl, _, _ := newTestClient() + var b bytes.Buffer + x := bufio.NewWriter(&b) + cl.ops.log = slog.New(slog.NewTextHandler(x, nil)) + + pk := packets.Packet{ + ProtocolVersion: 4, + Connect: packets.ConnectParams{ + ProtocolName: []byte{'M', 'Q', 'T', 'T'}, + Keepalive: minimumKeepalive - 1, + ClientIdentifier: "mochi", + }, + } + cl.ParseConnect("tcp1", pk) + err := x.Flush() + require.NoError(t, err) + require.True(t, strings.Contains(b.String(), ErrMinimumKeepalive.Error())) + require.NotEmpty(t, cl.ID) +} + +func TestClientNextPacketID(t *testing.T) { + cl, _, _ := newTestClient() + + i, err := cl.NextPacketID() + require.NoError(t, err) + require.Equal(t, uint32(1), i) + + i, err = cl.NextPacketID() + require.NoError(t, err) + require.Equal(t, uint32(2), i) +} + +func TestClientNextPacketIDInUse(t *testing.T) { + cl, _, _ := newTestClient() + + // skip over 2 + cl.State.Inflight.Set(packets.Packet{PacketID: 2}) + + i, err := cl.NextPacketID() + require.NoError(t, err) + require.Equal(t, uint32(1), i) + + i, err = cl.NextPacketID() + require.NoError(t, err) + require.Equal(t, uint32(3), i) + + // Skip over overflow + cl.State.Inflight.Set(packets.Packet{PacketID: 65535}) + atomic.StoreUint32(&cl.State.packetID, 65534) + + i, err = cl.NextPacketID() + require.NoError(t, err) + require.Equal(t, uint32(1), i) +} + +func TestClientNextPacketIDExhausted(t *testing.T) { + cl, _, _ := newTestClient() + for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ { + cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)} + } + + i, err := cl.NextPacketID() + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrQuotaExceeded) + require.Equal(t, uint32(0), i) +} + +func TestClientNextPacketIDOverflow(t *testing.T) { + cl, _, _ := newTestClient() + for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ { + cl.State.Inflight.internal[uint16(i)] = packets.Packet{} + } + + cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1 + i, err := cl.NextPacketID() + require.NoError(t, err) + require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i) + cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{} + + cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID + _, err = cl.NextPacketID() + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrQuotaExceeded) +} + +func TestClientClearInflights(t *testing.T) { + cl, _, _ := newTestClient() + n := time.Now().Unix() + + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n}) + + require.Equal(t, 5, cl.State.Inflight.Len()) + cl.ClearInflights() + require.Equal(t, 0, cl.State.Inflight.Len()) +} + +func TestClientClearExpiredInflights(t *testing.T) { + cl, _, _ := newTestClient() + + n := time.Now().Unix() + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2}) + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n}) + require.Equal(t, 5, cl.State.Inflight.Len()) + + deleted := cl.ClearExpiredInflights(n, 4) + require.Len(t, deleted, 3) + require.ElementsMatch(t, []uint16{1, 2, 5}, deleted) + require.Equal(t, 2, cl.State.Inflight.Len()) + + cl.State.Inflight.Set(packets.Packet{PacketID: 11, Expiry: n - 1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 12, Expiry: n - 2}) // expiry is ineffective for v3. + cl.State.Inflight.Set(packets.Packet{PacketID: 13, Created: n - 3}) // within bounds for v3 + cl.State.Inflight.Set(packets.Packet{PacketID: 15, Created: n - 5}) // over max server expiry limit + require.Equal(t, 6, cl.State.Inflight.Len()) + + deleted = cl.ClearExpiredInflights(n, 4) + require.Len(t, deleted, 3) + require.ElementsMatch(t, []uint16{11, 12, 15}, deleted) + require.Equal(t, 3, cl.State.Inflight.Len()) + + cl.State.Inflight.Set(packets.Packet{PacketID: 17, Created: n - 1}) + deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not process abandon messages + require.Len(t, deleted, 0) + require.Equal(t, 4, cl.State.Inflight.Len()) + + cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 18, Expiry: n - 1}) + deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not abandon messages + require.ElementsMatch(t, []uint16{18}, deleted) // expiry is still effective for v5. + require.Len(t, deleted, 1) + require.Equal(t, 4, cl.State.Inflight.Len()) +} + +func TestClientResendInflightMessages(t *testing.T) { + pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback) + cl, r, w := newTestClient() + + cl.State.Inflight.Set(*pk1.Packet) + require.Equal(t, 1, cl.State.Inflight.Len()) + + go func() { + err := cl.ResendInflightMessages(true) + require.NoError(t, err) + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, 0, cl.State.Inflight.Len()) + require.Equal(t, pk1.RawBytes, buf) +} + +func TestClientResendInflightMessagesWriteFailure(t *testing.T) { + pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) + cl, r, _ := newTestClient() + _ = r.Close() + + cl.State.Inflight.Set(*pk1.Packet) + require.Equal(t, 1, cl.State.Inflight.Len()) + err := cl.ResendInflightMessages(true) + require.Error(t, err) + require.ErrorIs(t, err, io.ErrClosedPipe) + require.Equal(t, 1, cl.State.Inflight.Len()) +} + +func TestClientResendInflightMessagesNoMessages(t *testing.T) { + cl, _, _ := newTestClient() + err := cl.ResendInflightMessages(true) + require.NoError(t, err) +} + +func TestClientRefreshDeadline(t *testing.T) { + cl, _, _ := newTestClient() + cl.refreshDeadline(10) + require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline? +} + +func TestClientReadFixedHeader(t *testing.T) { + cl, r, _ := newTestClient() + + defer cl.Stop(errClientStop) + go func() { + _, _ = r.Write([]byte{packets.Connect << 4, 0x00}) + _ = r.Close() + }() + + fh := new(packets.FixedHeader) + err := cl.ReadFixedHeader(fh) + require.NoError(t, err) + require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived)) +} + +func TestClientReadFixedHeaderDecodeError(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + + go func() { + _, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}) + _ = r.Close() + }() + + fh := new(packets.FixedHeader) + err := cl.ReadFixedHeader(fh) + require.Error(t, err) +} + +func TestClientReadFixedHeaderPacketOversized(t *testing.T) { + cl, r, _ := newTestClient() + cl.ops.options.Capabilities.MaximumPacketSize = 2 + defer cl.Stop(errClientStop) + + go func() { + _, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) + _ = r.Close() + }() + + fh := new(packets.FixedHeader) + err := cl.ReadFixedHeader(fh) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrPacketTooLarge) +} + +func TestClientReadFixedHeaderReadEOF(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + + go func() { + _ = r.Close() + }() + + fh := new(packets.FixedHeader) + err := cl.ReadFixedHeader(fh) + require.Error(t, err) + require.Equal(t, io.EOF, err) +} + +func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + + go func() { + _, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}) + _ = r.Close() + }() + + fh := new(packets.FixedHeader) + err := cl.ReadFixedHeader(fh) + require.Error(t, err) +} + +func TestClientReadOK(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + go func() { + _, _ = r.Write([]byte{ + packets.Publish << 4, 18, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload, + packets.Publish << 4, 11, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'd', '/', 'e', '/', 'f', // Topic Name + 'y', 'e', 'a', 'h', // Payload + }) + _ = r.Close() + }() + + var pks []packets.Packet + o := make(chan error) + go func() { + o <- cl.Read(func(cl *Client, pk packets.Packet) error { + pks = append(pks, pk) + return nil + }) + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 2, len(pks)) + require.Equal(t, []packets.Packet{ + { + ProtocolVersion: cl.Properties.ProtocolVersion, + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Remaining: 18, + }, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + }, + { + ProtocolVersion: cl.Properties.ProtocolVersion, + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Remaining: 11, + }, + TopicName: "d/e/f", + Payload: []byte("yeah"), + }, + }, pks) + + require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived)) +} + +func TestClientReadDone(t *testing.T) { + cl, _, _ := newTestClient() + defer cl.Stop(errClientStop) + cl.State.cancelOpen() + + o := make(chan error) + go func() { + o <- cl.Read(func(cl *Client, pk packets.Packet) error { + return nil + }) + }() + + require.NoError(t, <-o) +} + +func TestClientStop(t *testing.T) { + cl, _, _ := newTestClient() + require.Equal(t, int64(0), cl.StopTime()) + cl.Stop(nil) + require.Equal(t, nil, cl.State.stopCause.Load()) + require.InDelta(t, time.Now().Unix(), cl.State.disconnected, 1.0) + require.Equal(t, cl.State.disconnected, cl.StopTime()) + require.True(t, cl.Closed()) + require.Equal(t, nil, cl.StopCause()) +} + +func TestClientClosed(t *testing.T) { + cl, _, _ := newTestClient() + require.False(t, cl.Closed()) + cl.Stop(nil) + require.True(t, cl.Closed()) +} + +func TestClientReadFixedHeaderError(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + go func() { + _, _ = r.Write([]byte{ + packets.Publish << 4, 11, // Fixed header + }) + _ = r.Close() + }() + + cl.Net.bconn = nil + fh := new(packets.FixedHeader) + err := cl.ReadFixedHeader(fh) + require.Error(t, err) + require.ErrorIs(t, ErrConnectionClosed, err) +} + +func TestClientReadReadHandlerErr(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + go func() { + _, _ = r.Write([]byte{ + packets.Publish << 4, 11, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'd', '/', 'e', '/', 'f', // Topic Name + 'y', 'e', 'a', 'h', // Payload + }) + _ = r.Close() + }() + + err := cl.Read(func(cl *Client, pk packets.Packet) error { + return errors.New("test") + }) + + require.Error(t, err) +} + +func TestClientReadReadPacketOK(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + go func() { + _, _ = r.Write([]byte{ + packets.Publish << 4, 11, // Fixed header + 0, 5, + 'd', '/', 'e', '/', 'f', + 'y', 'e', 'a', 'h', + }) + _ = r.Close() + }() + + fh := new(packets.FixedHeader) + err := cl.ReadFixedHeader(fh) + require.NoError(t, err) + + pk, err := cl.ReadPacket(fh) + require.NoError(t, err) + require.NotNil(t, pk) + + require.Equal(t, packets.Packet{ + ProtocolVersion: cl.Properties.ProtocolVersion, + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Remaining: 11, + }, + TopicName: "d/e/f", + Payload: []byte("yeah"), + }, pk) +} + +func TestClientReadPacket(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + + for _, tx := range pkTable { + tt := tx // avoid data race + t.Run(tt.Desc, func(t *testing.T) { + atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0) + go func() { + _, _ = r.Write(tt.RawBytes) + }() + + fh := new(packets.FixedHeader) + err := cl.ReadFixedHeader(fh) + require.NoError(t, err) + + if tt.Packet.ProtocolVersion == 5 { + cl.Properties.ProtocolVersion = 5 + } else { + cl.Properties.ProtocolVersion = 0 + } + + pk, err := cl.ReadPacket(fh) + require.NoError(t, err, pkInfo, tt.Case, tt.Desc) + require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc) + require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc) + + if tt.Packet.FixedHeader.Type == packets.Publish { + require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc) + } + }) + } +} + +func TestClientReadPacketInvalidTypeError(t *testing.T) { + cl, _, _ := newTestClient() + _ = cl.Net.Conn.Close() + _, err := cl.ReadPacket(&packets.FixedHeader{}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid packet type") +} + +func TestClientWritePacket(t *testing.T) { + for _, tt := range pkTable { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + + cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion + + o := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + o <- buf + }() + + err := cl.WritePacket(*tt.Packet) + require.NoError(t, err, pkInfo, tt.Case, tt.Desc) + + time.Sleep(2 * time.Millisecond) + _ = cl.Net.Conn.Close() + + require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc) + + cl.Stop(errClientStop) + time.Sleep(time.Millisecond * 1) + + // The stop cause is either the test error, EOF, or a + // closed pipe, depending on which goroutine runs first. + err = cl.StopCause() + require.True(t, + errors.Is(err, errClientStop) || + errors.Is(err, io.EOF) || + errors.Is(err, io.ErrClosedPipe)) + + require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent)) + require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent)) + if tt.Packet.FixedHeader.Type == packets.Publish { + require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent)) + } + } +} + +func TestClientWritePacketBuffer(t *testing.T) { + r, w := net.Pipe() + + cl := newClient(w, &ops{ + info: new(system.Info), + hooks: new(Hooks), + log: logger, + options: &Options{ + Capabilities: &Capabilities{ + ReceiveMaximum: 10, + TopicAliasMaximum: 10000, + MaximumClientWritesPending: 3, + maximumPacketID: 10, + }, + }, + }) + + cl.ID = "mochi" + cl.State.Inflight.maximumSendQuota = 5 + cl.State.Inflight.sendQuota = 5 + cl.State.Inflight.maximumReceiveQuota = 10 + cl.State.Inflight.receiveQuota = 10 + cl.Properties.Props.TopicAliasMaximum = 0 + cl.Properties.Props.RequestResponseInfo = 0x1 + + cl.ops.options.ClientNetWriteBufferSize = 10 + defer cl.Stop(errClientStop) + + small := packets.TPacketData[packets.Publish].Get(packets.TPublishNoPayload).Packet + large := packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + + cl.State.outbound <- small + + tt := []struct { + pks []*packets.Packet + size int + }{ + { + pks: []*packets.Packet{small, small}, + size: 18, + }, + { + pks: []*packets.Packet{large}, + size: 20, + }, + { + pks: []*packets.Packet{small}, + size: 0, + }, + } + + go func() { + for i, tx := range tt { + for _, pk := range tx.pks { + cl.Properties.ProtocolVersion = pk.ProtocolVersion + err := cl.WritePacket(*pk) + require.NoError(t, err, "index: %d", i) + if i == len(tt)-1 { + cl.Net.Conn.Close() + } + time.Sleep(100 * time.Millisecond) + } + } + }() + + var n int + var err error + for i, tx := range tt { + buf := make([]byte, 100) + if i == len(tt)-1 { + buf, err = io.ReadAll(r) + n = len(buf) + } else { + n, err = io.ReadAtLeast(r, buf, 1) + } + require.NoError(t, err, "index: %d", i) + require.Equal(t, tx.size, n, "index: %d", i) + } +} + +func TestWriteClientOversizePacket(t *testing.T) { + cl, _, _ := newTestClient() + cl.Properties.Props.MaximumPacketSize = 2 + pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet + err := cl.WritePacket(pk) + require.Error(t, err) + require.ErrorIs(t, packets.ErrPacketTooLarge, err) +} + +func TestClientReadPacketReadingError(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + go func() { + _, _ = r.Write([]byte{ + 0, 11, // Fixed header + 0, 5, + 'd', '/', 'e', '/', 'f', + 'y', 'e', 'a', 'h', + }) + _ = r.Close() + }() + + _, err := cl.ReadPacket(&packets.FixedHeader{ + Type: 0, + Remaining: 11, + }) + require.Error(t, err) +} + +func TestClientReadPacketReadUnknown(t *testing.T) { + cl, r, _ := newTestClient() + defer cl.Stop(errClientStop) + go func() { + _, _ = r.Write([]byte{ + 0, 11, // Fixed header + 0, 5, + 'd', '/', 'e', '/', 'f', + 'y', 'e', 'a', 'h', + }) + _ = r.Close() + }() + + _, err := cl.ReadPacket(&packets.FixedHeader{ + Remaining: 1, + }) + require.Error(t, err) +} + +func TestClientWritePacketWriteNoConn(t *testing.T) { + cl, _, _ := newTestClient() + cl.Stop(errClientStop) + + err := cl.WritePacket(*pkTable[1].Packet) + require.Error(t, err) + require.Equal(t, ErrConnectionClosed, err) +} + +func TestClientWritePacketWriteError(t *testing.T) { + cl, _, _ := newTestClient() + _ = cl.Net.Conn.Close() + + err := cl.WritePacket(*pkTable[1].Packet) + require.Error(t, err) +} + +func TestClientWritePacketInvalidPacket(t *testing.T) { + cl, _, _ := newTestClient() + err := cl.WritePacket(packets.Packet{}) + require.Error(t, err) +} + +var ( + pkTable = []packets.TPacketCase{ + packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311), + packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5), + packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession), + packets.TPacketData[packets.Publish].Get(packets.TPublishBasic), + packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5), + packets.TPacketData[packets.Puback].Get(packets.TPuback), + packets.TPacketData[packets.Pubrec].Get(packets.TPubrec), + packets.TPacketData[packets.Pubrel].Get(packets.TPubrel), + packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp), + packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe), + packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5), + packets.TPacketData[packets.Suback].Get(packets.TSuback), + packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5), + packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe), + packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5), + packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback), + packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5), + packets.TPacketData[packets.Pingreq].Get(packets.TPingreq), + packets.TPacketData[packets.Pingresp].Get(packets.TPingresp), + packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect), + packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5), + packets.TPacketData[packets.Auth].Get(packets.TAuth), + } +) diff --git a/mqtt/hooks.go b/mqtt/hooks.go new file mode 100644 index 0000000..98c8433 --- /dev/null +++ b/mqtt/hooks.go @@ -0,0 +1,864 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, thedevop, dgduncan + +package mqtt + +import ( + "errors" + "fmt" + "log/slog" + "sync" + "sync/atomic" + + "testmqtt/hooks/storage" + "testmqtt/packets" + "testmqtt/system" +) + +const ( + SetOptions byte = iota + OnSysInfoTick + OnStarted + OnStopped + OnConnectAuthenticate + OnACLCheck + OnConnect + OnSessionEstablish + OnSessionEstablished + OnDisconnect + OnAuthPacket + OnPacketRead + OnPacketEncode + OnPacketSent + OnPacketProcessed + OnSubscribe + OnSubscribed + OnSelectSubscribers + OnUnsubscribe + OnUnsubscribed + OnPublish + OnPublished + OnPublishDropped + OnRetainMessage + OnRetainPublished + OnQosPublish + OnQosComplete + OnQosDropped + OnPacketIDExhausted + OnWill + OnWillSent + OnClientExpired + OnRetainedExpired + StoredClients + StoredSubscriptions + StoredInflightMessages + StoredRetainedMessages + StoredSysInfo +) + +var ( + // ErrInvalidConfigType indicates a different Type of config value was expected to what was received. + // ErrInvalidConfigType = errors.New("invalid config type provided") + ErrInvalidConfigType = errors.New("提供的配置类型无效") +) + +// HookLoadConfig contains the hook and configuration as loaded from a configuration (usually file). +type HookLoadConfig struct { + Hook Hook + Config any +} + +// Hook provides an interface of handlers for different events which occur +// during the lifecycle of the broker. +type Hook interface { + ID() string + Provides(b byte) bool + Init(config any) error + Stop() error + SetOpts(l *slog.Logger, o *HookOptions) + + OnStarted() + OnStopped() + OnConnectAuthenticate(cl *Client, pk packets.Packet) bool + OnACLCheck(cl *Client, topic string, write bool) bool + OnSysInfoTick(*system.Info) + OnConnect(cl *Client, pk packets.Packet) error + OnSessionEstablish(cl *Client, pk packets.Packet) + OnSessionEstablished(cl *Client, pk packets.Packet) + OnDisconnect(cl *Client, err error, expire bool) + OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) + OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation + OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client + OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client + OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled) + OnSubscribe(cl *Client, pk packets.Packet) packets.Packet + OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) + OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers + OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet + OnUnsubscribed(cl *Client, pk packets.Packet) + OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) + OnPublished(cl *Client, pk packets.Packet) + OnPublishDropped(cl *Client, pk packets.Packet) + OnRetainMessage(cl *Client, pk packets.Packet, r int64) + OnRetainPublished(cl *Client, pk packets.Packet) + OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) + OnQosComplete(cl *Client, pk packets.Packet) + OnQosDropped(cl *Client, pk packets.Packet) + OnPacketIDExhausted(cl *Client, pk packets.Packet) + OnWill(cl *Client, will Will) (Will, error) + OnWillSent(cl *Client, pk packets.Packet) + OnClientExpired(cl *Client) + OnRetainedExpired(filter string) + StoredClients() ([]storage.Client, error) + StoredSubscriptions() ([]storage.Subscription, error) + StoredInflightMessages() ([]storage.Message, error) + StoredRetainedMessages() ([]storage.Message, error) + StoredSysInfo() (storage.SystemInfo, error) +} + +// HookOptions contains values which are inherited from the server on initialisation. +type HookOptions struct { + Capabilities *Capabilities +} + +// Hooks is a slice of Hook interfaces to be called in sequence. +type Hooks struct { + Log *slog.Logger // a logger for the hook (from the server) + internal atomic.Value // a slice of []Hook + wg sync.WaitGroup // a waitgroup for syncing hook shutdown + qty int64 // the number of hooks in use + sync.Mutex // a mutex for locking when adding hooks +} + +// Len returns the number of hooks added. +func (h *Hooks) Len() int64 { + return atomic.LoadInt64(&h.qty) +} + +// Provides returns true if any one hook provides any of the requested hook methods. +func (h *Hooks) Provides(b ...byte) bool { + for _, hook := range h.GetAll() { + for _, hb := range b { + if hook.Provides(hb) { + return true + } + } + } + + return false +} + +// Add adds and initializes a new hook. +func (h *Hooks) Add(hook Hook, config any) error { + h.Lock() + defer h.Unlock() + + err := hook.Init(config) + if err != nil { + return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err) + } + + i, ok := h.internal.Load().([]Hook) + if !ok { + i = []Hook{} + } + + i = append(i, hook) + h.internal.Store(i) + atomic.AddInt64(&h.qty, 1) + h.wg.Add(1) + + return nil +} + +// GetAll returns a slice of all the hooks. +func (h *Hooks) GetAll() []Hook { + i, ok := h.internal.Load().([]Hook) + if !ok { + return []Hook{} + } + + return i +} + +// Stop indicates all attached hooks to gracefully end. +func (h *Hooks) Stop() { + go func() { + for _, hook := range h.GetAll() { + h.Log.Info("stopping hook", "hook", hook.ID()) + if err := hook.Stop(); err != nil { + h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID()) + } + + h.wg.Done() + } + }() + + h.wg.Wait() +} + +// OnSysInfoTick is called when the $SYS topic values are published out. +func (h *Hooks) OnSysInfoTick(sys *system.Info) { + for _, hook := range h.GetAll() { + if hook.Provides(OnSysInfoTick) { + hook.OnSysInfoTick(sys) + } + } +} + +// OnStarted is called when the server has successfully started. +func (h *Hooks) OnStarted() { + for _, hook := range h.GetAll() { + if hook.Provides(OnStarted) { + hook.OnStarted() + } + } +} + +// OnStopped is called when the server has successfully stopped. +func (h *Hooks) OnStopped() { + for _, hook := range h.GetAll() { + if hook.Provides(OnStopped) { + hook.OnStopped() + } + } +} + +// OnConnect is called when a new client connects, and may return a packets.Code as an error to halt the connection. +func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) error { + for _, hook := range h.GetAll() { + if hook.Provides(OnConnect) { + err := hook.OnConnect(cl, pk) + if err != nil { + return err + } + } + } + return nil +} + +// OnSessionEstablish is called right after a new client connects and authenticates and right before +// the session is established and CONNACK is sent. +func (h *Hooks) OnSessionEstablish(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnSessionEstablish) { + hook.OnSessionEstablish(cl, pk) + } + } +} + +// OnSessionEstablished is called when a new client establishes a session (after OnConnect). +func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnSessionEstablished) { + hook.OnSessionEstablished(cl, pk) + } + } +} + +// OnDisconnect is called when a client is disconnected for any reason. +func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) { + for _, hook := range h.GetAll() { + if hook.Provides(OnDisconnect) { + hook.OnDisconnect(cl, err, expire) + } + } +} + +// OnPacketRead is called when a packet is received from a client. +func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { + pkx = pk + for _, hook := range h.GetAll() { + if hook.Provides(OnPacketRead) { + npk, err := hook.OnPacketRead(cl, pkx) + if err != nil && errors.Is(err, packets.ErrRejectPacket) { + h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx) + return pk, err + } else if err != nil { + continue + } + + pkx = npk + } + } + + return +} + +// OnAuthPacket is called when an auth packet is received. It is intended to allow developers +// to create their own auth packet handling mechanisms. +func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { + pkx = pk + for _, hook := range h.GetAll() { + if hook.Provides(OnAuthPacket) { + npk, err := hook.OnAuthPacket(cl, pkx) + if err != nil { + return pk, err + } + + pkx = npk + } + } + + return +} + +// OnPacketEncode is called immediately before a packet is encoded to be sent to a client. +func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { + for _, hook := range h.GetAll() { + if hook.Provides(OnPacketEncode) { + pk = hook.OnPacketEncode(cl, pk) + } + } + + return pk +} + +// OnPacketProcessed is called when a packet has been received and successfully handled by the broker. +func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) { + for _, hook := range h.GetAll() { + if hook.Provides(OnPacketProcessed) { + hook.OnPacketProcessed(cl, pk, err) + } + } +} + +// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter +// containing the bytes sent. +func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) { + for _, hook := range h.GetAll() { + if hook.Provides(OnPacketSent) { + hook.OnPacketSent(cl, pk, b) + } + } +} + +// OnSubscribe is called when a client subscribes to one or more filters. This method +// differs from OnSubscribed in that it allows you to modify the subscription values +// before the packet is processed. The return values of the hook methods are passed-through +// in the order the hooks were attached. +func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { + for _, hook := range h.GetAll() { + if hook.Provides(OnSubscribe) { + pk = hook.OnSubscribe(cl, pk) + } + } + return pk +} + +// OnSubscribed is called when a client subscribes to one or more filters. +func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) { + for _, hook := range h.GetAll() { + if hook.Provides(OnSubscribed) { + hook.OnSubscribed(cl, pk, reasonCodes) + } + } +} + +// OnSelectSubscribers is called when subscribers have been collected for a topic, but before +// shared subscription subscribers have been selected. This hook can be used to programmatically +// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared +// group in a custom manner (such as based on client id, ip, etc). +func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers { + for _, hook := range h.GetAll() { + if hook.Provides(OnSelectSubscribers) { + subs = hook.OnSelectSubscribers(subs, pk) + } + } + return subs +} + +// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method +// differs from OnUnsubscribed in that it allows you to modify the unsubscription values +// before the packet is processed. The return values of the hook methods are passed-through +// in the order the hooks were attached. +func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { + for _, hook := range h.GetAll() { + if hook.Provides(OnUnsubscribe) { + pk = hook.OnUnsubscribe(cl, pk) + } + } + return pk +} + +// OnUnsubscribed is called when a client unsubscribes from one or more filters. +func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnUnsubscribed) { + hook.OnUnsubscribed(cl, pk) + } + } +} + +// OnPublish is called when a client publishes a message. This method differs from OnPublished +// in that it allows you to modify you to modify the incoming packet before it is processed. +// The return values of the hook methods are passed-through in the order the hooks were attached. +func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { + pkx = pk + for _, hook := range h.GetAll() { + if hook.Provides(OnPublish) { + npk, err := hook.OnPublish(cl, pkx) + if err != nil { + if errors.Is(err, packets.ErrRejectPacket) { + h.Log.Debug("publish packet rejected", + "error", err, + "hook", hook.ID(), + "packet", pkx) + return pk, err + } + h.Log.Error("publish packet error", + "error", err, + "hook", hook.ID(), + "packet", pkx) + return pk, err + } + pkx = npk + } + } + + return +} + +// OnPublished is called when a client has published a message to subscribers. +func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnPublished) { + hook.OnPublished(cl, pk) + } + } +} + +// OnPublishDropped is called when a message to a client was dropped instead of delivered +// such as when a client is too slow to respond. +func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnPublishDropped) { + hook.OnPublishDropped(cl, pk) + } + } +} + +// OnRetainMessage is called then a published message is retained. +func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) { + for _, hook := range h.GetAll() { + if hook.Provides(OnRetainMessage) { + hook.OnRetainMessage(cl, pk, r) + } + } +} + +// OnRetainPublished is called when a retained message is published. +func (h *Hooks) OnRetainPublished(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnRetainPublished) { + hook.OnRetainPublished(cl, pk) + } + } +} + +// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber. +// In other words, this method is called when a new inflight message is created or resent. +// It is typically used to store a new inflight message. +func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) { + for _, hook := range h.GetAll() { + if hook.Provides(OnQosPublish) { + hook.OnQosPublish(cl, pk, sent, resends) + } + } +} + +// OnQosComplete is called when the Qos flow for a message has been completed. +// In other words, when an inflight message is resolved. +// It is typically used to delete an inflight message from a store. +func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnQosComplete) { + hook.OnQosComplete(cl, pk) + } + } +} + +// OnQosDropped is called the Qos flow for a message expires. In other words, when +// an inflight message expires or is abandoned. It is typically used to delete an +// inflight message from a store. +func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnQosDropped) { + hook.OnQosDropped(cl, pk) + } + } +} + +// OnPacketIDExhausted is called when the client runs out of unused packet ids to +// assign to a packet. +func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnPacketIDExhausted) { + hook.OnPacketIDExhausted(cl, pk) + } + } +} + +// OnWill is called when a client disconnects and publishes an LWT message. This method +// differs from OnWillSent in that it allows you to modify the LWT message before it is +// published. The return values of the hook methods are passed-through in the order +// the hooks were attached. +func (h *Hooks) OnWill(cl *Client, will Will) Will { + for _, hook := range h.GetAll() { + if hook.Provides(OnWill) { + mlwt, err := hook.OnWill(cl, will) + if err != nil { + h.Log.Error("parse will error", + "error", err, + "hook", hook.ID(), + "will", will) + continue + } + will = mlwt + } + } + + return will +} + +// OnWillSent is called when an LWT message has been issued from a disconnecting client. +func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) { + for _, hook := range h.GetAll() { + if hook.Provides(OnWillSent) { + hook.OnWillSent(cl, pk) + } + } +} + +// OnClientExpired is called when a client session has expired and should be deleted. +func (h *Hooks) OnClientExpired(cl *Client) { + for _, hook := range h.GetAll() { + if hook.Provides(OnClientExpired) { + hook.OnClientExpired(cl) + } + } +} + +// OnRetainedExpired is called when a retained message has expired and should be deleted. +func (h *Hooks) OnRetainedExpired(filter string) { + for _, hook := range h.GetAll() { + if hook.Provides(OnRetainedExpired) { + hook.OnRetainedExpired(filter) + } + } +} + +// StoredClients returns all clients, e.g. from a persistent store, is used to +// populate the server clients list before start. +func (h *Hooks) StoredClients() (v []storage.Client, err error) { + for _, hook := range h.GetAll() { + if hook.Provides(StoredClients) { + v, err := hook.StoredClients() + if err != nil { + h.Log.Error("failed to load clients", "error", err, "hook", hook.ID()) + return v, err + } + + if len(v) > 0 { + return v, nil + } + } + } + + return +} + +// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is +// used to populate the server subscriptions list before start. +func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { + for _, hook := range h.GetAll() { + if hook.Provides(StoredSubscriptions) { + v, err := hook.StoredSubscriptions() + if err != nil { + h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID()) + return v, err + } + + if len(v) > 0 { + return v, nil + } + } + } + + return +} + +// StoredInflightMessages returns all inflight messages, e.g. from a persistent store, +// and is used to populate the restored clients with inflight messages before start. +func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { + for _, hook := range h.GetAll() { + if hook.Provides(StoredInflightMessages) { + v, err := hook.StoredInflightMessages() + if err != nil { + h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID()) + return v, err + } + + if len(v) > 0 { + return v, nil + } + } + } + + return +} + +// StoredRetainedMessages returns all retained messages, e.g. from a persistent store, +// and is used to populate the server topics with retained messages before start. +func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { + for _, hook := range h.GetAll() { + if hook.Provides(StoredRetainedMessages) { + v, err := hook.StoredRetainedMessages() + if err != nil { + h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID()) + return v, err + } + + if len(v) > 0 { + return v, nil + } + } + } + + return +} + +// StoredSysInfo returns a set of system info values. +func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { + for _, hook := range h.GetAll() { + if hook.Provides(StoredSysInfo) { + v, err := hook.StoredSysInfo() + if err != nil { + h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID()) + return v, err + } + + if v.Version != "" { + return v, nil + } + } + } + + return +} + +// OnConnectAuthenticate is called when a user attempts to authenticate with the server. +// An implementation of this method MUST be used to allow or deny access to the +// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to +// check connecting users against an existing user database. +func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { + for _, hook := range h.GetAll() { + if hook.Provides(OnConnectAuthenticate) { + if ok := hook.OnConnectAuthenticate(cl, pk); ok { + return true + } + } + } + + return false +} + +// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter. +// An implementation of this method MUST be used to allow or deny access to the +// (see hooks/auth/allow_all or basic). It can be used in custom hooks to +// check publishing and subscribing users against an existing permissions or roles database. +func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { + for _, hook := range h.GetAll() { + if hook.Provides(OnACLCheck) { + if ok := hook.OnACLCheck(cl, topic, write); ok { + return true + } + } + } + + return false +} + +// HookBase provides a set of default methods for each hook. It should be embedded in +// all hooks. +type HookBase struct { + Hook + Log *slog.Logger + Opts *HookOptions +} + +// ID returns the ID of the hook. +func (h *HookBase) ID() string { + return "base" +} + +// Provides indicates which methods a hook provides. The default is none - this method +// should be overridden by the embedding hook. +func (h *HookBase) Provides(b byte) bool { + return false +} + +// Init performs any pre-start initializations for the hook, such as connecting to databases +// or opening files. +func (h *HookBase) Init(config any) error { + return nil +} + +// SetOpts is called by the server to propagate internal values and generally should +// not be called manually. +func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + +// Stop is called to gracefully shut down the hook. +func (h *HookBase) Stop() error { + return nil +} + +// OnStarted is called when the server starts. +func (h *HookBase) OnStarted() {} + +// OnStopped is called when the server stops. +func (h *HookBase) OnStopped() {} + +// OnSysInfoTick is called when the server publishes system info. +func (h *HookBase) OnSysInfoTick(*system.Info) {} + +// OnConnectAuthenticate is called when a user attempts to authenticate with the server. +func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { + return false +} + +// OnACLCheck is called when a user attempts to subscribe or publish to a topic. +func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool { + return false +} + +// OnConnect is called when a new client connects. +func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) error { + return nil +} + +// OnSessionEstablish is called right after a new client connects and authenticates and right before +// the session is established and CONNACK is sent. +func (h *HookBase) OnSessionEstablish(cl *Client, pk packets.Packet) {} + +// OnSessionEstablished is called when a new client establishes a session (after OnConnect). +func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {} + +// OnDisconnect is called when a client is disconnected for any reason. +func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {} + +// OnAuthPacket is called when an auth packet is received from the client. +func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) { + return pk, nil +} + +// OnPacketRead is called when a packet is received. +func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) { + return pk, nil +} + +// OnPacketEncode is called before a packet is byte-encoded and written to the client. +func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { + return pk +} + +// OnPacketSent is called immediately after a packet is written to a client. +func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {} + +// OnPacketProcessed is called immediately after a packet from a client is processed. +func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {} + +// OnSubscribe is called when a client subscribes to one or more filters. +func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { + return pk +} + +// OnSubscribed is called when a client subscribes to one or more filters. +func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {} + +// OnSelectSubscribers is called when selecting subscribers to receive a message. +func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers { + return subs +} + +// OnUnsubscribe is called when a client unsubscribes from one or more filters. +func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { + return pk +} + +// OnUnsubscribed is called when a client unsubscribes from one or more filters. +func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {} + +// OnPublish is called when a client publishes a message. +func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) { + return pk, nil +} + +// OnPublished is called when a client has published a message to subscribers. +func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {} + +// OnPublishDropped is called when a message to a client is dropped instead of being delivered. +func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {} + +// OnRetainMessage is called then a published message is retained. +func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {} + +// OnRetainPublished is called when a retained message is published. +func (h *HookBase) OnRetainPublished(cl *Client, pk packets.Packet) {} + +// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber. +func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {} + +// OnQosComplete is called when the Qos flow for a message has been completed. +func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {} + +// OnQosDropped is called the Qos flow for a message expires. +func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {} + +// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet. +func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {} + +// OnWill is called when a client disconnects and publishes an LWT message. +func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) { + return will, nil +} + +// OnWillSent is called when an LWT message has been issued from a disconnecting client. +func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {} + +// OnClientExpired is called when a client session has expired. +func (h *HookBase) OnClientExpired(cl *Client) {} + +// OnRetainedExpired is called when a retained message for a topic has expired. +func (h *HookBase) OnRetainedExpired(topic string) {} + +// StoredClients returns all clients from a store. +func (h *HookBase) StoredClients() (v []storage.Client, err error) { + return +} + +// StoredSubscriptions returns all subcriptions from a store. +func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) { + return +} + +// StoredInflightMessages returns all inflight messages from a store. +func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) { + return +} + +// StoredRetainedMessages 返回存储区中所有保留的消息 +// StoredRetainedMessages returns all retained messages from a store. +func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) { + return +} + +// StoredSysInfo 返回一组系统信息值 +// StoredSysInfo returns a set of system info values. +func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) { + return +} diff --git a/mqtt/hooks_test.go b/mqtt/hooks_test.go new file mode 100644 index 0000000..b792727 --- /dev/null +++ b/mqtt/hooks_test.go @@ -0,0 +1,667 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package mqtt + +import ( + "errors" + "strconv" + "sync/atomic" + "testing" + "time" + + "testmqtt/hooks/storage" + "testmqtt/packets" + "testmqtt/system" + + "github.com/stretchr/testify/require" +) + +type modifiedHookBase struct { + HookBase + err error + fail bool + failAt int +} + +var errTestHook = errors.New("error") + +func (h *modifiedHookBase) ID() string { + return "modified" +} + +func (h *modifiedHookBase) Init(config any) error { + if config != nil { + return errTestHook + } + return nil +} + +func (h *modifiedHookBase) Provides(b byte) bool { + return true +} + +func (h *modifiedHookBase) Stop() error { + if h.fail { + return errTestHook + } + + return nil +} + +func (h *modifiedHookBase) OnConnect(cl *Client, pk packets.Packet) error { + if h.fail { + return errTestHook + } + + return nil +} + +func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { + return true +} + +func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool { + return true +} + +func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) { + if h.fail { + if h.err != nil { + return pk, h.err + } + + return pk, errTestHook + } + + return pk, nil +} + +func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) { + if h.fail { + if h.err != nil { + return pk, h.err + } + + return pk, errTestHook + } + + return pk, nil +} + +func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) { + if h.fail { + if h.err != nil { + return pk, h.err + } + + return pk, errTestHook + } + + return pk, nil +} + +func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) { + if h.fail { + return will, errTestHook + } + + return will, nil +} + +func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) { + if h.fail || h.failAt == 1 { + return v, errTestHook + } + + return []storage.Client{ + {ID: "cl1"}, + {ID: "cl2"}, + {ID: "cl3"}, + }, nil +} + +func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) { + if h.fail || h.failAt == 2 { + return v, errTestHook + } + + return []storage.Subscription{ + {ID: "sub1"}, + {ID: "sub2"}, + {ID: "sub3"}, + }, nil +} + +func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) { + if h.fail || h.failAt == 3 { + return v, errTestHook + } + + return []storage.Message{ + {ID: "r1"}, + {ID: "r2"}, + {ID: "r3"}, + }, nil +} + +func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) { + if h.fail || h.failAt == 4 { + return v, errTestHook + } + + return []storage.Message{ + {ID: "i1"}, + {ID: "i2"}, + {ID: "i3"}, + }, nil +} + +func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) { + if h.fail || h.failAt == 5 { + return v, errTestHook + } + + return storage.SystemInfo{ + Info: system.Info{ + Version: "2.0.0", + }, + }, nil +} + +type providesCheckHook struct { + HookBase +} + +func (h *providesCheckHook) Provides(b byte) bool { + return b == OnConnect +} + +func TestHooksProvides(t *testing.T) { + h := new(Hooks) + err := h.Add(new(providesCheckHook), nil) + require.NoError(t, err) + + err = h.Add(new(HookBase), nil) + require.NoError(t, err) + + require.True(t, h.Provides(OnConnect, OnDisconnect)) + require.False(t, h.Provides(OnDisconnect)) +} + +func TestHooksAddLenGetAll(t *testing.T) { + h := new(Hooks) + err := h.Add(new(HookBase), nil) + require.NoError(t, err) + + err = h.Add(new(modifiedHookBase), nil) + require.NoError(t, err) + + require.Equal(t, int64(2), atomic.LoadInt64(&h.qty)) + require.Equal(t, int64(2), h.Len()) + + all := h.GetAll() + require.Equal(t, "base", all[0].ID()) + require.Equal(t, "modified", all[1].ID()) +} + +func TestHooksAddInitFailure(t *testing.T) { + h := new(Hooks) + err := h.Add(new(modifiedHookBase), map[string]any{}) + require.Error(t, err) + require.Equal(t, int64(0), atomic.LoadInt64(&h.qty)) +} + +func TestHooksStop(t *testing.T) { + h := new(Hooks) + h.Log = logger + + err := h.Add(new(HookBase), nil) + require.NoError(t, err) + require.Equal(t, int64(1), atomic.LoadInt64(&h.qty)) + require.Equal(t, int64(1), h.Len()) + + h.Stop() +} + +// coverage: also cover some empty functions +func TestHooksNonReturns(t *testing.T) { + h := new(Hooks) + cl := new(Client) + + for i := 0; i < 2; i++ { + t.Run("step-"+strconv.Itoa(i), func(t *testing.T) { + // on first iteration, check without hook methods + h.OnStarted() + h.OnStopped() + h.OnSysInfoTick(new(system.Info)) + h.OnSessionEstablish(cl, packets.Packet{}) + h.OnSessionEstablished(cl, packets.Packet{}) + h.OnDisconnect(cl, nil, false) + h.OnPacketSent(cl, packets.Packet{}, []byte{}) + h.OnPacketProcessed(cl, packets.Packet{}, nil) + h.OnSubscribed(cl, packets.Packet{}, []byte{1}) + h.OnUnsubscribed(cl, packets.Packet{}) + h.OnPublished(cl, packets.Packet{}) + h.OnPublishDropped(cl, packets.Packet{}) + h.OnRetainMessage(cl, packets.Packet{}, 0) + h.OnRetainPublished(cl, packets.Packet{}) + h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0) + h.OnQosComplete(cl, packets.Packet{}) + h.OnQosDropped(cl, packets.Packet{}) + h.OnPacketIDExhausted(cl, packets.Packet{}) + h.OnWillSent(cl, packets.Packet{}) + h.OnClientExpired(cl) + h.OnRetainedExpired("a/b/c") + + // on second iteration, check added hook methods + err := h.Add(new(modifiedHookBase), nil) + require.NoError(t, err) + }) + } +} + +func TestHooksOnConnectAuthenticate(t *testing.T) { + h := new(Hooks) + + ok := h.OnConnectAuthenticate(new(Client), packets.Packet{}) + require.False(t, ok) + + err := h.Add(new(modifiedHookBase), nil) + require.NoError(t, err) + + ok = h.OnConnectAuthenticate(new(Client), packets.Packet{}) + require.True(t, ok) +} + +func TestHooksOnACLCheck(t *testing.T) { + h := new(Hooks) + + ok := h.OnACLCheck(new(Client), "a/b/c", true) + require.False(t, ok) + + err := h.Add(new(modifiedHookBase), nil) + require.NoError(t, err) + + ok = h.OnACLCheck(new(Client), "a/b/c", true) + require.True(t, ok) +} + +func TestHooksOnSubscribe(t *testing.T) { + h := new(Hooks) + err := h.Add(new(modifiedHookBase), nil) + require.NoError(t, err) + + pki := packets.Packet{ + Filters: packets.Subscriptions{ + {Filter: "a/b/c", Qos: 1}, + }, + } + pk := h.OnSubscribe(new(Client), pki) + require.EqualValues(t, pk, pki) +} + +func TestHooksOnSelectSubscribers(t *testing.T) { + h := new(Hooks) + err := h.Add(new(modifiedHookBase), nil) + require.NoError(t, err) + + subs := &Subscribers{ + Subscriptions: map[string]packets.Subscription{ + "cl1": {Filter: "a/b/c"}, + }, + } + + subs2 := h.OnSelectSubscribers(subs, packets.Packet{}) + require.EqualValues(t, subs, subs2) +} + +func TestHooksOnUnsubscribe(t *testing.T) { + h := new(Hooks) + err := h.Add(new(modifiedHookBase), nil) + require.NoError(t, err) + + pki := packets.Packet{ + Filters: packets.Subscriptions{ + {Filter: "a/b/c", Qos: 1}, + }, + } + + pk := h.OnUnsubscribe(new(Client), pki) + require.EqualValues(t, pk, pki) +} + +func TestHooksOnPublish(t *testing.T) { + h := new(Hooks) + h.Log = logger + + hook := new(modifiedHookBase) + err := h.Add(hook, nil) + require.NoError(t, err) + + pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10}) + require.NoError(t, err) + require.Equal(t, uint16(10), pk.PacketID) + + // coverage: failure + hook.fail = true + pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10}) + require.Error(t, err) + require.Equal(t, uint16(10), pk.PacketID) + + // coverage: reject packet + hook.err = packets.ErrRejectPacket + pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10}) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrRejectPacket) + require.Equal(t, uint16(10), pk.PacketID) +} + +func TestHooksOnPacketRead(t *testing.T) { + h := new(Hooks) + h.Log = logger + + hook := new(modifiedHookBase) + err := h.Add(hook, nil) + require.NoError(t, err) + + pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10}) + require.NoError(t, err) + require.Equal(t, uint16(10), pk.PacketID) + + // coverage: failure + hook.fail = true + pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10}) + require.NoError(t, err) + require.Equal(t, uint16(10), pk.PacketID) + + // coverage: reject packet + hook.err = packets.ErrRejectPacket + pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10}) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrRejectPacket) + require.Equal(t, uint16(10), pk.PacketID) +} + +func TestHooksOnAuthPacket(t *testing.T) { + h := new(Hooks) + h.Log = logger + + hook := new(modifiedHookBase) + err := h.Add(hook, nil) + require.NoError(t, err) + + pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10}) + require.NoError(t, err) + require.Equal(t, uint16(10), pk.PacketID) + + hook.fail = true + pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10}) + require.Error(t, err) + require.Equal(t, uint16(10), pk.PacketID) +} + +func TestHooksOnConnect(t *testing.T) { + h := new(Hooks) + h.Log = logger + + hook := new(modifiedHookBase) + err := h.Add(hook, nil) + require.NoError(t, err) + + err = h.OnConnect(new(Client), packets.Packet{PacketID: 10}) + require.NoError(t, err) + + hook.fail = true + err = h.OnConnect(new(Client), packets.Packet{PacketID: 10}) + require.Error(t, err) +} + +func TestHooksOnPacketEncode(t *testing.T) { + h := new(Hooks) + h.Log = logger + + hook := new(modifiedHookBase) + err := h.Add(hook, nil) + require.NoError(t, err) + + pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10}) + require.Equal(t, uint16(10), pk.PacketID) +} + +func TestHooksOnLWT(t *testing.T) { + h := new(Hooks) + h.Log = logger + + hook := new(modifiedHookBase) + err := h.Add(hook, nil) + require.NoError(t, err) + + lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"}) + require.Equal(t, "a/b/c", lwt.TopicName) + + // coverage: fail lwt + hook.fail = true + lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"}) + require.Equal(t, "a/b/c", lwt.TopicName) +} + +func TestHooksStoredClients(t *testing.T) { + h := new(Hooks) + h.Log = logger + + v, err := h.StoredClients() + require.NoError(t, err) + require.Len(t, v, 0) + + hook := new(modifiedHookBase) + err = h.Add(hook, nil) + require.NoError(t, err) + + v, err = h.StoredClients() + require.NoError(t, err) + require.Len(t, v, 3) + + hook.fail = true + v, err = h.StoredClients() + require.Error(t, err) + require.Len(t, v, 0) +} + +func TestHooksStoredSubscriptions(t *testing.T) { + h := new(Hooks) + h.Log = logger + + v, err := h.StoredSubscriptions() + require.NoError(t, err) + require.Len(t, v, 0) + + hook := new(modifiedHookBase) + err = h.Add(hook, nil) + require.NoError(t, err) + + v, err = h.StoredSubscriptions() + require.NoError(t, err) + require.Len(t, v, 3) + + hook.fail = true + v, err = h.StoredSubscriptions() + require.Error(t, err) + require.Len(t, v, 0) +} + +func TestHooksStoredRetainedMessages(t *testing.T) { + h := new(Hooks) + h.Log = logger + + v, err := h.StoredRetainedMessages() + require.NoError(t, err) + require.Len(t, v, 0) + + hook := new(modifiedHookBase) + err = h.Add(hook, nil) + require.NoError(t, err) + + v, err = h.StoredRetainedMessages() + require.NoError(t, err) + require.Len(t, v, 3) + + hook.fail = true + v, err = h.StoredRetainedMessages() + require.Error(t, err) + require.Len(t, v, 0) +} + +func TestHooksStoredInflightMessages(t *testing.T) { + h := new(Hooks) + h.Log = logger + + v, err := h.StoredInflightMessages() + require.NoError(t, err) + require.Len(t, v, 0) + + hook := new(modifiedHookBase) + err = h.Add(hook, nil) + require.NoError(t, err) + + v, err = h.StoredInflightMessages() + require.NoError(t, err) + require.Len(t, v, 3) + + hook.fail = true + v, err = h.StoredInflightMessages() + require.Error(t, err) + require.Len(t, v, 0) +} + +func TestHooksStoredSysInfo(t *testing.T) { + h := new(Hooks) + h.Log = logger + + v, err := h.StoredSysInfo() + require.NoError(t, err) + require.Equal(t, "", v.Info.Version) + + hook := new(modifiedHookBase) + err = h.Add(hook, nil) + require.NoError(t, err) + + v, err = h.StoredSysInfo() + require.NoError(t, err) + require.Equal(t, "2.0.0", v.Info.Version) + + hook.fail = true + v, err = h.StoredSysInfo() + require.Error(t, err) + require.Equal(t, "", v.Info.Version) +} + +func TestHookBaseID(t *testing.T) { + h := new(HookBase) + require.Equal(t, "base", h.ID()) +} + +func TestHookBaseProvidesNone(t *testing.T) { + h := new(HookBase) + require.False(t, h.Provides(OnConnect)) + require.False(t, h.Provides(OnDisconnect)) +} + +func TestHookBaseInit(t *testing.T) { + h := new(HookBase) + require.Nil(t, h.Init(nil)) +} + +func TestHookBaseSetOpts(t *testing.T) { + h := new(HookBase) + h.SetOpts(logger, new(HookOptions)) + require.NotNil(t, h.Log) + require.NotNil(t, h.Opts) +} + +func TestHookBaseClose(t *testing.T) { + h := new(HookBase) + require.Nil(t, h.Stop()) +} + +func TestHookBaseOnConnectAuthenticate(t *testing.T) { + h := new(HookBase) + v := h.OnConnectAuthenticate(new(Client), packets.Packet{}) + require.False(t, v) +} + +func TestHookBaseOnACLCheck(t *testing.T) { + h := new(HookBase) + v := h.OnACLCheck(new(Client), "topic", true) + require.False(t, v) +} + +func TestHookBaseOnConnect(t *testing.T) { + h := new(HookBase) + err := h.OnConnect(new(Client), packets.Packet{}) + require.NoError(t, err) +} + +func TestHookBaseOnPublish(t *testing.T) { + h := new(HookBase) + pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10}) + require.NoError(t, err) + require.Equal(t, uint16(10), pk.PacketID) +} + +func TestHookBaseOnPacketRead(t *testing.T) { + h := new(HookBase) + pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10}) + require.NoError(t, err) + require.Equal(t, uint16(10), pk.PacketID) +} + +func TestHookBaseOnAuthPacket(t *testing.T) { + h := new(HookBase) + pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10}) + require.NoError(t, err) + require.Equal(t, uint16(10), pk.PacketID) +} + +func TestHookBaseOnLWT(t *testing.T) { + h := new(HookBase) + lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"}) + require.NoError(t, err) + require.Equal(t, "a/b/c", lwt.TopicName) +} + +func TestHookBaseStoredClients(t *testing.T) { + h := new(HookBase) + v, err := h.StoredClients() + require.NoError(t, err) + require.Empty(t, v) +} + +func TestHookBaseStoredSubscriptions(t *testing.T) { + h := new(HookBase) + v, err := h.StoredSubscriptions() + require.NoError(t, err) + require.Empty(t, v) +} + +func TestHookBaseStoredInflightMessages(t *testing.T) { + h := new(HookBase) + v, err := h.StoredInflightMessages() + require.NoError(t, err) + require.Empty(t, v) +} + +func TestHookBaseStoredRetainedMessages(t *testing.T) { + h := new(HookBase) + v, err := h.StoredRetainedMessages() + require.NoError(t, err) + require.Empty(t, v) +} + +func TestHookBaseStoreSysInfo(t *testing.T) { + h := new(HookBase) + v, err := h.StoredSysInfo() + require.NoError(t, err) + require.Equal(t, "", v.Version) +} diff --git a/mqtt/inflight.go b/mqtt/inflight.go new file mode 100644 index 0000000..c631c7f --- /dev/null +++ b/mqtt/inflight.go @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package mqtt + +import ( + "sort" + "sync" + "sync/atomic" + + "testmqtt/packets" +) + +// Inflight is a map of InflightMessage keyed on packet id. +type Inflight struct { + sync.RWMutex + internal map[uint16]packets.Packet // internal contains the inflight packets + receiveQuota int32 // remaining inbound qos quota for flow control + sendQuota int32 // remaining outbound qos quota for flow control + maximumReceiveQuota int32 // maximum allowed receive quota + maximumSendQuota int32 // maximum allowed send quota +} + +// NewInflights returns a new instance of an Inflight packets map. +func NewInflights() *Inflight { + return &Inflight{ + internal: map[uint16]packets.Packet{}, + } +} + +// Set adds or updates an inflight packet by packet id. +func (i *Inflight) Set(m packets.Packet) bool { + i.Lock() + defer i.Unlock() + + _, ok := i.internal[m.PacketID] + i.internal[m.PacketID] = m + return !ok +} + +// Get returns an inflight packet by packet id. +func (i *Inflight) Get(id uint16) (packets.Packet, bool) { + i.RLock() + defer i.RUnlock() + + if m, ok := i.internal[id]; ok { + return m, true + } + + return packets.Packet{}, false +} + +// Len returns the size of the inflight messages map. +func (i *Inflight) Len() int { + i.RLock() + defer i.RUnlock() + return len(i.internal) +} + +// Clone returns a new instance of Inflight with the same message data. +// This is used when transferring inflights from a taken-over session. +func (i *Inflight) Clone() *Inflight { + c := NewInflights() + i.RLock() + defer i.RUnlock() + for k, v := range i.internal { + c.internal[k] = v + } + return c +} + +// GetAll returns all the inflight messages. +func (i *Inflight) GetAll(immediate bool) []packets.Packet { + i.RLock() + defer i.RUnlock() + + m := []packets.Packet{} + for _, v := range i.internal { + if !immediate || (immediate && v.Expiry < 0) { + m = append(m, v) + } + } + + sort.Slice(m, func(i, j int) bool { + return uint16(m[i].Created) < uint16(m[j].Created) + }) + + return m +} + +// NextImmediate returns the next inflight packet which is indicated to be sent immediately. +// This typically occurs when the quota has been exhausted, and we need to wait until new quota +// is free to continue sending. +func (i *Inflight) NextImmediate() (packets.Packet, bool) { + i.RLock() + defer i.RUnlock() + + m := i.GetAll(true) + if len(m) > 0 { + return m[0], true + } + + return packets.Packet{}, false +} + +// Delete removes an in-flight message from the map. Returns true if the message existed. +func (i *Inflight) Delete(id uint16) bool { + i.Lock() + defer i.Unlock() + + _, ok := i.internal[id] + delete(i.internal, id) + + return ok +} + +// TakeRecieveQuota reduces the receive quota by 1. +func (i *Inflight) DecreaseReceiveQuota() { + if atomic.LoadInt32(&i.receiveQuota) > 0 { + atomic.AddInt32(&i.receiveQuota, -1) + } +} + +// TakeRecieveQuota increases the receive quota by 1. +func (i *Inflight) IncreaseReceiveQuota() { + if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) { + atomic.AddInt32(&i.receiveQuota, 1) + } +} + +// ResetReceiveQuota resets the receive quota to the maximum allowed value. +func (i *Inflight) ResetReceiveQuota(n int32) { + atomic.StoreInt32(&i.receiveQuota, n) + atomic.StoreInt32(&i.maximumReceiveQuota, n) +} + +// DecreaseSendQuota reduces the send quota by 1. +func (i *Inflight) DecreaseSendQuota() { + if atomic.LoadInt32(&i.sendQuota) > 0 { + atomic.AddInt32(&i.sendQuota, -1) + } +} + +// IncreaseSendQuota increases the send quota by 1. +func (i *Inflight) IncreaseSendQuota() { + if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) { + atomic.AddInt32(&i.sendQuota, 1) + } +} + +// ResetSendQuota resets the send quota to the maximum allowed value. +func (i *Inflight) ResetSendQuota(n int32) { + atomic.StoreInt32(&i.sendQuota, n) + atomic.StoreInt32(&i.maximumSendQuota, n) +} diff --git a/mqtt/inflight_test.go b/mqtt/inflight_test.go new file mode 100644 index 0000000..d755204 --- /dev/null +++ b/mqtt/inflight_test.go @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package mqtt + +import ( + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + "testmqtt/packets" +) + +func TestInflightSet(t *testing.T) { + cl, _, _ := newTestClient() + + r := cl.State.Inflight.Set(packets.Packet{PacketID: 1}) + require.True(t, r) + require.NotNil(t, cl.State.Inflight.internal[1]) + require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID) + + r = cl.State.Inflight.Set(packets.Packet{PacketID: 1}) + require.False(t, r) +} + +func TestInflightGet(t *testing.T) { + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: 2}) + + msg, ok := cl.State.Inflight.Get(2) + require.True(t, ok) + require.NotEqual(t, 0, msg.PacketID) +} + +func TestInflightGetAllAndImmediate(t *testing.T) { + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) + cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5}) + + require.Equal(t, []packets.Packet{ + {PacketID: 1, Created: 1}, + {PacketID: 2, Created: 2}, + {PacketID: 3, Created: 3, Expiry: -1}, + {PacketID: 4, Created: 4, Expiry: -1}, + {PacketID: 5, Created: 5}, + }, cl.State.Inflight.GetAll(false)) + + require.Equal(t, []packets.Packet{ + {PacketID: 3, Created: 3, Expiry: -1}, + {PacketID: 4, Created: 4, Expiry: -1}, + }, cl.State.Inflight.GetAll(true)) +} + +func TestInflightLen(t *testing.T) { + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: 2}) + require.Equal(t, 1, cl.State.Inflight.Len()) +} + +func TestInflightClone(t *testing.T) { + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: 2}) + require.Equal(t, 1, cl.State.Inflight.Len()) + + cloned := cl.State.Inflight.Clone() + require.NotNil(t, cloned) + require.NotSame(t, cloned, cl.State.Inflight) +} + +func TestInflightDelete(t *testing.T) { + cl, _, _ := newTestClient() + + cl.State.Inflight.Set(packets.Packet{PacketID: 3}) + require.NotNil(t, cl.State.Inflight.internal[3]) + + r := cl.State.Inflight.Delete(3) + require.True(t, r) + require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID) + + _, ok := cl.State.Inflight.Get(3) + require.False(t, ok) + + r = cl.State.Inflight.Delete(3) + require.False(t, r) +} + +func TestResetReceiveQuota(t *testing.T) { + i := NewInflights() + require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota)) + require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) + i.ResetReceiveQuota(6) + require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota)) + require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota)) +} + +func TestReceiveQuota(t *testing.T) { + i := NewInflights() + i.receiveQuota = 4 + i.maximumReceiveQuota = 5 + require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) + require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota)) + + // Return 1 + i.IncreaseReceiveQuota() + require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) + require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) + + // Try to go over max limit + i.IncreaseReceiveQuota() + require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) + require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) + + // Reset to max 1 + i.ResetReceiveQuota(1) + require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) + require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota)) + + // Take 1 + i.DecreaseReceiveQuota() + require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) + require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) + + // Try to go below zero + i.DecreaseReceiveQuota() + require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) + require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) +} + +func TestResetSendQuota(t *testing.T) { + i := NewInflights() + require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota)) + require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) + i.ResetSendQuota(6) + require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota)) + require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota)) +} + +func TestSendQuota(t *testing.T) { + i := NewInflights() + i.sendQuota = 4 + i.maximumSendQuota = 5 + require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) + require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota)) + + // Return 1 + i.IncreaseSendQuota() + require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) + require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) + + // Try to go over max limit + i.IncreaseSendQuota() + require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) + require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) + + // Reset to max 1 + i.ResetSendQuota(1) + require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) + require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota)) + + // Take 1 + i.DecreaseSendQuota() + require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) + require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) + + // Try to go below zero + i.DecreaseSendQuota() + require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) + require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) +} + +func TestNextImmediate(t *testing.T) { + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) + cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5}) + + pk, ok := cl.State.Inflight.NextImmediate() + require.True(t, ok) + require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk) + + r := cl.State.Inflight.Delete(3) + require.True(t, r) + + pk, ok = cl.State.Inflight.NextImmediate() + require.True(t, ok) + require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk) + + r = cl.State.Inflight.Delete(4) + require.True(t, r) + + _, ok = cl.State.Inflight.NextImmediate() + require.False(t, ok) +} diff --git a/mqtt/server.go b/mqtt/server.go new file mode 100644 index 0000000..af7992a --- /dev/null +++ b/mqtt/server.go @@ -0,0 +1,1759 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +// Package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility. +package mqtt + +import ( + "errors" + "fmt" + "math" + "net" + "os" + "runtime" + "sort" + "strconv" + "strings" + "sync/atomic" + "time" + + "testmqtt/hooks/storage" + "testmqtt/listeners" + "testmqtt/packets" + "testmqtt/system" + + "log/slog" +) + +const ( + // Version 服务器版本 + Version = "2.6.5" // the current server version. + // 默认系统主题发布时间间隔 + defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes + LocalListener = "local" + InlineClientId = "inline" +) + +var ( + // Deprecated: Use NewDefaultServerCapabilities to avoid data race issue. + DefaultServerCapabilities = NewDefaultServerCapabilities() + + ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists + ErrConnectionClosed = errors.New("connection not open") // connection is closed + ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default + ErrOptionsUnreadable = errors.New("unable to read options from bytes") +) + +// Capabilities indicates the capabilities and features provided by the server. +type Capabilities struct { + MaximumClients int64 `yaml:"maximum_clients" json:"maximum_clients"` // maximum number of connected clients + MaximumMessageExpiryInterval int64 `yaml:"maximum_message_expiry_interval" json:"maximum_message_expiry_interval"` // maximum message expiry if message expiry is 0 or over + MaximumClientWritesPending int32 `yaml:"maximum_client_writes_pending" json:"maximum_client_writes_pending"` // maximum number of pending message writes for a client + MaximumSessionExpiryInterval uint32 `yaml:"maximum_session_expiry_interval" json:"maximum_session_expiry_interval"` // maximum number of seconds to keep disconnected sessions + MaximumPacketSize uint32 `yaml:"maximum_packet_size" json:"maximum_packet_size"` // maximum packet size, no limit if 0 + maximumPacketID uint32 // unexported, used for testing only + ReceiveMaximum uint16 `yaml:"receive_maximum" json:"receive_maximum"` // maximum number of concurrent qos messages per client + MaximumInflight uint16 `yaml:"maximum_inflight" json:"maximum_inflight"` // maximum number of qos > 0 messages can be stored, 0(=8192)-65535 + TopicAliasMaximum uint16 `yaml:"topic_alias_maximum" json:"topic_alias_maximum"` // maximum topic alias value + SharedSubAvailable byte `yaml:"shared_sub_available" json:"shared_sub_available"` // support of shared subscriptions + MinimumProtocolVersion byte `yaml:"minimum_protocol_version" json:"minimum_protocol_version"` // minimum supported mqtt version + Compatibilities Compatibilities `yaml:"compatibilities" json:"compatibilities"` // version compatibilities the server provides + MaximumQos byte `yaml:"maximum_qos" json:"maximum_qos"` // maximum qos value available to clients + RetainAvailable byte `yaml:"retain_available" json:"retain_available"` // support of retain messages + WildcardSubAvailable byte `yaml:"wildcard_sub_available" json:"wildcard_sub_available"` // support of wildcard subscriptions + SubIDAvailable byte `yaml:"sub_id_available" json:"sub_id_available"` // support of subscription identifiers +} + +// NewDefaultServerCapabilities defines the default features and capabilities provided by the server. +func NewDefaultServerCapabilities() *Capabilities { + return &Capabilities{ + MaximumClients: math.MaxInt64, // maximum number of connected clients + MaximumMessageExpiryInterval: 60 * 60 * 24, // maximum message expiry if message expiry is 0 or over + MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client + MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions + MaximumPacketSize: 0, // no maximum packet size + maximumPacketID: math.MaxUint16, + ReceiveMaximum: 1024, // maximum number of concurrent qos messages per client + MaximumInflight: 1024 * 8, // maximum number of qos > 0 messages can be stored + TopicAliasMaximum: math.MaxUint16, // maximum topic alias value + SharedSubAvailable: 1, // shared subscriptions are available + MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0) + MaximumQos: 2, // maximum qos value available to clients + RetainAvailable: 1, // retain messages is available + WildcardSubAvailable: 1, // wildcard subscriptions are available + SubIDAvailable: 1, // subscription identifiers are available + } +} + +// Compatibilities provides flags for using compatibility modes. +type Compatibilities struct { + ObscureNotAuthorized bool `yaml:"obscure_not_authorized" json:"obscure_not_authorized"` // return unspecified errors instead of not authorized + PassiveClientDisconnect bool `yaml:"passive_client_disconnect" json:"passive_client_disconnect"` // don't disconnect the client forcefully after sending disconnect packet (paho - spec violation) + AlwaysReturnResponseInfo bool `yaml:"always_return_response_info" json:"always_return_response_info"` // always return response info (useful for testing) + RestoreSysInfoOnRestart bool `yaml:"restore_sys_info_on_restart" json:"restore_sys_info_on_restart"` // restore system info from store as if server never stopped + NoInheritedPropertiesOnAck bool `yaml:"no_inherited_properties_on_ack" json:"no_inherited_properties_on_ack"` // don't allow inherited user properties on ack (paho - spec violation) +} + +// Options contains configurable options for the server. +type Options struct { + // Listeners specifies any listeners which should be dynamically added on serve. Used when setting listeners by config. + Listeners []listeners.Config `yaml:"listeners" json:"listeners"` + + // Hooks specifies any hooks which should be dynamically added on serve. Used when setting hooks by config. + Hooks []HookLoadConfig `yaml:"hooks" json:"hooks"` + + // Capabilities defines the server features and behaviour. If you only wish to modify + // several of these values, set them explicitly - e.g. + // server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 + Capabilities *Capabilities `yaml:"capabilities" json:"capabilities"` + + // ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer. + ClientNetWriteBufferSize int `yaml:"client_net_write_buffer_size" json:"client_net_write_buffer_size"` + + // ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer. + ClientNetReadBufferSize int `yaml:"client_net_read_buffer_size" json:"client_net_read_buffer_size"` + + // Logger specifies a custom configured implementation of log/slog to override + // the servers default logger configuration. If you wish to change the log level, + // of the default logger, you can do so by setting: + // server := mqtt.New(nil) + // level := new(slog.LevelVar) + // server.Slog = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + // Level: level, + // })) + // level.Set(slog.LevelDebug) + Logger *slog.Logger `yaml:"-" json:"-"` + + // SysTopicResendInterval specifies the interval between $SYS topic updates in seconds. + SysTopicResendInterval int64 `yaml:"sys_topic_resend_interval" json:"sys_topic_resend_interval"` + + // Enable Inline client to allow direct subscribing and publishing from the parent codebase, + // with negligible performance difference (disabled by default to prevent confusion in statistics). + InlineClient bool `yaml:"inline_client" json:"inline_client"` +} + +// Server is an MQTT broker server. It should be created with server.New() +// in order to ensure all the internal fields are correctly populated. +type Server struct { + Options *Options // configurable server options + Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections + Clients *Clients // clients known to the broker + Topics *TopicsIndex // an index of topic filter subscriptions and retained messages + Info *system.Info // values about the server commonly known as $SYS topics + loop *loop // loop contains tickers for the system event loop + done chan bool // indicate that the server is ending + Log *slog.Logger // minimal no-alloc logger + hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage + inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish +} + +// loop contains interval tickers for the system events loop. +type loop struct { + sysTopics *time.Ticker // interval ticker for sending updating $SYS topics + clientExpiry *time.Ticker // interval ticker for cleaning expired clients + inflightExpiry *time.Ticker // interval ticker for cleaning up expired inflight messages + retainedExpiry *time.Ticker // interval ticker for cleaning retained messages + willDelaySend *time.Ticker // interval ticker for sending Will Messages with a delay + willDelayed *packets.Packets // activate LWT packets which will be sent after a delay +} + +// ops contains server values which can be propagated to other structs. +type ops struct { + options *Options // a pointer to the server options and capabilities, for referencing in clients + info *system.Info // pointers to server system info + hooks *Hooks // pointer to the server hooks + log *slog.Logger // a structured logger for the client +} + +// New returns a new instance of mochi mqtt broker. Optional parameters +// can be specified to override some default settings (see Options). +func New(opts *Options) *Server { + if opts == nil { + opts = new(Options) + } + + opts.ensureDefaults() + + s := &Server{ + done: make(chan bool), + Clients: NewClients(), + Topics: NewTopicsIndex(), + Listeners: listeners.New(), + loop: &loop{ + sysTopics: time.NewTicker(time.Second * time.Duration(opts.SysTopicResendInterval)), + clientExpiry: time.NewTicker(time.Second), + inflightExpiry: time.NewTicker(time.Second), + retainedExpiry: time.NewTicker(time.Second), + willDelaySend: time.NewTicker(time.Second), + willDelayed: packets.NewPackets(), + }, + Options: opts, + Info: &system.Info{ + Version: Version, + Started: time.Now().Unix(), + }, + Log: opts.Logger, + hooks: &Hooks{ + Log: opts.Logger, + }, + } + + if s.Options.InlineClient { + s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true) + s.Clients.Add(s.inlineClient) + } + + return s +} + +// ensureDefaults ensures that the server starts with sane default values, if none are provided. +func (o *Options) ensureDefaults() { + if o.Capabilities == nil { + o.Capabilities = NewDefaultServerCapabilities() + } + + o.Capabilities.maximumPacketID = math.MaxUint16 // spec maximum is 65535 + + if o.Capabilities.MaximumInflight == 0 { + o.Capabilities.MaximumInflight = 1024 * 8 + } + + if o.SysTopicResendInterval == 0 { + o.SysTopicResendInterval = defaultSysTopicInterval + } + + if o.ClientNetWriteBufferSize == 0 { + o.ClientNetWriteBufferSize = 1024 * 2 + } + + if o.ClientNetReadBufferSize == 0 { + o.ClientNetReadBufferSize = 1024 * 2 + } + + if o.Logger == nil { + log := slog.New(slog.NewTextHandler(os.Stdout, nil)) + o.Logger = log + } +} + +// NewClient returns a new Client instance, populated with all the required values and +// references to be used with the server. If you are using this client to directly publish +// messages from the embedding application, set the inline flag to true to bypass ACL and +// topic validation checks. +func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client { + cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit + options: s.Options, + info: s.Info, + hooks: s.hooks, + log: s.Log, + }) + + cl.ID = id + cl.Net.Listener = listener + + if inline { // inline clients bypass acl and some validity checks. + cl.Net.Inline = true + // By default, we don't want to restrict developer publishes, + // but if you do, reset this after creating inline client. + cl.State.Inflight.ResetReceiveQuota(math.MaxInt32) + } + + return cl +} + +// AddHook attaches a new Hook to the server. Ideally, this should be called +// before the server is started with s.Serve(). +func (s *Server) AddHook(hook Hook, config any) error { + nl := s.Log.With("hook", hook.ID()) + hook.SetOpts(nl, &HookOptions{ + Capabilities: s.Options.Capabilities, + }) + + s.Log.Info("added hook", "hook", hook.ID()) + return s.hooks.Add(hook, config) +} + +// AddHooksFromConfig adds hooks to the server which were specified in the hooks config (usually from a config file). +// New built-in hooks should be added to this list. +func (s *Server) AddHooksFromConfig(hooks []HookLoadConfig) error { + for _, h := range hooks { + if err := s.AddHook(h.Hook, h.Config); err != nil { + return err + } + } + return nil +} + +// AddListener adds a new network listener to the server, for receiving incoming client connections. +func (s *Server) AddListener(l listeners.Listener) error { + if _, ok := s.Listeners.Get(l.ID()); ok { + return ErrListenerIDExists + } + + nl := s.Log.With(slog.String("listener", l.ID())) + err := l.Init(nl) + if err != nil { + return err + } + + s.Listeners.Add(l) + + s.Log.Info("attached listener", "id", l.ID(), "protocol", l.Protocol(), "address", l.Address()) + return nil +} + +// AddListenersFromConfig adds listeners to the server which were specified in the listeners config (usually from a config file). +// New built-in listeners should be added to this list. +func (s *Server) AddListenersFromConfig(configs []listeners.Config) error { + for _, conf := range configs { + var l listeners.Listener + switch strings.ToLower(conf.Type) { + case listeners.TypeTCP: + l = listeners.NewTCP(conf) + case listeners.TypeWS: + l = listeners.NewWebsocket(conf) + case listeners.TypeUnix: + l = listeners.NewUnixSock(conf) + case listeners.TypeHealthCheck: + l = listeners.NewHTTPHealthCheck(conf) + case listeners.TypeSysInfo: + l = listeners.NewHTTPStats(conf, s.Info) + case listeners.TypeMock: + l = listeners.NewMockListener(conf.ID, conf.Address) + default: + s.Log.Error("listener type unavailable by config", "listener", conf.Type) + continue + } + if err := s.AddListener(l); err != nil { + return err + } + } + return nil +} + +// Serve starts the event loops responsible for establishing client connections +// on all attached listeners, publishing the system topics, and starting all hooks. +func (s *Server) Serve() error { + s.Log.Info("mochi mqtt starting", "version", Version) + defer s.Log.Info("mochi mqtt server started") + + if len(s.Options.Listeners) > 0 { + err := s.AddListenersFromConfig(s.Options.Listeners) + if err != nil { + return err + } + } + + if len(s.Options.Hooks) > 0 { + err := s.AddHooksFromConfig(s.Options.Hooks) + if err != nil { + return err + } + } + + if s.hooks.Provides( + StoredClients, + StoredInflightMessages, + StoredRetainedMessages, + StoredSubscriptions, + StoredSysInfo, + ) { + err := s.readStore() + if err != nil { + return err + } + } + + go s.eventLoop() // spin up event loop for issuing $SYS values and closing server. + s.Listeners.ServeAll(s.EstablishConnection) // start listening on all listeners. + s.publishSysTopics() // begin publishing $SYS system values. + s.hooks.OnStarted() + + return nil +} + +// eventLoop loops forever, running various server housekeeping methods at different intervals. +func (s *Server) eventLoop() { + s.Log.Debug("system event loop started") + defer s.Log.Debug("system event loop halted") + + for { + select { + case <-s.done: + s.loop.sysTopics.Stop() + return + case <-s.loop.sysTopics.C: + s.publishSysTopics() + case <-s.loop.clientExpiry.C: + s.clearExpiredClients(time.Now().Unix()) + case <-s.loop.retainedExpiry.C: + s.clearExpiredRetainedMessages(time.Now().Unix()) + case <-s.loop.willDelaySend.C: + s.sendDelayedLWT(time.Now().Unix()) + case <-s.loop.inflightExpiry.C: + s.clearExpiredInflights(time.Now().Unix()) + } + } +} + +// EstablishConnection establishes a new client when a listener accepts a new connection. +func (s *Server) EstablishConnection(listener string, c net.Conn) error { + cl := s.NewClient(c, listener, "", false) + return s.attachClient(cl, listener) +} + +// attachClient validates an incoming client connection and if viable, attaches the client +// to the server, performs session housekeeping, and reads incoming packets. +func (s *Server) attachClient(cl *Client, listener string) error { + defer s.Listeners.ClientsWg.Done() + s.Listeners.ClientsWg.Add(1) + + go cl.WriteLoop() + defer cl.Stop(nil) + + pk, err := s.readConnectionPacket(cl) + if err != nil { + return fmt.Errorf("read connection: %w", err) + } + + cl.ParseConnect(listener, pk) + if atomic.LoadInt64(&s.Info.ClientsConnected) >= s.Options.Capabilities.MaximumClients { + if cl.Properties.ProtocolVersion < 5 { + s.SendConnack(cl, packets.ErrServerUnavailable, false, nil) + } else { + s.SendConnack(cl, packets.ErrServerBusy, false, nil) + } + + return packets.ErrServerBusy + } + + code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2] + if code != packets.CodeSuccess { + if err := s.SendConnack(cl, code, false, nil); err != nil { + return fmt.Errorf("invalid connection send ack: %w", err) + } + return code // [MQTT-3.2.2-7] [MQTT-3.1.4-6] + } + + err = s.hooks.OnConnect(cl, pk) + if err != nil { + return err + } + + cl.refreshDeadline(cl.State.Keepalive) + if !s.hooks.OnConnectAuthenticate(cl, pk) { // [MQTT-3.1.4-2] + err := s.SendConnack(cl, packets.ErrBadUsernameOrPassword, false, nil) + if err != nil { + return fmt.Errorf("invalid connection send ack: %w", err) + } + + return packets.ErrBadUsernameOrPassword + } + + atomic.AddInt64(&s.Info.ClientsConnected, 1) + defer atomic.AddInt64(&s.Info.ClientsConnected, -1) + + s.hooks.OnSessionEstablish(cl, pk) + + sessionPresent := s.inheritClientSession(pk, cl) + s.Clients.Add(cl) // [MQTT-4.1.0-1] + + err = s.SendConnack(cl, code, sessionPresent, nil) // [MQTT-3.1.4-5] [MQTT-3.2.0-1] [MQTT-3.2.0-2] &[MQTT-3.14.0-1] + if err != nil { + return fmt.Errorf("ack connection packet: %w", err) + } + + s.loop.willDelayed.Delete(cl.ID) // [MQTT-3.1.3-9] + + if sessionPresent { + err = cl.ResendInflightMessages(true) + if err != nil { + return fmt.Errorf("resend inflight: %w", err) + } + } + + s.hooks.OnSessionEstablished(cl, pk) + + err = cl.Read(s.receivePacket) + if err != nil { + s.sendLWT(cl) + cl.Stop(err) + } else { + cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10] + } + s.Log.Debug("client disconnected", "error", err, "client", cl.ID, "remote", cl.Net.Remote, "listener", listener) + + expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) + s.hooks.OnDisconnect(cl, err, expire) + + if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 { + cl.ClearInflights() + s.UnsubscribeClient(cl) + s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23] + } + + return err +} + +// readConnectionPacket reads the first incoming header for a connection, and if +// acceptable, returns the valid connection packet. +func (s *Server) readConnectionPacket(cl *Client) (pk packets.Packet, err error) { + fh := new(packets.FixedHeader) + err = cl.ReadFixedHeader(fh) + if err != nil { + return + } + + if fh.Type != packets.Connect { + return pk, packets.ErrProtocolViolationRequireFirstConnect // [MQTT-3.1.0-1] + } + + pk, err = cl.ReadPacket(fh) + if err != nil { + return + } + + return +} + +// receivePacket processes an incoming packet for a client, and issues a disconnect to the client +// if an error has occurred (if mqtt v5). +func (s *Server) receivePacket(cl *Client, pk packets.Packet) error { + err := s.processPacket(cl, pk) + if err != nil { + if code, ok := err.(packets.Code); ok && + cl.Properties.ProtocolVersion == 5 && + code.Code >= packets.ErrUnspecifiedError.Code { + _ = s.DisconnectClient(cl, code) + } + + s.Log.Warn("error processing packet", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "pk", pk) + + return err + } + + return nil +} + +// validateConnect validates that a connect packet is compliant. +func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code { + code := pk.ConnectValidate() // [MQTT-3.1.4-1] [MQTT-3.1.4-2] + if code != packets.CodeSuccess { + return code + } + + if cl.Properties.ProtocolVersion < 5 && !pk.Connect.Clean && pk.Connect.ClientIdentifier == "" { + return packets.ErrUnspecifiedError + } + + if cl.Properties.ProtocolVersion < s.Options.Capabilities.MinimumProtocolVersion { + return packets.ErrUnsupportedProtocolVersion // [MQTT-3.1.2-2] + } else if cl.Properties.Will.Qos > s.Options.Capabilities.MaximumQos { + return packets.ErrQosNotSupported // [MQTT-3.2.2-12] + } else if cl.Properties.Will.Retain && s.Options.Capabilities.RetainAvailable == 0x00 { + return packets.ErrRetainNotSupported // [MQTT-3.2.2-13] + } + + return code +} + +// inheritClientSession inherits the state of an existing client sharing the same +// connection ID. If clean is true, the state of any previously existing client +// session is abandoned. +func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { + if existing, ok := s.Clients.Get(cl.ID); ok { + _ = s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] + if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] + s.UnsubscribeClient(existing) + existing.ClearInflights() + atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred + return false // [MQTT-3.2.2-3] + } + + atomic.StoreUint32(&existing.State.isTakenOver, 1) + if existing.State.Inflight.Len() > 0 { + cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5] + if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 { + cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client + cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max + } + } + + for _, sub := range existing.State.Subscriptions.GetAll() { + existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] + if !existed { + atomic.AddInt64(&s.Info.Subscriptions, 1) + } + cl.State.Subscriptions.Add(sub.Filter, sub) + } + + // Clean the state of the existing client to prevent sequential take-overs + // from increasing memory usage by inflights + subs * client-id. + s.UnsubscribeClient(existing) + existing.ClearInflights() + + s.Log.Debug("session taken over", "client", cl.ID, "old_remote", existing.Net.Remote, "new_remote", cl.Net.Remote) + + return true // [MQTT-3.2.2-3] + } + + if atomic.LoadInt64(&s.Info.ClientsConnected) > atomic.LoadInt64(&s.Info.ClientsMaximum) { + atomic.AddInt64(&s.Info.ClientsMaximum, 1) + } + + return false // [MQTT-3.2.2-2] +} + +// SendConnack returns a Connack packet to a client. +func (s *Server) SendConnack(cl *Client, reason packets.Code, present bool, properties *packets.Properties) error { + if properties == nil { + properties = &packets.Properties{ + ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum, + } + } + + properties.ReceiveMaximum = s.Options.Capabilities.ReceiveMaximum // 3.2.2.3.3 Receive Maximum + if cl.State.ServerKeepalive { // You can set this dynamically using the OnConnect hook. + properties.ServerKeepAlive = cl.State.Keepalive // [MQTT-3.1.2-21] + properties.ServerKeepAliveFlag = true + } + + if reason.Code >= packets.ErrUnspecifiedError.Code { + if cl.Properties.ProtocolVersion < 5 { + if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes + reason = v3reason + } + } + + properties.ReasonString = reason.Reason + ack := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Connack, + }, + SessionPresent: false, // [MQTT-3.2.2-6] + ReasonCode: reason.Code, // [MQTT-3.2.2-8] + Properties: *properties, + } + return cl.WritePacket(ack) + } + + if s.Options.Capabilities.MaximumQos < 2 { + properties.MaximumQos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] + properties.MaximumQosFlag = true + } + + if cl.Properties.Props.AssignedClientID != "" { + properties.AssignedClientID = cl.Properties.Props.AssignedClientID // [MQTT-3.1.3-7] [MQTT-3.2.2-16] + } + + if cl.Properties.Props.SessionExpiryInterval > s.Options.Capabilities.MaximumSessionExpiryInterval { + properties.SessionExpiryInterval = s.Options.Capabilities.MaximumSessionExpiryInterval + properties.SessionExpiryIntervalFlag = true + cl.Properties.Props.SessionExpiryInterval = properties.SessionExpiryInterval + cl.Properties.Props.SessionExpiryIntervalFlag = true + } + + ack := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Connack, + }, + SessionPresent: present, + ReasonCode: reason.Code, // [MQTT-3.2.2-8] + Properties: *properties, + } + return cl.WritePacket(ack) +} + +// processPacket processes an inbound packet for a client. Since the method is +// typically called as a goroutine, errors are primarily for test checking purposes. +func (s *Server) processPacket(cl *Client, pk packets.Packet) error { + var err error + + switch pk.FixedHeader.Type { + case packets.Connect: + err = s.processConnect(cl, pk) + case packets.Disconnect: + err = s.processDisconnect(cl, pk) + case packets.Pingreq: + err = s.processPingreq(cl, pk) + case packets.Publish: + code := pk.PublishValidate(s.Options.Capabilities.TopicAliasMaximum) + if code != packets.CodeSuccess { + return code + } + err = s.processPublish(cl, pk) + case packets.Puback: + err = s.processPuback(cl, pk) + case packets.Pubrec: + err = s.processPubrec(cl, pk) + case packets.Pubrel: + err = s.processPubrel(cl, pk) + case packets.Pubcomp: + err = s.processPubcomp(cl, pk) + case packets.Subscribe: + code := pk.SubscribeValidate() + if code != packets.CodeSuccess { + return code + } + err = s.processSubscribe(cl, pk) + case packets.Unsubscribe: + code := pk.UnsubscribeValidate() + if code != packets.CodeSuccess { + return code + } + err = s.processUnsubscribe(cl, pk) + case packets.Auth: + code := pk.AuthValidate() + if code != packets.CodeSuccess { + return code + } + err = s.processAuth(cl, pk) + default: + return fmt.Errorf("no valid packet available; %v", pk.FixedHeader.Type) + } + + s.hooks.OnPacketProcessed(cl, pk, err) + if err != nil { + return err + } + + if cl.State.Inflight.Len() > 0 && atomic.LoadInt32(&cl.State.Inflight.sendQuota) > 0 { + next, ok := cl.State.Inflight.NextImmediate() + if ok { + _ = cl.WritePacket(next) + if ok := cl.State.Inflight.Delete(next.PacketID); ok { + atomic.AddInt64(&s.Info.Inflight, -1) + } + cl.State.Inflight.DecreaseSendQuota() + } + } + + return nil +} + +// processConnect processes a Connect packet. The packet cannot be used to establish +// a new connection on an existing connection. See EstablishConnection instead. +func (s *Server) processConnect(cl *Client, _ packets.Packet) error { + s.sendLWT(cl) + return packets.ErrProtocolViolationSecondConnect // [MQTT-3.1.0-2] +} + +// processPingreq processes a Pingreq packet. +func (s *Server) processPingreq(cl *Client, _ packets.Packet) error { + return cl.WritePacket(packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Pingresp, // [MQTT-3.12.4-1] + }, + }) +} + +// Publish publishes a publish packet into the broker as if it were sent from the specified client. +// This is a convenience function which wraps InjectPacket. As such, this method can publish packets +// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the +// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete). +func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + return s.InjectPacket(s.inlineClient, packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Qos: qos, + Retain: retain, + }, + TopicName: topic, + Payload: payload, + PacketID: uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks. + }) +} + +// Subscribe adds an inline subscription for the specified topic filter and subscription identifier +// with the provided handler function. +func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + if handler == nil { + return packets.ErrInlineSubscriptionHandlerInvalid + } + + if !IsValidFilter(filter, false) { + return packets.ErrTopicFilterInvalid + } + + subscription := packets.Subscription{ + Identifier: subscriptionId, + Filter: filter, + } + + pk := s.hooks.OnSubscribe(s.inlineClient, packets.Packet{ // subscribe like a normal client. + Origin: s.inlineClient.ID, + FixedHeader: packets.FixedHeader{Type: packets.Subscribe}, + Filters: packets.Subscriptions{subscription}, + }) + + inlineSubscription := InlineSubscription{ + Subscription: subscription, + Handler: handler, + } + + s.Topics.InlineSubscribe(inlineSubscription) + s.hooks.OnSubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code}) + + // Handling retained messages. + for _, pkv := range s.Topics.Messages(filter) { // [MQTT-3.8.4-4] + handler(s.inlineClient, inlineSubscription.Subscription, pkv) + } + return nil +} + +// Unsubscribe removes an inline subscription for the specified subscription and topic filter. +// It allows you to unsubscribe a specific subscription from the internal subscription +// associated with the given topic filter. +func (s *Server) Unsubscribe(filter string, subscriptionId int) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + if !IsValidFilter(filter, false) { + return packets.ErrTopicFilterInvalid + } + + pk := s.hooks.OnUnsubscribe(s.inlineClient, packets.Packet{ + Origin: s.inlineClient.ID, + FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe}, + Filters: packets.Subscriptions{ + { + Identifier: subscriptionId, + Filter: filter, + }, + }, + }) + + s.Topics.InlineUnsubscribe(subscriptionId, filter) + s.hooks.OnUnsubscribed(s.inlineClient, pk) + return nil +} + +// InjectPacket injects a packet into the broker as if it were sent from the specified client. +// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. +func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { + pk.ProtocolVersion = cl.Properties.ProtocolVersion + + err := s.processPacket(cl, pk) + if err != nil { + return err + } + + atomic.AddInt64(&cl.ops.info.PacketsReceived, 1) + if pk.FixedHeader.Type == packets.Publish { + atomic.AddInt64(&cl.ops.info.MessagesReceived, 1) + } + + return nil +} + +// processPublish processes a Publish packet. +func (s *Server) processPublish(cl *Client, pk packets.Packet) error { + if !cl.Net.Inline && !IsValidFilter(pk.TopicName, true) { + return nil + } + + if atomic.LoadInt32(&cl.State.Inflight.receiveQuota) == 0 { + return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8] + } + + if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) { + if pk.FixedHeader.Qos == 0 { + return nil + } + + if cl.Properties.ProtocolVersion != 5 { + return s.DisconnectClient(cl, packets.ErrNotAuthorized) + } + + ackType := packets.Puback + if pk.FixedHeader.Qos == 2 { + ackType = packets.Pubrec + } + + ack := s.buildAck(pk.PacketID, ackType, 0, pk.Properties, packets.ErrNotAuthorized) + return cl.WritePacket(ack) + } + + pk.Origin = cl.ID + pk.Created = time.Now().Unix() + + if !cl.Net.Inline { + if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok { + if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10] + ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse) + return cl.WritePacket(ack) + } + if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] + atomic.AddInt64(&s.Info.Inflight, -1) + } + } + } + + if pk.Properties.TopicAliasFlag && pk.Properties.TopicAlias > 0 { // [MQTT-3.3.2-11] + pk.TopicName = cl.State.TopicAliases.Inbound.Set(pk.Properties.TopicAlias, pk.TopicName) + } + + if pk.FixedHeader.Qos > s.Options.Capabilities.MaximumQos { + pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce qos based on server max qos capability + } + + pkx, err := s.hooks.OnPublish(cl, pk) + if err == nil { + pk = pkx + } else if errors.Is(err, packets.ErrRejectPacket) { + return nil + } else if errors.Is(err, packets.CodeSuccessIgnore) { + pk.Ignore = true + } else if cl.Properties.ProtocolVersion == 5 && pk.FixedHeader.Qos > 0 && errors.As(err, new(packets.Code)) { + err = cl.WritePacket(s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, err.(packets.Code))) + if err != nil { + return err + } + return nil + } + + if pk.FixedHeader.Retain { // [MQTT-3.3.1-5] ![MQTT-3.3.1-8] + s.retainMessage(cl, pk) + } + + // If it's inlineClient, it can't handle PUBREC and PUBREL. + // When it publishes a package with a qos > 0, the server treats + // the package as qos=0, and the client receives it as qos=1 or 2. + if pk.FixedHeader.Qos == 0 || cl.Net.Inline { + s.publishToSubscribers(pk) + s.hooks.OnPublished(cl, pk) + return nil + } + + cl.State.Inflight.DecreaseReceiveQuota() + ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4] + if pk.FixedHeader.Qos == 2 { + ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8] + } + + if ok := cl.State.Inflight.Set(ack); ok { + atomic.AddInt64(&s.Info.Inflight, 1) + s.hooks.OnQosPublish(cl, ack, ack.Created, 0) + } + + err = cl.WritePacket(ack) + if err != nil { + return err + } + + if pk.FixedHeader.Qos == 1 { + if ok := cl.State.Inflight.Delete(ack.PacketID); ok { + atomic.AddInt64(&s.Info.Inflight, -1) + } + cl.State.Inflight.IncreaseReceiveQuota() + s.hooks.OnQosComplete(cl, ack) + } + + s.publishToSubscribers(pk) + s.hooks.OnPublished(cl, pk) + + return nil +} + +// retainMessage adds a message to a topic, and if a persistent store is provided, +// adds the message to the store to be reloaded if necessary. +func (s *Server) retainMessage(cl *Client, pk packets.Packet) { + if s.Options.Capabilities.RetainAvailable == 0 || pk.Ignore { + return + } + + out := pk.Copy(false) + r := s.Topics.RetainMessage(out) + s.hooks.OnRetainMessage(cl, pk, r) + atomic.StoreInt64(&s.Info.Retained, int64(s.Topics.Retained.Len())) +} + +// publishToSubscribers publishes a publish packet to all subscribers with matching topic filters. +func (s *Server) publishToSubscribers(pk packets.Packet) { + if pk.Ignore { + return + } + + if pk.Created == 0 { + pk.Created = time.Now().Unix() + } + + pk.Expiry = pk.Created + s.Options.Capabilities.MaximumMessageExpiryInterval + if pk.Properties.MessageExpiryInterval > 0 { + pk.Expiry = pk.Created + int64(pk.Properties.MessageExpiryInterval) + } + + subscribers := s.Topics.Subscribers(pk.TopicName) + if len(subscribers.Shared) > 0 { + subscribers = s.hooks.OnSelectSubscribers(subscribers, pk) + if len(subscribers.SharedSelected) == 0 { + subscribers.SelectShared() + } + subscribers.MergeSharedSelected() + } + + for _, inlineSubscription := range subscribers.InlineSubscriptions { + inlineSubscription.Handler(s.inlineClient, inlineSubscription.Subscription, pk) + } + + for id, subs := range subscribers.Subscriptions { + if cl, ok := s.Clients.Get(id); ok { + _, err := s.publishToClient(cl, subs, pk) + if err != nil { + s.Log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk) + } + } + } +} + +func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) { + if sub.NoLocal && pk.Origin == cl.ID { + return pk, nil // [MQTT-3.8.3-3] + } + + out := pk.Copy(false) + if !s.hooks.OnACLCheck(cl, pk.TopicName, false) { + return out, packets.ErrNotAuthorized + } + if !sub.FwdRetainedFlag && ((cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished) || cl.Properties.ProtocolVersion < 5) { // ![MQTT-3.3.1-13] [v3 MQTT-3.3.1-9] + out.FixedHeader.Retain = false // [MQTT-3.3.1-12] + } + + if len(sub.Identifiers) > 0 { // [MQTT-3.3.4-3] + out.Properties.SubscriptionIdentifier = []int{} + for _, id := range sub.Identifiers { + out.Properties.SubscriptionIdentifier = append(out.Properties.SubscriptionIdentifier, id) // [MQTT-3.3.4-4] ![MQTT-3.3.4-5] + } + sort.Ints(out.Properties.SubscriptionIdentifier) + } + + if out.FixedHeader.Qos > sub.Qos { + out.FixedHeader.Qos = sub.Qos + } + + if out.FixedHeader.Qos > s.Options.Capabilities.MaximumQos { + out.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] + } + + if cl.Properties.Props.TopicAliasMaximum > 0 { + var aliasExists bool + out.Properties.TopicAlias, aliasExists = cl.State.TopicAliases.Outbound.Set(pk.TopicName) + if out.Properties.TopicAlias > 0 { + out.Properties.TopicAliasFlag = true + if aliasExists { + out.TopicName = "" + } + } + } + + if out.FixedHeader.Qos > 0 { + if cl.State.Inflight.Len() >= int(s.Options.Capabilities.MaximumInflight) { + // add hook? + atomic.AddInt64(&s.Info.InflightDropped, 1) + s.Log.Warn("client store quota reached", "client", cl.ID, "listener", cl.Net.Listener) + return out, packets.ErrQuotaExceeded + } + + i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1] + if err != nil { + s.hooks.OnPacketIDExhausted(cl, pk) + atomic.AddInt64(&s.Info.InflightDropped, 1) + s.Log.Warn("packet ids exhausted", "error", err, "client", cl.ID, "listener", cl.Net.Listener) + return out, packets.ErrQuotaExceeded + } + + out.PacketID = uint16(i) // [MQTT-2.2.1-4] + sentQuota := atomic.LoadInt32(&cl.State.Inflight.sendQuota) + + if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3] + atomic.AddInt64(&s.Info.Inflight, 1) + s.hooks.OnQosPublish(cl, out, out.Created, 0) + cl.State.Inflight.DecreaseSendQuota() + } + + if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 { + out.Expiry = -1 + cl.State.Inflight.Set(out) + return out, nil + } + } + + if cl.Net.Conn == nil || cl.Closed() { + return out, packets.CodeDisconnect + } + + select { + case cl.State.outbound <- &out: + atomic.AddInt32(&cl.State.outboundQty, 1) + default: + atomic.AddInt64(&s.Info.MessagesDropped, 1) + cl.ops.hooks.OnPublishDropped(cl, pk) + if out.FixedHeader.Qos > 0 { + cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight. + cl.State.Inflight.IncreaseSendQuota() + } + return out, packets.ErrPendingClientWritesExceeded + } + + return out, nil +} + +func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, existed bool) { + if IsSharedFilter(sub.Filter) { + return // 4.8.2 Non-normative - Shared Subscriptions - No Retained Messages are sent to the Session when it first subscribes. + } + + if sub.RetainHandling == 1 && existed || sub.RetainHandling == 2 { // [MQTT-3.3.1-10] [MQTT-3.3.1-11] + return + } + + sub.FwdRetainedFlag = true + for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4] + _, err := s.publishToClient(cl, sub, pkv) + if err != nil { + s.Log.Debug("failed to publish retained message", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "packet", pkv) + continue + } + s.hooks.OnRetainPublished(cl, pkv) + } +} + +// buildAck builds a standardised ack message for Puback, Pubrec, Pubrel, Pubcomp packets. +func (s *Server) buildAck(packetID uint16, pkt, qos byte, properties packets.Properties, reason packets.Code) packets.Packet { + if s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck { + properties = packets.Properties{} + } + if reason.Code >= packets.ErrUnspecifiedError.Code { + properties.ReasonString = reason.Reason + } + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: pkt, + Qos: qos, + }, + PacketID: packetID, // [MQTT-2.2.1-5] + ReasonCode: reason.Code, // [MQTT-3.4.2-1] + Properties: properties, + Created: time.Now().Unix(), + Expiry: time.Now().Unix() + s.Options.Capabilities.MaximumMessageExpiryInterval, + } + + return pk +} + +// processPuback processes a Puback packet, denoting completion of a QOS 1 packet sent from the server. +func (s *Server) processPuback(cl *Client, pk packets.Packet) error { + if _, ok := cl.State.Inflight.Get(pk.PacketID); !ok { + return nil // omit, but would be packets.ErrPacketIdentifierNotFound + } + + if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] + cl.State.Inflight.IncreaseSendQuota() + atomic.AddInt64(&s.Info.Inflight, -1) + s.hooks.OnQosComplete(cl, pk) + } + + return nil +} + +// processPubrec processes a Pubrec packet, denoting receipt of a QOS 2 packet sent from the server. +func (s *Server) processPubrec(cl *Client, pk packets.Packet) error { + if _, ok := cl.State.Inflight.Get(pk.PacketID); !ok { // [MQTT-4.3.3-7] [MQTT-4.3.3-13] + return cl.WritePacket(s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.ErrPacketIdentifierNotFound)) + } + + if pk.ReasonCode >= packets.ErrUnspecifiedError.Code || !pk.ReasonCodeValid() { // [MQTT-4.3.3-4] + if ok := cl.State.Inflight.Delete(pk.PacketID); ok { + atomic.AddInt64(&s.Info.Inflight, -1) + } + cl.ops.hooks.OnQosDropped(cl, pk) + return nil // as per MQTT5 Section 4.13.2 paragraph 2 + } + + ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6] + cl.State.Inflight.DecreaseReceiveQuota() // -1 RECV QUOTA + cl.State.Inflight.Set(ack) // [MQTT-4.3.3-5] + return cl.WritePacket(ack) +} + +// processPubrel processes a Pubrel packet, denoting completion of a QOS 2 packet sent from the client. +func (s *Server) processPubrel(cl *Client, pk packets.Packet) error { + if _, ok := cl.State.Inflight.Get(pk.PacketID); !ok { // [MQTT-4.3.3-7] [MQTT-4.3.3-13] + return cl.WritePacket(s.buildAck(pk.PacketID, packets.Pubcomp, 0, pk.Properties, packets.ErrPacketIdentifierNotFound)) + } + + if pk.ReasonCode >= packets.ErrUnspecifiedError.Code || !pk.ReasonCodeValid() { // [MQTT-4.3.3-9] + if ok := cl.State.Inflight.Delete(pk.PacketID); ok { + atomic.AddInt64(&s.Info.Inflight, -1) + } + cl.ops.hooks.OnQosDropped(cl, pk) + return nil + } + + ack := s.buildAck(pk.PacketID, packets.Pubcomp, 0, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-11] + cl.State.Inflight.Set(ack) + + err := cl.WritePacket(ack) + if err != nil { + return err + } + + cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA + cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA + if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12] + atomic.AddInt64(&s.Info.Inflight, -1) + s.hooks.OnQosComplete(cl, pk) + } + + return nil +} + +// processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server. +func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error { + // regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas. + cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA + cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA + if ok := cl.State.Inflight.Delete(pk.PacketID); ok { + atomic.AddInt64(&s.Info.Inflight, -1) + s.hooks.OnQosComplete(cl, pk) + } + + return nil +} + +// processSubscribe processes a Subscribe packet. +func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { + pk = s.hooks.OnSubscribe(cl, pk) + code := packets.CodeSuccess + if _, ok := cl.State.Inflight.Get(pk.PacketID); ok { + code = packets.ErrPacketIdentifierInUse + } + + filterExisted := make([]bool, len(pk.Filters)) + reasonCodes := make([]byte, len(pk.Filters)) + for i, sub := range pk.Filters { + if code != packets.CodeSuccess { + reasonCodes[i] = code.Code // NB 3.9.3 Non-normative 0x91 + continue + } else if !IsValidFilter(sub.Filter, false) { + reasonCodes[i] = packets.ErrTopicFilterInvalid.Code + } else if sub.NoLocal && IsSharedFilter(sub.Filter) { + reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4] + } else if !s.hooks.OnACLCheck(cl, sub.Filter, false) { + reasonCodes[i] = packets.ErrNotAuthorized.Code + if s.Options.Capabilities.Compatibilities.ObscureNotAuthorized { + reasonCodes[i] = packets.ErrUnspecifiedError.Code + } + } else { + isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] + if isNew { + atomic.AddInt64(&s.Info.Subscriptions, 1) + } + cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10] + + if sub.Qos > s.Options.Capabilities.MaximumQos { + sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] + } + + filterExisted[i] = !isNew + reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7] + } + + if reasonCodes[i] > packets.CodeGrantedQos2.Code && cl.Properties.ProtocolVersion < 5 { // MQTT3 + reasonCodes[i] = packets.ErrUnspecifiedError.Code + } + } + + ack := packets.Packet{ // [MQTT-3.8.4-1] [MQTT-3.8.4-5] + FixedHeader: packets.FixedHeader{ + Type: packets.Suback, + }, + PacketID: pk.PacketID, // [MQTT-2.2.1-6] [MQTT-3.8.4-2] + ReasonCodes: reasonCodes, // [MQTT-3.8.4-6] + Properties: packets.Properties{ + User: pk.Properties.User, + }, + } + + if code.Code >= packets.ErrUnspecifiedError.Code { + ack.Properties.ReasonString = code.Reason + } + + s.hooks.OnSubscribed(cl, pk, reasonCodes) + err := cl.WritePacket(ack) + if err != nil { + return err + } + + for i, sub := range pk.Filters { // [MQTT-3.3.1-9] + if reasonCodes[i] >= packets.ErrUnspecifiedError.Code { + continue + } + + s.publishRetainedToClient(cl, sub, filterExisted[i]) + } + + return nil +} + +// processUnsubscribe processes an unsubscribe packet. +func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error { + code := packets.CodeSuccess + if _, ok := cl.State.Inflight.Get(pk.PacketID); ok { + code = packets.ErrPacketIdentifierInUse + } + + pk = s.hooks.OnUnsubscribe(cl, pk) + reasonCodes := make([]byte, len(pk.Filters)) + for i, sub := range pk.Filters { // [MQTT-3.10.4-6] [MQTT-3.11.3-1] + if code != packets.CodeSuccess { + reasonCodes[i] = code.Code // NB 3.11.3 Non-normative 0x91 + continue + } + + if q := s.Topics.Unsubscribe(sub.Filter, cl.ID); q { + atomic.AddInt64(&s.Info.Subscriptions, -1) + reasonCodes[i] = packets.CodeSuccess.Code + } else { + reasonCodes[i] = packets.CodeNoSubscriptionExisted.Code + } + + cl.State.Subscriptions.Delete(sub.Filter) // [MQTT-3.10.4-2] [MQTT-3.10.4-2] ~[MQTT-3.10.4-3] + } + + ack := packets.Packet{ // [MQTT-3.10.4-4] + FixedHeader: packets.FixedHeader{ + Type: packets.Unsuback, + }, + PacketID: pk.PacketID, // [MQTT-2.2.1-6] [MQTT-3.10.4-5] + ReasonCodes: reasonCodes, // [MQTT-3.11.3-2] + Properties: packets.Properties{ + User: pk.Properties.User, + }, + } + + if code.Code >= packets.ErrUnspecifiedError.Code { + ack.Properties.ReasonString = code.Reason + } + + s.hooks.OnUnsubscribed(cl, pk) + return cl.WritePacket(ack) +} + +// UnsubscribeClient unsubscribes a client from all of their subscriptions. +func (s *Server) UnsubscribeClient(cl *Client) { + i := 0 + filterMap := cl.State.Subscriptions.GetAll() + filters := make([]packets.Subscription, len(filterMap)) + for k := range filterMap { + cl.State.Subscriptions.Delete(k) + } + + if atomic.LoadUint32(&cl.State.isTakenOver) == 1 { + return + } + + for k, v := range filterMap { + if s.Topics.Unsubscribe(k, cl.ID) { + atomic.AddInt64(&s.Info.Subscriptions, -1) + } + filters[i] = v + i++ + } + s.hooks.OnUnsubscribed(cl, packets.Packet{FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe}, Filters: filters}) +} + +// processAuth processes an Auth packet. +func (s *Server) processAuth(cl *Client, pk packets.Packet) error { + _, err := s.hooks.OnAuthPacket(cl, pk) + if err != nil { + return err + } + + return nil +} + +// processDisconnect processes a Disconnect packet. +func (s *Server) processDisconnect(cl *Client, pk packets.Packet) error { + if pk.Properties.SessionExpiryIntervalFlag { + if pk.Properties.SessionExpiryInterval > 0 && cl.Properties.Props.SessionExpiryInterval == 0 { + return packets.ErrProtocolViolationZeroNonZeroExpiry + } + + cl.Properties.Props.SessionExpiryInterval = pk.Properties.SessionExpiryInterval + cl.Properties.Props.SessionExpiryIntervalFlag = true + } + + if pk.ReasonCode == packets.CodeDisconnectWillMessage.Code { // [MQTT-3.1.2.5] Non-normative comment + return packets.CodeDisconnectWillMessage + } + + s.loop.willDelayed.Delete(cl.ID) // [MQTT-3.1.3-9] [MQTT-3.1.2-8] + cl.Stop(packets.CodeDisconnect) // [MQTT-3.14.4-2] + + return nil +} + +// DisconnectClient sends a Disconnect packet to a client and then closes the client connection. +func (s *Server) DisconnectClient(cl *Client, code packets.Code) error { + out := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Disconnect, + }, + ReasonCode: code.Code, + Properties: packets.Properties{}, + } + + if code.Code >= packets.ErrUnspecifiedError.Code { + out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1] + } + + // We already have a code we are using to disconnect the client, so we are not + // interested if the write packet fails due to a closed connection (as we are closing it). + err := cl.WritePacket(out) + if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect { + cl.Stop(code) + if code.Code >= packets.ErrUnspecifiedError.Code { + return code + } + } + + return err +} + +// publishSysTopics publishes the current values to the server $SYS topics. +// Due to the int to string conversions this method is not as cheap as +// some of the others so the publishing interval should be set appropriately. +func (s *Server) publishSysTopics() { + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Retain: true, + }, + Created: time.Now().Unix(), + } + + var m runtime.MemStats + runtime.ReadMemStats(&m) + atomic.StoreInt64(&s.Info.MemoryAlloc, int64(m.HeapInuse)) + atomic.StoreInt64(&s.Info.Threads, int64(runtime.NumGoroutine())) + atomic.StoreInt64(&s.Info.Time, time.Now().Unix()) + atomic.StoreInt64(&s.Info.Uptime, time.Now().Unix()-atomic.LoadInt64(&s.Info.Started)) + atomic.StoreInt64(&s.Info.ClientsTotal, int64(s.Clients.Len())) + atomic.StoreInt64(&s.Info.ClientsDisconnected, atomic.LoadInt64(&s.Info.ClientsTotal)-atomic.LoadInt64(&s.Info.ClientsConnected)) + + info := s.Info.Clone() + topics := map[string]string{ + SysPrefix + "/broker/version": s.Info.Version, + SysPrefix + "/broker/time": Int64toa(info.Time), + SysPrefix + "/broker/uptime": Int64toa(info.Uptime), + SysPrefix + "/broker/started": Int64toa(info.Started), + SysPrefix + "/broker/load/bytes/received": Int64toa(info.BytesReceived), + SysPrefix + "/broker/load/bytes/sent": Int64toa(info.BytesSent), + SysPrefix + "/broker/clients/connected": Int64toa(info.ClientsConnected), + SysPrefix + "/broker/clients/disconnected": Int64toa(info.ClientsDisconnected), + SysPrefix + "/broker/clients/maximum": Int64toa(info.ClientsMaximum), + SysPrefix + "/broker/clients/total": Int64toa(info.ClientsTotal), + SysPrefix + "/broker/packets/received": Int64toa(info.PacketsReceived), + SysPrefix + "/broker/packets/sent": Int64toa(info.PacketsSent), + SysPrefix + "/broker/messages/received": Int64toa(info.MessagesReceived), + SysPrefix + "/broker/messages/sent": Int64toa(info.MessagesSent), + SysPrefix + "/broker/messages/dropped": Int64toa(info.MessagesDropped), + SysPrefix + "/broker/messages/inflight": Int64toa(info.Inflight), + SysPrefix + "/broker/retained": Int64toa(info.Retained), + SysPrefix + "/broker/subscriptions": Int64toa(info.Subscriptions), + SysPrefix + "/broker/system/memory": Int64toa(info.MemoryAlloc), + SysPrefix + "/broker/system/threads": Int64toa(info.Threads), + } + + for topic, payload := range topics { + pk.TopicName = topic + pk.Payload = []byte(payload) + s.Topics.RetainMessage(pk.Copy(false)) + s.publishToSubscribers(pk) + } + + s.hooks.OnSysInfoTick(info) +} + +// Close attempts to gracefully shut down the server, all listeners, clients, and stores. +func (s *Server) Close() error { + close(s.done) + s.Log.Info("gracefully stopping server") + s.Listeners.CloseAll(s.closeListenerClients) + s.hooks.OnStopped() + s.hooks.Stop() + + s.Log.Info("mochi mqtt server stopped") + return nil +} + +// closeListenerClients closes all clients on the specified listener. +func (s *Server) closeListenerClients(listener string) { + clients := s.Clients.GetByListener(listener) + for _, cl := range clients { + _ = s.DisconnectClient(cl, packets.ErrServerShuttingDown) + } +} + +// sendLWT issues an LWT message to a topic when a client disconnects. +func (s *Server) sendLWT(cl *Client) { + if atomic.LoadUint32(&cl.Properties.Will.Flag) == 0 { + return + } + + modifiedLWT := s.hooks.OnWill(cl, cl.Properties.Will) + + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Retain: modifiedLWT.Retain, // [MQTT-3.1.2-14] [MQTT-3.1.2-15] + Qos: modifiedLWT.Qos, + }, + TopicName: modifiedLWT.TopicName, + Payload: modifiedLWT.Payload, + Properties: packets.Properties{ + User: modifiedLWT.User, + }, + Origin: cl.ID, + Created: time.Now().Unix(), + } + + if cl.Properties.Will.WillDelayInterval > 0 { + pk.Connect.WillProperties.WillDelayInterval = cl.Properties.Will.WillDelayInterval + pk.Expiry = time.Now().Unix() + int64(pk.Connect.WillProperties.WillDelayInterval) + s.loop.willDelayed.Add(cl.ID, pk) + return + } + + if pk.FixedHeader.Retain { + s.retainMessage(cl, pk) + } + + s.publishToSubscribers(pk) // [MQTT-3.1.2-8] + atomic.StoreUint32(&cl.Properties.Will.Flag, 0) // [MQTT-3.1.2-10] + s.hooks.OnWillSent(cl, pk) +} + +// readStore reads in any data from the persistent datastore (if applicable). +func (s *Server) readStore() error { + if s.hooks.Provides(StoredClients) { + clients, err := s.hooks.StoredClients() + if err != nil { + return fmt.Errorf("failed to load clients; %w", err) + } + s.loadClients(clients) + s.Log.Debug("loaded clients from store", "len", len(clients)) + } + + if s.hooks.Provides(StoredSubscriptions) { + subs, err := s.hooks.StoredSubscriptions() + if err != nil { + return fmt.Errorf("load subscriptions; %w", err) + } + s.loadSubscriptions(subs) + s.Log.Debug("loaded subscriptions from store", "len", len(subs)) + } + + if s.hooks.Provides(StoredInflightMessages) { + inflight, err := s.hooks.StoredInflightMessages() + if err != nil { + return fmt.Errorf("load inflight; %w", err) + } + s.loadInflight(inflight) + s.Log.Debug("loaded inflights from store", "len", len(inflight)) + } + + if s.hooks.Provides(StoredRetainedMessages) { + retained, err := s.hooks.StoredRetainedMessages() + if err != nil { + return fmt.Errorf("load retained; %w", err) + } + s.loadRetained(retained) + s.Log.Debug("loaded retained messages from store", "len", len(retained)) + } + + if s.hooks.Provides(StoredSysInfo) { + sysInfo, err := s.hooks.StoredSysInfo() + if err != nil { + return fmt.Errorf("load server info; %w", err) + } + s.loadServerInfo(sysInfo.Info) + s.Log.Debug("loaded $SYS info from store") + } + + return nil +} + +// loadServerInfo restores server info from the datastore. +func (s *Server) loadServerInfo(v system.Info) { + if s.Options.Capabilities.Compatibilities.RestoreSysInfoOnRestart { + atomic.StoreInt64(&s.Info.BytesReceived, v.BytesReceived) + atomic.StoreInt64(&s.Info.BytesSent, v.BytesSent) + atomic.StoreInt64(&s.Info.ClientsMaximum, v.ClientsMaximum) + atomic.StoreInt64(&s.Info.ClientsTotal, v.ClientsTotal) + atomic.StoreInt64(&s.Info.ClientsDisconnected, v.ClientsDisconnected) + atomic.StoreInt64(&s.Info.MessagesReceived, v.MessagesReceived) + atomic.StoreInt64(&s.Info.MessagesSent, v.MessagesSent) + atomic.StoreInt64(&s.Info.MessagesDropped, v.MessagesDropped) + atomic.StoreInt64(&s.Info.PacketsReceived, v.PacketsReceived) + atomic.StoreInt64(&s.Info.PacketsSent, v.PacketsSent) + atomic.StoreInt64(&s.Info.InflightDropped, v.InflightDropped) + } + atomic.StoreInt64(&s.Info.Retained, v.Retained) + atomic.StoreInt64(&s.Info.Inflight, v.Inflight) + atomic.StoreInt64(&s.Info.Subscriptions, v.Subscriptions) +} + +// loadSubscriptions restores subscriptions from the datastore. +func (s *Server) loadSubscriptions(v []storage.Subscription) { + for _, sub := range v { + sb := packets.Subscription{ + Filter: sub.Filter, + RetainHandling: sub.RetainHandling, + Qos: sub.Qos, + RetainAsPublished: sub.RetainAsPublished, + NoLocal: sub.NoLocal, + Identifier: sub.Identifier, + } + if s.Topics.Subscribe(sub.Client, sb) { + if cl, ok := s.Clients.Get(sub.Client); ok { + cl.State.Subscriptions.Add(sub.Filter, sb) + } + } + } +} + +// loadClients restores clients from the datastore. +func (s *Server) loadClients(v []storage.Client) { + for _, c := range v { + cl := s.NewClient(nil, c.Listener, c.ID, false) + cl.Properties.Username = c.Username + cl.Properties.Clean = c.Clean + cl.Properties.ProtocolVersion = c.ProtocolVersion + cl.Properties.Props = packets.Properties{ + SessionExpiryInterval: c.Properties.SessionExpiryInterval, + SessionExpiryIntervalFlag: c.Properties.SessionExpiryIntervalFlag, + AuthenticationMethod: c.Properties.AuthenticationMethod, + AuthenticationData: c.Properties.AuthenticationData, + RequestProblemInfoFlag: c.Properties.RequestProblemInfoFlag, + RequestProblemInfo: c.Properties.RequestProblemInfo, + RequestResponseInfo: c.Properties.RequestResponseInfo, + ReceiveMaximum: c.Properties.ReceiveMaximum, + TopicAliasMaximum: c.Properties.TopicAliasMaximum, + User: c.Properties.User, + MaximumPacketSize: c.Properties.MaximumPacketSize, + } + cl.Properties.Will = Will(c.Will) + + // cancel the context, update cl.State such as disconnected time and stopCause. + cl.Stop(packets.ErrServerShuttingDown) + + expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) + s.hooks.OnDisconnect(cl, packets.ErrServerShuttingDown, expire) + if expire { + cl.ClearInflights() + s.UnsubscribeClient(cl) + } else { + s.Clients.Add(cl) + } + } +} + +// loadInflight restores inflight messages from the datastore. +func (s *Server) loadInflight(v []storage.Message) { + for _, msg := range v { + if client, ok := s.Clients.Get(msg.Origin); ok { + client.State.Inflight.Set(msg.ToPacket()) + } + } +} + +// loadRetained restores retained messages from the datastore. +func (s *Server) loadRetained(v []storage.Message) { + for _, msg := range v { + s.Topics.RetainMessage(msg.ToPacket()) + } +} + +// clearExpiredClients deletes all clients which have been disconnected for longer +// than their given expiry intervals. +func (s *Server) clearExpiredClients(dt int64) { + for id, client := range s.Clients.GetAll() { + disconnected := client.StopTime() + if disconnected == 0 { + continue + } + + expire := s.Options.Capabilities.MaximumSessionExpiryInterval + if client.Properties.ProtocolVersion == 5 && client.Properties.Props.SessionExpiryIntervalFlag { + expire = client.Properties.Props.SessionExpiryInterval + } + + if disconnected+int64(expire) < dt { + s.hooks.OnClientExpired(client) + s.Clients.Delete(id) // [MQTT-4.1.0-2] + } + } +} + +// clearExpiredRetainedMessage deletes retained messages from topics if they have expired. +func (s *Server) clearExpiredRetainedMessages(now int64) { + for filter, pk := range s.Topics.Retained.GetAll() { + expired := pk.ProtocolVersion == 5 && pk.Expiry > 0 && pk.Expiry < now // [MQTT-3.3.2-5] + + // If the maximum message expiry interval is set (greater than 0), and the message + // retention period exceeds the maximum expiry, the message will be forcibly removed. + enforced := s.Options.Capabilities.MaximumMessageExpiryInterval > 0 && + now-pk.Created > s.Options.Capabilities.MaximumMessageExpiryInterval + + if expired || enforced { + s.Topics.Retained.Delete(filter) + s.hooks.OnRetainedExpired(filter) + } + } +} + +// clearExpiredInflights deletes any inflight messages which have expired. +func (s *Server) clearExpiredInflights(now int64) { + for _, client := range s.Clients.GetAll() { + if deleted := client.ClearExpiredInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 { + for _, id := range deleted { + s.hooks.OnQosDropped(client, packets.Packet{PacketID: id}) + } + } + } +} + +// sendDelayedLWT sends any LWT messages which have reached their issue time. +func (s *Server) sendDelayedLWT(dt int64) { + for id, pk := range s.loop.willDelayed.GetAll() { + if dt > pk.Expiry { + s.publishToSubscribers(pk) // [MQTT-3.1.2-8] + if cl, ok := s.Clients.Get(id); ok { + if pk.FixedHeader.Retain { + s.retainMessage(cl, pk) + } + cl.Properties.Will = Will{} // [MQTT-3.1.2-10] + s.hooks.OnWillSent(cl, pk) + } + s.loop.willDelayed.Delete(id) + } + } +} + +// Int64toa converts an int64 to a string. +func Int64toa(v int64) string { + return strconv.FormatInt(v, 10) +} diff --git a/mqtt/server_test.go b/mqtt/server_test.go new file mode 100644 index 0000000..d9b8705 --- /dev/null +++ b/mqtt/server_test.go @@ -0,0 +1,3915 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package mqtt + +import ( + "bytes" + "encoding/binary" + "io" + "log/slog" + "net" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "testmqtt/hooks/storage" + "testmqtt/listeners" + "testmqtt/packets" + "testmqtt/system" + + "github.com/stretchr/testify/require" +) + +var logger = slog.New(slog.NewTextHandler(io.Discard, nil)) + +type ProtocolTest []struct { + protocolVersion byte + in packets.TPacketCase + out packets.TPacketCase + data map[string]any +} + +type AllowHook struct { + HookBase +} + +func (h *AllowHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + +func (h *AllowHook) ID() string { + return "allow-all-auth" +} + +func (h *AllowHook) Provides(b byte) bool { + return bytes.Contains([]byte{OnConnectAuthenticate, OnACLCheck}, []byte{b}) +} + +func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return true } +func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool { return true } + +type DenyHook struct { + HookBase +} + +func (h *DenyHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + +func (h *DenyHook) ID() string { + return "deny-all-auth" +} + +func (h *DenyHook) Provides(b byte) bool { + return bytes.Contains([]byte{OnConnectAuthenticate, OnACLCheck}, []byte{b}) +} + +func (h *DenyHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return false } +func (h *DenyHook) OnACLCheck(cl *Client, topic string, write bool) bool { return false } + +type DelayHook struct { + HookBase + DisconnectDelay time.Duration +} + +func (h *DelayHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + +func (h *DelayHook) ID() string { + return "delay-hook" +} + +func (h *DelayHook) Provides(b byte) bool { + return bytes.Contains([]byte{OnDisconnect}, []byte{b}) +} + +func (h *DelayHook) OnDisconnect(cl *Client, err error, expire bool) { + time.Sleep(h.DisconnectDelay) +} + +func newServer() *Server { + cc := NewDefaultServerCapabilities() + cc.MaximumMessageExpiryInterval = 0 + cc.ReceiveMaximum = 0 + s := New(&Options{ + Logger: logger, + Capabilities: cc, + }) + _ = s.AddHook(new(AllowHook), nil) + return s +} + +func newServerWithInlineClient() *Server { + cc := NewDefaultServerCapabilities() + cc.MaximumMessageExpiryInterval = 0 + cc.ReceiveMaximum = 0 + s := New(&Options{ + Logger: logger, + Capabilities: cc, + InlineClient: true, + }) + _ = s.AddHook(new(AllowHook), nil) + return s +} + +func TestOptionsSetDefaults(t *testing.T) { + opts := &Options{} + opts.ensureDefaults() + + require.Equal(t, defaultSysTopicInterval, opts.SysTopicResendInterval) + require.Equal(t, NewDefaultServerCapabilities(), opts.Capabilities) + + opts = new(Options) + opts.ensureDefaults() + require.Equal(t, defaultSysTopicInterval, opts.SysTopicResendInterval) +} + +func TestNew(t *testing.T) { + s := New(nil) + require.NotNil(t, s) + require.NotNil(t, s.Clients) + require.NotNil(t, s.Listeners) + require.NotNil(t, s.Topics) + require.NotNil(t, s.Info) + require.NotNil(t, s.Log) + require.NotNil(t, s.Options) + require.NotNil(t, s.loop) + require.NotNil(t, s.loop.sysTopics) + require.NotNil(t, s.loop.inflightExpiry) + require.NotNil(t, s.loop.clientExpiry) + require.NotNil(t, s.hooks) + require.NotNil(t, s.hooks.Log) + require.NotNil(t, s.done) + require.Nil(t, s.inlineClient) + require.Equal(t, 0, s.Clients.Len()) +} + +func TestNewWithInlineClient(t *testing.T) { + s := New(&Options{ + InlineClient: true, + }) + require.NotNil(t, s.inlineClient) + require.Equal(t, 1, s.Clients.Len()) +} + +func TestNewNilOpts(t *testing.T) { + s := New(nil) + require.NotNil(t, s) + require.NotNil(t, s.Options) +} + +func TestServerNewClient(t *testing.T) { + s := New(nil) + s.Log = logger + r, _ := net.Pipe() + + cl := s.NewClient(r, "testing", "test", false) + require.NotNil(t, cl) + require.Equal(t, "test", cl.ID) + require.Equal(t, "testing", cl.Net.Listener) + require.False(t, cl.Net.Inline) + require.NotNil(t, cl.State.Inflight.internal) + require.NotNil(t, cl.State.Subscriptions) + require.NotNil(t, cl.State.TopicAliases) + require.Equal(t, defaultKeepalive, cl.State.Keepalive) + require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion) + require.NotNil(t, cl.Net.Conn) + require.NotNil(t, cl.Net.bconn) + require.NotNil(t, cl.ops) + require.Equal(t, s.Log, cl.ops.log) +} + +func TestServerNewClientInline(t *testing.T) { + s := New(nil) + cl := s.NewClient(nil, "testing", "test", true) + require.True(t, cl.Net.Inline) +} + +func TestServerAddHook(t *testing.T) { + s := New(nil) + + s.Log = logger + require.NotNil(t, s) + + require.Equal(t, int64(0), s.hooks.Len()) + err := s.AddHook(new(HookBase), nil) + require.NoError(t, err) + require.Equal(t, int64(1), s.hooks.Len()) +} + +func TestServerAddListener(t *testing.T) { + s := newServer() + defer s.Close() + + require.NotNil(t, s) + + err := s.AddListener(listeners.NewMockListener("t1", ":1882")) + require.NoError(t, err) + + // add existing listener + err = s.AddListener(listeners.NewMockListener("t1", ":1882")) + require.Error(t, err) + require.Equal(t, ErrListenerIDExists, err) +} + +func TestServerAddHooksFromConfig(t *testing.T) { + s := newServer() + defer s.Close() + require.NotNil(t, s) + s.Log = logger + + hooks := []HookLoadConfig{ + {Hook: new(modifiedHookBase)}, + } + + err := s.AddHooksFromConfig(hooks) + require.NoError(t, err) +} + +func TestServerAddHooksFromConfigError(t *testing.T) { + s := newServer() + defer s.Close() + require.NotNil(t, s) + s.Log = logger + + hooks := []HookLoadConfig{ + {Hook: new(modifiedHookBase), Config: map[string]interface{}{}}, + } + + err := s.AddHooksFromConfig(hooks) + require.Error(t, err) +} + +func TestServerAddListenerInitFailure(t *testing.T) { + s := newServer() + defer s.Close() + + require.NotNil(t, s) + + m := listeners.NewMockListener("t1", ":1882") + m.ErrListen = true + err := s.AddListener(m) + require.Error(t, err) +} + +func TestServerAddListenersFromConfig(t *testing.T) { + s := newServer() + defer s.Close() + require.NotNil(t, s) + s.Log = logger + + lc := []listeners.Config{ + {Type: listeners.TypeTCP, ID: "tcp", Address: ":1883"}, + {Type: listeners.TypeWS, ID: "ws", Address: ":1882"}, + {Type: listeners.TypeHealthCheck, ID: "health", Address: ":1881"}, + {Type: listeners.TypeSysInfo, ID: "info", Address: ":1880"}, + {Type: listeners.TypeUnix, ID: "unix", Address: "mochi.sock"}, + {Type: listeners.TypeMock, ID: "mock", Address: "0"}, + {Type: "unknown", ID: "unknown"}, + } + + err := s.AddListenersFromConfig(lc) + require.NoError(t, err) + require.Equal(t, 6, s.Listeners.Len()) + + tcp, _ := s.Listeners.Get("tcp") + require.Equal(t, "[::]:1883", tcp.Address()) + + ws, _ := s.Listeners.Get("ws") + require.Equal(t, ":1882", ws.Address()) + + health, _ := s.Listeners.Get("health") + require.Equal(t, ":1881", health.Address()) + + info, _ := s.Listeners.Get("info") + require.Equal(t, ":1880", info.Address()) + + unix, _ := s.Listeners.Get("unix") + require.Equal(t, "mochi.sock", unix.Address()) + + mock, _ := s.Listeners.Get("mock") + require.Equal(t, "0", mock.Address()) +} + +func TestServerAddListenersFromConfigError(t *testing.T) { + s := newServer() + defer s.Close() + require.NotNil(t, s) + s.Log = logger + + lc := []listeners.Config{ + {Type: listeners.TypeTCP, ID: "tcp", Address: "x"}, + } + + err := s.AddListenersFromConfig(lc) + require.Error(t, err) + require.Equal(t, 0, s.Listeners.Len()) +} + +func TestServerServe(t *testing.T) { + s := newServer() + defer s.Close() + + require.NotNil(t, s) + + err := s.AddListener(listeners.NewMockListener("t1", ":1882")) + require.NoError(t, err) + + err = s.Serve() + require.NoError(t, err) + + time.Sleep(time.Millisecond) + + require.Equal(t, 1, s.Listeners.Len()) + listener, ok := s.Listeners.Get("t1") + + require.Equal(t, true, ok) + require.Equal(t, true, listener.(*listeners.MockListener).IsServing()) +} + +func TestServerServeFromConfig(t *testing.T) { + s := newServer() + defer s.Close() + require.NotNil(t, s) + + s.Options.Listeners = []listeners.Config{ + {Type: listeners.TypeMock, ID: "mock", Address: "0"}, + } + + s.Options.Hooks = []HookLoadConfig{ + {Hook: new(modifiedHookBase)}, + } + + err := s.Serve() + require.NoError(t, err) + + time.Sleep(time.Millisecond) + + require.Equal(t, 1, s.Listeners.Len()) + listener, ok := s.Listeners.Get("mock") + + require.Equal(t, true, ok) + require.Equal(t, true, listener.(*listeners.MockListener).IsServing()) +} + +func TestServerServeFromConfigListenerError(t *testing.T) { + s := newServer() + defer s.Close() + require.NotNil(t, s) + + s.Options.Listeners = []listeners.Config{ + {Type: listeners.TypeTCP, ID: "tcp", Address: "x"}, + } + + err := s.Serve() + require.Error(t, err) +} + +func TestServerServeFromConfigHookError(t *testing.T) { + s := newServer() + defer s.Close() + require.NotNil(t, s) + + s.Options.Hooks = []HookLoadConfig{ + {Hook: new(modifiedHookBase), Config: map[string]interface{}{}}, + } + + err := s.Serve() + require.Error(t, err) +} + +func TestServerServeReadStoreFailure(t *testing.T) { + s := newServer() + defer s.Close() + + require.NotNil(t, s) + + err := s.AddListener(listeners.NewMockListener("t1", ":1882")) + require.NoError(t, err) + + hook := new(modifiedHookBase) + hook.failAt = 1 + err = s.AddHook(hook, nil) + require.NoError(t, err) + + err = s.Serve() + require.Error(t, err) +} + +func TestServerEventLoop(t *testing.T) { + s := newServer() + defer s.Close() + + s.loop.sysTopics = time.NewTicker(time.Millisecond) + s.loop.inflightExpiry = time.NewTicker(time.Millisecond) + s.loop.clientExpiry = time.NewTicker(time.Millisecond) + s.loop.retainedExpiry = time.NewTicker(time.Millisecond) + s.loop.willDelaySend = time.NewTicker(time.Millisecond) + go s.eventLoop() + + time.Sleep(time.Millisecond * 3) +} + +func TestServerReadConnectionPacket(t *testing.T) { + s := newServer() + defer s.Close() + + cl, r, _ := newTestClient() + s.Clients.Add(cl) + + o := make(chan packets.Packet) + go func() { + pk, err := s.readConnectionPacket(cl) + require.NoError(t, err) + o <- pk + }() + + go func() { + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _ = r.Close() + }() + + require.Equal(t, *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet, <-o) +} + +func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) { + s := newServer() + defer s.Close() + + cl, r, _ := newTestClient() + s.Clients.Add(cl) + + o := make(chan error) + go func() { + _, err := s.readConnectionPacket(cl) + o <- err + }() + + go func() { + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes) + _ = r.Close() + }() + + err := <-o + require.Error(t, err) + require.Equal(t, packets.ErrMalformedVariableByteInteger, err) +} + +func TestServerReadConnectionPacketBadPacketType(t *testing.T) { + s := newServer() + defer s.Close() + + cl, r, _ := newTestClient() + s.Clients.Add(cl) + + go func() { + _, _ = r.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes) + _ = r.Close() + }() + + _, err := s.readConnectionPacket(cl) + require.Error(t, err) + require.Equal(t, packets.ErrProtocolViolationRequireFirstConnect, err) +} + +func TestServerReadConnectionPacketBadPacket(t *testing.T) { + s := newServer() + defer s.Close() + + cl, r, _ := newTestClient() + s.Clients.Add(cl) + + go func() { + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes) + _ = r.Close() + }() + + _, err := s.readConnectionPacket(cl) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrMalformedProtocolName) +} + +func TestEstablishConnection(t *testing.T) { + s := newServer() + defer s.Close() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + }() + + // receive the connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w) + require.NoError(t, err) + recv <- buf + }() + + err := <-o + require.NoError(t, err) + + // Todo: + // s.Clients is already empty here. Is it necessary to check v.StopCause()? + + // for _, v := range s.Clients.GetAll() { + // require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect + // } + + require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) + + _ = w.Close() + _ = r.Close() + + // client must be deleted on session close if Clean = true + _, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet.Connect.ClientIdentifier) + require.False(t, ok) +} + +func TestEstablishConnectionAckFailure(t *testing.T) { + s := newServer() + defer s.Close() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _ = w.Close() + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, io.ErrClosedPipe) + + _ = r.Close() +} + +func TestEstablishConnectionReadError(t *testing.T) { + s := newServer() + defer s.Close() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error + }() + + // receive the connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w) + require.NoError(t, err) + recv <- buf + }() + + err := <-o + require.Error(t, err) + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect + + ret := <-recv + require.Equal(t, append( + packets.TPacketData[packets.Connack].Get(packets.TConnackMinCleanMqtt5).RawBytes, + packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectSecondConnect).RawBytes...), + ret, + ) + + _ = w.Close() + _ = r.Close() +} + +func TestEstablishConnectionInheritExisting(t *testing.T) { + s := newServer() + defer s.Close() + + cl, r0, _ := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.Properties.Username = []byte("mochi") + cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier + cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) + cl.State.Inflight.Set(*packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + s.Clients.Add(cl) + + r, w := net.Pipe() + o := make(chan error) + go func() { + err := s.EstablishConnection("tcp", r) + o <- err + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect. + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + }() + + // receive the disconnect session takeover + takeover := make(chan []byte) + go func() { + buf, err := io.ReadAll(r0) + require.NoError(t, err) + takeover <- buf + }() + + // receive the connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w) + require.NoError(t, err) + recv <- buf + }() + + err := <-o + require.NoError(t, err) + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect + + connackPlusPacket := append( + packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, + packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes..., + ) + require.Equal(t, connackPlusPacket, <-recv) + require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover) + + time.Sleep(time.Microsecond * 100) + _ = w.Close() + _ = r.Close() + + clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.NotEmpty(t, clw.State.Subscriptions) + + // Prevent sequential takeover memory-bloom. + require.Empty(t, cl.State.Subscriptions.GetAll()) +} + +// See https://github.com/mochi-mqtt/server/issues/173 +func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { + s := newServer() + d := new(DelayHook) + d.DisconnectDelay = time.Millisecond * 200 + _ = s.AddHook(d, nil) + defer s.Close() + + // Clean session, 0 session expiry interval + cl1RawBytes := []byte{ + packets.Connect << 4, 21, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 5, // Protocol Version + 1 << 1, // Packet Flags + 0, 30, // Keepalive + 5, // Properties length + 17, 0, 0, 0, 0, // Session Expiry Interval (17) + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + } + + // Make first connection + r1, w1 := net.Pipe() + o1 := make(chan error) + go func() { + err := s.EstablishConnection("tcp", r1) + o1 <- err + }() + go func() { + _, _ = w1.Write(cl1RawBytes) + }() + + // receive the first connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w1) + require.NoError(t, err) + recv <- buf + }() + + // Get the first client pointer + time.Sleep(time.Millisecond * 50) + cl1, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).Packet.Connect.ClientIdentifier) + require.True(t, ok) + cl1.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) + cl1.State.Subscriptions.Add("d/e/f", packets.Subscription{Filter: "d/e/f", Qos: 0}) + time.Sleep(time.Millisecond * 50) + + // Make the second connection + r2, w2 := net.Pipe() + o2 := make(chan error) + go func() { + err := s.EstablishConnection("tcp", r2) + o2 <- err + }() + go func() { + x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:] + x[19] = '.' // differentiate username bytes in debugging + _, _ = w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes) + }() + + // receive the second connack + recv2 := make(chan []byte) + go func() { + buf, err := io.ReadAll(w2) + require.NoError(t, err) + recv2 <- buf + }() + + // Capture first Client pointer + clp1, ok := s.Clients.Get("zen") + require.True(t, ok) + require.Empty(t, clp1.Properties.Username) + require.NotEmpty(t, clp1.State.Subscriptions.GetAll()) + + err1 := <-o1 + require.Error(t, err1) + require.ErrorIs(t, err1, io.ErrClosedPipe) + + // Capture second Client pointer + clp2, ok := s.Clients.Get("zen") + require.True(t, ok) + require.Equal(t, []byte(".ochi"), clp2.Properties.Username) + require.NotEmpty(t, clp2.State.Subscriptions.GetAll()) + require.Empty(t, clp1.State.Subscriptions.GetAll()) + + _, _ = w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + require.NoError(t, <-o2) +} + +func TestEstablishConnectionResentPendingInflightsError(t *testing.T) { + s := newServer() + defer s.Close() + + n := time.Now().Unix() + cl, r0, _ := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier + cl.State.Inflight = NewInflights() + cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: n - 2}) // no packet type + s.Clients.Add(cl) + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + }() + + go func() { + _, err := io.ReadAll(r0) + require.NoError(t, err) + }() + + go func() { + _, err := io.ReadAll(w) + require.NoError(t, err) + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrNoValidPacketAvailable) +} + +func TestEstablishConnectionInheritExistingClean(t *testing.T) { + s := newServer() + defer s.Close() + + cl, r0, _ := newTestClient() + cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier + cl.Properties.Clean = true + cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) + s.Clients.Add(cl) + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + }() + + // receive the disconnect + takeover := make(chan []byte) + go func() { + buf, err := io.ReadAll(r0) + require.NoError(t, err) + takeover <- buf + }() + + // receive the connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w) + require.NoError(t, err) + recv <- buf + }() + + err := <-o + require.NoError(t, err) + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect + + require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) + require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover) + + _ = w.Close() + _ = r.Close() + + clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.Equal(t, 0, clw.State.Subscriptions.Len()) +} + +func TestEstablishConnectionBadAuthentication(t *testing.T) { + s := New(&Options{ + Logger: logger, + }) + defer s.Close() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + }() + + // receive the connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w) + require.NoError(t, err) + recv <- buf + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrBadUsernameOrPassword) + require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackBadUsernamePasswordNoSession).RawBytes, <-recv) + + _ = w.Close() + _ = r.Close() +} + +func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) { + s := New(&Options{ + Logger: logger, + }) + defer s.Close() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _ = w.Close() + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, io.ErrClosedPipe) + + _ = r.Close() +} + +func TestServerEstablishConnectionInvalidConnect(t *testing.T) { + s := newServer() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + }() + + // receive the connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w) + require.NoError(t, err) + recv <- buf + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, packets.ErrProtocolViolationReservedBit, err) + require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackProtocolViolationNoSession).RawBytes, <-recv) + + _ = r.Close() +} + +func TestEstablishConnectionMaximumClientsReached(t *testing.T) { + cc := NewDefaultServerCapabilities() + cc.MaximumClients = 0 + s := New(&Options{ + Logger: logger, + Capabilities: cc, + }) + _ = s.AddHook(new(AllowHook), nil) + defer s.Close() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + }() + + // receive the connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w) + require.NoError(t, err) + recv <- buf + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrServerBusy) + + _ = r.Close() +} + +// See https://github.com/mochi-mqtt/server/issues/178 +func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { + s := newServer() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + }() + + // receive the connack error + go func() { + _, err := io.ReadAll(w) + require.NoError(t, err) + }() + + err := <-o + require.NoError(t, err) + + _ = r.Close() +} + +func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) { + s := newServer() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) + _ = w.Close() + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, io.ErrClosedPipe) + + _ = r.Close() +} + +func TestServerEstablishConnectionBadPacket(t *testing.T) { + s := newServer() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnackBadProtocolVersion).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrProtocolViolationRequireFirstConnect) + + _ = r.Close() +} + +func TestServerEstablishConnectionOnConnectError(t *testing.T) { + s := newServer() + hook := new(modifiedHookBase) + hook.fail = true + err := s.AddHook(hook, nil) + require.NoError(t, err) + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + }() + + err = <-o + require.Error(t, err) + require.ErrorIs(t, err, errTestHook) + + _ = r.Close() +} + +func TestServerSendConnack(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + s.Options.Capabilities.MaximumQos = 1 + cl.Properties.Props = packets.Properties{ + AssignedClientID: "mochi", + } + go func() { + err := s.SendConnack(cl, packets.CodeSuccess, true, nil) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackMinMqtt5).RawBytes, buf) +} + +func TestServerSendConnackFailureReason(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + go func() { + err := s.SendConnack(cl, packets.ErrUnspecifiedError, true, nil) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackInvalidMinMqtt5).RawBytes, buf) +} + +func TestServerSendConnackWithServerKeepalive(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.State.Keepalive = 10 + cl.State.ServerKeepalive = true + go func() { + err := s.SendConnack(cl, packets.CodeSuccess, true, nil) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackServerKeepalive).RawBytes, buf) +} + +func TestServerValidateConnect(t *testing.T) { + packet := *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet + invalidBitPacket := packet + invalidBitPacket.ReservedBit = 1 + packetCleanIdPacket := packet + packetCleanIdPacket.Connect.Clean = false + packetCleanIdPacket.Connect.ClientIdentifier = "" + tt := []struct { + desc string + client *Client + capabilities Capabilities + packet packets.Packet + expect packets.Code + }{ + { + desc: "unsupported protocol version", + client: &Client{Properties: ClientProperties{ProtocolVersion: 3}}, + capabilities: Capabilities{MinimumProtocolVersion: 4}, + packet: packet, + expect: packets.ErrUnsupportedProtocolVersion, + }, + { + desc: "will qos not supported", + client: &Client{Properties: ClientProperties{Will: Will{Qos: 2}}}, + capabilities: Capabilities{MaximumQos: 1}, + packet: packet, + expect: packets.ErrQosNotSupported, + }, + { + desc: "retain not supported", + client: &Client{Properties: ClientProperties{Will: Will{Retain: true}}}, + capabilities: Capabilities{RetainAvailable: 0}, + packet: packet, + expect: packets.ErrRetainNotSupported, + }, + { + desc: "invalid packet validate", + client: &Client{Properties: ClientProperties{Will: Will{Retain: true}}}, + capabilities: Capabilities{RetainAvailable: 0}, + packet: invalidBitPacket, + expect: packets.ErrProtocolViolationReservedBit, + }, + { + desc: "mqtt3 clean no client id ", + client: &Client{Properties: ClientProperties{ProtocolVersion: 3}}, + capabilities: Capabilities{}, + packet: packetCleanIdPacket, + expect: packets.ErrUnspecifiedError, + }, + } + + s := newServer() + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + s.Options.Capabilities = &tx.capabilities + err := s.validateConnect(tx.client, tx.packet) + require.Error(t, err) + require.ErrorIs(t, err, tx.expect) + }) + } +} + +func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.Properties.Props.SessionExpiryInterval = uint32(300) + s.Options.Capabilities.MaximumSessionExpiryInterval = 120 + go func() { + err := s.SendConnack(cl, packets.CodeSuccess, false, nil) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedAdjustedExpiryInterval).RawBytes, buf) +} + +func TestInheritClientSession(t *testing.T) { + s := newServer() + + n := time.Now().Unix() + + existing, _, _ := newTestClient() + existing.Net.Conn = nil + existing.ID = "mochi" + existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) + existing.State.Inflight = NewInflights() + existing.State.Inflight.Set(packets.Packet{PacketID: 1, Created: n - 1}) + existing.State.Inflight.Set(packets.Packet{PacketID: 2, Created: n - 2}) + + s.Clients.Add(existing) + + cl, _, _ := newTestClient() + cl.Properties.ProtocolVersion = 5 + + require.Equal(t, 0, cl.State.Inflight.Len()) + require.Equal(t, 0, cl.State.Subscriptions.Len()) + + // Inherit existing client properties + b := s.inheritClientSession(packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi"}}, cl) + require.True(t, b) + require.Equal(t, 2, cl.State.Inflight.Len()) + require.Equal(t, 1, cl.State.Subscriptions.Len()) + + // On clean, clear existing properties + cl, _, _ = newTestClient() + cl.Properties.ProtocolVersion = 5 + b = s.inheritClientSession(packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi", Clean: true}}, cl) + require.False(t, b) + require.Equal(t, 0, cl.State.Inflight.Len()) + require.Equal(t, 0, cl.State.Subscriptions.Len()) +} + +func TestServerUnsubscribeClient(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + pk := packets.Subscription{Filter: "a/b/c", Qos: 1} + cl.State.Subscriptions.Add("a/b/c", pk) + s.Topics.Subscribe(cl.ID, pk) + subs := s.Topics.Subscribers("a/b/c") + require.Equal(t, 1, len(subs.Subscriptions)) + s.UnsubscribeClient(cl) + subs = s.Topics.Subscribers("a/b/c") + require.Equal(t, 0, len(subs.Subscriptions)) +} + +func TestServerProcessPacketFailure(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + err := s.processPacket(cl, packets.Packet{}) + require.Error(t, err) +} + +func TestServerProcessPacketConnect(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + + err := s.processPacket(cl, *packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet) + require.Error(t, err) +} + +func TestServerProcessPacketPingreq(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Pingresp].Get(packets.TPingresp).RawBytes, buf) +} + +func TestServerProcessPacketPingreqError(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.Stop(packets.CodeDisconnect) + + err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) + require.Error(t, err) + require.ErrorIs(t, cl.StopCause(), packets.CodeDisconnect) +} + +func TestServerProcessPacketPublishInvalid(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishInvalidQosMustPacketID).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID) +} + +func TestInjectPacketPublishAndReceive(t *testing.T) { + s := newServer() + _ = s.Serve() + defer s.Close() + + sender, _, w1 := newTestClient() + sender.Net.Inline = true + sender.ID = "sender" + s.Clients.Add(sender) + + receiver, r2, w2 := newTestClient() + receiver.ID = "receiver" + s.Clients.Add(receiver) + s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + receiverBuf <- buf + }() + + go func() { + err := s.InjectPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + require.NoError(t, err) + _ = w1.Close() + time.Sleep(time.Millisecond * 10) + _ = w2.Close() + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) +} + +func TestServerPublishAndReceive(t *testing.T) { + s := newServerWithInlineClient() + + _ = s.Serve() + defer s.Close() + + sender, _, w1 := newTestClient() + sender.Net.Inline = true + sender.ID = "sender" + s.Clients.Add(sender) + + receiver, r2, w2 := newTestClient() + receiver.ID = "receiver" + s.Clients.Add(receiver) + s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + receiverBuf <- buf + }() + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) + require.NoError(t, err) + _ = w1.Close() + time.Sleep(time.Millisecond * 10) + _ = w2.Close() + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) +} + +func TestServerPublishNoInlineClient(t *testing.T) { + s := newServer() + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + +func TestInjectPacketError(t *testing.T) { + s := newServer() + defer s.Close() + cl, _, _ := newTestClient() + cl.Net.Inline = true + pkx := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet + pkx.Filters = packets.Subscriptions{} + err := s.InjectPacket(cl, pkx) + require.Error(t, err) +} + +func TestInjectPacketPublishInvalidTopic(t *testing.T) { + s := newServer() + defer s.Close() + cl, _, _ := newTestClient() + cl.Net.Inline = true + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + pkx.TopicName = "$SYS/test" + err := s.InjectPacket(cl, pkx) + require.NoError(t, err) // bypass topic validity and acl checks +} + +func TestServerProcessPacketPublishAndReceive(t *testing.T) { + s := newServer() + _ = s.Serve() + defer s.Close() + + sender, _, w1 := newTestClient() + sender.ID = "sender" + s.Clients.Add(sender) + + receiver, r2, w2 := newTestClient() + receiver.ID = "receiver" + s.Clients.Add(receiver) + s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) + require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + receiverBuf <- buf + }() + + go func() { + err := s.processPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.NoError(t, err) + time.Sleep(time.Millisecond * 10) + _ = w1.Close() + _ = w2.Close() + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) + require.Equal(t, 1, len(s.Topics.Messages("a/b/c"))) +} + +func TestServerBuildAck(t *testing.T) { + s := newServer() + properties := packets.Properties{ + User: []packets.UserProperty{ + {Key: "hello", Val: "世界"}, + }, + } + ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1) + require.Equal(t, packets.Puback, ack.FixedHeader.Type) + require.Equal(t, uint8(1), ack.FixedHeader.Qos) + require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode) + require.Equal(t, properties, ack.Properties) +} + +func TestServerBuildAckError(t *testing.T) { + s := newServer() + properties := packets.Properties{ + User: []packets.UserProperty{ + {Key: "hello", Val: "世界"}, + }, + } + ack := s.buildAck(7, packets.Puback, 1, properties, packets.ErrMalformedPacket) + require.Equal(t, packets.Puback, ack.FixedHeader.Type) + require.Equal(t, uint8(1), ack.FixedHeader.Qos) + require.Equal(t, packets.ErrMalformedPacket.Code, ack.ReasonCode) + properties.ReasonString = packets.ErrMalformedPacket.Reason + require.Equal(t, properties, ack.Properties) +} + +func TestServerBuildAckPahoCompatibility(t *testing.T) { + s := newServer() + s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true + properties := packets.Properties{ + User: []packets.UserProperty{ + {Key: "hello", Val: "世界"}, + }, + } + ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1) + require.Equal(t, packets.Puback, ack.FixedHeader.Type) + require.Equal(t, uint8(1), ack.FixedHeader.Qos) + require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode) + require.Equal(t, packets.Properties{}, ack.Properties) +} + +func TestServerProcessPacketAndNextImmediate(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + + next := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet + next.Expiry = -1 + cl.State.Inflight.Set(next) + atomic.StoreInt64(&s.Info.Inflight, 1) + require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Inflight)) + require.Equal(t, int32(5), cl.State.Inflight.sendQuota) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, buf) + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight)) + require.Equal(t, int32(4), cl.State.Inflight.sendQuota) +} + +func TestServerProcessPublishAckFailure(t *testing.T) { + s := newServer() + _ = s.Serve() + defer s.Close() + + cl, _, w := newTestClient() + s.Clients.Add(cl) + + _ = w.Close() + err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) + require.Error(t, err) + require.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) { + s := newServer() + hook := new(modifiedHookBase) + hook.fail = true + hook.err = packets.ErrUnspecifiedError + err := s.AddHook(hook, nil) + require.NoError(t, err) + + cl, _, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + s.Clients.Add(cl) + _ = w.Close() + + err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.Error(t, err) + require.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) { + s := newServer() + hook := new(modifiedHookBase) + hook.fail = true + hook.err = packets.ErrPayloadFormatInvalid + err := s.AddHook(hook, nil) + require.NoError(t, err) + _ = s.Serve() + defer s.Close() + + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + s.Clients.Add(cl) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPubackUnexpectedError).RawBytes, buf) +} + +func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { + s := newServer() + hook := new(modifiedHookBase) + hook.fail = true + hook.err = packets.CodeSuccessIgnore + err := s.AddHook(hook, nil) + require.NoError(t, err) + _ = s.Serve() + defer s.Close() + + cl, r, w := newTestClient() + s.Clients.Add(cl) + + receiver, r2, w2 := newTestClient() + receiver.ID = "receiver" + s.Clients.Add(receiver) + s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) + require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + receiverBuf <- buf + }() + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.NoError(t, err) + _ = w.Close() + _ = w2.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf) + require.Equal(t, []byte{}, <-receiverBuf) + require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) +} + +func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { + s := newServer() + _ = s.Serve() + defer s.Close() + + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.State.Inflight.ResetReceiveQuota(0) + s.Clients.Add(cl) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrReceiveMaximum) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectReceiveMaximum).RawBytes, buf) +} + +func TestServerProcessPublishInvalidTopic(t *testing.T) { + s := newServer() + _ = s.Serve() + defer s.Close() + cl, _, _ := newTestClient() + err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet) + require.NoError(t, err) // $SYS Topics should be ignored? +} + +func TestServerProcessPublishACLCheckDeny(t *testing.T) { + tt := []struct { + name string + protocolVersion byte + pk packets.Packet + expectErr error + expectReponse []byte + expectDisconnect bool + }{ + { + name: "v4_QOS0", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet, + expectErr: nil, + expectReponse: nil, + expectDisconnect: false, + }, + { + name: "v4_QOS1", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet, + expectErr: packets.ErrNotAuthorized, + expectReponse: nil, + expectDisconnect: true, + }, + { + name: "v4_QOS2", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet, + expectErr: packets.ErrNotAuthorized, + expectReponse: nil, + expectDisconnect: true, + }, + { + name: "v5_QOS0", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet, + expectErr: nil, + expectReponse: nil, + expectDisconnect: false, + }, + { + name: "v5_QOS1", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Mqtt5).Packet, + expectErr: nil, + expectReponse: packets.TPacketData[packets.Puback].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, + expectDisconnect: false, + }, + { + name: "v5_QOS2", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet, + expectErr: nil, + expectReponse: packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, + expectDisconnect: false, + }, + } + + for _, tx := range tt { + t.Run(tx.name, func(t *testing.T) { + cc := NewDefaultServerCapabilities() + s := New(&Options{ + Logger: logger, + Capabilities: cc, + }) + _ = s.AddHook(new(DenyHook), nil) + _ = s.Serve() + defer s.Close() + + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = tx.protocolVersion + s.Clients.Add(cl) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + err := s.processPublish(cl, tx.pk) + require.ErrorIs(t, err, tx.expectErr) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + + if tx.expectReponse != nil { + require.Equal(t, tx.expectReponse, buf) + } + + require.Equal(t, tx.expectDisconnect, cl.Closed()) + wg.Wait() + }) + } +} + +func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) { + s := newServer() + require.NotNil(t, s) + hook := new(modifiedHookBase) + hook.fail = true + hook.err = packets.ErrRejectPacket + + err := s.AddHook(hook, nil) + require.NoError(t, err) + + _ = s.Serve() + defer s.Close() + cl, _, _ := newTestClient() + err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + require.NoError(t, err) // packets rejected silently +} + +func TestServerProcessPacketPublishQos0(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte{}, buf) +} + +func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) + atomic.StoreInt64(&s.Info.Inflight, 1) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf) + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight)) +} + +func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Pubrec}}) + atomic.StoreInt64(&s.Info.Inflight, 1) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5IDInUse).RawBytes, buf) + require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Inflight)) +} + +func TestServerProcessPacketPublishQos1(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf) +} + +func TestServerProcessPacketPublishQos2(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).RawBytes, buf) +} + +func TestServerProcessPacketPublishDowngradeQos(t *testing.T) { + s := newServer() + s.Options.Capabilities.MaximumQos = 1 + cl, r, w := newTestClient() + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf) +} + +func TestPublishToSubscribersSelfNoLocal(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", NoLocal: true}) + require.True(t, subbed) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + pkx.Origin = cl.ID + s.publishToSubscribers(pkx) + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + receiverBuf <- buf + }() + + require.Equal(t, []byte{}, <-receiverBuf) +} + +func TestPublishToSubscribers(t *testing.T) { + s := newServer() + cl, r1, w1 := newTestClient() + cl.ID = "cl1" + cl2, r2, w2 := newTestClient() + cl2.ID = "cl2" + cl3, r3, w3 := newTestClient() + cl3.ID = "cl3" + s.Clients.Add(cl) + s.Clients.Add(cl2) + s.Clients.Add(cl3) + require.True(t, s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"})) + require.True(t, s.Topics.Subscribe(cl2.ID, packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c"})) + require.True(t, s.Topics.Subscribe(cl3.ID, packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c"})) + + cl1Recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(r1) + require.NoError(t, err) + cl1Recv <- buf + }() + + cl2Recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + cl2Recv <- buf + }() + + cl3Recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(r3) + require.NoError(t, err) + cl3Recv <- buf + }() + + go func() { + s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + time.Sleep(time.Millisecond) + _ = w1.Close() + _ = w2.Close() + _ = w3.Close() + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-cl1Recv) + rcv2 := <-cl2Recv + rcv3 := <-cl3Recv + + ok := false + if len(rcv2) > 0 { + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, rcv2) + require.Equal(t, []byte{}, rcv3) + ok = true + } else if len(rcv3) > 0 { + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, rcv3) + require.Equal(t, []byte{}, rcv2) + ok = true + } + require.True(t, ok) +} + +func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) { + s := newServer() + s.Options.Capabilities.MaximumMessageExpiryInterval = 86400 + cl, r1, w1 := newTestClient() + cl.ID = "cl1" + cl.Properties.ProtocolVersion = 5 + s.Clients.Add(cl) + require.True(t, s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"})) + + cl1Recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(r1) + require.NoError(t, err) + cl1Recv <- buf + }() + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + pkx.Created = time.Now().Unix() - 30 + s.publishToSubscribers(pkx) + time.Sleep(time.Millisecond) + _ = w1.Close() + }() + + b := <-cl1Recv + pk := new(packets.Packet) + pk.ProtocolVersion = 5 + require.Equal(t, uint32(s.Options.Capabilities.MaximumMessageExpiryInterval-30), binary.BigEndian.Uint32(b[11:15])) +} + +func TestPublishToSubscribersIdentifiers(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + s.Clients.Add(cl) + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/+", Identifier: 2}) + require.True(t, subbed) + subbed = s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/#", Identifier: 3}) + require.True(t, subbed) + subbed = s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "d/e/f", Identifier: 4}) + require.True(t, subbed) + + go func() { + s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + receiverBuf <- buf + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishSubscriberIdentifier).RawBytes, <-receiverBuf) +} + +func TestPublishToSubscribersPkIgnore(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "#", Identifier: 1}) + require.True(t, subbed) + + go func() { + pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + pk.Ignore = true + s.publishToSubscribers(pk) + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + receiverBuf <- buf + }() + + require.Equal(t, []byte{}, <-receiverBuf) +} + +func TestPublishToClientServerDowngradeQos(t *testing.T) { + s := newServer() + s.Options.Capabilities.MaximumQos = 1 + + cl, r, w := newTestClient() + s.Clients.Add(cl) + + _, ok := cl.State.Inflight.Get(1) + require.False(t, ok) + cl.State.packetID = 6 // just to match the same packet id (7) in the fixtures + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet + pkx.FixedHeader.Qos = 2 + _, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx) + time.Sleep(time.Microsecond * 100) + _ = w.Close() + }() + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + receiverBuf <- buf + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf) +} + +func TestPublishToClientSubscriptionDowngradeQos(t *testing.T) { + s := newServer() + s.Options.Capabilities.MaximumQos = 2 + + cl, r, w := newTestClient() + s.Clients.Add(cl) + + _, ok := cl.State.Inflight.Get(1) + require.False(t, ok) + cl.State.packetID = 6 // just to match the same packet id (7) in the fixtures + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet + pkx.FixedHeader.Qos = 2 + _, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx) + time.Sleep(time.Microsecond * 100) + _ = w.Close() + }() + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + receiverBuf <- buf + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf) +} + +func TestPublishToClientExceedClientWritesPending(t *testing.T) { + var sendQuota uint16 = 5 + s := newServer() + + _, w := net.Pipe() + cl := newClient(w, &ops{ + info: new(system.Info), + hooks: new(Hooks), + log: logger, + options: &Options{ + Capabilities: &Capabilities{ + MaximumClientWritesPending: 3, + maximumPacketID: 10, + }, + }, + }) + cl.Properties.Props.ReceiveMaximum = sendQuota + cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) + + s.Clients.Add(cl) + + for i := int32(0); i < cl.ops.options.Capabilities.MaximumClientWritesPending; i++ { + cl.State.outbound <- new(packets.Packet) + atomic.AddInt32(&cl.State.outboundQty, 1) + } + + id, _ := cl.NextPacketID() + cl.State.Inflight.Set(packets.Packet{PacketID: uint16(id)}) + cl.State.Inflight.DecreaseSendQuota() + sendQuota-- + + _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{}) + require.Error(t, err) + require.ErrorIs(t, packets.ErrPendingClientWritesExceeded, err) + require.Equal(t, int32(sendQuota), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + + _, err = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{FixedHeader: packets.FixedHeader{Qos: 1}}) + require.Error(t, err) + require.ErrorIs(t, packets.ErrPendingClientWritesExceeded, err) + require.Equal(t, int32(sendQuota), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) +} + +func TestPublishToClientServerTopicAlias(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.Properties.Props.TopicAliasMaximum = 5 + s.Clients.Add(cl) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet + _, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + receiverBuf <- buf + }() + + ret := <-receiverBuf + pk1 := make([]byte, len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes)) + pk2 := make([]byte, len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes)-5) + copy(pk1, ret[:len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes)]) + copy(pk2, ret[len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes):]) + require.Equal(t, append(pk1, pk2...), ret) +} + +func TestPublishToClientMqtt3RetainFalseLeverageNoConn(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.Net.Conn = nil + + out, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", RetainAsPublished: true}, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.False(t, out.FixedHeader.Retain) + require.Error(t, err) + require.ErrorIs(t, err, packets.CodeDisconnect) +} + +func TestPublishToClientMqtt5RetainAsPublishedTrueLeverageNoConn(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.Net.Conn = nil + + out, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", RetainAsPublished: true}, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.True(t, out.FixedHeader.Retain) + require.Error(t, err) + require.ErrorIs(t, err, packets.CodeDisconnect) +} + +func TestPublishToClientExceedMaximumInflight(t *testing.T) { + const MaxInflight uint16 = 5 + s := newServer() + cl, _, _ := newTestClient() + s.Options.Capabilities.MaximumInflight = MaxInflight + cl.ops.options.Capabilities.MaximumInflight = MaxInflight + for i := uint16(0); i < MaxInflight; i++ { + cl.State.Inflight.Set(packets.Packet{PacketID: i}) + } + + _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrQuotaExceeded) + require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.InflightDropped)) +} + +func TestPublishToClientExhaustedPacketID(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ { + cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) + } + + _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrQuotaExceeded) + require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.InflightDropped)) +} + +func TestPublishToClientACLNotAuthorized(t *testing.T) { + s := New(&Options{ + Logger: logger, + }) + err := s.AddHook(new(DenyHook), nil) + require.NoError(t, err) + cl, _, _ := newTestClient() + + _, err = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrNotAuthorized) +} + +func TestPublishToClientNoConn(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.Net.Conn = nil + + _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.CodeDisconnect) +} + +func TestProcessPublishWithTopicAlias(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}) + require.True(t, subbed) + + cl2, _, w2 := newTestClient() + cl2.Properties.ProtocolVersion = 5 + cl2.State.TopicAliases.Inbound.Set(1, "a/b/c") + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5).Packet + pkx.Properties.SubscriptionIdentifier = []int{} // must not contain from client to server + pkx.TopicName = "" + pkx.Properties.TopicAlias = 1 + _ = s.processPacket(cl2, pkx) + time.Sleep(time.Millisecond) + _ = w2.Close() + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, buf) +} + +func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + cl.State.Inflight.sendQuota = 0 + + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) + require.True(t, subbed) + + // coverage: subscriber publish errors are non-returnable + // can we hook into log/slog ? + _ = r.Close() + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet + pkx.PacketID = 0 + s.publishToSubscribers(pkx) + time.Sleep(time.Millisecond) + _ = w.Close() +} + +func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ { + cl.State.Inflight.Set(packets.Packet{PacketID: 1}) + } + + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) + require.True(t, subbed) + + // coverage: subscriber publish errors are non-returnable + // can we hook into log/slog ? + _ = r.Close() + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet + pkx.PacketID = 0 + s.publishToSubscribers(pkx) + time.Sleep(time.Millisecond) + _ = w.Close() +} + +func TestPublishToSubscribersNoConnection(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) + require.True(t, subbed) + + // coverage: subscriber publish errors are non-returnable + // can we hook into log/slog ? + _ = r.Close() + s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + time.Sleep(time.Millisecond) + _ = w.Close() +} + +func TestPublishRetainedToClient(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) + require.True(t, subbed) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetainMqtt5).Packet) + require.Equal(t, int64(1), retained) + + go func() { + s.publishRetainedToClient(cl, packets.Subscription{Filter: "a/b/c"}, false) + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, buf) +} + +func TestPublishRetainedToClientIsShared(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + + sub := packets.Subscription{Filter: SharePrefix + "/test/a/b/c"} + subbed := s.Topics.Subscribe(cl.ID, sub) + require.True(t, subbed) + + go func() { + s.publishRetainedToClient(cl, sub, false) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte{}, buf) +} + +func TestPublishRetainedToClientError(t *testing.T) { + s := newServer() + cl, _, w := newTestClient() + s.Clients.Add(cl) + + sub := packets.Subscription{Filter: "a/b/c"} + subbed := s.Topics.Subscribe(cl.ID, sub) + require.True(t, subbed) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + _ = w.Close() + s.publishRetainedToClient(cl, sub, false) +} + +func TestNoRetainMessageIfUnavailable(t *testing.T) { + s := newServer() + s.Options.Capabilities.RetainAvailable = 0 + cl, _, _ := newTestClient() + s.Clients.Add(cl) + + s.retainMessage(new(Client), *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Retained)) +} + +func TestNoRetainMessageIfPkIgnore(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + s.Clients.Add(cl) + + pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet + pk.Ignore = true + s.retainMessage(new(Client), pk) + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Retained)) +} + +func TestNoRetainMessage(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + s.Clients.Add(cl) + + s.retainMessage(new(Client), *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Retained)) +} + +func TestServerProcessPacketPuback(t *testing.T) { + tt := ProtocolTest{ + { + protocolVersion: 4, + in: packets.TPacketData[packets.Puback].Get(packets.TPuback), + }, + { + protocolVersion: 5, + in: packets.TPacketData[packets.Puback].Get(packets.TPubackMqtt5), + }, + } + + for _, tx := range tt { + t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) { + pID := uint16(7) + s := newServer() + cl, _, _ := newTestClient() + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + + cl.State.Inflight.Set(packets.Packet{PacketID: pID}) + atomic.AddInt64(&s.Info.Inflight, 1) + + err := s.processPacket(cl, *tx.in.Packet) + require.NoError(t, err) + + require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight)) + _, ok := cl.State.Inflight.Get(pID) + require.False(t, ok) + }) + } +} + +func TestServerProcessPacketPubackNoPacketID(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + + pk := *packets.TPacketData[packets.Puback].Get(packets.TPuback).Packet + err := s.processPacket(cl, pk) + require.NoError(t, err) + + require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) +} + +func TestServerProcessPacketPubrec(t *testing.T) { + pID := uint16(7) + s := newServer() + cl, r, w := newTestClient() + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + + cl.State.Inflight.Set(packets.Packet{PacketID: pID}) + atomic.AddInt64(&s.Info.Inflight, 1) + + recv := make(chan []byte) + go func() { // receive the ack + buf, err := io.ReadAll(r) + require.NoError(t, err) + recv <- buf + }() + + err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet) + require.NoError(t, err) + _ = w.Close() + + require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).RawBytes, <-recv) + + require.Equal(t, int32(2), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) + require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Inflight)) + _, ok := cl.State.Inflight.Get(pID) + require.True(t, ok) +} + +func TestServerProcessPacketPubrecNoPacketID(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + + recv := make(chan []byte) + go func() { // receive the ack + buf, err := io.ReadAll(r) + require.NoError(t, err) + recv <- buf + }() + + pk := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet // not sending properties + err := s.processPacket(cl, pk) + require.NoError(t, err) + _ = w.Close() + + require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrelMqtt5AckNoPacket).RawBytes, <-recv) + + require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) +} + +func TestServerProcessPacketPubrecInvalidReason(t *testing.T) { + pID := uint16(7) + s := newServer() + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: pID}) + err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrecInvalidReason).Packet) + require.NoError(t, err) + require.Equal(t, int64(-1), atomic.LoadInt64(&s.Info.Inflight)) + _, ok := cl.State.Inflight.Get(pID) + require.False(t, ok) +} + +func TestServerProcessPacketPubrecFailure(t *testing.T) { + pID := uint16(7) + s := newServer() + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: pID}) + cl.Stop(packets.CodeDisconnect) + err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet) + require.Error(t, err) + require.ErrorIs(t, cl.StopCause(), packets.CodeDisconnect) +} + +func TestServerProcessPacketPubrel(t *testing.T) { + pID := uint16(7) + s := newServer() + cl, r, w := newTestClient() + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + + cl.State.Inflight.Set(packets.Packet{PacketID: pID}) + atomic.AddInt64(&s.Info.Inflight, 1) + + recv := make(chan []byte) + go func() { // receive the ack + buf, err := io.ReadAll(r) + require.NoError(t, err) + recv <- buf + }() + + err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet) + require.NoError(t, err) + _ = w.Close() + + require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) + require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + + require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp).RawBytes, <-recv) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight)) + _, ok := cl.State.Inflight.Get(pID) + require.False(t, ok) +} + +func TestServerProcessPacketPubrelNoPacketID(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + + recv := make(chan []byte) + go func() { // receive the ack + buf, err := io.ReadAll(r) + require.NoError(t, err) + recv <- buf + }() + + pk := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet // not sending properties + err := s.processPacket(cl, pk) + require.NoError(t, err) + _ = w.Close() + + require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5AckNoPacket).RawBytes, <-recv) + + require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) +} + +func TestServerProcessPacketPubrelFailure(t *testing.T) { + pID := uint16(7) + s := newServer() + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: pID}) + cl.Stop(packets.CodeDisconnect) + err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet) + require.Error(t, err) + require.ErrorIs(t, cl.StopCause(), packets.CodeDisconnect) +} + +func TestServerProcessPacketPubrelBadReason(t *testing.T) { + pID := uint16(7) + s := newServer() + cl, _, _ := newTestClient() + cl.State.Inflight.Set(packets.Packet{PacketID: pID}) + err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrelInvalidReason).Packet) + require.NoError(t, err) + require.Equal(t, int64(-1), atomic.LoadInt64(&s.Info.Inflight)) + _, ok := cl.State.Inflight.Get(pID) + require.False(t, ok) +} + +func TestServerProcessPacketPubcomp(t *testing.T) { + tt := ProtocolTest{ + { + protocolVersion: 4, + in: packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp), + }, + { + protocolVersion: 5, + in: packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5), + }, + } + + for _, tx := range tt { + t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) { + pID := uint16(7) + s := newServer() + cl, _, _ := newTestClient() + cl.Properties.ProtocolVersion = tx.protocolVersion + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + + cl.State.Inflight.Set(packets.Packet{PacketID: pID}) + atomic.AddInt64(&s.Info.Inflight, 1) + + err := s.processPacket(cl, *tx.in.Packet) + require.NoError(t, err) + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight)) + + require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) + require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + + _, ok := cl.State.Inflight.Get(pID) + require.False(t, ok) + }) + } +} + +func TestServerProcessInboundQos2Flow(t *testing.T) { + tt := ProtocolTest{ + { + protocolVersion: 5, + in: packets.TPacketData[packets.Publish].Get(packets.TPublishQos2), + out: packets.TPacketData[packets.Pubrec].Get(packets.TPubrec), + data: map[string]any{ + "sendquota": int32(3), + "recvquota": int32(2), + "inflight": int64(1), + }, + }, + { + protocolVersion: 5, + in: packets.TPacketData[packets.Pubrel].Get(packets.TPubrel), + out: packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp), + data: map[string]any{ + "sendquota": int32(4), + "recvquota": int32(3), + "inflight": int64(0), + }, + }, + } + + pID := uint16(7) + s := newServer() + cl, r, w := newTestClient() + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + + for i, tx := range tt { + t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) { + r, w = net.Pipe() + cl.Net.Conn = w + + recv := make(chan []byte) + go func() { // receive the ack + buf, err := io.ReadAll(r) + require.NoError(t, err) + recv <- buf + }() + + err := s.processPacket(cl, *tx.in.Packet) + require.NoError(t, err) + _ = w.Close() + + require.Equal(t, tx.out.RawBytes, <-recv) + if i == 0 { + _, ok := cl.State.Inflight.Get(pID) + require.True(t, ok) + } + + require.Equal(t, tx.data["inflight"].(int64), atomic.LoadInt64(&s.Info.Inflight)) + require.Equal(t, tx.data["recvquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) + require.Equal(t, tx.data["sendquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + }) + } + + _, ok := cl.State.Inflight.Get(pID) + require.False(t, ok) +} + +func TestServerProcessOutboundQos2Flow(t *testing.T) { + tt := ProtocolTest{ + { + protocolVersion: 5, + in: packets.TPacketData[packets.Publish].Get(packets.TPublishQos2), + out: packets.TPacketData[packets.Publish].Get(packets.TPublishQos2), + data: map[string]any{ + "sendquota": int32(2), + "recvquota": int32(3), + "inflight": int64(1), + }, + }, + { + protocolVersion: 5, + in: packets.TPacketData[packets.Pubrec].Get(packets.TPubrec), + out: packets.TPacketData[packets.Pubrel].Get(packets.TPubrel), + data: map[string]any{ + "sendquota": int32(2), + "recvquota": int32(2), + "inflight": int64(1), + }, + }, + { + protocolVersion: 5, + in: packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp), + data: map[string]any{ + "sendquota": int32(3), + "recvquota": int32(3), + "inflight": int64(0), + }, + }, + } + + pID := uint16(6) + s := newServer() + cl, _, _ := newTestClient() + cl.State.packetID = uint32(6) + cl.State.Inflight.sendQuota = 3 + cl.State.Inflight.receiveQuota = 3 + s.Clients.Add(cl) + s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) + + for i, tx := range tt { + t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) { + r, w := net.Pipe() + time.Sleep(time.Millisecond) + cl.Net.Conn = w + + recv := make(chan []byte) + go func() { // receive the ack + buf, err := io.ReadAll(r) + require.NoError(t, err) + recv <- buf + }() + + if i == 0 { + s.publishToSubscribers(*tx.in.Packet) + } else { + err := s.processPacket(cl, *tx.in.Packet) + require.NoError(t, err) + } + + time.Sleep(time.Millisecond) + _ = w.Close() + + if i != 2 { + require.Equal(t, tx.out.RawBytes, <-recv) + } + + require.Equal(t, tx.data["inflight"].(int64), atomic.LoadInt64(&s.Info.Inflight)) + require.Equal(t, tx.data["recvquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) + require.Equal(t, tx.data["sendquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) + }) + } + + _, ok := cl.State.Inflight.Get(pID) + require.False(t, ok) +} + +func TestServerProcessPacketSubscribe(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5).RawBytes, buf) +} + +func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) + + pkx := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet + pkx.PacketID = 15 + go func() { + err := s.processPacket(cl, pkx) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackPacketIDInUse).RawBytes, buf) +} + +func TestServerProcessPacketSubscribeInvalid(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.Properties.ProtocolVersion = 5 + + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeSpecQosMustPacketID).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID) +} + +func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidFilter).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackInvalidFilter).RawBytes, buf) +} + +func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackInvalidSharedNoLocal).RawBytes, buf) +} + +func TestServerProcessSubscribeWithRetain(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) + require.NoError(t, err) + + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, append( + packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, + packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes..., + ), buf) +} + +func TestServerProcessSubscribeDowngradeQos(t *testing.T) { + s := newServer() + s.Options.Capabilities.MaximumQos = 1 + cl, r, w := newTestClient() + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMany).Packet) + require.NoError(t, err) + + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte{0, 1, 1}, buf[4:]) +} + +func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"}) + s.Clients.Add(cl) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainHandling1).Packet) + require.NoError(t, err) + + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, buf) +} + +func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainHandling2).Packet) + require.NoError(t, err) + + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, buf) +} + +func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainAsPublished).Packet) + require.NoError(t, err) + + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, append( + packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, + packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes..., + ), buf) +} + +func TestServerProcessSubscribeNoConnection(t *testing.T) { + s := newServer() + cl, r, _ := newTestClient() + _ = r.Close() + err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) + require.Error(t, err) + require.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { + s := New(&Options{ + Logger: logger, + }) + _ = s.Serve() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + + go func() { + err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackDeny).RawBytes, buf) +} + +func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { + s := New(&Options{ + Logger: logger, + }) + _ = s.Serve() + s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + + go func() { + err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackUnspecifiedErrorMqtt5).RawBytes, buf) +} + +func TestServerProcessSubscribeErrorDowngrade(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 3 + cl.State.packetID = 1 // just to match the same packet id (7) in the fixtures + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackUnspecifiedError).RawBytes, buf) +} + +func TestServerProcessPacketUnsubscribe(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b", Qos: 0}) + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5).RawBytes, buf) + require.Equal(t, int64(-1), atomic.LoadInt64(&s.Info.Subscriptions)) +} + +func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackPacketIDInUse).RawBytes, buf) + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Subscriptions)) +} + +func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID) +} + +func TestServerReceivePacketError(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + err := s.receivePacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID) +} + +func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + cl.Properties.Props.SessionExpiryInterval = 0 + cl.Properties.ProtocolVersion = 5 + cl.Properties.Props.RequestProblemInfo = 0 + cl.Properties.Props.RequestProblemInfoFlag = true + go func() { + err := s.receivePacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrProtocolViolationZeroNonZeroExpiry) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf) +} + +func TestServerRecievePacketDisconnectClient(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + + go func() { + err := s.DisconnectClient(cl, packets.CodeDisconnect) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, buf) +} + +func TestServerProcessPacketDisconnect(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.Properties.Props.SessionExpiryInterval = 30 + cl.Properties.ProtocolVersion = 5 + + s.loop.willDelayed.Add(cl.ID, packets.Packet{TopicName: "a/b/c", Payload: []byte("hello")}) + require.Equal(t, 1, s.loop.willDelayed.Len()) + + err := s.processPacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet) + require.NoError(t, err) + + require.Equal(t, 0, s.loop.willDelayed.Len()) + require.True(t, cl.Closed()) + require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected)) +} + +func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.Properties.Props.SessionExpiryInterval = 0 + cl.Properties.ProtocolVersion = 5 + cl.Properties.Props.RequestProblemInfo = 0 + cl.Properties.Props.RequestProblemInfoFlag = true + + err := s.processPacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrProtocolViolationZeroNonZeroExpiry) +} + +func TestServerProcessPacketDisconnectDisconnectWithWillMessage(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + cl.Properties.Props.SessionExpiryInterval = 30 + cl.Properties.ProtocolVersion = 5 + + s.loop.willDelayed.Add(cl.ID, packets.Packet{TopicName: "a/b/c", Payload: []byte("hello")}) + require.Equal(t, 1, s.loop.willDelayed.Len()) + + err := s.processPacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5DisconnectWithWillMessage).Packet) + require.Error(t, err) + + require.Equal(t, 1, s.loop.willDelayed.Len()) + require.False(t, cl.Closed()) +} + +func TestServerProcessPacketAuth(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet) + require.NoError(t, err) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte{}, buf) +} + +func TestServerProcessPacketAuthInvalidReason(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + pkx := *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet + pkx.ReasonCode = 99 + err := s.processPacket(cl, pkx) + require.Error(t, err) + require.ErrorIs(t, packets.ErrProtocolViolationInvalidReason, err) +} + +func TestServerProcessPacketAuthFailure(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + + hook := new(modifiedHookBase) + hook.fail = true + err := s.AddHook(hook, nil) + require.NoError(t, err) + + err = s.processAuth(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet) + require.Error(t, err) + require.ErrorIs(t, errTestHook, err) +} + +func TestServerSendLWT(t *testing.T) { + s := newServer() + _ = s.Serve() + defer s.Close() + + sender, _, w1 := newTestClient() + sender.ID = "sender" + sender.Properties.Will = Will{ + Flag: 1, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + } + s.Clients.Add(sender) + + receiver, r2, w2 := newTestClient() + receiver.ID = "receiver" + s.Clients.Add(receiver) + s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) + require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + receiverBuf <- buf + }() + + go func() { + s.sendLWT(sender) + time.Sleep(time.Millisecond * 10) + _ = w1.Close() + _ = w2.Close() + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) +} + +func TestServerSendLWTRetain(t *testing.T) { + s := newServer() + _ = s.Serve() + defer s.Close() + + sender, _, w1 := newTestClient() + sender.ID = "sender" + sender.Properties.Will = Will{ + Flag: 1, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + Retain: true, + } + s.Clients.Add(sender) + + receiver, r2, w2 := newTestClient() + receiver.ID = "receiver" + s.Clients.Add(receiver) + s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) + require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + receiverBuf <- buf + }() + + go func() { + s.sendLWT(sender) + time.Sleep(time.Millisecond * 10) + _ = w1.Close() + _ = w2.Close() + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) +} + +func TestServerSendLWTDelayed(t *testing.T) { + s := newServer() + cl1, _, _ := newTestClient() + cl1.ID = "cl1" + cl1.Properties.Will = Will{ + Flag: 1, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + Retain: true, + WillDelayInterval: 2, + } + s.Clients.Add(cl1) + + cl2, r, w := newTestClient() + cl2.ID = "cl2" + s.Clients.Add(cl2) + require.True(t, s.Topics.Subscribe(cl2.ID, packets.Subscription{Filter: "a/b/c"})) + + go func() { + s.sendLWT(cl1) + pk, ok := s.loop.willDelayed.Get(cl1.ID) + require.True(t, ok) + pk.Expiry = time.Now().Unix() - 1 // set back expiry time + s.loop.willDelayed.Add(cl1.ID, pk) + require.Equal(t, 1, s.loop.willDelayed.Len()) + s.sendDelayedLWT(time.Now().Unix()) + require.Equal(t, 0, s.loop.willDelayed.Len()) + time.Sleep(time.Millisecond) + _ = w.Close() + }() + + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + recv <- buf + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-recv) +} + +func TestServerReadStore(t *testing.T) { + s := newServer() + hook := new(modifiedHookBase) + _ = s.AddHook(hook, nil) + + hook.failAt = 1 // clients + err := s.readStore() + require.Error(t, err) + + hook.failAt = 2 // subscriptions + err = s.readStore() + require.Error(t, err) + + hook.failAt = 3 // inflight + err = s.readStore() + require.Error(t, err) + + hook.failAt = 4 // retained + err = s.readStore() + require.Error(t, err) + + hook.failAt = 5 // sys info + err = s.readStore() + require.Error(t, err) +} + +func TestServerLoadClients(t *testing.T) { + v := []storage.Client{ + {ID: "mochi"}, + {ID: "zen"}, + {ID: "mochi-co"}, + {ID: "v3-clean", ProtocolVersion: 4, Clean: true}, + {ID: "v3-not-clean", ProtocolVersion: 4, Clean: false}, + { + ID: "v5-clean", + ProtocolVersion: 5, + Clean: true, + Properties: storage.ClientProperties{ + SessionExpiryInterval: 10, + }, + }, + { + ID: "v5-expire-interval-0", + ProtocolVersion: 5, + Properties: storage.ClientProperties{ + SessionExpiryInterval: 0, + }, + }, + { + ID: "v5-expire-interval-not-0", + ProtocolVersion: 5, + Properties: storage.ClientProperties{ + SessionExpiryInterval: 10, + }, + }, + } + + s := newServer() + require.Equal(t, 0, s.Clients.Len()) + s.loadClients(v) + require.Equal(t, 6, s.Clients.Len()) + cl, ok := s.Clients.Get("mochi") + require.True(t, ok) + require.Equal(t, "mochi", cl.ID) + + _, ok = s.Clients.Get("v3-clean") + require.False(t, ok) + _, ok = s.Clients.Get("v3-not-clean") + require.True(t, ok) + _, ok = s.Clients.Get("v5-clean") + require.True(t, ok) + _, ok = s.Clients.Get("v5-expire-interval-0") + require.False(t, ok) + _, ok = s.Clients.Get("v5-expire-interval-not-0") + require.True(t, ok) +} + +func TestServerLoadSubscriptions(t *testing.T) { + v := []storage.Subscription{ + {ID: "sub1", Client: "mochi", Filter: "a/b/c"}, + {ID: "sub2", Client: "mochi", Filter: "d/e/f", Qos: 1}, + {ID: "sub3", Client: "mochi", Filter: "h/i/j", Qos: 2}, + } + + s := newServer() + cl, _, _ := newTestClient() + s.Clients.Add(cl) + require.Equal(t, 0, cl.State.Subscriptions.Len()) + s.loadSubscriptions(v) + require.Equal(t, 3, cl.State.Subscriptions.Len()) +} + +func TestServerLoadInflightMessages(t *testing.T) { + s := newServer() + s.loadClients([]storage.Client{ + {ID: "mochi"}, + {ID: "zen"}, + {ID: "mochi-co"}, + }) + + require.Equal(t, 3, s.Clients.Len()) + + v := []storage.Message{ + {Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"}, + {Origin: "mochi", PacketID: 2, Payload: []byte("yes"), TopicName: "a/b/c"}, + {Origin: "zen", PacketID: 3, Payload: []byte("hello world"), TopicName: "a/b/c"}, + {Origin: "mochi-co", PacketID: 4, Payload: []byte("hello world"), TopicName: "a/b/c"}, + } + s.loadInflight(v) + + cl, ok := s.Clients.Get("mochi") + require.True(t, ok) + require.Equal(t, "mochi", cl.ID) + + msg, ok := cl.State.Inflight.Get(2) + require.True(t, ok) + require.Equal(t, []byte{'y', 'e', 's'}, msg.Payload) + require.Equal(t, "a/b/c", msg.TopicName) + + cl, ok = s.Clients.Get("mochi-co") + require.True(t, ok) + msg, ok = cl.State.Inflight.Get(4) + require.True(t, ok) +} + +func TestServerLoadRetainedMessages(t *testing.T) { + s := newServer() + + v := []storage.Message{ + {Origin: "mochi", FixedHeader: packets.FixedHeader{Retain: true}, Payload: []byte("hello world"), TopicName: "a/b/c"}, + {Origin: "mochi-co", FixedHeader: packets.FixedHeader{Retain: true}, Payload: []byte("yes"), TopicName: "d/e/f"}, + {Origin: "zen", FixedHeader: packets.FixedHeader{Retain: true}, Payload: []byte("hello world"), TopicName: "h/i/j"}, + } + s.loadRetained(v) + require.Equal(t, 1, len(s.Topics.Messages("a/b/c"))) + require.Equal(t, 1, len(s.Topics.Messages("d/e/f"))) + require.Equal(t, 1, len(s.Topics.Messages("h/i/j"))) + require.Equal(t, 0, len(s.Topics.Messages("w/x/y"))) +} + +func TestServerClose(t *testing.T) { + s := newServer() + + hook := new(modifiedHookBase) + _ = s.AddHook(hook, nil) + + cl, r, _ := newTestClient() + cl.Net.Listener = "t1" + cl.Properties.ProtocolVersion = 5 + s.Clients.Add(cl) + + err := s.AddListener(listeners.NewMockListener("t1", ":1882")) + require.NoError(t, err) + _ = s.Serve() + + // receive the disconnect + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + recv <- buf + }() + + time.Sleep(time.Millisecond) + require.Equal(t, 1, s.Listeners.Len()) + + listener, ok := s.Listeners.Get("t1") + require.Equal(t, true, ok) + require.Equal(t, true, listener.(*listeners.MockListener).IsServing()) + + _ = s.Close() + time.Sleep(time.Millisecond) + require.Equal(t, false, listener.(*listeners.MockListener).IsServing()) + require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv) +} + +func TestServerClearExpiredInflights(t *testing.T) { + s := New(nil) + require.NotNil(t, s) + s.Options.Capabilities.MaximumMessageExpiryInterval = 4 + + n := time.Now().Unix() + cl, _, _ := newTestClient() + cl.ops.info = s.Info + + cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) + cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2}) + cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds + cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit + cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n}) + + s.Clients.Add(cl) + + require.Len(t, cl.State.Inflight.GetAll(false), 5) + s.clearExpiredInflights(n) + require.Len(t, cl.State.Inflight.GetAll(false), 2) + require.Equal(t, int64(-3), s.Info.Inflight) + + s.Options.Capabilities.MaximumMessageExpiryInterval = 0 + cl.State.Inflight.Set(packets.Packet{PacketID: 8, Expiry: n - 8}) + s.clearExpiredInflights(n) + require.Len(t, cl.State.Inflight.GetAll(false), 3) +} + +func TestServerClearExpiredRetained(t *testing.T) { + s := New(nil) + require.NotNil(t, s) + s.Options.Capabilities.MaximumMessageExpiryInterval = 4 + + n := time.Now().Unix() + s.Topics.Retained.Add("a/b/c", packets.Packet{ProtocolVersion: 5, Created: n, Expiry: n - 1}) + s.Topics.Retained.Add("d/e/f", packets.Packet{ProtocolVersion: 5, Created: n, Expiry: n - 2}) + s.Topics.Retained.Add("g/h/i", packets.Packet{ProtocolVersion: 5, Created: n - 3}) // within bounds + s.Topics.Retained.Add("j/k/l", packets.Packet{ProtocolVersion: 5, Created: n - 5}) // over max server expiry limit + s.Topics.Retained.Add("m/n/o", packets.Packet{ProtocolVersion: 5, Created: n}) + + require.Len(t, s.Topics.Retained.GetAll(), 5) + s.clearExpiredRetainedMessages(n) + require.Len(t, s.Topics.Retained.GetAll(), 2) + + s.Topics.Retained.Add("p/q/r", packets.Packet{Created: n, Expiry: n - 1}) + s.Topics.Retained.Add("s/t/u", packets.Packet{Created: n, Expiry: n - 2}) // expiry is ineffective for v3. + s.Topics.Retained.Add("v/w/x", packets.Packet{Created: n - 3}) // within bounds for v3 + s.Topics.Retained.Add("y/z/1", packets.Packet{Created: n - 5}) // over max server expiry limit + require.Len(t, s.Topics.Retained.GetAll(), 6) + s.clearExpiredRetainedMessages(n) + require.Len(t, s.Topics.Retained.GetAll(), 5) + + s.Options.Capabilities.MaximumMessageExpiryInterval = 0 + s.Topics.Retained.Add("2/3/4", packets.Packet{Created: n - 8}) + s.clearExpiredRetainedMessages(n) + require.Len(t, s.Topics.Retained.GetAll(), 6) +} + +func TestServerClearExpiredClients(t *testing.T) { + s := New(nil) + require.NotNil(t, s) + + n := time.Now().Unix() + + cl, _, _ := newTestClient() + cl.ID = "cl" + s.Clients.Add(cl) + + // No Expiry + cl0, _, _ := newTestClient() + cl0.ID = "c0" + cl0.State.disconnected = n - 10 + cl0.State.cancelOpen() + cl0.Properties.ProtocolVersion = 5 + cl0.Properties.Props.SessionExpiryInterval = 12 + cl0.Properties.Props.SessionExpiryIntervalFlag = true + s.Clients.Add(cl0) + + // Normal Expiry + cl1, _, _ := newTestClient() + cl1.ID = "c1" + cl1.State.disconnected = n - 10 + cl1.State.cancelOpen() + cl1.Properties.ProtocolVersion = 5 + cl1.Properties.Props.SessionExpiryInterval = 8 + cl1.Properties.Props.SessionExpiryIntervalFlag = true + s.Clients.Add(cl1) + + // No Expiry, indefinite session + cl2, _, _ := newTestClient() + cl2.ID = "c2" + cl2.State.disconnected = n - 10 + cl2.State.cancelOpen() + cl2.Properties.ProtocolVersion = 5 + cl2.Properties.Props.SessionExpiryInterval = 0 + cl2.Properties.Props.SessionExpiryIntervalFlag = true + s.Clients.Add(cl2) + + require.Equal(t, 4, s.Clients.Len()) + + s.clearExpiredClients(n) + require.Equal(t, 2, s.Clients.Len()) +} + +func TestLoadServerInfoRestoreOnRestart(t *testing.T) { + s := New(nil) + s.Options.Capabilities.Compatibilities.RestoreSysInfoOnRestart = true + info := system.Info{ + BytesReceived: 60, + } + + s.loadServerInfo(info) + require.Equal(t, int64(60), s.Info.BytesReceived) +} + +func TestItoa(t *testing.T) { + i := int64(22) + require.Equal(t, "22", Int64toa(i)) +} + +func TestServerSubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {} + + s := newServerWithInlineClient() + require.NotNil(t, s) + + tt := []struct { + desc string + filter string + identifier int + handler InlineSubFn + expect error + }{ + { + desc: "subscribe", + filter: "a/b/c", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "re-subscribe", + filter: "a/b/c", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "subscribe d/e/f", + filter: "d/e/f", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "re-subscribe d/e/f by different identifier", + filter: "d/e/f", + identifier: 2, + handler: handler, + expect: nil, + }, + { + desc: "subscribe different handler", + filter: "a/b/c", + identifier: 1, + handler: func(cl *Client, sub packets.Subscription, pk packets.Packet) {}, + expect: nil, + }, + { + desc: "subscribe $SYS/info", + filter: "$SYS/info", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "subscribe invalid ###", + filter: "###", + identifier: 1, + handler: handler, + expect: packets.ErrTopicFilterInvalid, + }, + { + desc: "subscribe invalid handler", + filter: "a/b/c", + identifier: 1, + handler: nil, + expect: packets.ErrInlineSubscriptionHandlerInvalid, + }, + } + + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + require.Equal(t, tx.expect, s.Subscribe(tx.filter, tx.identifier, tx.handler)) + }) + } +} + +func TestServerSubscribeNoInlineClient(t *testing.T) { + s := newServer() + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {}) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + +func TestServerUnsubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + s := newServerWithInlineClient() + err := s.Subscribe("a/b/c", 1, handler) + require.Nil(t, err) + + err = s.Subscribe("d/e/f", 1, handler) + require.Nil(t, err) + + err = s.Subscribe("d/e/f", 2, handler) + require.Nil(t, err) + + err = s.Unsubscribe("a/b/c", 1) + require.Nil(t, err) + + err = s.Unsubscribe("d/e/f", 1) + require.Nil(t, err) + + err = s.Unsubscribe("d/e/f", 2) + require.Nil(t, err) + + err = s.Unsubscribe("not/exist", 1) + require.Nil(t, err) + + err = s.Unsubscribe("#/#/invalid", 1) + require.Equal(t, packets.ErrTopicFilterInvalid, err) +} + +func TestServerUnsubscribeNoInlineClient(t *testing.T) { + s := newServer() + err := s.Unsubscribe("a/b/c", 1) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + +func TestPublishToInlineSubscriber(t *testing.T) { + s := newServerWithInlineClient() + finishCh := make(chan bool) + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.publishToSubscribers(pkx) + }() + + require.Equal(t, true, <-finishCh) +} + +func TestPublishToInlineSubscribersDifferentFilter(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("mochi mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "z/e/n", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.publishToSubscribers(pkx) + + pkx = *packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet + s.publishToSubscribers(pkx) + }() + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestPublishToInlineSubscribersDifferentIdentifier(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 2, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.publishToSubscribers(pkx) + }() + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestServerSubscribeWithRetain(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 1 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + require.Equal(t, true, <-finishCh) +} + +func TestServerSubscribeWithRetainDifferentFilter(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + retained = s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("mochi mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "z/e/n", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestServerSubscribeWithRetainDifferentIdentifier(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 2, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} diff --git a/mqtt/topics.go b/mqtt/topics.go new file mode 100644 index 0000000..f42e621 --- /dev/null +++ b/mqtt/topics.go @@ -0,0 +1,824 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package mqtt + +import ( + "strings" + "sync" + "sync/atomic" + + "testmqtt/packets" +) + +var ( + // SharePrefix 共享主题的前缀 + SharePrefix = "$SHARE" // the prefix indicating a share topic + // SysPrefix 系统信息主题的前缀 + SysPrefix = "$SYS" // the prefix indicating a system info topic +) + +// TopicAliases contains inbound and outbound topic alias registrations. +type TopicAliases struct { + Inbound *InboundTopicAliases + Outbound *OutboundTopicAliases +} + +// NewTopicAliases returns an instance of TopicAliases. +func NewTopicAliases(topicAliasMaximum uint16) TopicAliases { + return TopicAliases{ + Inbound: NewInboundTopicAliases(topicAliasMaximum), + Outbound: NewOutboundTopicAliases(topicAliasMaximum), + } +} + +// NewInboundTopicAliases returns a pointer to InboundTopicAliases. +func NewInboundTopicAliases(topicAliasMaximum uint16) *InboundTopicAliases { + return &InboundTopicAliases{ + maximum: topicAliasMaximum, + internal: map[uint16]string{}, + } +} + +// InboundTopicAliases contains a map of topic aliases received from the client. +type InboundTopicAliases struct { + internal map[uint16]string + sync.RWMutex + maximum uint16 +} + +// Set sets a new alias for a specific topic. +func (a *InboundTopicAliases) Set(id uint16, topic string) string { + a.Lock() + defer a.Unlock() + + if a.maximum == 0 { + return topic // ? + } + + if existing, ok := a.internal[id]; ok && topic == "" { + return existing + } + + a.internal[id] = topic + return topic +} + +// OutboundTopicAliases contains a map of topic aliases sent from the broker to the client. +type OutboundTopicAliases struct { + internal map[string]uint16 + sync.RWMutex + cursor uint32 + maximum uint16 +} + +// NewOutboundTopicAliases returns a pointer to OutboundTopicAliases. +func NewOutboundTopicAliases(topicAliasMaximum uint16) *OutboundTopicAliases { + return &OutboundTopicAliases{ + maximum: topicAliasMaximum, + internal: map[string]uint16{}, + } +} + +// Set sets a new topic alias for a topic and returns the alias value, and a boolean +// indicating if the alias already existed. +func (a *OutboundTopicAliases) Set(topic string) (uint16, bool) { + a.Lock() + defer a.Unlock() + + if a.maximum == 0 { + return 0, false + } + + if i, ok := a.internal[topic]; ok { + return i, true + } + + i := atomic.LoadUint32(&a.cursor) + if i+1 > uint32(a.maximum) { + // if i+1 > math.MaxUint16 { + return 0, false + } + + a.internal[topic] = uint16(i) + 1 + atomic.StoreUint32(&a.cursor, i+1) + return uint16(i) + 1, false +} + +// SharedSubscriptions contains a map of subscriptions to a shared filter, +// keyed on share group then client id. +type SharedSubscriptions struct { + internal map[string]map[string]packets.Subscription + sync.RWMutex +} + +// NewSharedSubscriptions returns a new instance of Subscriptions. +func NewSharedSubscriptions() *SharedSubscriptions { + return &SharedSubscriptions{ + internal: map[string]map[string]packets.Subscription{}, + } +} + +// Add creates a new shared subscription for a group and client id pair. +func (s *SharedSubscriptions) Add(group, id string, val packets.Subscription) { + s.Lock() + defer s.Unlock() + if _, ok := s.internal[group]; !ok { + s.internal[group] = map[string]packets.Subscription{} + } + s.internal[group][id] = val +} + +// Delete deletes a client id from a shared subscription group. +func (s *SharedSubscriptions) Delete(group, id string) { + s.Lock() + defer s.Unlock() + delete(s.internal[group], id) + if len(s.internal[group]) == 0 { + delete(s.internal, group) + } +} + +// Get returns the subscription properties for a client id in a share group, if one exists. +func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, ok bool) { + s.RLock() + defer s.RUnlock() + if _, ok := s.internal[group]; !ok { + return val, ok + } + + val, ok = s.internal[group][id] + return val, ok +} + +// GroupLen returns the number of groups subscribed to the filter. +func (s *SharedSubscriptions) GroupLen() int { + s.RLock() + defer s.RUnlock() + val := len(s.internal) + return val +} + +// Len returns the total number of shared subscriptions to a filter across all groups. +func (s *SharedSubscriptions) Len() int { + s.RLock() + defer s.RUnlock() + n := 0 + for _, group := range s.internal { + n += len(group) + } + return n +} + +// GetAll returns all shared subscription groups and their subscriptions. +func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscription { + s.RLock() + defer s.RUnlock() + m := map[string]map[string]packets.Subscription{} + for group, subs := range s.internal { + if _, ok := m[group]; !ok { + m[group] = map[string]packets.Subscription{} + } + + for id, sub := range subs { + m[group][id] = sub + } + } + return m +} + +// InlineSubFn is the signature for a callback function which will be called +// when an inline client receives a message on a topic it is subscribed to. +// The sub argument contains information about the subscription that was matched for any filters. +type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet) + +// InlineSubscriptions represents a map of internal subscriptions keyed on client. +type InlineSubscriptions struct { + internal map[int]InlineSubscription + sync.RWMutex +} + +// NewInlineSubscriptions returns a new instance of InlineSubscriptions. +func NewInlineSubscriptions() *InlineSubscriptions { + return &InlineSubscriptions{ + internal: map[int]InlineSubscription{}, + } +} + +// Add adds a new internal subscription for a client id. +func (s *InlineSubscriptions) Add(val InlineSubscription) { + s.Lock() + defer s.Unlock() + s.internal[val.Identifier] = val +} + +// GetAll returns all internal subscriptions. +func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription { + s.RLock() + defer s.RUnlock() + m := map[int]InlineSubscription{} + for k, v := range s.internal { + m[k] = v + } + return m +} + +// Get returns an internal subscription for a client id. +func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) { + s.RLock() + defer s.RUnlock() + val, ok = s.internal[id] + return val, ok +} + +// Len returns the number of internal subscriptions. +func (s *InlineSubscriptions) Len() int { + s.RLock() + defer s.RUnlock() + val := len(s.internal) + return val +} + +// Delete removes an internal subscription by the client id. +func (s *InlineSubscriptions) Delete(id int) { + s.Lock() + defer s.Unlock() + delete(s.internal, id) +} + +// Subscriptions is a map of subscriptions keyed on client. +type Subscriptions struct { + internal map[string]packets.Subscription + sync.RWMutex +} + +// NewSubscriptions returns a new instance of Subscriptions. +func NewSubscriptions() *Subscriptions { + return &Subscriptions{ + internal: map[string]packets.Subscription{}, + } +} + +// Add adds a new subscription for a client. ID can be a filter in the +// case this map is client state, or a client id if particle state. +func (s *Subscriptions) Add(id string, val packets.Subscription) { + s.Lock() + defer s.Unlock() + s.internal[id] = val +} + +// GetAll returns all subscriptions. +func (s *Subscriptions) GetAll() map[string]packets.Subscription { + s.RLock() + defer s.RUnlock() + m := map[string]packets.Subscription{} + for k, v := range s.internal { + m[k] = v + } + return m +} + +// Get returns a subscriptions for a specific client or filter id. +func (s *Subscriptions) Get(id string) (val packets.Subscription, ok bool) { + s.RLock() + defer s.RUnlock() + val, ok = s.internal[id] + return val, ok +} + +// Len returns the number of subscriptions. +func (s *Subscriptions) Len() int { + s.RLock() + defer s.RUnlock() + val := len(s.internal) + return val +} + +// Delete removes a subscription by client or filter id. +func (s *Subscriptions) Delete(id string) { + s.Lock() + defer s.Unlock() + delete(s.internal, id) +} + +// ClientSubscriptions is a map of aggregated subscriptions for a client. +type ClientSubscriptions map[string]packets.Subscription + +type InlineSubscription struct { + packets.Subscription + Handler InlineSubFn +} + +// Subscribers contains the shared and non-shared subscribers matching a topic. +type Subscribers struct { + Shared map[string]map[string]packets.Subscription + SharedSelected map[string]packets.Subscription + Subscriptions map[string]packets.Subscription + InlineSubscriptions map[int]InlineSubscription +} + +// SelectShared returns one subscriber for each shared subscription group. +func (s *Subscribers) SelectShared() { + s.SharedSelected = map[string]packets.Subscription{} + for _, subs := range s.Shared { + for client, sub := range subs { + cls, ok := s.SharedSelected[client] + if !ok { + cls = sub + } + + s.SharedSelected[client] = cls.Merge(sub) + break + } + } +} + +// MergeSharedSelected merges the selected subscribers for a shared subscription group +// and the non-shared subscribers, to ensure that no subscriber gets multiple messages +// due to have both types of subscription matching the same filter. +func (s *Subscribers) MergeSharedSelected() { + for client, sub := range s.SharedSelected { + cls, ok := s.Subscriptions[client] + if !ok { + cls = sub + } + + s.Subscriptions[client] = cls.Merge(sub) + } +} + +// TopicsIndex is a prefix/trie tree containing topic subscribers and retained messages. +type TopicsIndex struct { + Retained *packets.Packets + root *particle // a leaf containing a message and more leaves. +} + +// NewTopicsIndex returns a pointer to a new instance of Index. +func NewTopicsIndex() *TopicsIndex { + return &TopicsIndex{ + Retained: packets.NewPackets(), + root: &particle{ + particles: newParticles(), + subscriptions: NewSubscriptions(), + }, + } +} + +// InlineSubscribe adds a new internal subscription for a topic filter, returning +// true if the subscription was new. +func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool { + x.root.Lock() + defer x.root.Unlock() + + var existed bool + n := x.set(subscription.Filter, 0) + _, existed = n.inlineSubscriptions.Get(subscription.Identifier) + n.inlineSubscriptions.Add(subscription) + + return !existed +} + +// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client, +// returning true if the subscription existed. +func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool { + x.root.Lock() + defer x.root.Unlock() + + particle := x.seek(filter, 0) + if particle == nil { + return false + } + + particle.inlineSubscriptions.Delete(id) + + if particle.inlineSubscriptions.Len() == 0 { + x.trim(particle) + } + return true +} + +// Subscribe adds a new subscription for a client to a topic filter, returning +// true if the subscription was new. +func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool { + x.root.Lock() + defer x.root.Unlock() + + var existed bool + prefix, _ := isolateParticle(subscription.Filter, 0) + if strings.EqualFold(prefix, SharePrefix) { + group, _ := isolateParticle(subscription.Filter, 1) + n := x.set(subscription.Filter, 2) + _, existed = n.shared.Get(group, client) + n.shared.Add(group, client, subscription) + } else { + n := x.set(subscription.Filter, 0) + _, existed = n.subscriptions.Get(client) + n.subscriptions.Add(client, subscription) + } + + return !existed +} + +// Unsubscribe removes a subscription filter for a client, returning true if the +// subscription existed. +func (x *TopicsIndex) Unsubscribe(filter, client string) bool { + x.root.Lock() + defer x.root.Unlock() + + var d int + prefix, _ := isolateParticle(filter, 0) + shareSub := strings.EqualFold(prefix, SharePrefix) + if shareSub { + d = 2 + } + + particle := x.seek(filter, d) + if particle == nil { + return false + } + + if shareSub { + group, _ := isolateParticle(filter, 1) + particle.shared.Delete(group, client) + } else { + particle.subscriptions.Delete(client) + } + + x.trim(particle) + return true +} + +// RetainMessage saves a message payload to the end of a topic address. Returns +// 1 if a retained message was added, and -1 if the retained message was removed. +// 0 is returned if sequential empty payloads are received. +func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 { + x.root.Lock() + defer x.root.Unlock() + + n := x.set(pk.TopicName, 0) + n.Lock() + defer n.Unlock() + if len(pk.Payload) > 0 { + n.retainPath = pk.TopicName + x.Retained.Add(pk.TopicName, pk) + return 1 + } + + var out int64 + if pke, ok := x.Retained.Get(pk.TopicName); ok && len(pke.Payload) > 0 && pke.FixedHeader.Retain { + out = -1 // if a retained packet existed, return -1 + } + + n.retainPath = "" + x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7] + x.trim(n) + + return out +} + +// set creates a topic address in the index and returns the final particle. +func (x *TopicsIndex) set(topic string, d int) *particle { + var key string + var hasNext = true + n := x.root + for hasNext { + key, hasNext = isolateParticle(topic, d) + d++ + + p := n.particles.get(key) + if p == nil { + p = newParticle(key, n) + n.particles.add(p) + } + n = p + } + + return n +} + +// seek finds the particle at a specific index in a topic filter. +func (x *TopicsIndex) seek(filter string, d int) *particle { + var key string + var hasNext = true + n := x.root + for hasNext { + key, hasNext = isolateParticle(filter, d) + n = n.particles.get(key) + d++ + if n == nil { + return nil + } + } + + return n +} + +// trim removes empty filter particles from the index. +func (x *TopicsIndex) trim(n *particle) { + for n.parent != nil && n.retainPath == "" && n.particles.len()+n.subscriptions.Len()+n.shared.Len()+n.inlineSubscriptions.Len() == 0 { + key := n.key + n = n.parent + n.particles.delete(key) + } +} + +// Messages returns a slice of any retained messages which match a filter. +func (x *TopicsIndex) Messages(filter string) []packets.Packet { + return x.scanMessages(filter, 0, nil, []packets.Packet{}) +} + +// scanMessages returns all retained messages on topics matching a given filter. +func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []packets.Packet) []packets.Packet { + if n == nil { + n = x.root + } + + if len(filter) == 0 || x.Retained.Len() == 0 { + return pks + } + + if !strings.ContainsRune(filter, '#') && !strings.ContainsRune(filter, '+') { + if pk, ok := x.Retained.Get(filter); ok { + pks = append(pks, pk) + } + return pks + } + + key, hasNext := isolateParticle(filter, d) + if key == "+" || key == "#" || d == -1 { + for _, adjacent := range n.particles.getAll() { + if d == 0 && adjacent.key == SysPrefix { + continue + } + + if !hasNext { + if adjacent.retainPath != "" { + if pk, ok := x.Retained.Get(adjacent.retainPath); ok { + pks = append(pks, pk) + } + } + } + + if hasNext || (d >= 0 && key == "#") { + pks = x.scanMessages(filter, d+1, adjacent, pks) + } + } + return pks + } + + if particle := n.particles.get(key); particle != nil { + if hasNext { + return x.scanMessages(filter, d+1, particle, pks) + } + + if pk, ok := x.Retained.Get(particle.retainPath); ok { + pks = append(pks, pk) + } + } + + return pks +} + +// Subscribers returns a map of clients who are subscribed to matching filters, +// their subscription ids and highest qos. +func (x *TopicsIndex) Subscribers(topic string) *Subscribers { + return x.scanSubscribers(topic, 0, nil, &Subscribers{ + Shared: map[string]map[string]packets.Subscription{}, + SharedSelected: map[string]packets.Subscription{}, + Subscriptions: map[string]packets.Subscription{}, + InlineSubscriptions: map[int]InlineSubscription{}, + }) +} + +// scanSubscribers returns a list of client subscriptions matching an indexed topic address. +func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Subscribers) *Subscribers { + if n == nil { + n = x.root + } + + if len(topic) == 0 { + return subs + } + + key, hasNext := isolateParticle(topic, d) + for _, partKey := range []string{key, "+"} { + if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3] + if hasNext { + x.scanSubscribers(topic, d+1, particle, subs) + } else { + x.gatherSubscriptions(topic, particle, subs) + x.gatherSharedSubscriptions(particle, subs) + x.gatherInlineSubscriptions(particle, subs) + + if wild := particle.particles.get("#"); wild != nil && partKey != "+" { + x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2 + x.gatherSharedSubscriptions(wild, subs) + x.gatherInlineSubscriptions(particle, subs) + } + } + } + } + + if particle := n.particles.get("#"); particle != nil { + x.gatherSubscriptions(topic, particle, subs) + x.gatherSharedSubscriptions(particle, subs) + x.gatherInlineSubscriptions(particle, subs) + } + + return subs +} + +// gatherSubscriptions collects any matching subscriptions, and gathers any identifiers or highest qos values. +func (x *TopicsIndex) gatherSubscriptions(topic string, particle *particle, subs *Subscribers) { + if subs.Subscriptions == nil { + subs.Subscriptions = map[string]packets.Subscription{} + } + + for client, sub := range particle.subscriptions.GetAll() { + if len(sub.Filter) > 0 && topic[0] == '$' && (sub.Filter[0] == '+' || sub.Filter[0] == '#') { // don't match $ topics with top level wildcards [MQTT-4.7.1-1] [MQTT-4.7.1-2] + continue + } + + cls, ok := subs.Subscriptions[client] + if !ok { + cls = sub + } + + subs.Subscriptions[client] = cls.Merge(sub) + } +} + +// gatherSharedSubscriptions gathers all shared subscriptions for a particle. +func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscribers) { + if subs.Shared == nil { + subs.Shared = map[string]map[string]packets.Subscription{} + } + + for _, shares := range particle.shared.GetAll() { + for client, sub := range shares { + if _, ok := subs.Shared[sub.Filter]; !ok { + subs.Shared[sub.Filter] = map[string]packets.Subscription{} + } + + subs.Shared[sub.Filter][client] = sub + } + } +} + +// gatherSharedSubscriptions gathers all inline subscriptions for a particle. +func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) { + if subs.InlineSubscriptions == nil { + subs.InlineSubscriptions = map[int]InlineSubscription{} + } + + for id, inline := range particle.inlineSubscriptions.GetAll() { + subs.InlineSubscriptions[id] = inline + } +} + +// isolateParticle extracts a particle between d / and d+1 / without allocations. +func isolateParticle(filter string, d int) (particle string, hasNext bool) { + var next, end int + for i := 0; end > -1 && i <= d; i++ { + end = strings.IndexRune(filter, '/') + + switch { + case d > -1 && i == d && end > -1: + hasNext = true + particle = filter[next:end] + case end > -1: + hasNext = false + filter = filter[end+1:] + default: + hasNext = false + particle = filter[next:] + } + } + + return +} + +// IsSharedFilter returns true if the filter uses the share prefix. +func IsSharedFilter(filter string) bool { + prefix, _ := isolateParticle(filter, 0) + return strings.EqualFold(prefix, SharePrefix) +} + +// IsValidFilter returns true if the filter is valid. +func IsValidFilter(filter string, forPublish bool) bool { + if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish. + return false // [MQTT-4.7.3-1] + } + + if forPublish { + if len(filter) >= len(SysPrefix) && strings.EqualFold(filter[0:len(SysPrefix)], SysPrefix) { + // 4.7.2 Non-normative - The Server SHOULD prevent Clients from using such Topic Names [$SYS] to exchange messages with other Clients. + return false + } + + if strings.ContainsRune(filter, '+') || strings.ContainsRune(filter, '#') { + return false //[MQTT-3.3.2-2] + } + } + + wildhash := strings.IndexRune(filter, '#') + if wildhash >= 0 && wildhash != len(filter)-1 { // [MQTT-4.7.1-2] + return false + } + + prefix, hasNext := isolateParticle(filter, 0) + if !hasNext && strings.EqualFold(prefix, SharePrefix) { + return false // [MQTT-4.8.2-1] + } + + if hasNext && strings.EqualFold(prefix, SharePrefix) { + group, hasNext := isolateParticle(filter, 1) + if !hasNext { + return false // [MQTT-4.8.2-1] + } + + if strings.ContainsRune(group, '+') || strings.ContainsRune(group, '#') { + return false // [MQTT-4.8.2-2] + } + } + + return true +} + +// particle is a child node on the tree. +type particle struct { + key string // the key of the particle + parent *particle // a pointer to the parent of the particle + particles particles // a map of child particles + subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address + shared *SharedSubscriptions // a map of shared subscriptions keyed on group name + inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle + retainPath string // path of a retained message + sync.Mutex // mutex for when making changes to the particle +} + +// newParticle returns a pointer to a new instance of particle. +func newParticle(key string, parent *particle) *particle { + return &particle{ + key: key, + parent: parent, + particles: newParticles(), + subscriptions: NewSubscriptions(), + shared: NewSharedSubscriptions(), + inlineSubscriptions: NewInlineSubscriptions(), + } +} + +// particles is a concurrency safe map of particles. +type particles struct { + internal map[string]*particle + sync.RWMutex +} + +// newParticles returns a map of particles. +func newParticles() particles { + return particles{ + internal: map[string]*particle{}, + } +} + +// add adds a new particle. +func (p *particles) add(val *particle) { + p.Lock() + p.internal[val.key] = val + p.Unlock() +} + +// getAll returns all particles. +func (p *particles) getAll() map[string]*particle { + p.RLock() + defer p.RUnlock() + m := map[string]*particle{} + for k, v := range p.internal { + m[k] = v + } + return m +} + +// get returns a particle by id (key). +func (p *particles) get(id string) *particle { + p.RLock() + defer p.RUnlock() + return p.internal[id] +} + +// len returns the number of particles. +func (p *particles) len() int { + p.RLock() + defer p.RUnlock() + val := len(p.internal) + return val +} + +// delete removes a particle. +func (p *particles) delete(id string) { + p.Lock() + defer p.Unlock() + delete(p.internal, id) +} diff --git a/mqtt/topics_test.go b/mqtt/topics_test.go new file mode 100644 index 0000000..ee63b71 --- /dev/null +++ b/mqtt/topics_test.go @@ -0,0 +1,1068 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package mqtt + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "testmqtt/packets" +) + +const ( + testGroup = "testgroup" + otherGroup = "other" +) + +func TestNewSharedSubscriptions(t *testing.T) { + s := NewSharedSubscriptions() + require.NotNil(t, s.internal) +} + +func TestSharedSubscriptionsAdd(t *testing.T) { + s := NewSharedSubscriptions() + s.Add(testGroup, "cl1", packets.Subscription{Filter: "a/b/c"}) + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl1") +} + +func TestSharedSubscriptionsGet(t *testing.T) { + s := NewSharedSubscriptions() + s.Add(testGroup, "cl1", packets.Subscription{Qos: 2}) + s.Add(testGroup, "cl2", packets.Subscription{Qos: 2}) + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl1") + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl2") + + sub, ok := s.Get(testGroup, "cl2") + require.Equal(t, true, ok) + require.Equal(t, byte(2), sub.Qos) +} + +func TestSharedSubscriptionsGetAll(t *testing.T) { + s := NewSharedSubscriptions() + s.Add(testGroup, "cl1", packets.Subscription{Qos: 0}) + s.Add(testGroup, "cl2", packets.Subscription{Qos: 1}) + s.Add(otherGroup, "cl3", packets.Subscription{Qos: 2}) + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl1") + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl2") + require.Contains(t, s.internal, otherGroup) + require.Contains(t, s.internal[otherGroup], "cl3") + + subs := s.GetAll() + require.Len(t, subs, 2) + require.Len(t, subs[testGroup], 2) + require.Len(t, subs[otherGroup], 1) +} + +func TestSharedSubscriptionsLen(t *testing.T) { + s := NewSharedSubscriptions() + s.Add(testGroup, "cl1", packets.Subscription{Qos: 0}) + s.Add(testGroup, "cl2", packets.Subscription{Qos: 1}) + s.Add(otherGroup, "cl2", packets.Subscription{Qos: 1}) + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl1") + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl2") + require.Contains(t, s.internal, otherGroup) + require.Contains(t, s.internal[otherGroup], "cl2") + require.Equal(t, 3, s.Len()) + require.Equal(t, 2, s.GroupLen()) +} + +func TestSharedSubscriptionsDelete(t *testing.T) { + s := NewSharedSubscriptions() + s.Add(testGroup, "cl1", packets.Subscription{Qos: 1}) + s.Add(testGroup, "cl2", packets.Subscription{Qos: 2}) + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl1") + require.Contains(t, s.internal, testGroup) + require.Contains(t, s.internal[testGroup], "cl2") + + require.Equal(t, 2, s.Len()) + + s.Delete(testGroup, "cl1") + _, ok := s.Get(testGroup, "cl1") + require.False(t, ok) + require.Equal(t, 1, s.GroupLen()) + require.Equal(t, 1, s.Len()) + + s.Delete(testGroup, "cl2") + _, ok = s.Get(testGroup, "cl2") + require.False(t, ok) + require.Equal(t, 0, s.GroupLen()) + require.Equal(t, 0, s.Len()) +} + +func TestNewSubscriptions(t *testing.T) { + s := NewSubscriptions() + require.NotNil(t, s.internal) +} + +func TestSubscriptionsAdd(t *testing.T) { + s := NewSubscriptions() + s.Add("cl1", packets.Subscription{}) + require.Contains(t, s.internal, "cl1") +} + +func TestSubscriptionsGet(t *testing.T) { + s := NewSubscriptions() + s.Add("cl1", packets.Subscription{Qos: 2}) + s.Add("cl2", packets.Subscription{Qos: 2}) + require.Contains(t, s.internal, "cl1") + require.Contains(t, s.internal, "cl2") + + sub, ok := s.Get("cl1") + require.True(t, ok) + require.Equal(t, byte(2), sub.Qos) +} + +func TestSubscriptionsGetAll(t *testing.T) { + s := NewSubscriptions() + s.Add("cl1", packets.Subscription{Qos: 0}) + s.Add("cl2", packets.Subscription{Qos: 1}) + s.Add("cl3", packets.Subscription{Qos: 2}) + require.Contains(t, s.internal, "cl1") + require.Contains(t, s.internal, "cl2") + require.Contains(t, s.internal, "cl3") + + subs := s.GetAll() + require.Len(t, subs, 3) +} + +func TestSubscriptionsLen(t *testing.T) { + s := NewSubscriptions() + s.Add("cl1", packets.Subscription{Qos: 0}) + s.Add("cl2", packets.Subscription{Qos: 1}) + require.Contains(t, s.internal, "cl1") + require.Contains(t, s.internal, "cl2") + require.Equal(t, 2, s.Len()) +} + +func TestSubscriptionsDelete(t *testing.T) { + s := NewSubscriptions() + s.Add("cl1", packets.Subscription{Qos: 1}) + require.Contains(t, s.internal, "cl1") + + s.Delete("cl1") + _, ok := s.Get("cl1") + require.False(t, ok) +} + +func TestNewTopicsIndex(t *testing.T) { + index := NewTopicsIndex() + require.NotNil(t, index) + require.NotNil(t, index.root) +} + +func BenchmarkNewTopicsIndex(b *testing.B) { + for n := 0; n < b.N; n++ { + NewTopicsIndex() + } +} + +func TestSubscribe(t *testing.T) { + tt := []struct { + desc string + client string + filter string + subscription packets.Subscription + wasNew bool + }{ + { + desc: "subscribe", + client: "cl1", + + subscription: packets.Subscription{Filter: "a/b/c", Qos: 2}, + wasNew: true, + }, + { + desc: "subscribe existed", + client: "cl1", + + subscription: packets.Subscription{Filter: "a/b/c", Qos: 1}, + wasNew: false, + }, + { + desc: "subscribe case sensitive didnt exist", + client: "cl1", + + subscription: packets.Subscription{Filter: "A/B/c", Qos: 1}, + wasNew: true, + }, + { + desc: "wildcard+ sub", + client: "cl1", + + subscription: packets.Subscription{Filter: "d/+"}, + wasNew: true, + }, + { + desc: "wildcard# sub", + client: "cl1", + subscription: packets.Subscription{Filter: "d/e/#"}, + wasNew: true, + }, + } + + index := NewTopicsIndex() + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + require.Equal(t, tx.wasNew, index.Subscribe(tx.client, tx.subscription)) + }) + } + + final := index.root.particles.get("a").particles.get("b").particles.get("c") + require.NotNil(t, final) + client, exists := final.subscriptions.Get("cl1") + require.True(t, exists) + require.Equal(t, byte(1), client.Qos) +} + +func TestSubscribeShared(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c", Qos: 2}) + final := index.root.particles.get("a").particles.get("b").particles.get("c") + require.NotNil(t, final) + client, exists := final.shared.Get("tmp", "cl1") + require.True(t, exists) + require.Equal(t, byte(2), client.Qos) + require.Equal(t, 0, final.subscriptions.Len()) + require.Equal(t, 1, final.shared.Len()) +} + +func BenchmarkSubscribe(b *testing.B) { + index := NewTopicsIndex() + for n := 0; n < b.N; n++ { + index.Subscribe("client-1", packets.Subscription{Filter: "a/b/c"}) + } +} + +func BenchmarkSubscribeShared(b *testing.B) { + index := NewTopicsIndex() + for n := 0; n < b.N; n++ { + index.Subscribe("client-1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c"}) + } +} + +func TestUnsubscribe(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/d", Qos: 1}) + client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").subscriptions.Get("cl1") + require.NotNil(t, client) + require.True(t, exists) + + index.Subscribe("cl1", packets.Subscription{Filter: "a/b/+/d", Qos: 1}) + client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1") + require.NotNil(t, client) + require.True(t, exists) + + index.Subscribe("cl1", packets.Subscription{Filter: "d/e/f", Qos: 1}) + client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl1") + require.NotNil(t, client) + require.True(t, exists) + + index.Subscribe("cl2", packets.Subscription{Filter: "d/e/f", Qos: 1}) + client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2") + require.NotNil(t, client) + require.True(t, exists) + + index.Subscribe("cl3", packets.Subscription{Filter: "#", Qos: 2}) + client, exists = index.root.particles.get("#").subscriptions.Get("cl3") + require.NotNil(t, client) + require.True(t, exists) + + ok := index.Unsubscribe("a/b/c/d", "cl1") + require.True(t, ok) + require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c")) + client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1") + require.NotNil(t, client) + require.True(t, exists) + + ok = index.Unsubscribe("d/e/f", "cl1") + require.True(t, ok) + + require.Equal(t, 1, index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Len()) + client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2") + require.NotNil(t, client) + require.True(t, exists) + + ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "nobody") + require.False(t, ok) +} + +func TestUnsubscribeNoCascade(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"}) + index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/e/e"}) + + ok := index.Unsubscribe("a/b/c/e/e", "cl1") + require.True(t, ok) + require.Equal(t, 1, index.root.particles.len()) + + client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").subscriptions.Get("cl1") + require.NotNil(t, client) + require.True(t, exists) +} + +func TestUnsubscribeShared(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c", Qos: 2}) + final := index.root.particles.get("a").particles.get("b").particles.get("c") + require.NotNil(t, final) + client, exists := final.shared.Get("tmp", "cl1") + require.True(t, exists) + require.Equal(t, byte(2), client.Qos) + + require.True(t, index.Unsubscribe("$share/tmp/a/b/c", "cl1")) + _, exists = final.shared.Get("tmp", "cl1") + require.False(t, exists) +} + +func BenchmarkUnsubscribe(b *testing.B) { + index := NewTopicsIndex() + + for n := 0; n < b.N; n++ { + b.StopTimer() + index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"}) + b.StartTimer() + index.Unsubscribe("a/b/c", "cl1") + } +} + +func TestIndexSeek(t *testing.T) { + filter := "a/b/c/d/e/f" + index := NewTopicsIndex() + k1 := index.set(filter, 0) + require.Equal(t, "f", k1.key) + k1.subscriptions.Add("cl1", packets.Subscription{}) + + require.Equal(t, k1, index.seek(filter, 0)) + require.Nil(t, index.seek("d/e/f", 0)) +} + +func TestIndexTrim(t *testing.T) { + index := NewTopicsIndex() + k1 := index.set("a/b/c", 0) + require.Equal(t, "c", k1.key) + k1.subscriptions.Add("cl1", packets.Subscription{}) + + k2 := index.set("a/b/c/d/e/f", 0) + require.Equal(t, "f", k2.key) + k2.subscriptions.Add("cl1", packets.Subscription{}) + + k3 := index.set("a/b", 0) + require.Equal(t, "b", k3.key) + k3.subscriptions.Add("cl1", packets.Subscription{}) + + index.trim(k2) + require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c")) + require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").particles.get("e").particles.get("f")) + require.NotNil(t, index.root.particles.get("a").particles.get("b")) + + k2.subscriptions.Delete("cl1") + index.trim(k2) + + require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d")) + require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c")) + + k1.subscriptions.Delete("cl1") + k3.subscriptions.Delete("cl1") + index.trim(k2) + require.Nil(t, index.root.particles.get("a")) +} + +func TestIndexSet(t *testing.T) { + index := NewTopicsIndex() + child := index.set("a/b/c", 0) + require.Equal(t, "c", child.key) + require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c")) + + child = index.set("a/b/c/d/e", 0) + require.Equal(t, "e", child.key) + + child = index.set("a/b/c/c/a", 0) + require.Equal(t, "a", child.key) +} + +func TestIndexSetPrefixed(t *testing.T) { + index := NewTopicsIndex() + child := index.set("/c", 0) + require.Equal(t, "c", child.key) + require.NotNil(t, index.root.particles.get("").particles.get("c")) +} + +func BenchmarkIndexSet(b *testing.B) { + index := NewTopicsIndex() + for n := 0; n < b.N; n++ { + index.set("a/b/c", 0) + } +} + +func TestRetainMessage(t *testing.T) { + pk := packets.Packet{ + FixedHeader: packets.FixedHeader{Retain: true}, + TopicName: "a/b/c", + Payload: []byte("hello"), + } + + index := NewTopicsIndex() + r := index.RetainMessage(pk) + require.Equal(t, int64(1), r) + pke, ok := index.Retained.Get(pk.TopicName) + require.True(t, ok) + require.Equal(t, pk, pke) + + pk2 := packets.Packet{ + FixedHeader: packets.FixedHeader{Retain: true}, + TopicName: "a/b/d/f", + Payload: []byte("hello"), + } + r = index.RetainMessage(pk2) + require.Equal(t, int64(1), r) + // The same message already exists, but we're not doing a deep-copy check, so it's considered to be a new message. + r = index.RetainMessage(pk2) + require.Equal(t, int64(1), r) + + // Clear existing retained + pk3 := packets.Packet{TopicName: "a/b/c", Payload: []byte{}} + r = index.RetainMessage(pk3) + require.Equal(t, int64(-1), r) + _, ok = index.Retained.Get(pk.TopicName) + require.False(t, ok) + + // Clear no retained + r = index.RetainMessage(pk3) + require.Equal(t, int64(0), r) +} + +func BenchmarkRetainMessage(b *testing.B) { + index := NewTopicsIndex() + for n := 0; n < b.N; n++ { + index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"}) + } +} + +func TestIsolateParticle(t *testing.T) { + particle, hasNext := isolateParticle("path/to/my/mqtt", 0) + require.Equal(t, "path", particle) + require.Equal(t, true, hasNext) + particle, hasNext = isolateParticle("path/to/my/mqtt", 1) + require.Equal(t, "to", particle) + require.Equal(t, true, hasNext) + particle, hasNext = isolateParticle("path/to/my/mqtt", 2) + require.Equal(t, "my", particle) + require.Equal(t, true, hasNext) + particle, hasNext = isolateParticle("path/to/my/mqtt", 3) + require.Equal(t, "mqtt", particle) + require.Equal(t, false, hasNext) + + particle, hasNext = isolateParticle("/path/", 0) + require.Equal(t, "", particle) + require.Equal(t, true, hasNext) + particle, hasNext = isolateParticle("/path/", 1) + require.Equal(t, "path", particle) + require.Equal(t, true, hasNext) + particle, hasNext = isolateParticle("/path/", 2) + require.Equal(t, "", particle) + require.Equal(t, false, hasNext) + + particle, hasNext = isolateParticle("a/b/c/+/+", 3) + require.Equal(t, "+", particle) + require.Equal(t, true, hasNext) + particle, hasNext = isolateParticle("a/b/c/+/+", 4) + require.Equal(t, "+", particle) + require.Equal(t, false, hasNext) +} + +func BenchmarkIsolateParticle(b *testing.B) { + for n := 0; n < b.N; n++ { + isolateParticle("path/to/my/mqtt", 3) + } +} + +func TestScanSubscribers(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c", Identifier: 22}) + index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c/d/e/f"}) + index.Subscribe("cl1", packets.Subscription{Qos: 2, Filter: "a/b/c/d/+/f"}) + index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/#"}) + index.Subscribe("cl2", packets.Subscription{Qos: 1, Filter: "a/b/c"}) + index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "a/b/+", Identifier: 77}) + index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "d/e/f", Identifier: 7237}) + index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "$SYS/uptime", Identifier: 3}) + index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: "+/b", Identifier: 234}) + index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: "#", Identifier: 5}) + index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2}) + + subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers)) + require.Equal(t, 3, len(subs.Subscriptions)) + require.Contains(t, subs.Subscriptions, "cl1") + require.Contains(t, subs.Subscriptions, "cl2") + require.Contains(t, subs.Subscriptions, "cl4") + + require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos) + require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos) + require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos) + + require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"]) + require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"]) + require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"]) + require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"]) + require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"]) + + subs = index.scanSubscribers("d/e/f/g", 0, nil, new(Subscribers)) + require.Equal(t, 1, len(subs.Subscriptions)) + require.Contains(t, subs.Subscriptions, "cl4") + require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos) + require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"]) + + subs = index.scanSubscribers("", 0, nil, new(Subscribers)) + require.Equal(t, 0, len(subs.Subscriptions)) +} + +func TestScanSubscribersTopicInheritanceBug(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Qos: 0, Filter: "a/b/c"}) + index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/b"}) + + subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers)) + require.Equal(t, 1, len(subs.Subscriptions)) +} + +func TestScanSubscribersShared(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111}) + index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112}) + index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113}) + index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10}) + index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200}) + index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201}) + index.Subscribe("cl5", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c/#"}) + subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers)) + require.Equal(t, 4, len(subs.Shared)) +} + +func TestSelectSharedSubscriber(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110}) + index.Subscribe("cl1b", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111}) + index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112}) + index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113}) + subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers)) + require.Equal(t, 2, len(subs.Shared)) + require.Contains(t, subs.Shared, SharePrefix+"/tmp/a/b/c") + require.Contains(t, subs.Shared, SharePrefix+"/tmp2/a/b/c") + require.Len(t, subs.Shared[SharePrefix+"/tmp/a/b/c"], 3) + require.Len(t, subs.Shared[SharePrefix+"/tmp2/a/b/c"], 1) + subs.SelectShared() + require.Len(t, subs.SharedSelected, 2) +} + +func TestMergeSharedSelected(t *testing.T) { + s := &Subscribers{ + SharedSelected: map[string]packets.Subscription{ + "cl1": {Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110}, + "cl2": {Qos: 1, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 111}, + }, + Subscriptions: map[string]packets.Subscription{ + "cl2": {Qos: 1, Filter: "a/b/c", Identifier: 112}, + }, + } + + s.MergeSharedSelected() + + require.Equal(t, 2, len(s.Subscriptions)) + require.Contains(t, s.Subscriptions, "cl1") + require.Contains(t, s.Subscriptions, "cl2") + require.EqualValues(t, map[string]int{ + SharePrefix + "/tmp2/a/b/c": 111, + "a/b/c": 112, + }, s.Subscriptions["cl2"].Identifiers) +} + +func TestSubscribersFind(t *testing.T) { + tt := []struct { + filter string + topic string + matched bool + }{ + {filter: "a", topic: "a", matched: true}, + {filter: "a/", topic: "a", matched: false}, + {filter: "a/", topic: "a/", matched: true}, + {filter: "/a", topic: "/a", matched: true}, + {filter: "path/to/my/mqtt", topic: "path/to/my/mqtt", matched: true}, + {filter: "path/to/+/mqtt", topic: "path/to/my/mqtt", matched: true}, + {filter: "+/to/+/mqtt", topic: "path/to/my/mqtt", matched: true}, + {filter: "#", topic: "path/to/my/mqtt", matched: true}, + {filter: "+/+/+/+", topic: "path/to/my/mqtt", matched: true}, + {filter: "+/+/+/#", topic: "path/to/my/mqtt", matched: true}, + {filter: "zen/#", topic: "zen", matched: true}, // as per 4.7.1.2 + {filter: "trailing-end/#", topic: "trailing-end/", matched: true}, + {filter: "+/prefixed", topic: "/prefixed", matched: true}, + {filter: "+/+/#", topic: "path/to/my/mqtt", matched: true}, + {filter: "path/to/", topic: "path/to/my/mqtt", matched: false}, + {filter: "#/stuff", topic: "path/to/my/mqtt", matched: false}, + {filter: "#", topic: "$SYS/info", matched: false}, + {filter: "$SYS/#", topic: "$SYS/info", matched: true}, + {filter: "+/info", topic: "$SYS/info", matched: false}, + } + + for _, tx := range tt { + t.Run("filter:'"+tx.filter+"' vs topic:'"+tx.topic+"'", func(t *testing.T) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Filter: tx.filter}) + subs := index.Subscribers(tx.topic) + require.Equal(t, tx.matched, len(subs.Subscriptions) == 1) + }) + } +} + +func BenchmarkSubscribers(b *testing.B) { + index := NewTopicsIndex() + index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"}) + index.Subscribe("cl1", packets.Subscription{Filter: "a/+/c"}) + index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/+"}) + index.Subscribe("cl2", packets.Subscription{Filter: "a/b/c/d"}) + index.Subscribe("cl3", packets.Subscription{Filter: "#"}) + + for n := 0; n < b.N; n++ { + index.Subscribers("a/b/c") + } +} + +func TestMessagesPattern(t *testing.T) { + payload := []byte("hello") + fh := packets.FixedHeader{Type: packets.Publish, Retain: true} + + pks := []packets.Packet{ + {TopicName: "$SYS/uptime", Payload: payload, FixedHeader: fh}, + {TopicName: "$SYS/info", Payload: payload, FixedHeader: fh}, + {TopicName: "a/b/c/d", Payload: payload, FixedHeader: fh}, + {TopicName: "a/b/c/e", Payload: payload, FixedHeader: fh}, + {TopicName: "a/b/d/f", Payload: payload, FixedHeader: fh}, + {TopicName: "q/w/e/r/t/y", Payload: payload, FixedHeader: fh}, + {TopicName: "q/x/e/r/t/o", Payload: payload, FixedHeader: fh}, + {TopicName: "asdf", Payload: payload, FixedHeader: fh}, + } + + tt := []struct { + filter string + len int + }{ + {"a/b/c/d", 1}, + {"$SYS/+", 2}, + {"$SYS/#", 2}, + {"#", len(pks) - 2}, + {"a/b/c/+", 2}, + {"a/+/c/+", 2}, + {"+/+/+/d", 1}, + {"q/w/e/#", 1}, + {"+/+/+/+", 3}, + {"q/#", 2}, + {"asdf", 1}, + {"", 0}, + {"#", 6}, + } + + index := NewTopicsIndex() + for _, pk := range pks { + index.RetainMessage(pk) + } + + for _, tx := range tt { + t.Run("filter:'"+tx.filter, func(t *testing.T) { + messages := index.Messages(tx.filter) + require.Equal(t, tx.len, len(messages)) + }) + } +} + +func BenchmarkMessages(b *testing.B) { + index := NewTopicsIndex() + index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"}) + index.RetainMessage(packets.Packet{TopicName: "a/b/d/e/f"}) + index.RetainMessage(packets.Packet{TopicName: "d/e/f/g"}) + index.RetainMessage(packets.Packet{TopicName: "$SYS/info"}) + index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"}) + + for n := 0; n < b.N; n++ { + index.Messages("+/b/c/+") + } +} + +func TestNewParticles(t *testing.T) { + cl := newParticles() + require.NotNil(t, cl.internal) +} + +func TestParticlesAdd(t *testing.T) { + p := newParticles() + p.add(&particle{key: "a"}) + require.Contains(t, p.internal, "a") +} + +func TestParticlesGet(t *testing.T) { + p := newParticles() + p.add(&particle{key: "a"}) + p.add(&particle{key: "b"}) + require.Contains(t, p.internal, "a") + require.Contains(t, p.internal, "b") + + particle := p.get("a") + require.NotNil(t, particle) + require.Equal(t, "a", particle.key) +} + +func TestParticlesGetAll(t *testing.T) { + p := newParticles() + p.add(&particle{key: "a"}) + p.add(&particle{key: "b"}) + p.add(&particle{key: "c"}) + require.Contains(t, p.internal, "a") + require.Contains(t, p.internal, "b") + require.Contains(t, p.internal, "c") + + particles := p.getAll() + require.Len(t, particles, 3) +} + +func TestParticlesLen(t *testing.T) { + p := newParticles() + p.add(&particle{key: "a"}) + p.add(&particle{key: "b"}) + require.Contains(t, p.internal, "a") + require.Contains(t, p.internal, "b") + require.Equal(t, 2, p.len()) +} + +func TestParticlesDelete(t *testing.T) { + p := newParticles() + p.add(&particle{key: "a"}) + require.Contains(t, p.internal, "a") + + p.delete("a") + particle := p.get("a") + require.Nil(t, particle) +} + +func TestIsValid(t *testing.T) { + require.True(t, IsValidFilter("a/b/c", false)) + require.True(t, IsValidFilter("a/b//c", false)) + require.True(t, IsValidFilter("$SYS", false)) + require.True(t, IsValidFilter("$SYS/info", false)) + require.True(t, IsValidFilter("$sys/info", false)) + require.True(t, IsValidFilter("abc/#", false)) + require.False(t, IsValidFilter("", false)) + require.False(t, IsValidFilter(SharePrefix, false)) + require.False(t, IsValidFilter(SharePrefix+"/", false)) + require.False(t, IsValidFilter(SharePrefix+"/b+/", false)) + require.False(t, IsValidFilter(SharePrefix+"/+", false)) + require.False(t, IsValidFilter(SharePrefix+"/#", false)) + require.False(t, IsValidFilter(SharePrefix+"/#/", false)) + require.False(t, IsValidFilter("a/#/c", false)) +} + +func TestIsValidForPublish(t *testing.T) { + require.True(t, IsValidFilter("", true)) + require.True(t, IsValidFilter("a/b/c", true)) + require.False(t, IsValidFilter("a/b/+/d", true)) + require.False(t, IsValidFilter("a/b/#", true)) + require.False(t, IsValidFilter("$SYS/info", true)) +} + +func TestIsSharedFilter(t *testing.T) { + require.True(t, IsSharedFilter(SharePrefix+"/tmp/a/b/c")) + require.False(t, IsSharedFilter("a/b/c")) +} + +func TestNewInboundAliases(t *testing.T) { + a := NewInboundTopicAliases(5) + require.NotNil(t, a) + require.NotNil(t, a.internal) + require.Equal(t, uint16(5), a.maximum) +} + +func TestInboundAliasesSet(t *testing.T) { + topic := "test" + id := uint16(1) + a := NewInboundTopicAliases(5) + require.Equal(t, topic, a.Set(id, topic)) + require.Contains(t, a.internal, id) + require.Equal(t, a.internal[id], topic) + + require.Equal(t, topic, a.Set(id, "")) +} + +func TestInboundAliasesSetMaxZero(t *testing.T) { + topic := "test" + id := uint16(1) + a := NewInboundTopicAliases(0) + require.Equal(t, topic, a.Set(id, topic)) + require.NotContains(t, a.internal, id) +} + +func TestNewOutboundAliases(t *testing.T) { + a := NewOutboundTopicAliases(5) + require.NotNil(t, a) + require.NotNil(t, a.internal) + require.Equal(t, uint16(5), a.maximum) + require.Equal(t, uint32(0), a.cursor) +} + +func TestOutboundAliasesSet(t *testing.T) { + a := NewOutboundTopicAliases(3) + n, ok := a.Set("t1") + require.False(t, ok) + require.Equal(t, uint16(1), n) + + n, ok = a.Set("t2") + require.False(t, ok) + require.Equal(t, uint16(2), n) + + n, ok = a.Set("t3") + require.False(t, ok) + require.Equal(t, uint16(3), n) + + n, ok = a.Set("t4") + require.False(t, ok) + require.Equal(t, uint16(0), n) + + n, ok = a.Set("t2") + require.True(t, ok) + require.Equal(t, uint16(2), n) +} + +func TestOutboundAliasesSetMaxZero(t *testing.T) { + topic := "test" + a := NewOutboundTopicAliases(0) + n, ok := a.Set(topic) + require.False(t, ok) + require.Equal(t, uint16(0), n) +} + +func TestNewTopicAliases(t *testing.T) { + a := NewTopicAliases(5) + require.NotNil(t, a.Inbound) + require.Equal(t, uint16(5), a.Inbound.maximum) + require.NotNil(t, a.Outbound) + require.Equal(t, uint16(5), a.Outbound.maximum) +} + +func TestNewInlineSubscriptions(t *testing.T) { + subscriptions := NewInlineSubscriptions() + require.NotNil(t, subscriptions) + require.NotNil(t, subscriptions.internal) + require.Equal(t, 0, subscriptions.Len()) +} + +func TestInlineSubscriptionAdd(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + sub, ok := subscriptions.Get(1) + require.True(t, ok) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler)) +} + +func TestInlineSubscriptionGet(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + sub, ok := subscriptions.Get(1) + require.True(t, ok) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler)) + + _, ok = subscriptions.Get(999) + require.False(t, ok) +} + +func TestInlineSubscriptionsGetAll(t *testing.T) { + subscriptions := NewInlineSubscriptions() + + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3}, + }) + + allSubs := subscriptions.GetAll() + require.Len(t, allSubs, 3) + require.Contains(t, allSubs, 1) + require.Contains(t, allSubs, 2) + require.Contains(t, allSubs, 3) +} + +func TestInlineSubscriptionDelete(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + subscriptions.Delete(1) + _, ok := subscriptions.Get(1) + require.False(t, ok) + require.Empty(t, subscriptions.GetAll()) + require.Zero(t, subscriptions.Len()) +} + +func TestInlineSubscribe(t *testing.T) { + + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + tt := []struct { + desc string + filter string + subscription InlineSubscription + wasNew bool + }{ + { + desc: "subscribe", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}}, + wasNew: true, + }, + { + desc: "subscribe existed", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}}, + wasNew: false, + }, + { + desc: "subscribe different identifier", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}}, + wasNew: true, + }, + { + desc: "subscribe case sensitive didnt exist", + filter: "A/B/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}}, + wasNew: true, + }, + { + desc: "wildcard+ sub", + filter: "d/+", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}}, + wasNew: true, + }, + { + desc: "wildcard# sub", + filter: "d/e/#", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}}, + wasNew: true, + }, + } + + index := NewTopicsIndex() + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + require.Equal(t, tx.wasNew, index.InlineSubscribe(tx.subscription)) + }) + } + + final := index.root.particles.get("a").particles.get("b").particles.get("c") + require.NotNil(t, final) +} + +func TestInlineUnsubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + index := NewTopicsIndex() + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}}) + sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index = NewTopicsIndex() + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}}) + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}}) + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}}) + sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + ok := index.InlineUnsubscribe(1, "a/b/c/d") + require.True(t, ok) + require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c")) + + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + ok = index.InlineUnsubscribe(1, "d/e/f") + require.True(t, ok) + require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f")) + + ok = index.InlineUnsubscribe(1, "not/exist") + require.False(t, ok) +} diff --git a/packets/codec.go b/packets/codec.go new file mode 100644 index 0000000..152d777 --- /dev/null +++ b/packets/codec.go @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "bytes" + "encoding/binary" + "io" + "unicode/utf8" + "unsafe" +) + +// bytesToString provides a zero-alloc no-copy byte to string conversion. +// via https://github.com/golang/go/issues/25484#issuecomment-391415660 +func bytesToString(bs []byte) string { + return *(*string)(unsafe.Pointer(&bs)) +} + +// decodeUint16 extracts the value of two bytes from a byte array. +func decodeUint16(buf []byte, offset int) (uint16, int, error) { + if len(buf) < offset+2 { + return 0, 0, ErrMalformedOffsetUintOutOfRange + } + + return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil +} + +// decodeUint32 extracts the value of four bytes from a byte array. +func decodeUint32(buf []byte, offset int) (uint32, int, error) { + if len(buf) < offset+4 { + return 0, 0, ErrMalformedOffsetUintOutOfRange + } + + return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil +} + +// decodeString extracts a string from a byte array, beginning at an offset. +func decodeString(buf []byte, offset int) (string, int, error) { + b, n, err := decodeBytes(buf, offset) + if err != nil { + return "", 0, err + } + + if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5] + return "", 0, ErrMalformedInvalidUTF8 + } + + return bytesToString(b), n, nil +} + +// validUTF8 checks if the byte array contains valid UTF-8 characters. +func validUTF8(b []byte) bool { + return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2] +} + +// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads. +func decodeBytes(buf []byte, offset int) ([]byte, int, error) { + length, next, err := decodeUint16(buf, offset) + if err != nil { + return make([]byte, 0), 0, err + } + + if next+int(length) > len(buf) { + return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange + } + + return buf[next : next+int(length)], next + int(length), nil +} + +// decodeByte extracts the value of a byte from a byte array. +func decodeByte(buf []byte, offset int) (byte, int, error) { + if len(buf) <= offset { + return 0, 0, ErrMalformedOffsetByteOutOfRange + } + return buf[offset], offset + 1, nil +} + +// decodeByteBool extracts the value of a byte from a byte array and returns a bool. +func decodeByteBool(buf []byte, offset int) (bool, int, error) { + if len(buf) <= offset { + return false, 0, ErrMalformedOffsetBoolOutOfRange + } + return 1&buf[offset] > 0, offset + 1, nil +} + +// encodeBool returns a byte instead of a bool. +func encodeBool(b bool) byte { + if b { + return 1 + } + return 0 +} + +// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads. +func encodeBytes(val []byte) []byte { + // In most circumstances the number of bytes being encoded is small. + // Setting the cap to a low amount allows us to account for those without + // triggering allocation growth on append unless we need to. + buf := make([]byte, 2, 32) + binary.BigEndian.PutUint16(buf, uint16(len(val))) + return append(buf, val...) +} + +// encodeUint16 encodes a uint16 value to a byte array. +func encodeUint16(val uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, val) + return buf +} + +// encodeUint32 encodes a uint16 value to a byte array. +func encodeUint32(val uint32) []byte { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, val) + return buf +} + +// encodeString encodes a string to a byte array. +func encodeString(val string) []byte { + // Like encodeBytes, we set the cap to a small number to avoid + // triggering allocation growth on append unless we absolutely need to. + buf := make([]byte, 2, 32) + binary.BigEndian.PutUint16(buf, uint16(len(val))) + return append(buf, []byte(val)...) +} + +// encodeLength writes length bits for the header. +func encodeLength(b *bytes.Buffer, length int64) { + // 1.5.5 Variable Byte Integer encode non-normative + // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027 + for { + eb := byte(length % 128) + length /= 128 + if length > 0 { + eb |= 0x80 + } + b.WriteByte(eb) + if length == 0 { + break // [MQTT-1.5.5-1] + } + } +} + +func DecodeLength(b io.ByteReader) (n, bu int, err error) { + // see 1.5.5 Variable Byte Integer decode non-normative + // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027 + var multiplier uint32 + var value uint32 + bu = 1 + for { + eb, err := b.ReadByte() + if err != nil { + return 0, bu, err + } + + value |= uint32(eb&127) << multiplier + if value > 268435455 { + return 0, bu, ErrMalformedVariableByteInteger + } + + if (eb & 128) == 0 { + break + } + + multiplier += 7 + bu++ + } + + return int(value), bu, nil +} diff --git a/packets/codec_test.go b/packets/codec_test.go new file mode 100644 index 0000000..9129721 --- /dev/null +++ b/packets/codec_test.go @@ -0,0 +1,422 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "bytes" + "errors" + "fmt" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBytesToString(t *testing.T) { + b := []byte{'a', 'b', 'c'} + require.Equal(t, "abc", bytesToString(b)) +} + +func TestDecodeString(t *testing.T) { + expect := []struct { + name string + rawBytes []byte + result string + offset int + shouldFail error + }{ + { + offset: 0, + rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, + result: "a/b/c/d", + }, + { + offset: 14, + rawBytes: []byte{ + Connect << 4, 17, // Fixed header + 0, 6, // Protocol Name - MSB+LSB + 'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name + 3, // Protocol Version + 0, // Packet Flags + 0, 30, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'h', 'e', 'y', // Client ID "zen"}, + }, + result: "hey", + }, + { + offset: 2, + rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97}, + result: "1/2/3/4/a/b/c/d/e/^/@/!", + }, + { + offset: 0, + rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38}, + result: "x/y/z", + }, + { + offset: 0, + rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'}, + shouldFail: ErrMalformedOffsetBytesOutOfRange, + }, + { + offset: 5, + rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'}, + shouldFail: ErrMalformedOffsetBytesOutOfRange, + }, + { + offset: 9, + rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'}, + shouldFail: ErrMalformedOffsetUintOutOfRange, + }, + { + offset: 17, + rawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 6, // Will Topic - MSB+LSB + 'l', + }, + shouldFail: ErrMalformedOffsetBytesOutOfRange, + }, + { + offset: 0, + rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100}, + shouldFail: ErrMalformedInvalidUTF8, + }, + } + + for i, wanted := range expect { + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, _, err := decodeString(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + + require.NoError(t, err) + require.Equal(t, wanted.result, result) + }) + } +} + +func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3] + result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0) + require.NoError(t, err) + require.Equal(t, "\ufeff", result) +} + +func TestDecodeBytes(t *testing.T) { + expect := []struct { + rawBytes []byte + result []uint8 + next int + offset int + shouldFail error + }{ + { + rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session) + result: []byte{0x4d, 0x51, 0x54, 0x54}, + next: 6, + offset: 0, + }, + { + rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start + result: []byte{0x4d, 0x51, 0x54, 0x54}, + next: 6, + offset: 0, + }, + { + rawBytes: []byte{0, 4, 77, 81}, + offset: 0, + shouldFail: ErrMalformedOffsetBytesOutOfRange, + }, + { + rawBytes: []byte{0, 4, 77, 81}, + offset: 8, + shouldFail: ErrMalformedOffsetUintOutOfRange, + }, + } + + for i, wanted := range expect { + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, _, err := decodeBytes(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + + require.NoError(t, err) + require.Equal(t, wanted.result, result) + }) + } +} + +func TestDecodeByte(t *testing.T) { + expect := []struct { + rawBytes []byte + result uint8 + offset int + shouldFail error + }{ + { + rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes + result: uint8(0x00), + offset: 0, + }, + { + rawBytes: []byte{0, 4, 77, 81, 84, 84}, + result: uint8(0x04), + offset: 1, + }, + { + rawBytes: []byte{0, 4, 77, 81, 84, 84}, + result: uint8(0x4d), + offset: 2, + }, + { + rawBytes: []byte{0, 4, 77, 81, 84, 84}, + result: uint8(0x51), + offset: 3, + }, + { + rawBytes: []byte{0, 4, 77, 80, 82, 84}, + offset: 8, + shouldFail: ErrMalformedOffsetByteOutOfRange, + }, + } + + for i, wanted := range expect { + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, offset, err := decodeByte(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + + require.NoError(t, err) + require.Equal(t, wanted.result, result) + require.Equal(t, i+1, offset) + }) + } +} + +func TestDecodeUint16(t *testing.T) { + expect := []struct { + rawBytes []byte + result uint16 + offset int + shouldFail error + }{ + { + rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, + result: uint16(0x07), + offset: 0, + }, + { + rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, + result: uint16(0x761), + offset: 1, + }, + { + rawBytes: []byte{0, 7, 255, 47}, + offset: 8, + shouldFail: ErrMalformedOffsetUintOutOfRange, + }, + } + + for i, wanted := range expect { + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + + require.NoError(t, err) + require.Equal(t, wanted.result, result) + require.Equal(t, i+2, offset) + }) + } +} + +func TestDecodeUint32(t *testing.T) { + expect := []struct { + rawBytes []byte + result uint32 + offset int + shouldFail error + }{ + { + rawBytes: []byte{0, 0, 0, 7, 8}, + result: uint32(7), + offset: 0, + }, + { + rawBytes: []byte{0, 0, 1, 226, 64, 8}, + result: uint32(123456), + offset: 1, + }, + { + rawBytes: []byte{0, 7, 255, 47}, + offset: 8, + shouldFail: ErrMalformedOffsetUintOutOfRange, + }, + } + + for i, wanted := range expect { + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + + require.NoError(t, err) + require.Equal(t, wanted.result, result) + require.Equal(t, i+4, offset) + }) + } +} + +func TestDecodeByteBool(t *testing.T) { + expect := []struct { + rawBytes []byte + result bool + offset int + shouldFail error + }{ + { + rawBytes: []byte{0x00, 0x00}, + result: false, + }, + { + rawBytes: []byte{0x01, 0x00}, + result: true, + }, + { + rawBytes: []byte{0x01, 0x00}, + offset: 5, + shouldFail: ErrMalformedOffsetBoolOutOfRange, + }, + } + + for i, wanted := range expect { + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + + require.NoError(t, err) + require.Equal(t, wanted.result, result) + require.Equal(t, 1, offset) + }) + } +} + +func TestDecodeLength(t *testing.T) { + b := bytes.NewBuffer([]byte{0x78}) + n, bu, err := DecodeLength(b) + require.NoError(t, err) + require.Equal(t, 120, n) + require.Equal(t, 1, bu) + + b = bytes.NewBuffer([]byte{255, 255, 255, 127}) + n, bu, err = DecodeLength(b) + require.NoError(t, err) + require.Equal(t, 268435455, n) + require.Equal(t, 4, bu) +} + +func TestDecodeLengthErrors(t *testing.T) { + b := bytes.NewBuffer([]byte{}) + _, _, err := DecodeLength(b) + require.Error(t, err) + + b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}) + _, _, err = DecodeLength(b) + require.Error(t, err) + require.ErrorIs(t, err, ErrMalformedVariableByteInteger) +} + +func TestEncodeBool(t *testing.T) { + result := encodeBool(true) + require.Equal(t, byte(1), result) + + result = encodeBool(false) + require.Equal(t, byte(0), result) + + // Check failure. + result = encodeBool(false) + require.NotEqual(t, byte(1), result) +} + +func TestEncodeBytes(t *testing.T) { + result := encodeBytes([]byte("testing")) + require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result) + + result = encodeBytes([]byte("testing")) + require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result) +} + +func TestEncodeUint16(t *testing.T) { + result := encodeUint16(0) + require.Equal(t, []byte{0x00, 0x00}, result) + + result = encodeUint16(32767) + require.Equal(t, []byte{0x7f, 0xff}, result) + + result = encodeUint16(math.MaxUint16) + require.Equal(t, []byte{0xff, 0xff}, result) +} + +func TestEncodeUint32(t *testing.T) { + result := encodeUint32(7) + require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result) + + result = encodeUint32(32767) + require.Equal(t, []byte{0, 0, 127, 255}, result) + + result = encodeUint32(math.MaxUint32) + require.Equal(t, []byte{255, 255, 255, 255}, result) +} + +func TestEncodeString(t *testing.T) { + result := encodeString("testing") + require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result) + + result = encodeString("") + require.Equal(t, []uint8{0x00, 0x00}, result) + + result = encodeString("a") + require.Equal(t, []uint8{0x00, 0x01, 0x61}, result) + + result = encodeString("b") + require.NotEqual(t, []uint8{0x00, 0x00}, result) +} + +func TestEncodeLength(t *testing.T) { + b := new(bytes.Buffer) + encodeLength(b, 120) + require.Equal(t, []byte{0x78}, b.Bytes()) + + b = new(bytes.Buffer) + encodeLength(b, math.MaxInt64) + require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes()) +} + +func TestValidUTF8(t *testing.T) { + require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67})) + require.False(t, validUTF8([]byte{0xff, 0xff})) + require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74})) +} diff --git a/packets/codes.go b/packets/codes.go new file mode 100644 index 0000000..5af1b74 --- /dev/null +++ b/packets/codes.go @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +// Code contains a reason code and reason string for a response. +type Code struct { + Reason string + Code byte +} + +// String returns the readable reason for a code. +func (c Code) String() string { + return c.Reason +} + +// Error returns the readable reason for a code. +func (c Code) Error() string { + return c.Reason +} + +var ( + // QosCodes indicates the reason codes for each Qos byte. + QosCodes = map[byte]Code{ + 0: CodeGrantedQos0, + 1: CodeGrantedQos1, + 2: CodeGrantedQos2, + } + + CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"} + CodeSuccess = Code{Code: 0x00, Reason: "success"} + CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"} + CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"} + CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"} + CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"} + CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"} + CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"} + CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"} + CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"} + CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"} + ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"} + ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"} + ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"} + ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"} + ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"} + ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"} + ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"} + ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"} + ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"} + ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"} + ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"} + ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"} + ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"} + ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"} + ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"} + ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"} + ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"} + ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"} + ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"} + ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"} + ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"} + ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"} + ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"} + ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"} + ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"} + ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"} + ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"} + ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"} + ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"} + ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"} + ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"} + ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"} + ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"} + ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"} + ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"} + ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"} + ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"} + ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"} + ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"} + ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"} + ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"} + ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"} + ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"} + ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"} + ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"} + ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"} + ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"} + ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"} + ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"} + ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"} + ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"} + ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"} + ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"} + ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"} + ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"} + ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"} + ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"} + ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"} + ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"} + ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"} + ErrServerBusy = Code{Code: 0x89, Reason: "server busy"} + ErrBanned = Code{Code: 0x8A, Reason: "banned"} + ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"} + ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"} + ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"} + ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"} + ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"} + ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"} + ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"} + ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"} + ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"} + ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"} + ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"} + ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"} + ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"} + ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"} + ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"} + ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"} + ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"} + ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"} + ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"} + ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"} + ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"} + ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"} + ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"} + ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} + ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} + ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."} + + // MQTTv3 specific bytes. + Err3UnsupportedProtocolVersion = Code{Code: 0x01} + Err3ClientIdentifierNotValid = Code{Code: 0x02} + Err3ServerUnavailable = Code{Code: 0x03} + ErrMalformedUsernameOrPassword = Code{Code: 0x04} + Err3NotAuthorized = Code{Code: 0x05} + + // V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes. + // This is required because MQTTv3 has different return byte specification. + // See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257 + V5CodesToV3 = map[Code]Code{ + ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion, + ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid, + ErrServerUnavailable: Err3ServerUnavailable, + ErrMalformedUsername: ErrMalformedUsernameOrPassword, + ErrMalformedPassword: ErrMalformedUsernameOrPassword, + ErrBadUsernameOrPassword: Err3NotAuthorized, + } +) diff --git a/packets/codes_test.go b/packets/codes_test.go new file mode 100644 index 0000000..aed8e57 --- /dev/null +++ b/packets/codes_test.go @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCodesString(t *testing.T) { + c := Code{ + Reason: "test", + Code: 0x1, + } + + require.Equal(t, "test", c.String()) +} + +func TestCodesError(t *testing.T) { + c := Code{ + Reason: "error", + Code: 0x1, + } + + require.Equal(t, "error", error(c).Error()) +} diff --git a/packets/fixedheader.go b/packets/fixedheader.go new file mode 100644 index 0000000..eb20451 --- /dev/null +++ b/packets/fixedheader.go @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "bytes" +) + +// FixedHeader contains the values of the fixed header portion of the MQTT packet. +type FixedHeader struct { + Remaining int `json:"remaining"` // the number of remaining bytes in the payload. + Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1). + Qos byte `json:"qos"` // indicates the quality of service expected. + Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time. + Retain bool `json:"retain"` // whether the message should be retained. +} + +// Encode encodes the FixedHeader and returns a bytes buffer. +func (fh *FixedHeader) Encode(buf *bytes.Buffer) { + buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain)) + encodeLength(buf, int64(fh.Remaining)) +} + +// Decode extracts the specification bits from the header byte. +func (fh *FixedHeader) Decode(hb byte) error { + fh.Type = hb >> 4 // Get the message type from the first 4 bytes. + + switch fh.Type { + case Publish: + if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 { + return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4] + } + + fh.Dup = (hb>>3)&0x01 > 0 // is duplicate + fh.Qos = (hb >> 1) & 0x03 // qos flag + fh.Retain = hb&0x01 > 0 // is retain flag + case Pubrel: + fallthrough + case Subscribe: + fallthrough + case Unsubscribe: + if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1] + return ErrMalformedFlags + } + + fh.Qos = (hb >> 1) & 0x03 + default: + if (hb>>0)&0x01 != 0 || + (hb>>1)&0x01 != 0 || + (hb>>2)&0x01 != 0 || + (hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1] + return ErrMalformedFlags + } + } + + if fh.Qos == 0 && fh.Dup { + return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2] + } + + return nil +} diff --git a/packets/fixedheader_test.go b/packets/fixedheader_test.go new file mode 100644 index 0000000..fe8c497 --- /dev/null +++ b/packets/fixedheader_test.go @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +type fixedHeaderTable struct { + desc string + rawBytes []byte + header FixedHeader + packetError bool + expect error +} + +var fixedHeaderExpected = []fixedHeaderTable{ + { + desc: "connect", + rawBytes: []byte{Connect << 4, 0x00}, + header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "connack", + rawBytes: []byte{Connack << 4, 0x00}, + header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "publish", + rawBytes: []byte{Publish << 4, 0x00}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "publish qos 1", + rawBytes: []byte{Publish<<4 | 1<<1, 0x00}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0}, + }, + { + desc: "publish qos 1 retain", + rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0}, + }, + { + desc: "publish qos 2", + rawBytes: []byte{Publish<<4 | 2<<1, 0x00}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0}, + }, + { + desc: "publish qos 2 retain", + rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0}, + }, + { + desc: "publish dup qos 0", + rawBytes: []byte{Publish<<4 | 1<<3, 0x00}, + header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0}, + expect: ErrProtocolViolationDupNoQos, + }, + { + desc: "publish dup qos 0 retain", + rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00}, + header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0}, + expect: ErrProtocolViolationDupNoQos, + }, + { + desc: "publish dup qos 1 retain", + rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00}, + header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0}, + }, + { + desc: "publish dup qos 2 retain", + rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00}, + header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0}, + }, + { + desc: "puback", + rawBytes: []byte{Puback << 4, 0x00}, + header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "pubrec", + rawBytes: []byte{Pubrec << 4, 0x00}, + header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "pubrel", + rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00}, + header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0}, + }, + { + desc: "pubcomp", + rawBytes: []byte{Pubcomp << 4, 0x00}, + header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "subscribe", + rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00}, + header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0}, + }, + { + desc: "suback", + rawBytes: []byte{Suback << 4, 0x00}, + header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "unsubscribe", + rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00}, + header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0}, + }, + { + desc: "unsuback", + rawBytes: []byte{Unsuback << 4, 0x00}, + header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "pingreq", + rawBytes: []byte{Pingreq << 4, 0x00}, + header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "pingresp", + rawBytes: []byte{Pingresp << 4, 0x00}, + header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "disconnect", + rawBytes: []byte{Disconnect << 4, 0x00}, + header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + { + desc: "auth", + rawBytes: []byte{Auth << 4, 0x00}, + header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0}, + }, + + // remaining length + { + desc: "remaining length 10", + rawBytes: []byte{Publish << 4, 0x0a}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10}, + }, + { + desc: "remaining length 512", + rawBytes: []byte{Publish << 4, 0x80, 0x04}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512}, + }, + { + desc: "remaining length 978", + rawBytes: []byte{Publish << 4, 0xd2, 0x07}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978}, + }, + { + desc: "remaining length 20202", + rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102}, + }, + { + desc: "remaining length oversize", + rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, + header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333}, + packetError: true, + }, + + // Invalid flags for packet + { + desc: "invalid type dup is true", + rawBytes: []byte{Connect<<4 | 1<<3, 0x00}, + header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0}, + expect: ErrMalformedFlags, + }, + { + desc: "invalid type qos is 1", + rawBytes: []byte{Connect<<4 | 1<<1, 0x00}, + header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0}, + expect: ErrMalformedFlags, + }, + { + desc: "invalid type retain is true", + rawBytes: []byte{Connect<<4 | 1, 0x00}, + header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0}, + expect: ErrMalformedFlags, + }, + { + desc: "invalid publish qos bits 1 + 2 set", + rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00}, + header: FixedHeader{Type: Publish}, + expect: ErrProtocolViolationQosOutOfRange, + }, + { + desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0", + rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00}, + header: FixedHeader{Type: Pubrel, Qos: 1}, + expect: ErrMalformedFlags, + }, + { + desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0", + rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00}, + header: FixedHeader{Type: Subscribe, Qos: 1}, + expect: ErrMalformedFlags, + }, +} + +func TestFixedHeaderEncode(t *testing.T) { + for _, wanted := range fixedHeaderExpected { + t.Run(wanted.desc, func(t *testing.T) { + buf := new(bytes.Buffer) + wanted.header.Encode(buf) + if wanted.expect == nil { + require.Equal(t, len(wanted.rawBytes), len(buf.Bytes())) + require.EqualValues(t, wanted.rawBytes, buf.Bytes()) + } + }) + } +} + +func TestFixedHeaderDecode(t *testing.T) { + for _, wanted := range fixedHeaderExpected { + t.Run(wanted.desc, func(t *testing.T) { + fh := new(FixedHeader) + err := fh.Decode(wanted.rawBytes[0]) + if wanted.expect != nil { + require.Equal(t, wanted.expect, err) + } else { + require.NoError(t, err) + require.Equal(t, wanted.header.Type, fh.Type) + require.Equal(t, wanted.header.Dup, fh.Dup) + require.Equal(t, wanted.header.Qos, fh.Qos) + require.Equal(t, wanted.header.Retain, fh.Retain) + } + }) + } +} diff --git a/packets/packets.go b/packets/packets.go new file mode 100644 index 0000000..d52d5af --- /dev/null +++ b/packets/packets.go @@ -0,0 +1,1173 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "bytes" + "errors" + "fmt" + "math" + "strconv" + "strings" + "sync" + + "testmqtt/mempool" +) + +// 所有有效的数据包类型及其数据包标识符。 +// All valid packet types and their packet identifiers. +const ( + Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets. + Connect // 1 + Connack // 2 + Publish // 3 + Puback // 4 + Pubrec // 5 + Pubrel // 6 + Pubcomp // 7 + Subscribe // 8 + Suback // 9 + Unsubscribe // 10 + Unsuback // 11 + Pingreq // 12 + Pingresp // 13 + Disconnect // 14 + Auth // 15 + WillProperties byte = 99 // Special byte for validating Will Properties. +) + +var ( + // ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification. + ErrNoValidPacketAvailable = errors.New("no valid packet available") + + // PacketNames is a map of packet bytes to human-readable names, for easier debugging. + PacketNames = map[byte]string{ + 0: "Reserved", + 1: "Connect", + 2: "Connack", + 3: "Publish", + 4: "Puback", + 5: "Pubrec", + 6: "Pubrel", + 7: "Pubcomp", + 8: "Subscribe", + 9: "Suback", + 10: "Unsubscribe", + 11: "Unsuback", + 12: "Pingreq", + 13: "Pingresp", + 14: "Disconnect", + 15: "Auth", + } +) + +// Packets is a concurrency safe map of packets. +type Packets struct { + internal map[string]Packet + sync.RWMutex +} + +// NewPackets returns a new instance of Packets. +func NewPackets() *Packets { + return &Packets{ + internal: map[string]Packet{}, + } +} + +// Add adds a new packet to the map. +func (p *Packets) Add(id string, val Packet) { + p.Lock() + defer p.Unlock() + p.internal[id] = val +} + +// GetAll returns all packets in the map. +func (p *Packets) GetAll() map[string]Packet { + p.RLock() + defer p.RUnlock() + m := map[string]Packet{} + for k, v := range p.internal { + m[k] = v + } + return m +} + +// Get returns a specific packet in the map by packet id. +func (p *Packets) Get(id string) (val Packet, ok bool) { + p.RLock() + defer p.RUnlock() + val, ok = p.internal[id] + return val, ok +} + +// Len returns the number of packets in the map. +func (p *Packets) Len() int { + p.RLock() + defer p.RUnlock() + val := len(p.internal) + return val +} + +// Delete removes a packet from the map by packet id. +func (p *Packets) Delete(id string) { + p.Lock() + defer p.Unlock() + delete(p.internal, id) +} + +// Packet represents an MQTT packet. Instead of providing a packet interface +// variant packet structs, this is a single concrete packet type to cover all packet +// types, which allows us to take advantage of various compiler optimizations. It +// contains a combination of mqtt spec values and internal broker control codes. +type Packet struct { + Connect ConnectParams // parameters for connect packets (just for organisation) + Properties Properties // all mqtt v5 packet properties + Payload []byte // a message/payload for publish packets + ReasonCodes []byte // one or more reason codes for multi-reason responses (suback, etc) + Filters Subscriptions // a list of subscription filters and their properties (subscribe, unsubscribe) + TopicName string // the topic a payload is being published to + Origin string // client id of the client who is issuing the packet (mostly internal use) + FixedHeader FixedHeader // - + Created int64 // unix timestamp indicating time packet was created/received on the server + Expiry int64 // unix timestamp indicating when the packet will expire and should be deleted + Mods Mods // internal broker control values for controlling certain mqtt v5 compliance + PacketID uint16 // packet id for the packet (publish, qos, etc) + ProtocolVersion byte // protocol version of the client the packet belongs to + SessionPresent bool // session existed for connack + ReasonCode byte // reason code for a packet response (acks, etc) + ReservedBit byte // reserved, do not use (except in testing) + Ignore bool // if true, do not perform any message forwarding operations +} + +// Mods specifies certain values required for certain mqtt v5 compliance within packet encoding/decoding. +type Mods struct { + MaxSize uint32 // the maximum packet size specified by the client / server + DisallowProblemInfo bool // if problem info is disallowed + AllowResponseInfo bool // if response info is disallowed +} + +// ConnectParams contains packet values which are specifically related to connect packets. +type ConnectParams struct { + WillProperties Properties `json:"willProperties"` // - + Password []byte `json:"password"` // - + Username []byte `json:"username"` // - + ProtocolName []byte `json:"protocolName"` // - + WillPayload []byte `json:"willPayload"` // - + ClientIdentifier string `json:"clientId"` // - + WillTopic string `json:"willTopic"` // - + Keepalive uint16 `json:"keepalive"` // - + PasswordFlag bool `json:"passwordFlag"` // - + UsernameFlag bool `json:"usernameFlag"` // - + WillQos byte `json:"willQos"` // - + WillFlag bool `json:"willFlag"` // - + WillRetain bool `json:"willRetain"` // - + Clean bool `json:"clean"` // CleanSession in v3.1.1, CleanStart in v5 +} + +// Subscriptions is a slice of Subscription. +type Subscriptions []Subscription // must be a slice to retain order. + +// Subscription contains details about a client subscription to a topic filter. +type Subscription struct { + ShareName []string + Filter string + Identifier int + Identifiers map[string]int + RetainHandling byte + Qos byte + RetainAsPublished bool + NoLocal bool + FwdRetainedFlag bool // true if the subscription forms part of a publish response to a client subscription and packet is retained. +} + +// Copy creates a new instance of a packet, but with an empty header for inheriting new QoS flags, etc. +func (pk *Packet) Copy(allowTransfer bool) Packet { + p := Packet{ + FixedHeader: FixedHeader{ + Remaining: pk.FixedHeader.Remaining, + Type: pk.FixedHeader.Type, + Retain: pk.FixedHeader.Retain, + Dup: false, // [MQTT-4.3.1-1] [MQTT-4.3.2-2] + Qos: pk.FixedHeader.Qos, + }, + Mods: Mods{ + MaxSize: pk.Mods.MaxSize, + }, + ReservedBit: pk.ReservedBit, + ProtocolVersion: pk.ProtocolVersion, + Connect: ConnectParams{ + ClientIdentifier: pk.Connect.ClientIdentifier, + Keepalive: pk.Connect.Keepalive, + WillQos: pk.Connect.WillQos, + WillTopic: pk.Connect.WillTopic, + WillFlag: pk.Connect.WillFlag, + WillRetain: pk.Connect.WillRetain, + WillProperties: pk.Connect.WillProperties.Copy(allowTransfer), + Clean: pk.Connect.Clean, + }, + TopicName: pk.TopicName, + Properties: pk.Properties.Copy(allowTransfer), + SessionPresent: pk.SessionPresent, + ReasonCode: pk.ReasonCode, + Filters: pk.Filters, + Created: pk.Created, + Expiry: pk.Expiry, + Origin: pk.Origin, + } + + if allowTransfer { + p.PacketID = pk.PacketID + } + + if len(pk.Connect.ProtocolName) > 0 { + p.Connect.ProtocolName = append([]byte{}, pk.Connect.ProtocolName...) + } + + if len(pk.Connect.Password) > 0 { + p.Connect.PasswordFlag = true + p.Connect.Password = append([]byte{}, pk.Connect.Password...) + } + + if len(pk.Connect.Username) > 0 { + p.Connect.UsernameFlag = true + p.Connect.Username = append([]byte{}, pk.Connect.Username...) + } + + if len(pk.Connect.WillPayload) > 0 { + p.Connect.WillPayload = append([]byte{}, pk.Connect.WillPayload...) + } + + if len(pk.Payload) > 0 { + p.Payload = append([]byte{}, pk.Payload...) + } + + if len(pk.ReasonCodes) > 0 { + p.ReasonCodes = append([]byte{}, pk.ReasonCodes...) + } + + return p +} + +// Merge merges a new subscription with a base subscription, preserving the highest +// qos value, matched identifiers and any special properties. +func (s Subscription) Merge(n Subscription) Subscription { + if s.Identifiers == nil { + s.Identifiers = map[string]int{ + s.Filter: s.Identifier, + } + } + + if n.Identifier > 0 { + s.Identifiers[n.Filter] = n.Identifier + } + + if n.Qos > s.Qos { + s.Qos = n.Qos // [MQTT-3.3.4-2] + } + + if n.NoLocal { + s.NoLocal = true // [MQTT-3.8.3-3] + } + + return s +} + +// encode encodes a subscription and properties into bytes. +func (s Subscription) encode() byte { + var flag byte + flag |= s.Qos + + if s.NoLocal { + flag |= 1 << 2 + } + + if s.RetainAsPublished { + flag |= 1 << 3 + } + + flag |= s.RetainHandling << 4 + return flag +} + +// decode decodes subscription bytes into a subscription struct. +func (s *Subscription) decode(b byte) { + s.Qos = b & 3 // byte + s.NoLocal = 1&(b>>2) > 0 // bool + s.RetainAsPublished = 1&(b>>3) > 0 // bool + s.RetainHandling = 3 & (b >> 4) // byte +} + +// ConnectEncode encodes a connect packet. +func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + nb.Write(encodeBytes(pk.Connect.ProtocolName)) + nb.WriteByte(pk.ProtocolVersion) + + nb.WriteByte( + encodeBool(pk.Connect.Clean)<<1 | + encodeBool(pk.Connect.WillFlag)<<2 | + pk.Connect.WillQos<<3 | + encodeBool(pk.Connect.WillRetain)<<5 | + encodeBool(pk.Connect.PasswordFlag)<<6 | + encodeBool(pk.Connect.UsernameFlag)<<7 | + 0, // [MQTT-2.1.3-1] + ) + + nb.Write(encodeUint16(pk.Connect.Keepalive)) + + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + (&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0) + nb.Write(pb.Bytes()) + } + + nb.Write(encodeString(pk.Connect.ClientIdentifier)) + + if pk.Connect.WillFlag { + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + (&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0) + nb.Write(pb.Bytes()) + } + + nb.Write(encodeString(pk.Connect.WillTopic)) + nb.Write(encodeBytes(pk.Connect.WillPayload)) + } + + if pk.Connect.UsernameFlag { + nb.Write(encodeBytes(pk.Connect.Username)) + } + + if pk.Connect.PasswordFlag { + nb.Write(encodeBytes(pk.Connect.Password)) + } + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + + return nil +} + +// ConnectDecode decodes a connect packet. +func (pk *Packet) ConnectDecode(buf []byte) error { + var offset int + var err error + + pk.Connect.ProtocolName, offset, err = decodeBytes(buf, 0) + if err != nil { + return ErrMalformedProtocolName + } + + pk.ProtocolVersion, offset, err = decodeByte(buf, offset) + if err != nil { + return ErrMalformedProtocolVersion + } + + flags, offset, err := decodeByte(buf, offset) + if err != nil { + return ErrMalformedFlags + } + + pk.ReservedBit = 1 & flags + pk.Connect.Clean = 1&(flags>>1) > 0 + pk.Connect.WillFlag = 1&(flags>>2) > 0 + pk.Connect.WillQos = 3 & (flags >> 3) // this one is not a bool + pk.Connect.WillRetain = 1&(flags>>5) > 0 + pk.Connect.PasswordFlag = 1&(flags>>6) > 0 + pk.Connect.UsernameFlag = 1&(flags>>7) > 0 + + pk.Connect.Keepalive, offset, err = decodeUint16(buf, offset) + if err != nil { + return ErrMalformedKeepalive + } + + if pk.ProtocolVersion == 5 { + n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + offset += n + } + + pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) // [MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4] + if err != nil { + return ErrClientIdentifierNotValid // [MQTT-3.1.3-8] + } + + if pk.Connect.WillFlag { // [MQTT-3.1.2-7] + if pk.ProtocolVersion == 5 { + n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:])) + if err != nil { + return ErrMalformedWillProperties + } + offset += n + } + + pk.Connect.WillTopic, offset, err = decodeString(buf, offset) + if err != nil { + return ErrMalformedWillTopic + } + + pk.Connect.WillPayload, offset, err = decodeBytes(buf, offset) + if err != nil { + return ErrMalformedWillPayload + } + } + + if pk.Connect.UsernameFlag { // [MQTT-3.1.3-12] + if offset >= len(buf) { // we are at the end of the packet + return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17] + } + + pk.Connect.Username, offset, err = decodeBytes(buf, offset) + if err != nil { + return ErrMalformedUsername + } + } + + if pk.Connect.PasswordFlag { + pk.Connect.Password, _, err = decodeBytes(buf, offset) + if err != nil { + return ErrMalformedPassword + } + } + + return nil +} + +// ConnectValidate ensures the connect packet is compliant. +func (pk *Packet) ConnectValidate() Code { + if !bytes.Equal(pk.Connect.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) && !bytes.Equal(pk.Connect.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) { + return ErrProtocolViolationProtocolName // [MQTT-3.1.2-1] + } + + if (bytes.Equal(pk.Connect.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) && pk.ProtocolVersion != 3) || + (bytes.Equal(pk.Connect.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) && pk.ProtocolVersion != 4 && pk.ProtocolVersion != 5) { + return ErrProtocolViolationProtocolVersion // [MQTT-3.1.2-2] + } + + if pk.ReservedBit != 0 { + return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3] + } + + if len(pk.Connect.Password) > math.MaxUint16 { + return ErrProtocolViolationPasswordTooLong + } + + if len(pk.Connect.Username) > math.MaxUint16 { + return ErrProtocolViolationUsernameTooLong + } + + if !pk.Connect.UsernameFlag && len(pk.Connect.Username) > 0 { + return ErrProtocolViolationUsernameNoFlag // [MQTT-3.1.2-16] + } + + if pk.Connect.PasswordFlag && len(pk.Connect.Password) == 0 { + return ErrProtocolViolationFlagNoPassword // [MQTT-3.1.2-19] + } + + if !pk.Connect.PasswordFlag && len(pk.Connect.Password) > 0 { + return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18] + } + + if len(pk.Connect.ClientIdentifier) > math.MaxUint16 { + return ErrClientIdentifierNotValid + } + + if pk.Connect.WillFlag { + if len(pk.Connect.WillPayload) == 0 || pk.Connect.WillTopic == "" { + return ErrProtocolViolationWillFlagNoPayload // [MQTT-3.1.2-9] + } + + if pk.Connect.WillQos > 2 { + return ErrProtocolViolationQosOutOfRange // [MQTT-3.1.2-12] + } + } + + if !pk.Connect.WillFlag && pk.Connect.WillRetain { + return ErrProtocolViolationWillFlagSurplusRetain // [MQTT-3.1.2-13] + } + + return CodeSuccess +} + +// ConnackEncode encodes a Connack packet. +func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + nb.WriteByte(encodeBool(pk.SessionPresent)) + nb.WriteByte(pk.ReasonCode) + + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode + nb.Write(pb.Bytes()) + } + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + + return nil +} + +// ConnackDecode decodes a Connack packet. +func (pk *Packet) ConnackDecode(buf []byte) error { + var offset int + var err error + + pk.SessionPresent, offset, err = decodeByteBool(buf, 0) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent) + } + + pk.ReasonCode, offset, err = decodeByte(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedReasonCode) + } + + if pk.ProtocolVersion == 5 { + _, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + } + + return nil +} + +// DisconnectEncode encodes a Disconnect packet. +func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + + if pk.ProtocolVersion == 5 { + nb.WriteByte(pk.ReasonCode) + + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) + nb.Write(pb.Bytes()) + } + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + + return nil +} + +// DisconnectDecode decodes a Disconnect packet. +func (pk *Packet) DisconnectDecode(buf []byte) error { + if pk.ProtocolVersion == 5 && pk.FixedHeader.Remaining > 1 { + var err error + var offset int + pk.ReasonCode, offset, err = decodeByte(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedReasonCode) + } + + if pk.FixedHeader.Remaining > 2 { + _, err = pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + } + } + + return nil +} + +// PingreqEncode encodes a Pingreq packet. +func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Encode(buf) + return nil +} + +// PingreqDecode decodes a Pingreq packet. +func (pk *Packet) PingreqDecode(buf []byte) error { + return nil +} + +// PingrespEncode encodes a Pingresp packet. +func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Encode(buf) + return nil +} + +// PingrespDecode decodes a Pingres packet. +func (pk *Packet) PingrespDecode(buf []byte) error { + return nil +} + +// PublishEncode encodes a Publish packet. +func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + + nb.Write(encodeString(pk.TopicName)) // [MQTT-3.3.2-1] + + if pk.FixedHeader.Qos > 0 { + if pk.PacketID == 0 { + return ErrProtocolViolationNoPacketID // [MQTT-2.2.1-2] + } + nb.Write(encodeUint16(pk.PacketID)) + } + + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload)) + nb.Write(pb.Bytes()) + } + + pk.FixedHeader.Remaining = nb.Len() + len(pk.Payload) + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + buf.Write(pk.Payload) + + return nil +} + +// PublishDecode extracts the data values from the packet. +func (pk *Packet) PublishDecode(buf []byte) error { + var offset int + var err error + + pk.TopicName, offset, err = decodeString(buf, 0) // [MQTT-3.3.2-1] + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedTopic) + } + + if pk.FixedHeader.Qos > 0 { + pk.PacketID, offset, err = decodeUint16(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedPacketID) + } + } + + if pk.ProtocolVersion == 5 { + n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + + offset += n + } + + pk.Payload = buf[offset:] + + return nil +} + +// PublishValidate validates a publish packet. +func (pk *Packet) PublishValidate(topicAliasMaximum uint16) Code { + if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { + return ErrProtocolViolationNoPacketID // [MQTT-2.2.1-3] [MQTT-2.2.1-4] + } + + if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 { + return ErrProtocolViolationSurplusPacketID // [MQTT-2.2.1-2] + } + + if strings.ContainsAny(pk.TopicName, "+#") { + return ErrProtocolViolationSurplusWildcard // [MQTT-3.3.2-2] + } + + if pk.Properties.TopicAlias > topicAliasMaximum { + return ErrTopicAliasInvalid // [MQTT-3.2.2-17] [MQTT-3.3.2-9] ~[MQTT-3.3.2-10] [MQTT-3.3.2-12] + } + + if pk.TopicName == "" && pk.Properties.TopicAlias == 0 { + return ErrProtocolViolationNoTopic // ~[MQTT-3.3.2-8] + } + + if pk.Properties.TopicAliasFlag && pk.Properties.TopicAlias == 0 { + return ErrTopicAliasInvalid // [MQTT-3.3.2-8] + } + + if len(pk.Properties.SubscriptionIdentifier) > 0 { + return ErrProtocolViolationSurplusSubID // [MQTT-3.3.4-6] + } + + return CodeSuccess +} + +// encodePubAckRelRecComp encodes a Puback, Pubrel, Pubrec, or Pubcomp packet. +func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error { + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + nb.Write(encodeUint16(pk.PacketID)) + + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) + if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 { + nb.WriteByte(pk.ReasonCode) + } + + if pb.Len() > 1 { + nb.Write(pb.Bytes()) + } + } + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + return nil +} + +// decode extracts the data values from a Puback, Pubrel, Pubrec, or Pubcomp packet. +func (pk *Packet) decodePubAckRelRecComp(buf []byte) error { + var offset int + var err error + pk.PacketID, offset, err = decodeUint16(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedPacketID) + } + + if pk.ProtocolVersion == 5 && pk.FixedHeader.Remaining > 2 { + pk.ReasonCode, offset, err = decodeByte(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedReasonCode) + } + + if pk.FixedHeader.Remaining > 3 { + _, err = pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + } + } + + return nil +} + +// PubackEncode encodes a Puback packet. +func (pk *Packet) PubackEncode(buf *bytes.Buffer) error { + return pk.encodePubAckRelRecComp(buf) +} + +// PubackDecode decodes a Puback packet. +func (pk *Packet) PubackDecode(buf []byte) error { + return pk.decodePubAckRelRecComp(buf) +} + +// PubcompEncode encodes a Pubcomp packet. +func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error { + return pk.encodePubAckRelRecComp(buf) +} + +// PubcompDecode decodes a Pubcomp packet. +func (pk *Packet) PubcompDecode(buf []byte) error { + return pk.decodePubAckRelRecComp(buf) +} + +// PubrecEncode encodes a Pubrec packet. +func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error { + return pk.encodePubAckRelRecComp(buf) +} + +// PubrecDecode decodes a Pubrec packet. +func (pk *Packet) PubrecDecode(buf []byte) error { + return pk.decodePubAckRelRecComp(buf) +} + +// PubrelEncode encodes a Pubrel packet. +func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error { + return pk.encodePubAckRelRecComp(buf) +} + +// PubrelDecode decodes a Pubrel packet. +func (pk *Packet) PubrelDecode(buf []byte) error { + return pk.decodePubAckRelRecComp(buf) +} + +// ReasonCodeValid returns true if the provided reason code is valid for the packet type. +func (pk *Packet) ReasonCodeValid() bool { + switch pk.FixedHeader.Type { + case Pubrec: + return bytes.Contains([]byte{ + CodeSuccess.Code, + CodeNoMatchingSubscribers.Code, + ErrUnspecifiedError.Code, + ErrImplementationSpecificError.Code, + ErrNotAuthorized.Code, + ErrTopicNameInvalid.Code, + ErrPacketIdentifierInUse.Code, + ErrQuotaExceeded.Code, + ErrPayloadFormatInvalid.Code, + }, []byte{pk.ReasonCode}) + case Pubrel: + fallthrough + case Pubcomp: + return bytes.Contains([]byte{ + CodeSuccess.Code, + ErrPacketIdentifierNotFound.Code, + }, []byte{pk.ReasonCode}) + case Suback: + return bytes.Contains([]byte{ + CodeGrantedQos0.Code, + CodeGrantedQos1.Code, + CodeGrantedQos2.Code, + ErrUnspecifiedError.Code, + ErrImplementationSpecificError.Code, + ErrNotAuthorized.Code, + ErrTopicFilterInvalid.Code, + ErrPacketIdentifierInUse.Code, + ErrQuotaExceeded.Code, + ErrSharedSubscriptionsNotSupported.Code, + ErrSubscriptionIdentifiersNotSupported.Code, + ErrWildcardSubscriptionsNotSupported.Code, + }, []byte{pk.ReasonCode}) + case Unsuback: + return bytes.Contains([]byte{ + CodeSuccess.Code, + CodeNoSubscriptionExisted.Code, + ErrUnspecifiedError.Code, + ErrImplementationSpecificError.Code, + ErrNotAuthorized.Code, + ErrTopicFilterInvalid.Code, + ErrPacketIdentifierInUse.Code, + }, []byte{pk.ReasonCode}) + } + + return true +} + +// SubackEncode encodes a Suback packet. +func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + nb.Write(encodeUint16(pk.PacketID)) + + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes)) + nb.Write(pb.Bytes()) + } + + nb.Write(pk.ReasonCodes) + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + + return nil +} + +// SubackDecode decodes a Suback packet. +func (pk *Packet) SubackDecode(buf []byte) error { + var offset int + var err error + + pk.PacketID, offset, err = decodeUint16(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedPacketID) + } + + if pk.ProtocolVersion == 5 { + n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + offset += n + } + + pk.ReasonCodes = buf[offset:] + + return nil +} + +// SubscribeEncode encodes a Subscribe packet. +func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { + if pk.PacketID == 0 { + return ErrProtocolViolationNoPacketID + } + + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + nb.Write(encodeUint16(pk.PacketID)) + + xb := mempool.GetBuffer() // capture and write filters after length checks + defer mempool.PutBuffer(xb) + for _, opts := range pk.Filters { + xb.Write(encodeString(opts.Filter)) // [MQTT-3.8.3-1] + if pk.ProtocolVersion == 5 { + xb.WriteByte(opts.encode()) + } else { + xb.WriteByte(opts.Qos) + } + } + + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len()) + nb.Write(pb.Bytes()) + } + + nb.Write(xb.Bytes()) + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + + return nil +} + +// SubscribeDecode decodes a Subscribe packet. +func (pk *Packet) SubscribeDecode(buf []byte) error { + var offset int + var err error + + pk.PacketID, offset, err = decodeUint16(buf, offset) + if err != nil { + return ErrMalformedPacketID + } + + if pk.ProtocolVersion == 5 { + n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + offset += n + } + + var filter string + pk.Filters = Subscriptions{} + for offset < len(buf) { + filter, offset, err = decodeString(buf, offset) // [MQTT-3.8.3-1] + if err != nil { + return ErrMalformedTopic + } + + var option byte + sub := &Subscription{ + Filter: filter, + } + + if pk.ProtocolVersion == 5 { + sub.decode(buf[offset]) + offset += 1 + } else { + option, offset, err = decodeByte(buf, offset) + if err != nil { + return ErrMalformedQos + } + sub.Qos = option + } + + if len(pk.Properties.SubscriptionIdentifier) > 0 { + sub.Identifier = pk.Properties.SubscriptionIdentifier[0] + } + + if sub.Qos > 2 { + return ErrProtocolViolationQosOutOfRange + } + + pk.Filters = append(pk.Filters, *sub) + } + + return nil +} + +// SubscribeValidate ensures the packet is compliant. +func (pk *Packet) SubscribeValidate() Code { + if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { + return ErrProtocolViolationNoPacketID // [MQTT-2.2.1-3] [MQTT-2.2.1-4] + } + + if len(pk.Filters) == 0 { + return ErrProtocolViolationNoFilters // [MQTT-3.10.3-2] + } + + for _, v := range pk.Filters { + if v.Identifier > 268435455 { // 3.3.2.3.8 The Subscription Identifier can have the value of 1 to 268,435,455. + return ErrProtocolViolationOversizeSubID // + } + } + + return CodeSuccess +} + +// UnsubackEncode encodes an Unsuback packet. +func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + nb.Write(encodeUint16(pk.PacketID)) + + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) + nb.Write(pb.Bytes()) + nb.Write(pk.ReasonCodes) + } + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + + return nil +} + +// UnsubackDecode decodes an Unsuback packet. +func (pk *Packet) UnsubackDecode(buf []byte) error { + var offset int + var err error + + pk.PacketID, offset, err = decodeUint16(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedPacketID) + } + + if pk.ProtocolVersion == 5 { + n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + + offset += n + + pk.ReasonCodes = buf[offset:] + } + + return nil +} + +// UnsubscribeEncode encodes an Unsubscribe packet. +func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { + if pk.PacketID == 0 { + return ErrProtocolViolationNoPacketID + } + + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + nb.Write(encodeUint16(pk.PacketID)) + + xb := mempool.GetBuffer() // capture filters and write after length checks + defer mempool.PutBuffer(xb) + for _, sub := range pk.Filters { + xb.Write(encodeString(sub.Filter)) // [MQTT-3.10.3-1] + } + + if pk.ProtocolVersion == 5 { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len()) + nb.Write(pb.Bytes()) + } + + nb.Write(xb.Bytes()) + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + + return nil +} + +// UnsubscribeDecode decodes an Unsubscribe packet. +func (pk *Packet) UnsubscribeDecode(buf []byte) error { + var offset int + var err error + + pk.PacketID, offset, err = decodeUint16(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedPacketID) + } + + if pk.ProtocolVersion == 5 { + n, err := pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + offset += n + } + + var filter string + pk.Filters = Subscriptions{} + for offset < len(buf) { + filter, offset, err = decodeString(buf, offset) // [MQTT-3.10.3-1] + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedTopic) + } + pk.Filters = append(pk.Filters, Subscription{Filter: filter}) + } + + return nil +} + +// UnsubscribeValidate validates an Unsubscribe packet. +func (pk *Packet) UnsubscribeValidate() Code { + if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { + return ErrProtocolViolationNoPacketID // [MQTT-2.2.1-3] [MQTT-2.2.1-4] + } + + if len(pk.Filters) == 0 { + return ErrProtocolViolationNoFilters // [MQTT-3.10.3-2] + } + + return CodeSuccess +} + +// AuthEncode encodes an Auth packet. +func (pk *Packet) AuthEncode(buf *bytes.Buffer) error { + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) + nb.WriteByte(pk.ReasonCode) + + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) + nb.Write(pb.Bytes()) + + pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Encode(buf) + buf.Write(nb.Bytes()) + return nil +} + +// AuthDecode decodes an Auth packet. +func (pk *Packet) AuthDecode(buf []byte) error { + var offset int + var err error + + pk.ReasonCode, offset, err = decodeByte(buf, offset) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedReasonCode) + } + + _, err = pk.Properties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) + if err != nil { + return fmt.Errorf("%s: %w", err, ErrMalformedProperties) + } + + return nil +} + +// AuthValidate returns success if the auth packet is valid. +func (pk *Packet) AuthValidate() Code { + if pk.ReasonCode != CodeSuccess.Code && + pk.ReasonCode != CodeContinueAuthentication.Code && + pk.ReasonCode != CodeReAuthenticate.Code { + return ErrProtocolViolationInvalidReason // [MQTT-3.15.2-1] + } + + return CodeSuccess +} + +// FormatID returns the PacketID field as a decimal integer. +func (pk *Packet) FormatID() string { + return strconv.FormatUint(uint64(pk.PacketID), 10) +} diff --git a/packets/packets_test.go b/packets/packets_test.go new file mode 100644 index 0000000..1e18f1f --- /dev/null +++ b/packets/packets_test.go @@ -0,0 +1,505 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "bytes" + "fmt" + "testing" + + "github.com/jinzhu/copier" + "github.com/stretchr/testify/require" +) + +const pkInfo = "packet type %v, %s" + +var packetList = []byte{ + Connect, + Connack, + Publish, + Puback, + Pubrec, + Pubrel, + Pubcomp, + Subscribe, + Suback, + Unsubscribe, + Unsuback, + Pingreq, + Pingresp, + Disconnect, + Auth, +} + +var pkTable = []TPacketCase{ + TPacketData[Connect].Get(TConnectMqtt311), + TPacketData[Connect].Get(TConnectMqtt5), + TPacketData[Connect].Get(TConnectUserPassLWT), + TPacketData[Connack].Get(TConnackAcceptedMqtt5), + TPacketData[Connack].Get(TConnackAcceptedNoSession), + TPacketData[Publish].Get(TPublishBasic), + TPacketData[Publish].Get(TPublishMqtt5), + TPacketData[Puback].Get(TPuback), + TPacketData[Pubrec].Get(TPubrec), + TPacketData[Pubrel].Get(TPubrel), + TPacketData[Pubcomp].Get(TPubcomp), + TPacketData[Subscribe].Get(TSubscribe), + TPacketData[Subscribe].Get(TSubscribeMqtt5), + TPacketData[Suback].Get(TSuback), + TPacketData[Unsubscribe].Get(TUnsubscribe), + TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5), + TPacketData[Pingreq].Get(TPingreq), + TPacketData[Pingresp].Get(TPingresp), + TPacketData[Disconnect].Get(TDisconnect), + TPacketData[Disconnect].Get(TDisconnectMqtt5), +} + +func TestNewPackets(t *testing.T) { + s := NewPackets() + require.NotNil(t, s.internal) +} + +func TestPacketsAdd(t *testing.T) { + s := NewPackets() + s.Add("cl1", Packet{}) + require.Contains(t, s.internal, "cl1") +} + +func TestPacketsGet(t *testing.T) { + s := NewPackets() + s.Add("cl1", Packet{TopicName: "a1"}) + s.Add("cl2", Packet{TopicName: "a2"}) + require.Contains(t, s.internal, "cl1") + require.Contains(t, s.internal, "cl2") + + pk, ok := s.Get("cl1") + require.True(t, ok) + require.Equal(t, "a1", pk.TopicName) +} + +func TestPacketsGetAll(t *testing.T) { + s := NewPackets() + s.Add("cl1", Packet{TopicName: "a1"}) + s.Add("cl2", Packet{TopicName: "a2"}) + s.Add("cl3", Packet{TopicName: "a3"}) + require.Contains(t, s.internal, "cl1") + require.Contains(t, s.internal, "cl2") + require.Contains(t, s.internal, "cl3") + + subs := s.GetAll() + require.Len(t, subs, 3) +} + +func TestPacketsLen(t *testing.T) { + s := NewPackets() + s.Add("cl1", Packet{TopicName: "a1"}) + s.Add("cl2", Packet{TopicName: "a2"}) + require.Contains(t, s.internal, "cl1") + require.Contains(t, s.internal, "cl2") + require.Equal(t, 2, s.Len()) +} + +func TestSPacketsDelete(t *testing.T) { + s := NewPackets() + s.Add("cl1", Packet{TopicName: "a1"}) + require.Contains(t, s.internal, "cl1") + + s.Delete("cl1") + _, ok := s.Get("cl1") + require.False(t, ok) +} + +func TestFormatPacketID(t *testing.T) { + for _, id := range []uint16{0, 7, 0x100, 0xffff} { + packet := &Packet{PacketID: id} + require.Equal(t, fmt.Sprint(id), packet.FormatID()) + } +} + +func TestSubscriptionOptionsEncodeDecode(t *testing.T) { + p := &Subscription{ + Qos: 2, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 2, + } + x := new(Subscription) + x.decode(p.encode()) + require.Equal(t, *p, *x) + + p = &Subscription{ + Qos: 1, + NoLocal: false, + RetainAsPublished: false, + RetainHandling: 1, + } + x = new(Subscription) + x.decode(p.encode()) + require.Equal(t, *p, *x) +} + +func TestPacketEncode(t *testing.T) { + for _, pkt := range packetList { + require.Contains(t, TPacketData, pkt) + for _, wanted := range TPacketData[pkt] { + t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) { + if !encodeTestOK(wanted) { + return + } + + pk := new(Packet) + _ = copier.Copy(pk, wanted.Packet) + require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc) + + pk.Mods.AllowResponseInfo = true + + buf := new(bytes.Buffer) + var err error + switch pkt { + case Connect: + err = pk.ConnectEncode(buf) + case Connack: + err = pk.ConnackEncode(buf) + case Publish: + err = pk.PublishEncode(buf) + case Puback: + err = pk.PubackEncode(buf) + case Pubrec: + err = pk.PubrecEncode(buf) + case Pubrel: + err = pk.PubrelEncode(buf) + case Pubcomp: + err = pk.PubcompEncode(buf) + case Subscribe: + err = pk.SubscribeEncode(buf) + case Suback: + err = pk.SubackEncode(buf) + case Unsubscribe: + err = pk.UnsubscribeEncode(buf) + case Unsuback: + err = pk.UnsubackEncode(buf) + case Pingreq: + err = pk.PingreqEncode(buf) + case Pingresp: + err = pk.PingrespEncode(buf) + case Disconnect: + err = pk.DisconnectEncode(buf) + case Auth: + err = pk.AuthEncode(buf) + } + if wanted.Expect != nil { + require.Error(t, err, pkInfo, pkt, wanted.Desc) + return + } + + require.NoError(t, err, pkInfo, pkt, wanted.Desc) + encoded := buf.Bytes() + + // If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc). + if len(wanted.ActualBytes) > 0 { + wanted.RawBytes = wanted.ActualBytes + } + require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc) + }) + } + } +} + +func TestPacketDecode(t *testing.T) { + for _, pkt := range packetList { + require.Contains(t, TPacketData, pkt) + for _, wanted := range TPacketData[pkt] { + t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) { + if !decodeTestOK(wanted) { + return + } + + pk := &Packet{FixedHeader: FixedHeader{Type: pkt}} + pk.Mods.AllowResponseInfo = true + _ = pk.FixedHeader.Decode(wanted.RawBytes[0]) + if len(wanted.RawBytes) > 0 { + pk.FixedHeader.Remaining = int(wanted.RawBytes[1]) + } + + if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 { + pk.ProtocolVersion = wanted.Packet.ProtocolVersion + } + + buf := wanted.RawBytes[2:] + var err error + switch pkt { + case Connect: + err = pk.ConnectDecode(buf) + case Connack: + err = pk.ConnackDecode(buf) + case Publish: + err = pk.PublishDecode(buf) + case Puback: + err = pk.PubackDecode(buf) + case Pubrec: + err = pk.PubrecDecode(buf) + case Pubrel: + err = pk.PubrelDecode(buf) + case Pubcomp: + err = pk.PubcompDecode(buf) + case Subscribe: + err = pk.SubscribeDecode(buf) + case Suback: + err = pk.SubackDecode(buf) + case Unsubscribe: + err = pk.UnsubscribeDecode(buf) + case Unsuback: + err = pk.UnsubackDecode(buf) + case Pingreq: + err = pk.PingreqDecode(buf) + case Pingresp: + err = pk.PingrespDecode(buf) + case Disconnect: + err = pk.DisconnectDecode(buf) + case Auth: + err = pk.AuthDecode(buf) + } + + if wanted.FailFirst != nil { + require.Error(t, err, pkInfo, pkt, wanted.Desc) + require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc) + return + } + + require.NoError(t, err, pkInfo, pkt, wanted.Desc) + + require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc) + + require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc) + + if pkt == Connect { + // we use ProtocolVersion for controlling packet encoding, but we don't need to test + // against it unless it's a connect packet. + require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc) + } + require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc) + + require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc) + + require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc) + + require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc) + require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc) + + require.EqualValues(t, wanted.Packet.Properties, pk.Properties) + require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties) + }) + } + } +} + +func TestValidate(t *testing.T) { + for _, pkt := range packetList { + require.Contains(t, TPacketData, pkt) + for _, wanted := range TPacketData[pkt] { + t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) { + if wanted.Group == "validate" || wanted.Primary { + pk := wanted.Packet + var err error + switch pkt { + case Connect: + err = pk.ConnectValidate() + case Publish: + err = pk.PublishValidate(1024) + case Subscribe: + err = pk.SubscribeValidate() + case Unsubscribe: + err = pk.UnsubscribeValidate() + case Auth: + err = pk.AuthValidate() + } + + if wanted.Expect != nil { + require.Error(t, err, pkInfo, pkt, wanted.Desc) + require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc) + } + } + }) + } + } +} + +func TestAckValidatePubrec(t *testing.T) { + for _, b := range []byte{ + CodeSuccess.Code, + CodeNoMatchingSubscribers.Code, + ErrUnspecifiedError.Code, + ErrImplementationSpecificError.Code, + ErrNotAuthorized.Code, + ErrTopicNameInvalid.Code, + ErrPacketIdentifierInUse.Code, + ErrQuotaExceeded.Code, + ErrPayloadFormatInvalid.Code, + } { + pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b} + require.True(t, pk.ReasonCodeValid()) + } + pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code} + require.False(t, pk.ReasonCodeValid()) +} + +func TestAckValidatePubrel(t *testing.T) { + for _, b := range []byte{ + CodeSuccess.Code, + ErrPacketIdentifierNotFound.Code, + } { + pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b} + require.True(t, pk.ReasonCodeValid()) + } + pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code} + require.False(t, pk.ReasonCodeValid()) +} + +func TestAckValidatePubcomp(t *testing.T) { + for _, b := range []byte{ + CodeSuccess.Code, + ErrPacketIdentifierNotFound.Code, + } { + pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b} + require.True(t, pk.ReasonCodeValid()) + } + pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code} + require.False(t, pk.ReasonCodeValid()) +} + +func TestAckValidateSuback(t *testing.T) { + for _, b := range []byte{ + CodeGrantedQos0.Code, + CodeGrantedQos1.Code, + CodeGrantedQos2.Code, + ErrUnspecifiedError.Code, + ErrImplementationSpecificError.Code, + ErrNotAuthorized.Code, + ErrTopicFilterInvalid.Code, + ErrPacketIdentifierInUse.Code, + ErrQuotaExceeded.Code, + ErrSharedSubscriptionsNotSupported.Code, + ErrSubscriptionIdentifiersNotSupported.Code, + ErrWildcardSubscriptionsNotSupported.Code, + } { + pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b} + require.True(t, pk.ReasonCodeValid()) + } + + pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code} + require.False(t, pk.ReasonCodeValid()) +} + +func TestAckValidateUnsuback(t *testing.T) { + for _, b := range []byte{ + CodeSuccess.Code, + CodeNoSubscriptionExisted.Code, + ErrUnspecifiedError.Code, + ErrImplementationSpecificError.Code, + ErrNotAuthorized.Code, + ErrTopicFilterInvalid.Code, + ErrPacketIdentifierInUse.Code, + } { + pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b} + require.True(t, pk.ReasonCodeValid()) + } + + pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code} + require.False(t, pk.ReasonCodeValid()) +} + +func TestReasonCodeValidMisc(t *testing.T) { + pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code} + require.True(t, pk.ReasonCodeValid()) +} + +func TestCopy(t *testing.T) { + for _, tt := range pkTable { + pkc := tt.Packet.Copy(true) + + require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc) + require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc) + require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc) + + require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc) + require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc) + require.EqualValues(t, pkc.Properties, tt.Packet.Properties) + + pkcc := tt.Packet.Copy(false) + require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc) + } +} + +func TestMergeSubscription(t *testing.T) { + sub := Subscription{ + Filter: "a/b/c", + RetainHandling: 0, + Qos: 0, + RetainAsPublished: false, + NoLocal: false, + Identifier: 1, + } + + sub2 := Subscription{ + Filter: "a/b/d", + RetainHandling: 0, + Qos: 2, + RetainAsPublished: false, + NoLocal: true, + Identifier: 2, + } + + expect := Subscription{ + Filter: "a/b/c", + RetainHandling: 0, + Qos: 2, + RetainAsPublished: false, + NoLocal: true, + Identifier: 1, + Identifiers: map[string]int{ + "a/b/c": 1, + "a/b/d": 2, + }, + } + require.Equal(t, expect, sub.Merge(sub2)) +} diff --git a/packets/properties.go b/packets/properties.go new file mode 100644 index 0000000..1ad6294 --- /dev/null +++ b/packets/properties.go @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "bytes" + "fmt" + "strings" + + "testmqtt/mempool" +) + +const ( + PropPayloadFormat byte = 1 + PropMessageExpiryInterval byte = 2 + PropContentType byte = 3 + PropResponseTopic byte = 8 + PropCorrelationData byte = 9 + PropSubscriptionIdentifier byte = 11 + PropSessionExpiryInterval byte = 17 + PropAssignedClientID byte = 18 + PropServerKeepAlive byte = 19 + PropAuthenticationMethod byte = 21 + PropAuthenticationData byte = 22 + PropRequestProblemInfo byte = 23 + PropWillDelayInterval byte = 24 + PropRequestResponseInfo byte = 25 + PropResponseInfo byte = 26 + PropServerReference byte = 28 + PropReasonString byte = 31 + PropReceiveMaximum byte = 33 + PropTopicAliasMaximum byte = 34 + PropTopicAlias byte = 35 + PropMaximumQos byte = 36 + PropRetainAvailable byte = 37 + PropUser byte = 38 + PropMaximumPacketSize byte = 39 + PropWildcardSubAvailable byte = 40 + PropSubIDAvailable byte = 41 + PropSharedSubAvailable byte = 42 +) + +// validPacketProperties indicates which properties are valid for which packet types. +var validPacketProperties = map[byte]map[byte]byte{ + PropPayloadFormat: {Publish: 1, WillProperties: 1}, + PropMessageExpiryInterval: {Publish: 1, WillProperties: 1}, + PropContentType: {Publish: 1, WillProperties: 1}, + PropResponseTopic: {Publish: 1, WillProperties: 1}, + PropCorrelationData: {Publish: 1, WillProperties: 1}, + PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1}, + PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1}, + PropAssignedClientID: {Connack: 1}, + PropServerKeepAlive: {Connack: 1}, + PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1}, + PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1}, + PropRequestProblemInfo: {Connect: 1}, + PropWillDelayInterval: {WillProperties: 1}, + PropRequestResponseInfo: {Connect: 1}, + PropResponseInfo: {Connack: 1}, + PropServerReference: {Connack: 1, Disconnect: 1}, + PropReasonString: {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1}, + PropReceiveMaximum: {Connect: 1, Connack: 1}, + PropTopicAliasMaximum: {Connect: 1, Connack: 1}, + PropTopicAlias: {Publish: 1}, + PropMaximumQos: {Connack: 1}, + PropRetainAvailable: {Connack: 1}, + PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1}, + PropMaximumPacketSize: {Connect: 1, Connack: 1}, + PropWildcardSubAvailable: {Connack: 1}, + PropSubIDAvailable: {Connack: 1}, + PropSharedSubAvailable: {Connack: 1}, +} + +// UserProperty is an arbitrary key-value pair for a packet user properties array. +type UserProperty struct { // [MQTT-1.5.7-1] + Key string `json:"k"` + Val string `json:"v"` +} + +// Properties contains all mqtt v5 properties available for a packet. +// Some properties have valid values of 0 or not-present. In this case, we opt for +// property flags to indicate the usage of property. +// Refer to mqtt v5 2.2.2.2 Property spec for more information. +type Properties struct { + CorrelationData []byte `json:"cd"` + SubscriptionIdentifier []int `json:"si"` + AuthenticationData []byte `json:"ad"` + User []UserProperty `json:"user"` + ContentType string `json:"ct"` + ResponseTopic string `json:"rt"` + AssignedClientID string `json:"aci"` + AuthenticationMethod string `json:"am"` + ResponseInfo string `json:"ri"` + ServerReference string `json:"sr"` + ReasonString string `json:"rs"` + MessageExpiryInterval uint32 `json:"me"` + SessionExpiryInterval uint32 `json:"sei"` + WillDelayInterval uint32 `json:"wdi"` + MaximumPacketSize uint32 `json:"mps"` + ServerKeepAlive uint16 `json:"ska"` + ReceiveMaximum uint16 `json:"rm"` + TopicAliasMaximum uint16 `json:"tam"` + TopicAlias uint16 `json:"ta"` + PayloadFormat byte `json:"pf"` + PayloadFormatFlag bool `json:"fpf"` + SessionExpiryIntervalFlag bool `json:"fsei"` + ServerKeepAliveFlag bool `json:"fska"` + RequestProblemInfo byte `json:"rpi"` + RequestProblemInfoFlag bool `json:"frpi"` + RequestResponseInfo byte `json:"rri"` + TopicAliasFlag bool `json:"fta"` + MaximumQos byte `json:"mqos"` + MaximumQosFlag bool `json:"fmqos"` + RetainAvailable byte `json:"ra"` + RetainAvailableFlag bool `json:"fra"` + WildcardSubAvailable byte `json:"wsa"` + WildcardSubAvailableFlag bool `json:"fwsa"` + SubIDAvailable byte `json:"sida"` + SubIDAvailableFlag bool `json:"fsida"` + SharedSubAvailable byte `json:"ssa"` + SharedSubAvailableFlag bool `json:"fssa"` +} + +// Copy creates a new Properties struct with copies of the values. +func (p *Properties) Copy(allowTransfer bool) Properties { + pr := Properties{ + PayloadFormat: p.PayloadFormat, // [MQTT-3.3.2-4] + PayloadFormatFlag: p.PayloadFormatFlag, + MessageExpiryInterval: p.MessageExpiryInterval, + ContentType: p.ContentType, // [MQTT-3.3.2-20] + ResponseTopic: p.ResponseTopic, // [MQTT-3.3.2-15] + SessionExpiryInterval: p.SessionExpiryInterval, + SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag, + AssignedClientID: p.AssignedClientID, + ServerKeepAlive: p.ServerKeepAlive, + ServerKeepAliveFlag: p.ServerKeepAliveFlag, + AuthenticationMethod: p.AuthenticationMethod, + RequestProblemInfo: p.RequestProblemInfo, + RequestProblemInfoFlag: p.RequestProblemInfoFlag, + WillDelayInterval: p.WillDelayInterval, + RequestResponseInfo: p.RequestResponseInfo, + ResponseInfo: p.ResponseInfo, + ServerReference: p.ServerReference, + ReasonString: p.ReasonString, + ReceiveMaximum: p.ReceiveMaximum, + TopicAliasMaximum: p.TopicAliasMaximum, + TopicAlias: 0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27] + MaximumQos: p.MaximumQos, + MaximumQosFlag: p.MaximumQosFlag, + RetainAvailable: p.RetainAvailable, + RetainAvailableFlag: p.RetainAvailableFlag, + MaximumPacketSize: p.MaximumPacketSize, + WildcardSubAvailable: p.WildcardSubAvailable, + WildcardSubAvailableFlag: p.WildcardSubAvailableFlag, + SubIDAvailable: p.SubIDAvailable, + SubIDAvailableFlag: p.SubIDAvailableFlag, + SharedSubAvailable: p.SharedSubAvailable, + SharedSubAvailableFlag: p.SharedSubAvailableFlag, + } + + if allowTransfer { + pr.TopicAlias = p.TopicAlias + pr.TopicAliasFlag = p.TopicAliasFlag + } + + if len(p.CorrelationData) > 0 { + pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16] + } + + if len(p.SubscriptionIdentifier) > 0 { + pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...) + } + + if len(p.AuthenticationData) > 0 { + pr.AuthenticationData = append([]byte{}, p.AuthenticationData...) + } + + if len(p.User) > 0 { + pr.User = []UserProperty{} + for _, v := range p.User { + pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17] + Key: v.Key, + Val: v.Val, + }) + } + } + + return pr +} + +// canEncode returns true if the property type is valid for the packet type. +func (p *Properties) canEncode(pkt byte, k byte) bool { + return validPacketProperties[k][pkt] == 1 +} + +// Encode encodes properties into a bytes buffer. +func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) { + if p == nil { + return + } + + buf := mempool.GetBuffer() + defer mempool.PutBuffer(buf) + if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag { + buf.WriteByte(PropPayloadFormat) + buf.WriteByte(p.PayloadFormat) + } + + if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 { + buf.WriteByte(PropMessageExpiryInterval) + buf.Write(encodeUint32(p.MessageExpiryInterval)) + } + + if p.canEncode(pkt, PropContentType) && p.ContentType != "" { + buf.WriteByte(PropContentType) + buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19] + } + + if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14] + p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28] + buf.WriteByte(PropResponseTopic) + buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13] + } + + if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28] + buf.WriteByte(PropCorrelationData) + buf.Write(encodeBytes(p.CorrelationData)) + } + + if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 { + for _, v := range p.SubscriptionIdentifier { + if v > 0 { + buf.WriteByte(PropSubscriptionIdentifier) + encodeLength(buf, int64(v)) + } + } + } + + if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2] + buf.WriteByte(PropSessionExpiryInterval) + buf.Write(encodeUint32(p.SessionExpiryInterval)) + } + + if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" { + buf.WriteByte(PropAssignedClientID) + buf.Write(encodeString(p.AssignedClientID)) + } + + if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag { + buf.WriteByte(PropServerKeepAlive) + buf.Write(encodeUint16(p.ServerKeepAlive)) + } + + if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" { + buf.WriteByte(PropAuthenticationMethod) + buf.Write(encodeString(p.AuthenticationMethod)) + } + + if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 { + buf.WriteByte(PropAuthenticationData) + buf.Write(encodeBytes(p.AuthenticationData)) + } + + if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag { + buf.WriteByte(PropRequestProblemInfo) + buf.WriteByte(p.RequestProblemInfo) + } + + if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 { + buf.WriteByte(PropWillDelayInterval) + buf.Write(encodeUint32(p.WillDelayInterval)) + } + + if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 { + buf.WriteByte(PropRequestResponseInfo) + buf.WriteByte(p.RequestResponseInfo) + } + + if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28] + buf.WriteByte(PropResponseInfo) + buf.Write(encodeString(p.ResponseInfo)) + } + + if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 { + buf.WriteByte(PropServerReference) + buf.Write(encodeString(p.ServerReference)) + } + + // [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2] + // [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2] + if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" { + b := encodeString(p.ReasonString) + if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize { + buf.WriteByte(PropReasonString) + buf.Write(b) + } + } + + if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 { + buf.WriteByte(PropReceiveMaximum) + buf.Write(encodeUint16(p.ReceiveMaximum)) + } + + if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 { + buf.WriteByte(PropTopicAliasMaximum) + buf.Write(encodeUint16(p.TopicAliasMaximum)) + } + + if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8] + buf.WriteByte(PropTopicAlias) + buf.Write(encodeUint16(p.TopicAlias)) + } + + if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 { + buf.WriteByte(PropMaximumQos) + buf.WriteByte(p.MaximumQos) + } + + if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag { + buf.WriteByte(PropRetainAvailable) + buf.WriteByte(p.RetainAvailable) + } + + if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) { + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) + for _, v := range p.User { + pb.WriteByte(PropUser) + pb.Write(encodeString(v.Key)) + pb.Write(encodeString(v.Val)) + } + // [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3] + // [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3] + if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize { + buf.Write(pb.Bytes()) + } + } + + if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 { + buf.WriteByte(PropMaximumPacketSize) + buf.Write(encodeUint32(p.MaximumPacketSize)) + } + + if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag { + buf.WriteByte(PropWildcardSubAvailable) + buf.WriteByte(p.WildcardSubAvailable) + } + + if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag { + buf.WriteByte(PropSubIDAvailable) + buf.WriteByte(p.SubIDAvailable) + } + + if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag { + buf.WriteByte(PropSharedSubAvailable) + buf.WriteByte(p.SharedSubAvailable) + } + + encodeLength(b, int64(buf.Len())) + b.Write(buf.Bytes()) // [MQTT-3.1.3-10] +} + +// Decode decodes property bytes into a properties struct. +func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) { + if p == nil { + return 0, nil + } + + var bu int + n, bu, err = DecodeLength(b) + if err != nil { + return n + bu, err + } + + if n == 0 { + return n + bu, nil + } + + bt := b.Bytes() + var k byte + for offset := 0; offset < n; { + k, offset, err = decodeByte(bt, offset) + if err != nil { + return n + bu, err + } + + if _, ok := validPacketProperties[k][pkt]; !ok { + return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty) + } + + switch k { + case PropPayloadFormat: + p.PayloadFormat, offset, err = decodeByte(bt, offset) + p.PayloadFormatFlag = true + case PropMessageExpiryInterval: + p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset) + case PropContentType: + p.ContentType, offset, err = decodeString(bt, offset) + case PropResponseTopic: + p.ResponseTopic, offset, err = decodeString(bt, offset) + case PropCorrelationData: + p.CorrelationData, offset, err = decodeBytes(bt, offset) + case PropSubscriptionIdentifier: + if p.SubscriptionIdentifier == nil { + p.SubscriptionIdentifier = []int{} + } + + n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:])) + if err != nil { + return n + bu, err + } + p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n) + offset += bu + case PropSessionExpiryInterval: + p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset) + p.SessionExpiryIntervalFlag = true + case PropAssignedClientID: + p.AssignedClientID, offset, err = decodeString(bt, offset) + case PropServerKeepAlive: + p.ServerKeepAlive, offset, err = decodeUint16(bt, offset) + p.ServerKeepAliveFlag = true + case PropAuthenticationMethod: + p.AuthenticationMethod, offset, err = decodeString(bt, offset) + case PropAuthenticationData: + p.AuthenticationData, offset, err = decodeBytes(bt, offset) + case PropRequestProblemInfo: + p.RequestProblemInfo, offset, err = decodeByte(bt, offset) + p.RequestProblemInfoFlag = true + case PropWillDelayInterval: + p.WillDelayInterval, offset, err = decodeUint32(bt, offset) + case PropRequestResponseInfo: + p.RequestResponseInfo, offset, err = decodeByte(bt, offset) + case PropResponseInfo: + p.ResponseInfo, offset, err = decodeString(bt, offset) + case PropServerReference: + p.ServerReference, offset, err = decodeString(bt, offset) + case PropReasonString: + p.ReasonString, offset, err = decodeString(bt, offset) + case PropReceiveMaximum: + p.ReceiveMaximum, offset, err = decodeUint16(bt, offset) + case PropTopicAliasMaximum: + p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset) + case PropTopicAlias: + p.TopicAlias, offset, err = decodeUint16(bt, offset) + p.TopicAliasFlag = true + case PropMaximumQos: + p.MaximumQos, offset, err = decodeByte(bt, offset) + p.MaximumQosFlag = true + case PropRetainAvailable: + p.RetainAvailable, offset, err = decodeByte(bt, offset) + p.RetainAvailableFlag = true + case PropUser: + var k, v string + k, offset, err = decodeString(bt, offset) + if err != nil { + return n + bu, err + } + v, offset, err = decodeString(bt, offset) + p.User = append(p.User, UserProperty{Key: k, Val: v}) + case PropMaximumPacketSize: + p.MaximumPacketSize, offset, err = decodeUint32(bt, offset) + case PropWildcardSubAvailable: + p.WildcardSubAvailable, offset, err = decodeByte(bt, offset) + p.WildcardSubAvailableFlag = true + case PropSubIDAvailable: + p.SubIDAvailable, offset, err = decodeByte(bt, offset) + p.SubIDAvailableFlag = true + case PropSharedSubAvailable: + p.SharedSubAvailable, offset, err = decodeByte(bt, offset) + p.SharedSubAvailableFlag = true + } + + if err != nil { + return n + bu, err + } + } + + return n + bu, nil +} diff --git a/packets/properties_test.go b/packets/properties_test.go new file mode 100644 index 0000000..b0a2f10 --- /dev/null +++ b/packets/properties_test.go @@ -0,0 +1,333 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +var ( + propertiesStruct = Properties{ + PayloadFormat: byte(1), // UTF-8 Format + PayloadFormatFlag: true, + MessageExpiryInterval: uint32(2), + ContentType: "text/plain", + ResponseTopic: "a/b/c", + CorrelationData: []byte("data"), + SubscriptionIdentifier: []int{322122}, + SessionExpiryInterval: uint32(120), + SessionExpiryIntervalFlag: true, + AssignedClientID: "mochi-v5", + ServerKeepAlive: uint16(20), + ServerKeepAliveFlag: true, + AuthenticationMethod: "SHA-1", + AuthenticationData: []byte("auth-data"), + RequestProblemInfo: byte(1), + RequestProblemInfoFlag: true, + WillDelayInterval: uint32(600), + RequestResponseInfo: byte(1), + ResponseInfo: "response", + ServerReference: "mochi-2", + ReasonString: "reason", + ReceiveMaximum: uint16(500), + TopicAliasMaximum: uint16(999), + TopicAlias: uint16(3), + TopicAliasFlag: true, + MaximumQos: byte(1), + MaximumQosFlag: true, + RetainAvailable: byte(1), + RetainAvailableFlag: true, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + { + Key: "key2", + Val: "value2", + }, + }, + MaximumPacketSize: uint32(32000), + WildcardSubAvailable: byte(1), + WildcardSubAvailableFlag: true, + SubIDAvailable: byte(1), + SubIDAvailableFlag: true, + SharedSubAvailable: byte(1), + SharedSubAvailableFlag: true, + } + + propertiesBytes = []byte{ + 172, 1, // VBI + + // Payload Format (1) (vbi:2) + 1, 1, + + // Message Expiry (2) (vbi:7) + 2, 0, 0, 0, 2, + + // Content Type (3) (vbi:20) + 3, + 0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n', + + // Response Topic (8) (vbi:28) + 8, + 0, 5, 'a', '/', 'b', '/', 'c', + + // Correlations Data (9) (vbi:35) + 9, + 0, 4, 'd', 'a', 't', 'a', + + // Subscription Identifier (11) (vbi:39) + 11, + 202, 212, 19, + + // Session Expiry Interval (17) (vbi:43) + 17, + 0, 0, 0, 120, + + // Assigned Client ID (18) (vbi:55) + 18, + 0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5', + + // Server Keep Alive (19) (vbi:58) + 19, + 0, 20, + + // Authentication Method (21) (vbi:66) + 21, + 0, 5, 'S', 'H', 'A', '-', '1', + + // Authentication Data (22) (vbi:78) + 22, + 0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', + + // Request Problem Info (23) (vbi:80) + 23, 1, + + // Will Delay Interval (24) (vbi:85) + 24, + 0, 0, 2, 88, + + // Request Response Info (25) (vbi:87) + 25, 1, + + // Response Info (26) (vbi:98) + 26, + 0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e', + + // Server Reference (28) (vbi:108) + 28, + 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', + + // Reason String (31) (vbi:117) + 31, + 0, 6, 'r', 'e', 'a', 's', 'o', 'n', + + // Receive Maximum (33) (vbi:120) + 33, + 1, 244, + + // Topic Alias Maximum (34) (vbi:123) + 34, + 3, 231, + + // Topic Alias (35) (vbi:126) + 35, + 0, 3, + + // Maximum Qos (36) (vbi:128) + 36, 1, + + // Retain Available (37) (vbi: 130) + 37, 1, + + // User Properties (38) (vbi:161) + 38, + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + 38, + 0, 4, 'k', 'e', 'y', '2', + 0, 6, 'v', 'a', 'l', 'u', 'e', '2', + + // Maximum Packet Size (39) (vbi:166) + 39, + 0, 0, 125, 0, + + // Wildcard Subscriptions Available (40) (vbi:168) + 40, 1, + + // Subscription ID Available (41) (vbi:170) + 41, 1, + + // Shared Subscriptions Available (42) (vbi:172) + 42, 1, + } +) + +func init() { + validPacketProperties[PropPayloadFormat][Reserved] = 1 + validPacketProperties[PropMessageExpiryInterval][Reserved] = 1 + validPacketProperties[PropContentType][Reserved] = 1 + validPacketProperties[PropResponseTopic][Reserved] = 1 + validPacketProperties[PropCorrelationData][Reserved] = 1 + validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1 + validPacketProperties[PropSessionExpiryInterval][Reserved] = 1 + validPacketProperties[PropAssignedClientID][Reserved] = 1 + validPacketProperties[PropServerKeepAlive][Reserved] = 1 + validPacketProperties[PropAuthenticationMethod][Reserved] = 1 + validPacketProperties[PropAuthenticationData][Reserved] = 1 + validPacketProperties[PropRequestProblemInfo][Reserved] = 1 + validPacketProperties[PropWillDelayInterval][Reserved] = 1 + validPacketProperties[PropRequestResponseInfo][Reserved] = 1 + validPacketProperties[PropResponseInfo][Reserved] = 1 + validPacketProperties[PropServerReference][Reserved] = 1 + validPacketProperties[PropReasonString][Reserved] = 1 + validPacketProperties[PropReceiveMaximum][Reserved] = 1 + validPacketProperties[PropTopicAliasMaximum][Reserved] = 1 + validPacketProperties[PropTopicAlias][Reserved] = 1 + validPacketProperties[PropMaximumQos][Reserved] = 1 + validPacketProperties[PropRetainAvailable][Reserved] = 1 + validPacketProperties[PropUser][Reserved] = 1 + validPacketProperties[PropMaximumPacketSize][Reserved] = 1 + validPacketProperties[PropWildcardSubAvailable][Reserved] = 1 + validPacketProperties[PropSubIDAvailable][Reserved] = 1 + validPacketProperties[PropSharedSubAvailable][Reserved] = 1 +} + +func TestEncodeProperties(t *testing.T) { + props := propertiesStruct + b := bytes.NewBuffer([]byte{}) + props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0) + require.Equal(t, propertiesBytes, b.Bytes()) +} + +func TestEncodePropertiesDisallowProblemInfo(t *testing.T) { + props := propertiesStruct + b := bytes.NewBuffer([]byte{}) + props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0) + require.NotEqual(t, propertiesBytes, b.Bytes()) + require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6})) + require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5})) + require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8})) +} + +func TestEncodePropertiesDisallowResponseInfo(t *testing.T) { + props := propertiesStruct + b := bytes.NewBuffer([]byte{}) + props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0) + require.NotEqual(t, propertiesBytes, b.Bytes()) + require.NotContains(t, b.Bytes(), []byte{8, 0, 5}) + require.NotContains(t, b.Bytes(), []byte{9, 0, 4}) +} + +func TestEncodePropertiesNil(t *testing.T) { + type tmp struct { + p *Properties + } + + pr := tmp{} + b := bytes.NewBuffer([]byte{}) + pr.p.Encode(Reserved, Mods{}, b, 0) + require.Equal(t, []byte{}, b.Bytes()) +} + +func TestEncodeZeroProperties(t *testing.T) { + // [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero. + props := new(Properties) + b := bytes.NewBuffer([]byte{}) + props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0) + require.Equal(t, []byte{0x00}, b.Bytes()) +} + +func TestDecodeProperties(t *testing.T) { + b := bytes.NewBuffer(propertiesBytes) + + props := new(Properties) + n, err := props.Decode(Reserved, b) + require.NoError(t, err) + require.Equal(t, 172+2, n) + require.EqualValues(t, propertiesStruct, *props) +} + +func TestDecodePropertiesNil(t *testing.T) { + b := bytes.NewBuffer(propertiesBytes) + + type tmp struct { + p *Properties + } + + pr := tmp{} + n, err := pr.p.Decode(Reserved, b) + require.NoError(t, err) + require.Equal(t, 0, n) +} + +func TestDecodePropertiesBadInitialVBI(t *testing.T) { + b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) + props := new(Properties) + _, err := props.Decode(Reserved, b) + require.Error(t, err) + require.ErrorIs(t, ErrMalformedVariableByteInteger, err) +} + +func TestDecodePropertiesZeroLengthVBI(t *testing.T) { + b := bytes.NewBuffer([]byte{0}) + props := new(Properties) + _, err := props.Decode(Reserved, b) + require.NoError(t, err) + require.Equal(t, props, new(Properties)) +} + +func TestDecodePropertiesBadKeyByte(t *testing.T) { + b := bytes.NewBuffer([]byte{64, 1}) + props := new(Properties) + _, err := props.Decode(Reserved, b) + require.Error(t, err) + require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange) +} + +func TestDecodePropertiesInvalidForPacket(t *testing.T) { + b := bytes.NewBuffer([]byte{1, 99}) + props := new(Properties) + _, err := props.Decode(Reserved, b) + require.Error(t, err) + require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty) +} + +func TestDecodePropertiesGeneralFailure(t *testing.T) { + b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19}) + props := new(Properties) + _, err := props.Decode(Reserved, b) + require.Error(t, err) +} + +func TestDecodePropertiesBadSubscriptionID(t *testing.T) { + b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255}) + props := new(Properties) + _, err := props.Decode(Reserved, b) + require.Error(t, err) +} + +func TestDecodePropertiesBadUserProps(t *testing.T) { + b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255}) + props := new(Properties) + _, err := props.Decode(Reserved, b) + require.Error(t, err) +} + +func TestCopyProperties(t *testing.T) { + require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true)) +} + +func TestCopyPropertiesNoTransfer(t *testing.T) { + pkA := propertiesStruct + pkB := pkA.Copy(false) + + // Properties which should never be transferred from one connection to another + require.Equal(t, uint16(0), pkB.TopicAlias) +} diff --git a/packets/tpackets.go b/packets/tpackets.go new file mode 100644 index 0000000..79a6f58 --- /dev/null +++ b/packets/tpackets.go @@ -0,0 +1,4031 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +// TPacketCase contains data for cross-checking the encoding and decoding +// of packets and expected scenarios. +type TPacketCase struct { + RawBytes []byte // the bytes that make the packet + ActualBytes []byte // the actual byte array that is created in the event of a byte mutation + Group string // a group that should run the test, blank for all + Desc string // a description of the test + FailFirst error // expected fail result to be run immediately after the method is called + Packet *Packet // the packet that is Expected + ActualPacket *Packet // the actual packet after mutations + Expect error // generic Expected fail result to be checked + Isolate bool // isolate can be used to isolate a test + Primary bool // primary is a test that should be run using readPackets + Case byte // the identifying byte of the case +} + +// TPacketCases is a slice of TPacketCase. +type TPacketCases []TPacketCase + +// Get returns a case matching a given T byte. +func (f TPacketCases) Get(b byte) TPacketCase { + for _, v := range f { + if v.Case == b { + return v + } + } + + return TPacketCase{} +} + +const ( + TConnectMqtt31 byte = iota + TConnectMqtt311 + TConnectMqtt5 + TConnectMqtt5LWT + TConnectClean + TConnectUserPass + TConnectUserPassLWT + TConnectMalProtocolName + TConnectMalProtocolVersion + TConnectMalFlags + TConnectMalKeepalive + TConnectMalClientID + TConnectMalWillTopic + TConnectMalWillFlag + TConnectMalUsername + TConnectMalPassword + TConnectMalFixedHeader + TConnectMalReservedBit + TConnectMalProperties + TConnectMalWillProperties + TConnectInvalidProtocolName + TConnectInvalidProtocolVersion + TConnectInvalidProtocolVersion2 + TConnectInvalidReservedBit + TConnectInvalidClientIDTooLong + TConnectInvalidFlagNoUsername + TConnectInvalidFlagNoPassword + TConnectInvalidUsernameNoFlag + TConnectInvalidPasswordNoFlag + TConnectInvalidUsernameTooLong + TConnectInvalidPasswordTooLong + TConnectInvalidWillFlagNoPayload + TConnectInvalidWillFlagQosOutOfRange + TConnectInvalidWillSurplusRetain + TConnectZeroByteUsername + TConnectSpecInvalidUTF8D800 + TConnectSpecInvalidUTF8DFFF + TConnectSpecInvalidUTF80000 + TConnectSpecInvalidUTF8NoSkip + TConnackAcceptedNoSession + TConnackAcceptedSessionExists + TConnackAcceptedMqtt5 + TConnackAcceptedAdjustedExpiryInterval + TConnackMinMqtt5 + TConnackMinCleanMqtt5 + TConnackServerKeepalive + TConnackInvalidMinMqtt5 + TConnackBadProtocolVersion + TConnackProtocolViolationNoSession + TConnackBadClientID + TConnackServerUnavailable + TConnackBadUsernamePassword + TConnackBadUsernamePasswordNoSession + TConnackMqtt5BadUsernamePasswordNoSession + TConnackNotAuthorised + TConnackMalSessionPresent + TConnackMalReturnCode + TConnackMalProperties + TConnackDropProperties + TConnackDropPropertiesPartial + TPublishNoPayload + TPublishBasic + TPublishBasicTopicAliasOnly + TPublishBasicMqtt5 + TPublishMqtt5 + TPublishQos1 + TPublishQos1Mqtt5 + TPublishQos1NoPayload + TPublishQos1Dup + TPublishQos2 + TPublishQos2Mqtt5 + TPublishQos2Upgraded + TPublishSubscriberIdentifier + TPublishRetain + TPublishRetainMqtt5 + TPublishDup + TPublishMalTopicName + TPublishMalPacketID + TPublishMalProperties + TPublishCopyBasic + TPublishSpecQos0NoPacketID + TPublishSpecQosMustPacketID + TPublishDropOversize + TPublishInvalidQos0NoPacketID + TPublishInvalidQosMustPacketID + TPublishInvalidSurplusSubID + TPublishInvalidSurplusWildcard + TPublishInvalidSurplusWildcard2 + TPublishInvalidNoTopic + TPublishInvalidTopicAlias + TPublishInvalidExcessTopicAlias + TPublishSpecDenySysTopic + TPuback + TPubackMqtt5 + TPubackMqtt5NotAuthorized + TPubackMalPacketID + TPubackMalProperties + TPubackUnexpectedError + TPubrec + TPubrecMqtt5 + TPubrecMqtt5IDInUse + TPubrecMqtt5NotAuthorized + TPubrecMalPacketID + TPubrecMalProperties + TPubrecMalReasonCode + TPubrecInvalidReason + TPubrel + TPubrelMqtt5 + TPubrelMqtt5AckNoPacket + TPubrelMalPacketID + TPubrelMalProperties + TPubrelInvalidReason + TPubcomp + TPubcompMqtt5 + TPubcompMqtt5AckNoPacket + TPubcompMalPacketID + TPubcompMalProperties + TPubcompInvalidReason + TSubscribe + TSubscribeMany + TSubscribeMqtt5 + TSubscribeRetainHandling1 + TSubscribeRetainHandling2 + TSubscribeRetainAsPublished + TSubscribeMalPacketID + TSubscribeMalTopic + TSubscribeMalQos + TSubscribeMalQosRange + TSubscribeMalProperties + TSubscribeInvalidQosMustPacketID + TSubscribeSpecQosMustPacketID + TSubscribeInvalidNoFilters + TSubscribeInvalidSharedNoLocal + TSubscribeInvalidFilter + TSubscribeInvalidIdentifierOversize + TSuback + TSubackMany + TSubackDeny + TSubackUnspecifiedError + TSubackUnspecifiedErrorMqtt5 + TSubackMqtt5 + TSubackPacketIDInUse + TSubackInvalidFilter + TSubackInvalidSharedNoLocal + TSubackMalPacketID + TSubackMalProperties + TUnsubscribe + TUnsubscribeMany + TUnsubscribeMqtt5 + TUnsubscribeMalPacketID + TUnsubscribeMalTopicName + TUnsubscribeMalProperties + TUnsubscribeInvalidQosMustPacketID + TUnsubscribeSpecQosMustPacketID + TUnsubscribeInvalidNoFilters + TUnsuback + TUnsubackMany + TUnsubackMqtt5 + TUnsubackPacketIDInUse + TUnsubackMalPacketID + TUnsubackMalProperties + TPingreq + TPingresp + TDisconnect + TDisconnectTakeover + TDisconnectMqtt5 + TDisconnectMqtt5DisconnectWithWillMessage + TDisconnectSecondConnect + TDisconnectReceiveMaximum + TDisconnectDropProperties + TDisconnectShuttingDown + TDisconnectMalProperties + TDisconnectMalReasonCode + TDisconnectZeroNonZeroExpiry + TAuth + TAuthMalReasonCode + TAuthMalProperties + TAuthInvalidReason + TAuthInvalidReason2 +) + +// TPacketData contains individual encoding and decoding scenarios for each packet type. +var TPacketData = map[byte]TPacketCases{ + Connect: { + { + Case: TConnectMqtt31, + Desc: "mqtt v3.1", + Primary: true, + RawBytes: []byte{ + Connect << 4, 17, // Fixed header + 0, 6, // Protocol Name - MSB+LSB + 'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name + 3, // Protocol Version + 0, // Packet Flags + 0, 30, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 17, + }, + ProtocolVersion: 3, + Connect: ConnectParams{ + ProtocolName: []byte("MQIsdp"), + Clean: false, + Keepalive: 30, + ClientIdentifier: "zen", + }, + }, + }, + { + Case: TConnectMqtt311, + Desc: "mqtt v3.1.1", + Primary: true, + RawBytes: []byte{ + Connect << 4, 15, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Packet Flags + 0, 60, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 15, + }, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Clean: false, + Keepalive: 60, + ClientIdentifier: "zen", + }, + }, + }, + { + Case: TConnectMqtt5, + Desc: "mqtt v5", + Primary: true, + RawBytes: []byte{ + Connect << 4, 87, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 5, // Protocol Version + 0, // Packet Flags + 0, 30, // Keepalive + + // Properties + 71, // length + 17, 0, 0, 0, 120, // Session Expiry Interval (17) + 21, 0, 5, 'S', 'H', 'A', '-', '1', // Authentication Method (21) + 22, 0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', // Authentication Data (22) + 23, 1, // Request Problem Info (23) + 25, 1, // Request Response Info (25) + 33, 1, 244, // Receive Maximum (33) + 34, 3, 231, // Topic Alias Maximum (34) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + 38, // User Properties (38) + 0, 4, 'k', 'e', 'y', '2', + 0, 6, 'v', 'a', 'l', 'u', 'e', '2', + 39, 0, 0, 125, 0, // Maximum Packet Size (39) + + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 87, + }, + ProtocolVersion: 5, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Clean: false, + Keepalive: 30, + ClientIdentifier: "zen", + }, + Properties: Properties{ + SessionExpiryInterval: uint32(120), + SessionExpiryIntervalFlag: true, + AuthenticationMethod: "SHA-1", + AuthenticationData: []byte("auth-data"), + RequestProblemInfo: byte(1), + RequestProblemInfoFlag: true, + RequestResponseInfo: byte(1), + ReceiveMaximum: uint16(500), + TopicAliasMaximum: uint16(999), + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + { + Key: "key2", + Val: "value2", + }, + }, + MaximumPacketSize: uint32(32000), + }, + }, + }, + { + Case: TConnectClean, + Desc: "mqtt 3.1.1, clean session", + RawBytes: []byte{ + Connect << 4, 15, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 2, // Packet Flags + 0, 45, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 15, + }, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Clean: true, + Keepalive: 45, + ClientIdentifier: "zen", + }, + }, + }, + { + Case: TConnectMqtt5LWT, + Desc: "mqtt 5 clean session, lwt", + RawBytes: []byte{ + Connect << 4, 47, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 5, // Protocol Version + 14, // Packet Flags + 0, 30, // Keepalive + + // Properties + 10, // length + 17, 0, 0, 0, 120, // Session Expiry Interval (17) + 39, 0, 0, 125, 0, // Maximum Packet Size (39) + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 5, // will properties length + 24, 0, 0, 2, 88, // will delay interval (24) + + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 8, // Will Message MSB+LSB + 'n', 'o', 't', 'a', 'g', 'a', 'i', 'n', + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 42, + }, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Clean: true, + Keepalive: 30, + ClientIdentifier: "zen", + WillFlag: true, + WillTopic: "lwt", + WillPayload: []byte("notagain"), + WillQos: 1, + WillProperties: Properties{ + WillDelayInterval: uint32(600), + }, + }, + Properties: Properties{ + SessionExpiryInterval: uint32(120), + SessionExpiryIntervalFlag: true, + MaximumPacketSize: uint32(32000), + }, + }, + }, + { + Case: TConnectUserPass, + Desc: "mqtt 3.1.1, username, password", + RawBytes: []byte{ + Connect << 4, 28, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0 | 1<<6 | 1<<7, // Packet Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 5, // Username MSB+LSB + 'm', 'o', 'c', 'h', 'i', + 0, 4, // Password MSB+LSB + ',', '.', '/', ';', + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 28, + }, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Clean: false, + Keepalive: 20, + ClientIdentifier: "zen", + UsernameFlag: true, + PasswordFlag: true, + Username: []byte("mochi"), + Password: []byte(",./;"), + }, + }, + }, + { + Case: TConnectUserPassLWT, + Desc: "mqtt 3.1.1, username, password, lwt", + Primary: true, + RawBytes: []byte{ + Connect << 4, 44, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 206, // Packet Flags + 0, 120, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', + 0, 5, // Username MSB+LSB + 'm', 'o', 'c', 'h', 'i', + 0, 4, // Password MSB+LSB + ',', '.', '/', ';', + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 44, + }, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Clean: true, + Keepalive: 120, + ClientIdentifier: "zen", + UsernameFlag: true, + PasswordFlag: true, + Username: []byte("mochi"), + Password: []byte(",./;"), + WillFlag: true, + WillTopic: "lwt", + WillPayload: []byte("not again"), + WillQos: 1, + }, + }, + }, + { + Case: TConnectZeroByteUsername, + Desc: "username flag but 0 byte username", + Group: "decode", + RawBytes: []byte{ + Connect << 4, 23, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 5, // Protocol Version + 130, // Packet Flags + 0, 30, // Keepalive + 5, // length + 17, 0, 0, 0, 120, // Session Expiry Interval (17) + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 0, // Username MSB+LSB + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 23, + }, + ProtocolVersion: 5, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Clean: true, + Keepalive: 30, + ClientIdentifier: "zen", + Username: []byte{}, + UsernameFlag: true, + }, + Properties: Properties{ + SessionExpiryInterval: uint32(120), + SessionExpiryIntervalFlag: true, + }, + }, + }, + + // Fail States + { + Case: TConnectMalProtocolName, + Desc: "malformed protocol name", + Group: "decode", + FailFirst: ErrMalformedProtocolName, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 7, // Protocol Name - MSB+LSB + 'M', 'Q', 'I', 's', 'd', // Protocol Name + }, + }, + { + Case: TConnectMalProtocolVersion, + Desc: "malformed protocol version", + Group: "decode", + FailFirst: ErrMalformedProtocolVersion, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + }, + }, + { + Case: TConnectMalFlags, + Desc: "malformed flags", + Group: "decode", + FailFirst: ErrMalformedFlags, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + + }, + }, + { + Case: TConnectMalKeepalive, + Desc: "malformed keepalive", + Group: "decode", + FailFirst: ErrMalformedKeepalive, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + }, + }, + { + Case: TConnectMalClientID, + Desc: "malformed client id", + Group: "decode", + FailFirst: ErrClientIdentifierNotValid, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', // Client ID "zen" + }, + }, + { + Case: TConnectMalWillTopic, + Desc: "malformed will topic", + Group: "decode", + FailFirst: ErrMalformedWillTopic, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 14, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 6, // Will Topic - MSB+LSB + 'l', + }, + }, + { + Case: TConnectMalWillFlag, + Desc: "malformed will flag", + Group: "decode", + FailFirst: ErrMalformedWillPayload, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 14, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', + }, + }, + { + Case: TConnectMalUsername, + Desc: "malformed username", + Group: "decode", + FailFirst: ErrMalformedUsername, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 206, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', + 0, 5, // Username MSB+LSB + 'm', 'o', 'c', + }, + }, + + { + Case: TConnectInvalidFlagNoUsername, + Desc: "username flag with no username bytes", + Group: "decode", + FailFirst: ErrProtocolViolationFlagNoUsername, + RawBytes: []byte{ + Connect << 4, 17, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 5, // Protocol Version + 130, // Flags + 0, 20, // Keepalive + 0, + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + }, + }, + { + Case: TConnectMalPassword, + Desc: "malformed password", + Group: "decode", + FailFirst: ErrMalformedPassword, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 206, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', + 0, 5, // Username MSB+LSB + 'm', 'o', 'c', 'h', 'i', + 0, 4, // Password MSB+LSB + ',', '.', + }, + }, + { + Case: TConnectMalFixedHeader, + Desc: "malformed fixedheader oversize", + Group: "decode", + FailFirst: ErrMalformedProtocolName, // packet test doesn't test fixedheader oversize + RawBytes: []byte{ + Connect << 4, 255, 255, 255, 255, 255, // Fixed header + }, + }, + { + Case: TConnectMalReservedBit, + Desc: "reserved bit not 0", + Group: "nodecode", + FailFirst: ErrProtocolViolation, + RawBytes: []byte{ + Connect << 4, 15, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 1, // Packet Flags + 0, 45, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', + }, + }, + { + Case: TConnectMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Connect << 4, 47, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 5, // Protocol Version + 14, // Packet Flags + 0, 30, // Keepalive + 10, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + { + Case: TConnectMalWillProperties, + Desc: "malformed will properties", + Group: "decode", + FailFirst: ErrMalformedWillProperties, + RawBytes: []byte{ + Connect << 4, 47, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 5, // Protocol Version + 14, // Packet Flags + 0, 30, // Keepalive + 10, // length + 17, 0, 0, 0, 120, // Session Expiry Interval (17) + 39, 0, 0, 125, 0, // Maximum Packet Size (39) + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 5, // will properties length + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + + // Validation Tests + { + Case: TConnectInvalidProtocolName, + Desc: "invalid protocol name", + Group: "validate", + Expect: ErrProtocolViolationProtocolName, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + Connect: ConnectParams{ + ProtocolName: []byte("stuff"), + }, + }, + }, + { + Case: TConnectInvalidProtocolVersion, + Desc: "invalid protocol version", + Group: "validate", + Expect: ErrProtocolViolationProtocolVersion, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 2, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + }, + }, + }, + { + Case: TConnectInvalidProtocolVersion2, + Desc: "invalid protocol version", + Group: "validate", + Expect: ErrProtocolViolationProtocolVersion, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 2, + Connect: ConnectParams{ + ProtocolName: []byte("MQIsdp"), + }, + }, + }, + { + Case: TConnectInvalidReservedBit, + Desc: "reserved bit not 0", + Group: "validate", + Expect: ErrProtocolViolationReservedBit, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + }, + ReservedBit: 1, + }, + }, + { + Case: TConnectInvalidClientIDTooLong, + Desc: "client id too long", + Group: "validate", + Expect: ErrClientIdentifierNotValid, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + ClientIdentifier: func() string { + return string(make([]byte, 65536)) + }(), + }, + }, + }, + { + Case: TConnectInvalidUsernameNoFlag, + Desc: "has username but no flag", + Group: "validate", + Expect: ErrProtocolViolationUsernameNoFlag, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Username: []byte("username"), + }, + }, + }, + { + Case: TConnectInvalidFlagNoPassword, + Desc: "has password flag but no password", + Group: "validate", + Expect: ErrProtocolViolationFlagNoPassword, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + PasswordFlag: true, + }, + }, + }, + { + Case: TConnectInvalidPasswordNoFlag, + Desc: "has password flag but no password", + Group: "validate", + Expect: ErrProtocolViolationPasswordNoFlag, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Password: []byte("password"), + }, + }, + }, + { + Case: TConnectInvalidUsernameTooLong, + Desc: "username too long", + Group: "validate", + Expect: ErrProtocolViolationUsernameTooLong, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + UsernameFlag: true, + Username: func() []byte { + return make([]byte, 65536) + }(), + }, + }, + }, + { + Case: TConnectInvalidPasswordTooLong, + Desc: "password too long", + Group: "validate", + Expect: ErrProtocolViolationPasswordTooLong, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + UsernameFlag: true, + Username: []byte{}, + PasswordFlag: true, + Password: func() []byte { + return make([]byte, 65536) + }(), + }, + }, + }, + { + Case: TConnectInvalidWillFlagNoPayload, + Desc: "will flag no payload", + Group: "validate", + Expect: ErrProtocolViolationWillFlagNoPayload, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + WillFlag: true, + }, + }, + }, + { + Case: TConnectInvalidWillFlagQosOutOfRange, + Desc: "will flag no payload", + Group: "validate", + Expect: ErrProtocolViolationQosOutOfRange, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + WillFlag: true, + WillTopic: "a/b/c", + WillPayload: []byte{'b'}, + WillQos: 4, + }, + }, + }, + { + Case: TConnectInvalidWillSurplusRetain, + Desc: "no will flag surplus retain", + Group: "validate", + Expect: ErrProtocolViolationWillFlagSurplusRetain, + Packet: &Packet{ + FixedHeader: FixedHeader{Type: Connect}, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + WillRetain: true, + }, + }, + }, + + // Spec Tests + { + Case: TConnectSpecInvalidUTF8D800, + Desc: "invalid utf8 string (a) - code point U+D800", + Group: "decode", + FailFirst: ErrClientIdentifierNotValid, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 4, // Client ID - MSB+LSB + 'e', 0xed, 0xa0, 0x80, // Client id bearing U+D800 + }, + }, + { + Case: TConnectSpecInvalidUTF8DFFF, + Desc: "invalid utf8 string (b) - code point U+DFFF", + Group: "decode", + FailFirst: ErrClientIdentifierNotValid, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 4, // Client ID - MSB+LSB + 'e', 0xed, 0xa3, 0xbf, // Client id bearing U+D8FF + }, + }, + + { + Case: TConnectSpecInvalidUTF80000, + Desc: "invalid utf8 string (c) - code point U+0000", + Group: "decode", + FailFirst: ErrClientIdentifierNotValid, + RawBytes: []byte{ + Connect << 4, 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'e', 0xc0, 0x80, // Client id bearing U+0000 + }, + }, + + { + Case: TConnectSpecInvalidUTF8NoSkip, + Desc: "utf8 string must not skip or strip code point U+FEFF", + //Group: "decode", + //FailFirst: ErrMalformedClientID, + RawBytes: []byte{ + Connect << 4, 18, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 6, // Client ID - MSB+LSB + 'e', 'b', 0xEF, 0xBB, 0xBF, 'd', // Client id bearing U+FEFF + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 16, + }, + ProtocolVersion: 4, + Connect: ConnectParams{ + ProtocolName: []byte("MQTT"), + Keepalive: 20, + ClientIdentifier: string([]byte{'e', 'b', 0xEF, 0xBB, 0xBF, 'd'}), + }, + }, + }, + }, + Connack: { + { + Case: TConnackAcceptedNoSession, + Desc: "accepted, no session", + Primary: true, + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 0, // No existing session + CodeSuccess.Code, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: false, + ReasonCode: CodeSuccess.Code, + }, + }, + { + Case: TConnackAcceptedSessionExists, + Desc: "accepted, session exists", + Primary: true, + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 1, // Session present + CodeSuccess.Code, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReasonCode: CodeSuccess.Code, + }, + }, + { + Case: TConnackAcceptedAdjustedExpiryInterval, + Desc: "accepted, no session, adjusted expiry interval mqtt5", + Primary: true, + RawBytes: []byte{ + Connack << 4, 8, // fixed header + 0, // Session present + CodeSuccess.Code, + 5, // length + 17, 0, 0, 0, 120, // Session Expiry Interval (17) + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 8, + }, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + SessionExpiryInterval: uint32(120), + SessionExpiryIntervalFlag: true, + }, + }, + }, + { + Case: TConnackAcceptedMqtt5, + Desc: "accepted no session mqtt5", + Primary: true, + RawBytes: []byte{ + Connack << 4, 124, // fixed header + 0, // No existing session + CodeSuccess.Code, + // Properties + 121, // length + 17, 0, 0, 0, 120, // Session Expiry Interval (17) + 18, 0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5', // Assigned Client ID (18) + 19, 0, 20, // Server Keep Alive (19) + 21, 0, 5, 'S', 'H', 'A', '-', '1', // Authentication Method (21) + 22, 0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', // Authentication Data (22) + 26, 0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e', // Response Info (26) + 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28) + 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31) + 33, 1, 244, // Receive Maximum (33) + 34, 3, 231, // Topic Alias Maximum (34) + 36, 1, // Maximum Qos (36) + 37, 1, // Retain Available (37) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + 38, // User Properties (38) + 0, 4, 'k', 'e', 'y', '2', + 0, 6, 'v', 'a', 'l', 'u', 'e', '2', + 39, 0, 0, 125, 0, // Maximum Packet Size (39) + 40, 1, // Wildcard Subscriptions Available (40) + 41, 1, // Subscription ID Available (41) + 42, 1, // Shared Subscriptions Available (42) + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 124, + }, + SessionPresent: false, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + SessionExpiryInterval: uint32(120), + SessionExpiryIntervalFlag: true, + AssignedClientID: "mochi-v5", + ServerKeepAlive: uint16(20), + ServerKeepAliveFlag: true, + AuthenticationMethod: "SHA-1", + AuthenticationData: []byte("auth-data"), + ResponseInfo: "response", + ServerReference: "mochi-2", + ReasonString: "reason", + ReceiveMaximum: uint16(500), + TopicAliasMaximum: uint16(999), + MaximumQos: byte(1), + MaximumQosFlag: true, + RetainAvailable: byte(1), + RetainAvailableFlag: true, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + { + Key: "key2", + Val: "value2", + }, + }, + MaximumPacketSize: uint32(32000), + WildcardSubAvailable: byte(1), + WildcardSubAvailableFlag: true, + SubIDAvailable: byte(1), + SubIDAvailableFlag: true, + SharedSubAvailable: byte(1), + SharedSubAvailableFlag: true, + }, + }, + }, + { + Case: TConnackMinMqtt5, + Desc: "accepted min properties mqtt5", + Primary: true, + RawBytes: []byte{ + Connack << 4, 13, // fixed header + 1, // existing session + CodeSuccess.Code, + 10, // Properties length + 18, 0, 5, 'm', 'o', 'c', 'h', 'i', // Assigned Client ID (18) + 36, 1, // Maximum Qos (36) + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 13, + }, + SessionPresent: true, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + AssignedClientID: "mochi", + MaximumQos: byte(1), + MaximumQosFlag: true, + }, + }, + }, + { + Case: TConnackMinCleanMqtt5, + Desc: "accepted min properties mqtt5b", + Primary: true, + RawBytes: []byte{ + Connack << 4, 3, // fixed header + 0, // existing session + CodeSuccess.Code, + 0, // Properties length + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 16, + }, + SessionPresent: false, + ReasonCode: CodeSuccess.Code, + }, + }, + { + Case: TConnackServerKeepalive, + Desc: "server set keepalive", + Primary: true, + RawBytes: []byte{ + Connack << 4, 6, // fixed header + 1, // existing session + CodeSuccess.Code, + 3, // Properties length + 19, 0, 10, // server keepalive + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 6, + }, + SessionPresent: true, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + ServerKeepAlive: uint16(10), + ServerKeepAliveFlag: true, + }, + }, + }, + { + Case: TConnackInvalidMinMqtt5, + Desc: "failure min properties mqtt5", + Primary: true, + RawBytes: append([]byte{ + Connack << 4, 23, // fixed header + 0, // No existing session + ErrUnspecifiedError.Code, + // Properties + 20, // length + 31, 0, 17, // Reason String (31) + }, []byte(ErrUnspecifiedError.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 23, + }, + SessionPresent: false, + ReasonCode: ErrUnspecifiedError.Code, + Properties: Properties{ + ReasonString: ErrUnspecifiedError.Reason, + }, + }, + }, + + { + Case: TConnackProtocolViolationNoSession, + Desc: "miscellaneous protocol violation", + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 0, // Session present + ErrProtocolViolation.Code, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + ReasonCode: ErrProtocolViolation.Code, + }, + }, + { + Case: TConnackBadProtocolVersion, + Desc: "bad protocol version", + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 1, // Session present + ErrProtocolViolationProtocolVersion.Code, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReasonCode: ErrProtocolViolationProtocolVersion.Code, + }, + }, + { + Case: TConnackBadClientID, + Desc: "bad client id", + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 1, // Session present + ErrClientIdentifierNotValid.Code, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReasonCode: ErrClientIdentifierNotValid.Code, + }, + }, + { + Case: TConnackServerUnavailable, + Desc: "server unavailable", + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 1, // Session present + ErrServerUnavailable.Code, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReasonCode: ErrServerUnavailable.Code, + }, + }, + { + Case: TConnackBadUsernamePassword, + Desc: "bad username or password", + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 1, // Session present + ErrBadUsernameOrPassword.Code, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReasonCode: ErrBadUsernameOrPassword.Code, + }, + }, + { + Case: TConnackBadUsernamePasswordNoSession, + Desc: "bad username or password no session", + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 0, // No session present + Err3NotAuthorized.Code, // use v3 remapping + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + ReasonCode: Err3NotAuthorized.Code, + }, + }, + { + Case: TConnackMqtt5BadUsernamePasswordNoSession, + Desc: "mqtt5 bad username or password no session", + RawBytes: []byte{ + Connack << 4, 3, // fixed header + 0, // No session present + ErrBadUsernameOrPassword.Code, + 0, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + ReasonCode: ErrBadUsernameOrPassword.Code, + }, + }, + + { + Case: TConnackNotAuthorised, + Desc: "not authorised", + RawBytes: []byte{ + Connack << 4, 2, // fixed header + 1, // Session present + ErrNotAuthorized.Code, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReasonCode: ErrNotAuthorized.Code, + }, + }, + { + Case: TConnackDropProperties, + Desc: "drop oversize properties", + Group: "encode", + RawBytes: []byte{ + Connack << 4, 40, // fixed header + 0, // No existing session + CodeSuccess.Code, + 19, // length + 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28) + 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + ActualBytes: []byte{ + Connack << 4, 13, // fixed header + 0, // No existing session + CodeSuccess.Code, + 10, // length + 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28) + }, + Packet: &Packet{ + Mods: Mods{ + MaxSize: 5, + }, + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 40, + }, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + ReasonString: "reason", + ServerReference: "mochi-2", + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TConnackDropPropertiesPartial, + Desc: "drop oversize properties partial", + Group: "encode", + RawBytes: []byte{ + Connack << 4, 40, // fixed header + 0, // No existing session + CodeSuccess.Code, + 19, // length + 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28) + 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + ActualBytes: []byte{ + Connack << 4, 22, // fixed header + 0, // No existing session + CodeSuccess.Code, + 19, // length + 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28) + 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31) + }, + Packet: &Packet{ + Mods: Mods{ + MaxSize: 18, + }, + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 40, + }, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + ReasonString: "reason", + ServerReference: "mochi-2", + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + // Fail States + { + Case: TConnackMalSessionPresent, + Desc: "malformed session present", + Group: "decode", + FailFirst: ErrMalformedSessionPresent, + RawBytes: []byte{ + Connect << 4, 2, // Fixed header + }, + }, + { + Case: TConnackMalReturnCode, + Desc: "malformed bad return Code", + Group: "decode", + //Primary: true, + FailFirst: ErrMalformedReasonCode, + RawBytes: []byte{ + Connect << 4, 2, // Fixed header + 0, + }, + }, + { + Case: TConnackMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Connack << 4, 40, // fixed header + 0, // No existing session + CodeSuccess.Code, + 19, // length + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + }, + + Publish: { + { + Case: TPublishNoPayload, + Desc: "no payload", + Primary: true, + RawBytes: []byte{ + Publish << 4, 7, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 7, + }, + TopicName: "a/b/c", + Payload: []byte{}, + }, + }, + { + Case: TPublishBasic, + Desc: "basic", + Primary: true, + RawBytes: []byte{ + Publish << 4, 18, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 18, + }, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + }, + }, + + { + Case: TPublishMqtt5, + Desc: "mqtt v5", + Primary: true, + RawBytes: []byte{ + Publish << 4, 77, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 58, // length + 1, 1, // Payload Format (1) + 2, 0, 0, 0, 2, // Message Expiry (2) + 3, 0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n', // Content Type (3) + 8, 0, 5, 'a', '/', 'b', '/', 'c', // Response Topic (8) + 9, 0, 4, 'd', 'a', 't', 'a', // Correlations Data (9) + 11, 202, 212, 19, // Subscription Identifier (11) + 35, 0, 3, // Topic Alias (35) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 77, + }, + TopicName: "a/b/c", + Properties: Properties{ + PayloadFormat: byte(1), // UTF-8 Format + PayloadFormatFlag: true, + MessageExpiryInterval: uint32(2), + ContentType: "text/plain", + ResponseTopic: "a/b/c", + CorrelationData: []byte("data"), + SubscriptionIdentifier: []int{322122}, + TopicAlias: uint16(3), + TopicAliasFlag: true, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + Payload: []byte("hello mochi"), + }, + }, + { + Case: TPublishBasicTopicAliasOnly, + Desc: "mqtt v5 topic alias only", + Primary: true, + RawBytes: []byte{ + Publish << 4, 17, // Fixed header + 0, 0, // Topic Name - LSB+MSB + 3, // length + 35, 0, 1, // Topic Alias (35) + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 17, + }, + Properties: Properties{ + TopicAlias: 1, + TopicAliasFlag: true, + }, + Payload: []byte("hello mochi"), + }, + }, + { + Case: TPublishBasicMqtt5, + Desc: "mqtt basic v5", + Primary: true, + RawBytes: []byte{ + Publish << 4, 22, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 3, // length + 35, 0, 1, // Topic Alias (35) + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 22, + }, + TopicName: "a/b/c", + Properties: Properties{ + TopicAlias: uint16(1), + TopicAliasFlag: true, + }, + Payload: []byte("hello mochi"), + }, + }, + + { + Case: TPublishQos1, + Desc: "qos:1, packet id", + Primary: true, + RawBytes: []byte{ + Publish<<4 | 1<<1, 20, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 7, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + Remaining: 20, + }, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + PacketID: 7, + }, + }, + { + Case: TPublishQos1Mqtt5, + Desc: "mqtt v5", + Primary: true, + RawBytes: []byte{ + Publish<<4 | 1<<1, 37, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 7, // Packet ID - LSB+MSB + // Properties + 16, // length + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 37, + Qos: 1, + }, + PacketID: 7, + TopicName: "a/b/c", + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + Payload: []byte("hello mochi"), + }, + }, + + { + Case: TPublishQos1Dup, + Desc: "qos:1, dup:true, packet id", + Primary: true, + RawBytes: []byte{ + Publish<<4 | 2 | 8, 20, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 7, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + Remaining: 20, + Dup: true, + }, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + PacketID: 7, + }, + }, + { + Case: TPublishQos1NoPayload, + Desc: "qos:1, packet id, no payload", + Primary: true, + RawBytes: []byte{ + Publish<<4 | 2, 9, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'y', '/', 'u', '/', 'i', // Topic Name + 0, 7, // Packet ID - LSB+MSB + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + Remaining: 9, + }, + TopicName: "y/u/i", + PacketID: 7, + Payload: []byte{}, + }, + }, + { + Case: TPublishQos2, + Desc: "qos:2, packet id", + Primary: true, + RawBytes: []byte{ + Publish<<4 | 2<<1, 14, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 7, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 2, + Remaining: 14, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 7, + }, + }, + { + Case: TPublishQos2Mqtt5, + Desc: "mqtt v5", + Primary: true, + RawBytes: []byte{ + Publish<<4 | 2<<1, 37, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 7, // Packet ID - LSB+MSB + // Properties + 16, // length + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 37, + Qos: 2, + }, + PacketID: 7, + TopicName: "a/b/c", + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + Payload: []byte("hello mochi"), + }, + }, + { + Case: TPublishSubscriberIdentifier, + Desc: "subscription identifiers", + Primary: true, + RawBytes: []byte{ + Publish << 4, 23, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 4, // properties length + 11, 2, // Subscription Identifier (11) + 11, 3, // Subscription Identifier (11) + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 23, + }, + TopicName: "a/b/c", + Properties: Properties{ + SubscriptionIdentifier: []int{2, 3}, + }, + Payload: []byte("hello mochi"), + }, + }, + + { + Case: TPublishQos2Upgraded, + Desc: "qos:2, upgraded from publish to qos2 sub", + Primary: true, + RawBytes: []byte{ + Publish<<4 | 2<<1, 20, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 1, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 2, + Remaining: 18, + }, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + PacketID: 1, + }, + }, + { + Case: TPublishRetain, + Desc: "retain", + RawBytes: []byte{ + Publish<<4 | 1<<0, 18, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Retain: true, + }, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + }, + }, + { + Case: TPublishRetainMqtt5, + Desc: "retain mqtt5", + RawBytes: []byte{ + Publish<<4 | 1<<0, 19, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, // properties length + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Publish, + Retain: true, + Remaining: 19, + }, + TopicName: "a/b/c", + Properties: Properties{}, + Payload: []byte("hello mochi"), + }, + }, + { + Case: TPublishDup, + Desc: "dup", + RawBytes: []byte{ + Publish<<4 | 8, 10, // Fixed header + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', // Topic Name + 'h', 'e', 'l', 'l', 'o', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Dup: true, + }, + TopicName: "a/b", + Payload: []byte("hello"), + }, + }, + + // Fail States + { + Case: TPublishMalTopicName, + Desc: "malformed topic name", + Group: "decode", + FailFirst: ErrMalformedTopic, + RawBytes: []byte{ + Publish << 4, 7, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', + 0, 11, // Packet ID - LSB+MSB + }, + }, + { + Case: TPublishMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Publish<<4 | 2, 7, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'x', '/', 'y', '/', 'z', // Topic Name + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TPublishMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Publish << 4, 35, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 16, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + + // Copy tests + { + Case: TPublishCopyBasic, + Desc: "basic copyable", + Group: "copy", + RawBytes: []byte{ + Publish << 4, 18, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'z', '/', 'e', '/', 'n', // Topic Name + 'm', 'o', 'c', 'h', 'i', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Dup: true, + Retain: true, + Qos: 1, + }, + TopicName: "z/e/n", + Payload: []byte("mochi mochi"), + }, + }, + + // Spec tests + { + Case: TPublishSpecQos0NoPacketID, + Desc: "packet id must be 0 if qos is 0 (a)", + Group: "encode", + // this version tests for correct byte array mutuation. + // this does not check if -incoming- Packets are parsed as correct, + // it is impossible for the parser to determine if the payload start is incorrect. + RawBytes: []byte{ + Publish << 4, 12, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 3, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', // Payload + }, + ActualBytes: []byte{ + Publish << 4, 12, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + // Packet ID is removed. + 'h', 'e', 'l', 'l', 'o', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 12, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + }, + }, + { + Case: TPublishSpecQosMustPacketID, + Desc: "no packet id with qos > 0", + Group: "encode", + Expect: ErrProtocolViolationNoPacketID, + RawBytes: []byte{ + Publish<<4 | 2, 14, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 0, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', // Payload + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 0, + }, + }, + { + Case: TPublishDropOversize, + Desc: "drop oversized publish packet", + Group: "encode", + FailFirst: ErrPacketTooLarge, + RawBytes: []byte{ + Publish << 4, 10, // Fixed header + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', // Topic Name + 'h', 'e', 'l', 'l', 'o', // Payload + }, + Packet: &Packet{ + Mods: Mods{ + MaxSize: 2, + }, + FixedHeader: FixedHeader{ + Type: Publish, + }, + TopicName: "a/b", + Payload: []byte("hello"), + }, + }, + + // Validation Tests + { + Case: TPublishInvalidQos0NoPacketID, + Desc: "packet id must be 0 if qos is 0 (b)", + Group: "validate", + Expect: ErrProtocolViolationSurplusPacketID, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 12, + Qos: 0, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 3, + }, + }, + { + Case: TPublishInvalidQosMustPacketID, + Desc: "no packet id with qos > 0", + Group: "validate", + Expect: ErrProtocolViolationNoPacketID, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + }, + PacketID: 0, + }, + }, + { + Case: TPublishInvalidSurplusSubID, + Desc: "surplus subscription identifier", + Group: "validate", + Expect: ErrProtocolViolationSurplusSubID, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + }, + Properties: Properties{ + SubscriptionIdentifier: []int{1}, + }, + TopicName: "a/b", + }, + }, + { + Case: TPublishInvalidSurplusWildcard, + Desc: "topic contains wildcards", + Group: "validate", + Expect: ErrProtocolViolationSurplusWildcard, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + }, + TopicName: "a/+", + }, + }, + { + Case: TPublishInvalidSurplusWildcard2, + Desc: "topic contains wildcards 2", + Group: "validate", + Expect: ErrProtocolViolationSurplusWildcard, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + }, + TopicName: "a/#", + }, + }, + { + Case: TPublishInvalidNoTopic, + Desc: "no topic or alias specified", + Group: "validate", + Expect: ErrProtocolViolationNoTopic, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + }, + }, + }, + { + Case: TPublishInvalidExcessTopicAlias, + Desc: "topic alias over maximum", + Group: "validate", + Expect: ErrTopicAliasInvalid, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + }, + Properties: Properties{ + TopicAlias: 1025, + }, + TopicName: "a/b", + }, + }, + { + Case: TPublishInvalidTopicAlias, + Desc: "topic alias flag and no alias", + Group: "validate", + Expect: ErrTopicAliasInvalid, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + }, + Properties: Properties{ + TopicAliasFlag: true, + TopicAlias: 0, + }, + TopicName: "a/b/", + }, + }, + { + Case: TPublishSpecDenySysTopic, + Desc: "deny publishing to $SYS topics", + Group: "validate", + Expect: CodeSuccess, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + }, + TopicName: "$SYS/any", + Payload: []byte("y"), + }, + RawBytes: []byte{ + Publish << 4, 11, // Fixed header + 0, 5, // Topic Name - LSB+MSB + '$', 'S', 'Y', 'S', '/', 'a', 'n', 'y', // Topic Name + 'y', // Payload + }, + }, + }, + + Puback: { + { + Case: TPuback, + Desc: "puback", + Primary: true, + RawBytes: []byte{ + Puback << 4, 2, // Fixed header + 0, 7, // Packet ID - LSB+MSB + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Puback, + Remaining: 2, + }, + PacketID: 7, + }, + }, + { + Case: TPubackMqtt5, + Desc: "mqtt5", + Primary: true, + RawBytes: []byte{ + Puback << 4, 20, // Fixed header + 0, 7, // Packet ID - LSB+MSB + CodeGrantedQos0.Code, // Reason Code + 16, // Properties Length + // 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Puback, + Remaining: 20, + }, + PacketID: 7, + ReasonCode: CodeGrantedQos0.Code, + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TPubackMqtt5NotAuthorized, + Desc: "QOS 1 publish not authorized mqtt5", + Primary: true, + RawBytes: []byte{ + Puback << 4, 37, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrNotAuthorized.Code, // Reason Code + 33, // Properties Length + 31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u', + 't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Puback, + Remaining: 31, + }, + PacketID: 7, + ReasonCode: ErrNotAuthorized.Code, + Properties: Properties{ + ReasonString: ErrNotAuthorized.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TPubackUnexpectedError, + Desc: "unexpected error", + Group: "decode", + RawBytes: []byte{ + Puback << 4, 29, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrPayloadFormatInvalid.Code, // Reason Code + 25, // Properties Length + 31, 0, 22, 'p', 'a', 'y', 'l', 'o', 'a', 'd', + ' ', 'f', 'o', 'r', 'm', 'a', 't', + ' ', 'i', 'n', 'v', 'a', 'l', 'i', 'd', // Reason String (31) + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Puback, + Remaining: 28, + }, + PacketID: 7, + ReasonCode: ErrPayloadFormatInvalid.Code, + Properties: Properties{ + ReasonString: ErrPayloadFormatInvalid.Reason, + }, + }, + }, + + // Fail states + { + Case: TPubackMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Puback << 4, 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TPubackMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Puback << 4, 20, // Fixed header + 0, 7, // Packet ID - LSB+MSB + CodeGrantedQos0.Code, // Reason Code + 16, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + }, + Pubrec: { + { + Case: TPubrec, + Desc: "pubrec", + Primary: true, + RawBytes: []byte{ + Pubrec << 4, 2, // Fixed header + 0, 7, // Packet ID - LSB+MSB + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubrec, + Remaining: 2, + }, + PacketID: 7, + }, + }, + { + Case: TPubrecMqtt5, + Desc: "pubrec mqtt5", + Primary: true, + RawBytes: []byte{ + Pubrec << 4, 20, // Fixed header + 0, 7, // Packet ID - LSB+MSB + CodeSuccess.Code, // Reason Code + 16, // Properties Length + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubrec, + Remaining: 20, + }, + PacketID: 7, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TPubrecMqtt5IDInUse, + Desc: "packet id in use mqtt5", + Primary: true, + RawBytes: []byte{ + Pubrec << 4, 47, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrPacketIdentifierInUse.Code, // Reason Code + 43, // Properties Length + 31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't', + ' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r', + ' ', 'i', 'n', + ' ', 'u', 's', 'e', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubrec, + Remaining: 31, + }, + PacketID: 7, + ReasonCode: ErrPacketIdentifierInUse.Code, + Properties: Properties{ + ReasonString: ErrPacketIdentifierInUse.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TPubrecMqtt5NotAuthorized, + Desc: "QOS 2 publish not authorized mqtt5", + Primary: true, + RawBytes: []byte{ + Pubrec << 4, 37, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrNotAuthorized.Code, // Reason Code + 33, // Properties Length + 31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u', + 't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubrec, + Remaining: 31, + }, + PacketID: 7, + ReasonCode: ErrNotAuthorized.Code, + Properties: Properties{ + ReasonString: ErrNotAuthorized.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TPubrecMalReasonCode, + Desc: "malformed reason code", + Group: "decode", + FailFirst: ErrMalformedReasonCode, + RawBytes: []byte{ + Pubrec << 4, 31, // Fixed header + 0, 7, // Packet ID - LSB+MSB + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + // Validation + { + Case: TPubrecInvalidReason, + Desc: "invalid reason code", + Group: "validate", + FailFirst: ErrProtocolViolationInvalidReason, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubrec, + }, + PacketID: 7, + ReasonCode: ErrConnectionRateExceeded.Code, + }, + }, + // Fail states + { + Case: TPubrecMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Pubrec << 4, 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TPubrecMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Pubrec << 4, 31, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrPacketIdentifierInUse.Code, // Reason Code + 27, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + }, + Pubrel: { + { + Case: TPubrel, + Desc: "pubrel", + Primary: true, + RawBytes: []byte{ + Pubrel<<4 | 1<<1, 2, // Fixed header + 0, 7, // Packet ID - LSB+MSB + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubrel, + Remaining: 2, + Qos: 1, + }, + PacketID: 7, + }, + }, + { + Case: TPubrelMqtt5, + Desc: "pubrel mqtt5", + Primary: true, + RawBytes: []byte{ + Pubrel<<4 | 1<<1, 20, // Fixed header + 0, 7, // Packet ID - LSB+MSB + CodeSuccess.Code, // Reason Code + 16, // Properties Length + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubrel, + Remaining: 20, + Qos: 1, + }, + PacketID: 7, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TPubrelMqtt5AckNoPacket, + Desc: "mqtt5 no packet id ack", + Primary: true, + RawBytes: append([]byte{ + Pubrel<<4 | 1<<1, 34, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrPacketIdentifierNotFound.Code, // Reason Code + 30, // Properties Length + 31, 0, byte(len(ErrPacketIdentifierNotFound.Reason)), + }, []byte(ErrPacketIdentifierNotFound.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubrel, + Remaining: 34, + Qos: 1, + }, + PacketID: 7, + ReasonCode: ErrPacketIdentifierNotFound.Code, + Properties: Properties{ + ReasonString: ErrPacketIdentifierNotFound.Reason, + }, + }, + }, + // Validation + { + Case: TPubrelInvalidReason, + Desc: "invalid reason code", + Group: "validate", + FailFirst: ErrProtocolViolationInvalidReason, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubrel, + }, + PacketID: 7, + ReasonCode: ErrConnectionRateExceeded.Code, + }, + }, + // Fail states + { + Case: TPubrelMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Pubrel << 4, 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TPubrelMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Pubrel<<4 | 1<<1, 20, // Fixed header + 0, 7, // Packet ID - LSB+MSB + CodeSuccess.Code, // Reason Code + 16, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + }, + Pubcomp: { + { + Case: TPubcomp, + Desc: "pubcomp", + Primary: true, + RawBytes: []byte{ + Pubcomp << 4, 2, // Fixed header + 0, 7, // Packet ID - LSB+MSB + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubcomp, + Remaining: 2, + }, + PacketID: 7, + }, + }, + { + Case: TPubcompMqtt5, + Desc: "mqtt5", + Primary: true, + RawBytes: []byte{ + Pubcomp << 4, 20, // Fixed header + 0, 7, // Packet ID - LSB+MSB + CodeSuccess.Code, // Reason Code + 16, // Properties Length + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubcomp, + Remaining: 20, + }, + PacketID: 7, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TPubcompMqtt5AckNoPacket, + Desc: "mqtt5 no packet id ack", + Primary: true, + RawBytes: append([]byte{ + Pubcomp << 4, 34, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrPacketIdentifierNotFound.Code, // Reason Code + 30, // Properties Length + 31, 0, byte(len(ErrPacketIdentifierNotFound.Reason)), + }, []byte(ErrPacketIdentifierNotFound.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubcomp, + Remaining: 34, + }, + PacketID: 7, + ReasonCode: ErrPacketIdentifierNotFound.Code, + Properties: Properties{ + ReasonString: ErrPacketIdentifierNotFound.Reason, + }, + }, + }, + // Validation + { + Case: TPubcompInvalidReason, + Desc: "invalid reason code", + Group: "validate", + FailFirst: ErrProtocolViolationInvalidReason, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubcomp, + }, + ReasonCode: ErrConnectionRateExceeded.Code, + }, + }, + // Fail states + { + Case: TPubcompMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Pubcomp << 4, 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TPubcompMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Pubcomp << 4, 34, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrPacketIdentifierNotFound.Code, // Reason Code + 22, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + }, + Subscribe: { + { + Case: TSubscribe, + Desc: "subscribe", + Primary: true, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 10, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, // QoS + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Remaining: 10, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + {Filter: "a/b/c"}, + }, + }, + }, + { + Case: TSubscribeMany, + Desc: "many", + Primary: true, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 30, // Fixed header + 0, 15, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', // Topic Name + 0, // QoS + + 0, 11, // Topic Name - LSB+MSB + 'd', '/', 'e', '/', 'f', '/', 'g', '/', 'h', '/', 'i', // Topic Name + 1, // QoS + + 0, 5, // Topic Name - LSB+MSB + 'x', '/', 'y', '/', 'z', // Topic Name + 2, // QoS + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Remaining: 30, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + {Filter: "a/b", Qos: 0}, + {Filter: "d/e/f/g/h/i", Qos: 1}, + {Filter: "x/y/z", Qos: 2}, + }, + }, + }, + { + Case: TSubscribeMqtt5, + Desc: "mqtt5", + Primary: true, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 31, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 20, + 11, 202, 212, 19, // Subscription Identifier (11) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + + 0, 5, 'a', '/', 'b', '/', 'c', // Topic Name + 46, // subscription options + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Subscribe, + Remaining: 31, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + { + Filter: "a/b/c", + Qos: 2, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 2, + Identifier: 322122, + }, + }, + Properties: Properties{ + SubscriptionIdentifier: []int{322122}, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TSubscribeRetainHandling1, + Desc: "retain handling 1", + Primary: true, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 11, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, // no properties + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0 | 1<<4, // subscription options + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Subscribe, + Remaining: 11, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + { + Filter: "a/b/c", + RetainHandling: 1, + }, + }, + }, + }, + { + Case: TSubscribeRetainHandling2, + Desc: "retain handling 2", + Primary: true, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 11, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, // no properties + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0 | 2<<4, // subscription options + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Subscribe, + Remaining: 11, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + { + Filter: "a/b/c", + RetainHandling: 2, + }, + }, + }, + }, + { + Case: TSubscribeRetainAsPublished, + Desc: "retain as published", + Primary: true, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 11, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, // no properties + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0 | 1<<3, // subscription options + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Subscribe, + Remaining: 11, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + { + Filter: "a/b/c", + RetainAsPublished: true, + }, + }, + }, + }, + { + Case: TSubscribeInvalidFilter, + Desc: "invalid filter", + Group: "reference", + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + {Filter: "$SHARE/#", Identifier: 5}, + }, + }, + }, + { + Case: TSubscribeInvalidSharedNoLocal, + Desc: "shared and no local", + Group: "reference", + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + {Filter: "$SHARE/tmp/a/b/c", Identifier: 5, NoLocal: true}, + }, + }, + }, + + // Fail states + { + Case: TSubscribeMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TSubscribeMalTopic, + Desc: "malformed topic", + Group: "decode", + FailFirst: ErrMalformedTopic, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 2, // Fixed header + 0, 21, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'a', '/', + }, + }, + { + Case: TSubscribeMalQos, + Desc: "malformed subscribe - qos", + Group: "decode", + FailFirst: ErrMalformedQos, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 2, // Fixed header + 0, 22, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'j', '/', 'b', // Topic Name + + }, + }, + { + Case: TSubscribeMalQosRange, + Desc: "malformed qos out of range", + Group: "decode", + FailFirst: ErrProtocolViolationQosOutOfRange, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 2, // Fixed header + 0, 22, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'c', '/', 'd', // Topic Name + 5, // QoS + + }, + }, + { + Case: TSubscribeMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Subscribe << 4, 11, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 4, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + + // Validation + { + Case: TSubscribeInvalidQosMustPacketID, + Desc: "no packet id with qos > 0", + Group: "validate", + Expect: ErrProtocolViolationNoPacketID, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Qos: 1, + }, + PacketID: 0, + Filters: Subscriptions{ + {Filter: "a/b"}, + }, + }, + }, + { + Case: TSubscribeInvalidNoFilters, + Desc: "no filters", + Group: "validate", + Expect: ErrProtocolViolationNoFilters, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Qos: 1, + }, + PacketID: 2, + }, + }, + + { + Case: TSubscribeInvalidIdentifierOversize, + Desc: "oversize identifier", + Group: "validate", + Expect: ErrProtocolViolationOversizeSubID, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Qos: 1, + }, + PacketID: 2, + Filters: Subscriptions{ + {Filter: "a/b", Identifier: 5}, + {Filter: "d/f", Identifier: 268435456}, + }, + }, + }, + + // Spec tests + { + Case: TSubscribeSpecQosMustPacketID, + Desc: "no packet id with qos > 0", + Group: "encode", + Expect: ErrProtocolViolationNoPacketID, + RawBytes: []byte{ + Subscribe<<4 | 1<<1, 10, // Fixed header + 0, 0, // Packet ID - LSB+MSB + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 1, // QoS + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Qos: 1, + Remaining: 10, + }, + Filters: Subscriptions{ + {Filter: "a/b/c", Qos: 1}, + }, + PacketID: 0, + }, + }, + }, + Suback: { + { + Case: TSuback, + Desc: "suback", + Primary: true, + RawBytes: []byte{ + Suback << 4, 3, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, // Return Code QoS 0 + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Suback, + Remaining: 3, + }, + PacketID: 15, + ReasonCodes: []byte{0}, + }, + }, + { + Case: TSubackMany, + Desc: "many", + Primary: true, + RawBytes: []byte{ + Suback << 4, 6, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, // Return Code QoS 0 + 1, // Return Code QoS 1 + 2, // Return Code QoS 2 + 0x80, // Return Code fail + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Suback, + Remaining: 6, + }, + PacketID: 15, + ReasonCodes: []byte{0, 1, 2, 0x80}, + }, + }, + { + Case: TSubackDeny, + Desc: "deny mqtt5", + Primary: true, + RawBytes: []byte{ + Suback << 4, 4, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, // no properties + ErrNotAuthorized.Code, // Return Code QoS 0 + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Suback, + Remaining: 4, + }, + PacketID: 15, + ReasonCodes: []byte{ErrNotAuthorized.Code}, + }, + }, + { + Case: TSubackUnspecifiedError, + Desc: "unspecified error", + Primary: true, + RawBytes: []byte{ + Suback << 4, 3, // Fixed header + 0, 15, // Packet ID - LSB+MSB + ErrUnspecifiedError.Code, // Return Code QoS 0 + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Suback, + Remaining: 3, + }, + PacketID: 15, + ReasonCodes: []byte{ErrUnspecifiedError.Code}, + }, + }, + { + Case: TSubackUnspecifiedErrorMqtt5, + Desc: "unspecified error mqtt5", + Primary: true, + RawBytes: []byte{ + Suback << 4, 4, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, // no properties + ErrUnspecifiedError.Code, // Return Code QoS 0 + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Suback, + Remaining: 4, + }, + PacketID: 15, + ReasonCodes: []byte{ErrUnspecifiedError.Code}, + }, + }, + { + Case: TSubackMqtt5, + Desc: "mqtt5", + Primary: true, + RawBytes: []byte{ + Suback << 4, 20, // Fixed header + 0, 15, // Packet ID + 16, // Properties Length + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + CodeGrantedQos2.Code, // Return Code QoS 0 + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Suback, + Remaining: 20, + }, + PacketID: 15, + ReasonCodes: []byte{CodeGrantedQos2.Code}, + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TSubackPacketIDInUse, + Desc: "packet id in use", + Primary: true, + RawBytes: []byte{ + Suback << 4, 47, // Fixed header + 0, 15, // Packet ID + 43, // Properties Length + 31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't', + ' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r', + ' ', 'i', 'n', + ' ', 'u', 's', 'e', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + ErrPacketIdentifierInUse.Code, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Suback, + Remaining: 47, + }, + PacketID: 15, + ReasonCodes: []byte{ErrPacketIdentifierInUse.Code}, + Properties: Properties{ + ReasonString: ErrPacketIdentifierInUse.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + + // Fail states + { + Case: TSubackMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Suback << 4, 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TSubackMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Suback << 4, 47, + 0, 15, + 43, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + + { + Case: TSubackInvalidFilter, + Desc: "malformed packet id", + Group: "reference", + RawBytes: []byte{ + Suback << 4, 4, + 0, 15, + 0, // no properties + ErrTopicFilterInvalid.Code, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + { + Case: TSubackInvalidSharedNoLocal, + Desc: "invalid shared no local", + Group: "reference", + RawBytes: []byte{ + Suback << 4, 4, + 0, 15, + 0, // no properties + ErrProtocolViolationInvalidSharedNoLocal.Code, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + }, + + Unsubscribe: { + { + Case: TUnsubscribe, + Desc: "unsubscribe", + Primary: true, + RawBytes: []byte{ + Unsubscribe<<4 | 1<<1, 9, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Remaining: 9, + Qos: 1, + }, + PacketID: 15, + Filters: Subscriptions{ + {Filter: "a/b/c"}, + }, + }, + }, + { + Case: TUnsubscribeMany, + Desc: "unsubscribe many", + Primary: true, + RawBytes: []byte{ + Unsubscribe<<4 | 1<<1, 27, // Fixed header + 0, 35, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', // Topic Name + + 0, 11, // Topic Name - LSB+MSB + 'd', '/', 'e', '/', 'f', '/', 'g', '/', 'h', '/', 'i', // Topic Name + + 0, 5, // Topic Name - LSB+MSB + 'x', '/', 'y', '/', 'z', // Topic Name + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Remaining: 27, + Qos: 1, + }, + PacketID: 35, + Filters: Subscriptions{ + {Filter: "a/b"}, + {Filter: "d/e/f/g/h/i"}, + {Filter: "x/y/z"}, + }, + }, + }, + { + Case: TUnsubscribeMqtt5, + Desc: "mqtt5", + Primary: true, + RawBytes: []byte{ + Unsubscribe<<4 | 1<<1, 31, // Fixed header + 0, 15, // Packet ID - LSB+MSB + + 16, + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', + + 0, 5, // Topic Name - LSB+MSB + 'x', '/', 'y', '/', 'w', + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Remaining: 31, + Qos: 1, + }, + PacketID: 15, + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + Filters: Subscriptions{ + {Filter: "a/b"}, + {Filter: "x/y/w"}, + }, + }, + }, + + // Fail states + { + Case: TUnsubscribeMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Unsubscribe << 4, 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TUnsubscribeMalTopicName, + Desc: "malformed topic", + Group: "decode", + FailFirst: ErrMalformedTopic, + RawBytes: []byte{ + Unsubscribe << 4, 2, // Fixed header + 0, 21, // Packet ID - LSB+MSB + 0, 3, // Topic Name - LSB+MSB + 'a', '/', + }, + }, + { + Case: TUnsubscribeMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Unsubscribe<<4 | 1<<1, 31, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 16, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + + { + Case: TUnsubscribeInvalidQosMustPacketID, + Desc: "no packet id with qos > 0", + Group: "validate", + Expect: ErrProtocolViolationNoPacketID, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Qos: 1, + }, + PacketID: 0, + Filters: Subscriptions{ + Subscription{Filter: "a/b"}, + }, + }, + }, + { + Case: TUnsubscribeInvalidNoFilters, + Desc: "no filters", + Group: "validate", + Expect: ErrProtocolViolationNoFilters, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Qos: 1, + }, + PacketID: 2, + }, + }, + + { + Case: TUnsubscribeSpecQosMustPacketID, + Desc: "no packet id with qos > 0", + Group: "encode", + Expect: ErrProtocolViolationNoPacketID, + RawBytes: []byte{ + Unsubscribe<<4 | 1<<1, 9, // Fixed header + 0, 0, // Packet ID - LSB+MSB + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Qos: 1, + Remaining: 9, + }, + PacketID: 0, + Filters: Subscriptions{ + {Filter: "a/b/c"}, + }, + }, + }, + }, + Unsuback: { + { + Case: TUnsuback, + Desc: "unsuback", + Primary: true, + RawBytes: []byte{ + Unsuback << 4, 2, // Fixed header + 0, 15, // Packet ID - LSB+MSB + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsuback, + Remaining: 2, + }, + PacketID: 15, + }, + }, + { + Case: TUnsubackMany, + Desc: "unsuback many", + Primary: true, + RawBytes: []byte{ + Unsuback << 4, 5, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 0, + CodeSuccess.Code, CodeSuccess.Code, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Unsuback, + Remaining: 5, + }, + PacketID: 15, + ReasonCodes: []byte{CodeSuccess.Code, CodeSuccess.Code}, + }, + }, + { + Case: TUnsubackMqtt5, + Desc: "mqtt5", + Primary: true, + RawBytes: []byte{ + Unsuback << 4, 21, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 16, // Properties Length + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + CodeSuccess.Code, CodeNoSubscriptionExisted.Code, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Unsuback, + Remaining: 21, + }, + PacketID: 15, + ReasonCodes: []byte{CodeSuccess.Code, CodeNoSubscriptionExisted.Code}, + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TUnsubackPacketIDInUse, + Desc: "packet id in use", + Primary: true, + RawBytes: []byte{ + Unsuback << 4, 48, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 43, // Properties Length + 31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't', + ' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r', + ' ', 'i', 'n', + ' ', 'u', 's', 'e', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + ErrPacketIdentifierInUse.Code, ErrPacketIdentifierInUse.Code, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Unsuback, + Remaining: 48, + }, + PacketID: 15, + ReasonCodes: []byte{ErrPacketIdentifierInUse.Code, ErrPacketIdentifierInUse.Code}, + Properties: Properties{ + ReasonString: ErrPacketIdentifierInUse.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + + // Fail states + { + Case: TUnsubackMalPacketID, + Desc: "malformed packet id", + Group: "decode", + FailFirst: ErrMalformedPacketID, + RawBytes: []byte{ + Unsuback << 4, 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + Case: TUnsubackMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Unsuback << 4, 48, // Fixed header + 0, 15, // Packet ID - LSB+MSB + 43, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + }, + + Pingreq: { + { + Case: TPingreq, + Desc: "ping request", + Primary: true, + RawBytes: []byte{ + Pingreq << 4, 0, // fixed header + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pingreq, + Remaining: 0, + }, + }, + }, + }, + Pingresp: { + { + Case: TPingresp, + Desc: "ping response", + Primary: true, + RawBytes: []byte{ + Pingresp << 4, 0, // fixed header + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pingresp, + Remaining: 0, + }, + }, + }, + }, + + Disconnect: { + { + Case: TDisconnect, + Desc: "disconnect", + Primary: true, + RawBytes: []byte{ + Disconnect << 4, 0, // fixed header + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 0, + }, + }, + }, + { + Case: TDisconnectTakeover, + Desc: "takeover", + Primary: true, + RawBytes: append([]byte{ + Disconnect << 4, 21, // fixed header + ErrSessionTakenOver.Code, // Reason Code + 19, // Properties Length + 31, 0, 16, // Reason String (31) + }, []byte(ErrSessionTakenOver.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 0, + }, + ReasonCode: ErrSessionTakenOver.Code, + Properties: Properties{ + ReasonString: ErrSessionTakenOver.Reason, + }, + }, + }, + { + Case: TDisconnectShuttingDown, + Desc: "shutting down", + Primary: true, + RawBytes: append([]byte{ + Disconnect << 4, 25, // fixed header + ErrServerShuttingDown.Code, // Reason Code + 23, // Properties Length + 31, 0, 20, // Reason String (31) + }, []byte(ErrServerShuttingDown.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 0, + }, + ReasonCode: ErrServerShuttingDown.Code, + Properties: Properties{ + ReasonString: ErrServerShuttingDown.Reason, + }, + }, + }, + { + Case: TDisconnectMqtt5, + Desc: "mqtt5", + Primary: true, + RawBytes: append([]byte{ + Disconnect << 4, 22, // fixed header + CodeDisconnect.Code, // Reason Code + 20, // Properties Length + 17, 0, 0, 0, 120, // Session Expiry Interval (17) + 31, 0, 12, // Reason String (31) + }, []byte(CodeDisconnect.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 22, + }, + ReasonCode: CodeDisconnect.Code, + Properties: Properties{ + ReasonString: CodeDisconnect.Reason, + SessionExpiryInterval: 120, + SessionExpiryIntervalFlag: true, + }, + }, + }, + { + Case: TDisconnectMqtt5DisconnectWithWillMessage, + Desc: "mqtt5 disconnect with will message", + Primary: true, + RawBytes: append([]byte{ + Disconnect << 4, 38, // fixed header + CodeDisconnectWillMessage.Code, // Reason Code + 36, // Properties Length + 17, 0, 0, 0, 120, // Session Expiry Interval (17) + 31, 0, 28, // Reason String (31) + }, []byte(CodeDisconnectWillMessage.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 22, + }, + ReasonCode: CodeDisconnectWillMessage.Code, + Properties: Properties{ + ReasonString: CodeDisconnectWillMessage.Reason, + SessionExpiryInterval: 120, + SessionExpiryIntervalFlag: true, + }, + }, + }, + { + Case: TDisconnectSecondConnect, + Desc: "second connect packet mqtt5", + RawBytes: append([]byte{ + Disconnect << 4, 46, // fixed header + ErrProtocolViolationSecondConnect.Code, + 44, + 31, 0, 41, // Reason String (31) + }, []byte(ErrProtocolViolationSecondConnect.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 45, + }, + ReasonCode: ErrProtocolViolationSecondConnect.Code, + Properties: Properties{ + ReasonString: ErrProtocolViolationSecondConnect.Reason, + }, + }, + }, + { + Case: TDisconnectZeroNonZeroExpiry, + Desc: "zero non zero expiry", + RawBytes: []byte{ + Disconnect << 4, 2, // fixed header + ErrProtocolViolationZeroNonZeroExpiry.Code, + 0, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 2, + }, + ReasonCode: ErrProtocolViolationZeroNonZeroExpiry.Code, + }, + }, + { + Case: TDisconnectReceiveMaximum, + Desc: "receive maximum mqtt5", + RawBytes: append([]byte{ + Disconnect << 4, 29, // fixed header + ErrReceiveMaximum.Code, + 27, // Properties Length + 31, 0, 24, // Reason String (31) + }, []byte(ErrReceiveMaximum.Reason)...), + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 29, + }, + ReasonCode: ErrReceiveMaximum.Code, + Properties: Properties{ + ReasonString: ErrReceiveMaximum.Reason, + }, + }, + }, + { + Case: TDisconnectDropProperties, + Desc: "drop oversize properties partial", + Group: "encode", + RawBytes: []byte{ + Disconnect << 4, 39, // fixed header + CodeDisconnect.Code, + 19, // length + 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28) + 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + ActualBytes: []byte{ + Disconnect << 4, 12, // fixed header + CodeDisconnect.Code, + 10, // length + 28, 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', // Server Reference (28) + }, + Packet: &Packet{ + Mods: Mods{ + MaxSize: 3, + }, + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 40, + }, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + ReasonString: "reason", + ServerReference: "mochi-2", + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + // fail states + { + Case: TDisconnectMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Disconnect << 4, 48, // fixed header + CodeDisconnect.Code, // Reason Code + 46, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + { + Case: TDisconnectMalReasonCode, + Desc: "malformed reason code", + Group: "decode", + FailFirst: ErrMalformedReasonCode, + RawBytes: []byte{ + Disconnect << 4, 48, // fixed header + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + }, + Auth: { + { + Case: TAuth, + Desc: "auth", + Primary: true, + RawBytes: []byte{ + Auth << 4, 47, + CodeSuccess.Code, // reason code + 45, + 21, 0, 5, 'S', 'H', 'A', '-', '1', // Authentication Method (21) + 22, 0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', // Authentication Data (22) + 31, 0, 6, 'r', 'e', 'a', 's', 'o', 'n', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Auth, + Remaining: 47, + }, + ReasonCode: CodeSuccess.Code, + Properties: Properties{ + AuthenticationMethod: "SHA-1", + AuthenticationData: []byte("auth-data"), + ReasonString: "reason", + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, + { + Case: TAuthMalReasonCode, + Desc: "malformed reason code", + Group: "decode", + FailFirst: ErrMalformedReasonCode, + RawBytes: []byte{ + Auth << 4, 47, + }, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Auth, + }, + ReasonCode: CodeNoMatchingSubscribers.Code, + }, + }, + // fail states + { + Case: TAuthMalProperties, + Desc: "malformed properties", + Group: "decode", + FailFirst: ErrMalformedProperties, + RawBytes: []byte{ + Auth << 4, 3, + CodeSuccess.Code, + 12, + }, + Packet: &Packet{ + ProtocolVersion: 5, + }, + }, + // Validation + { + Case: TAuthInvalidReason, + Desc: "invalid reason code", + Group: "validate", + Expect: ErrProtocolViolationInvalidReason, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Auth, + }, + ReasonCode: CodeNoMatchingSubscribers.Code, + }, + }, + { + Case: TAuthInvalidReason2, + Desc: "invalid reason code", + Group: "validate", + Expect: ErrProtocolViolationInvalidReason, + Packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Auth, + }, + ReasonCode: CodeNoMatchingSubscribers.Code, + }, + }, + }, +} diff --git a/packets/tpackets_test.go b/packets/tpackets_test.go new file mode 100644 index 0000000..8114207 --- /dev/null +++ b/packets/tpackets_test.go @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package packets + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func encodeTestOK(wanted TPacketCase) bool { + if wanted.RawBytes == nil { + return false + } + if wanted.Group != "" && wanted.Group != "encode" { + return false + } + return true +} + +func decodeTestOK(wanted TPacketCase) bool { + if wanted.Group != "" && wanted.Group != "decode" { + return false + } + return true +} + +func TestTPacketCaseGet(t *testing.T) { + require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311)) + require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128))) +} diff --git a/system/config.yaml b/system/config.yaml new file mode 100644 index 0000000..6691526 --- /dev/null +++ b/system/config.yaml @@ -0,0 +1,70 @@ +listeners: + - type: "tcp" + id: "file-tcp1" + address: ":1883" + - type: "ws" + id: "file-websocket" + address: ":1882" + - type: "healthcheck" + id: "file-healthcheck" + address: ":1880" +hooks: + debug: + enable: true + storage: + badger: + path: badger.db + gc_interval: 3 + gc_discard_ratio: 0.5 + pebble: + path: pebble.db + mode: "NoSync" + bolt: + path: bolt.db + bucket: "mochi" + redis: + h_prefix: "mc" + username: "mochi" + password: "melon" + address: "localhost:6379" + database: 1 + auth: + allow_all: false + ledger: + auth: + - username: peach + password: password1 + allow: true + acl: + - remote: 127.0.0.1:* + - username: melon + filters: + melon/#: 3 + updates/#: 2 +options: + client_net_write_buffer_size: 2048 + client_net_read_buffer_size: 2048 + sys_topic_resend_interval: 10 + inline_client: true + capabilities: + maximum_message_expiry_interval: 100 + maximum_client_writes_pending: 8192 + maximum_session_expiry_interval: 86400 + maximum_packet_size: 0 + receive_maximum: 1024 + maximum_inflight: 8192 + topic_alias_maximum: 65535 + shared_sub_available: 1 + minimum_protocol_version: 3 + maximum_qos: 2 + retain_available: 1 + wildcard_sub_available: 1 + sub_id_available: 1 + compatibilities: + obscure_not_authorized: true + passive_client_disconnect: false + always_return_response_info: false + restore_sys_info_on_restart: false + no_inherited_properties_on_ack: false +logging: + level: INFO diff --git a/system/system.go b/system/system.go new file mode 100644 index 0000000..1ceda94 --- /dev/null +++ b/system/system.go @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package system + +import "sync/atomic" + +// Info 系统信息 包含各种服务器统计信息的原子计数器和值 +// Info contains atomic counters and values for various server statistics +// commonly found in $SYS topics (and others). +// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics +type Info struct { + Version string `json:"version"` // the current version of the server + Started int64 `json:"started"` // the time the server started in unix seconds + Time int64 `json:"time"` // current time on the server + Uptime int64 `json:"uptime"` // the number of seconds the server has been online + BytesReceived int64 `json:"bytes_received"` // total number of bytes received since the broker started + BytesSent int64 `json:"bytes_sent"` // total number of bytes sent since the broker started + ClientsConnected int64 `json:"clients_connected"` // number of currently connected clients + ClientsDisconnected int64 `json:"clients_disconnected"` // total number of persistent clients (with clean session disabled) that are registered at the broker but are currently disconnected + ClientsMaximum int64 `json:"clients_maximum"` // maximum number of active clients that have been connected + ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered + MessagesReceived int64 `json:"messages_received"` // total number of publish messages received + MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent + MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber + Retained int64 `json:"retained"` // total number of retained messages active on the broker + Inflight int64 `json:"inflight"` // the number of messages currently in-flight + InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped + Subscriptions int64 `json:"subscriptions"` // total number of subscriptions active on the broker + PacketsReceived int64 `json:"packets_received"` // the total number of publish messages received + PacketsSent int64 `json:"packets_sent"` // total number of messages of any type sent since the broker started + MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated + Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity +} + +// Clone makes a copy of Info using atomic operation +func (i *Info) Clone() *Info { + return &Info{ + Version: i.Version, + Started: atomic.LoadInt64(&i.Started), + Time: atomic.LoadInt64(&i.Time), + Uptime: atomic.LoadInt64(&i.Uptime), + BytesReceived: atomic.LoadInt64(&i.BytesReceived), + BytesSent: atomic.LoadInt64(&i.BytesSent), + ClientsConnected: atomic.LoadInt64(&i.ClientsConnected), + ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum), + ClientsTotal: atomic.LoadInt64(&i.ClientsTotal), + ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected), + MessagesReceived: atomic.LoadInt64(&i.MessagesReceived), + MessagesSent: atomic.LoadInt64(&i.MessagesSent), + MessagesDropped: atomic.LoadInt64(&i.MessagesDropped), + Retained: atomic.LoadInt64(&i.Retained), + Inflight: atomic.LoadInt64(&i.Inflight), + InflightDropped: atomic.LoadInt64(&i.InflightDropped), + Subscriptions: atomic.LoadInt64(&i.Subscriptions), + PacketsReceived: atomic.LoadInt64(&i.PacketsReceived), + PacketsSent: atomic.LoadInt64(&i.PacketsSent), + MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc), + Threads: atomic.LoadInt64(&i.Threads), + } +} diff --git a/system/system_test.go b/system/system_test.go new file mode 100644 index 0000000..b76df21 --- /dev/null +++ b/system/system_test.go @@ -0,0 +1,37 @@ +package system + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClone(t *testing.T) { + o := &Info{ + Version: "version", + Started: 1, + Time: 2, + Uptime: 3, + BytesReceived: 4, + BytesSent: 5, + ClientsConnected: 6, + ClientsMaximum: 7, + ClientsTotal: 8, + ClientsDisconnected: 9, + MessagesReceived: 10, + MessagesSent: 11, + MessagesDropped: 20, + Retained: 12, + Inflight: 13, + InflightDropped: 14, + Subscriptions: 15, + PacketsReceived: 16, + PacketsSent: 17, + MemoryAlloc: 18, + Threads: 19, + } + + n := o.Clone() + + require.Equal(t, o, n) +}