代码整理

This commit is contained in:
2024-08-21 15:32:05 +08:00
commit 84e5e65ee7
71 changed files with 29733 additions and 0 deletions

649
mqtt/clients.go Normal file
View File

@@ -0,0 +1,649 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
"testmqtt/packets"
)
const (
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds.
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
minimumKeepalive uint16 = 5 // the minimum recommended keepalive - values under with display a warning.
)
var (
ErrMinimumKeepalive = errors.New("client keepalive is below minimum recommended value and may exhibit connection instability")
)
// ReadFn is the function signature for the function used for reading and processing new packets.
type ReadFn func(*Client, packets.Packet) error
// Clients contains a map of the clients known by the broker.
type Clients struct {
internal map[string]*Client // clients known by the broker, keyed on client id.
sync.RWMutex
}
// NewClients returns an instance of Clients.
func NewClients() *Clients {
return &Clients{
internal: make(map[string]*Client),
}
}
// Add adds a new client to the clients map, keyed on client id.
func (cl *Clients) Add(val *Client) {
cl.Lock()
defer cl.Unlock()
cl.internal[val.ID] = val
}
// GetAll returns all the clients.
func (cl *Clients) GetAll() map[string]*Client {
cl.RLock()
defer cl.RUnlock()
m := map[string]*Client{}
for k, v := range cl.internal {
m[k] = v
}
return m
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
defer cl.RUnlock()
val, ok := cl.internal[id]
return val, ok
}
// Len returns the length of the clients map.
func (cl *Clients) Len() int {
cl.RLock()
defer cl.RUnlock()
val := len(cl.internal)
return val
}
// Delete removes a client from the internal map.
func (cl *Clients) Delete(id string) {
cl.Lock()
defer cl.Unlock()
delete(cl.internal, id)
}
// GetByListener returns clients matching a listener id.
func (cl *Clients) GetByListener(id string) []*Client {
cl.RLock()
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && !client.Closed() {
clients = append(clients, client)
}
}
return clients
}
// Client contains information about a client known by the broker.
type Client struct {
Properties ClientProperties // client properties
State ClientState // the operational state of the client.
Net ClientConnection // network connection state of the client
ID string // the client id.
ops *ops // ops provides a reference to server ops.
sync.RWMutex // mutex
}
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.Reader // a buffered net.Conn for reading packets
outbuf *bytes.Buffer // a buffer for writing packets
Remote string // the remote address of the client
Listener string // listener id of the client
Inline bool // if true, the client is the built-in 'inline' embedded client
}
// ClientProperties contains the properties which define the client behaviour.
type ClientProperties struct {
Props packets.Properties
Will Will
Username []byte
ProtocolVersion byte
Clean bool
}
// Will contains the last will and testament details for a client connection.
type Will struct {
Payload []byte // -
User []packets.UserProperty // -
TopicName string // -
Flag uint32 // 0,1
WillDelayInterval uint32 // -
Qos byte // -
Retain bool // -
}
// ClientState tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
cancelOpen context.CancelFunc // cancel function for open context
outboundQty int32 // number of messages currently in the outbound queue
Keepalive uint16 // the number of seconds the connection can wait
ServerKeepalive bool // keepalive was set by the server
}
// newClient returns a new instance of Client. This is almost exclusively used by Server
// for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
ctx, cancel := context.WithCancel(context.Background())
cl := &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
open: ctx,
cancelOpen: cancel,
Keepalive: defaultKeepalive,
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
ops: o,
}
if c != nil {
cl.Net = ClientConnection{
Conn: c,
bconn: bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
Remote: c.RemoteAddr().String(),
}
}
return cl
}
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for {
select {
case pk := <-cl.State.outbound:
if err := cl.WritePacket(*pk); err != nil {
// TODO : Figure out what to do with error
cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
}
atomic.AddInt32(&cl.State.outboundQty, -1)
case <-cl.State.open.Done():
return
}
}
}
// ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid
cl.Properties.ProtocolVersion = pk.ProtocolVersion
cl.Properties.Username = pk.Connect.Username
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
if cl.Properties.Props.ReceiveMaximum > cl.ops.options.Capabilities.MaximumInflight { // 3.3.4 Non-normative
cl.Properties.Props.ReceiveMaximum = cl.ops.options.Capabilities.MaximumInflight
}
if pk.Connect.Keepalive <= minimumKeepalive {
cl.ops.log.Warn(
ErrMinimumKeepalive.Error(),
"client", cl.ID,
"keepalive", pk.Connect.Keepalive,
"recommended", minimumKeepalive,
)
}
cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
cl.ID = pk.Connect.ClientIdentifier
if cl.ID == "" {
cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
cl.Properties.Props.AssignedClientID = cl.ID
}
if pk.Connect.WillFlag {
cl.Properties.Will = Will{
Qos: pk.Connect.WillQos,
Retain: pk.Connect.WillRetain,
Payload: pk.Connect.WillPayload,
TopicName: pk.Connect.WillTopic,
WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
User: pk.Connect.WillProperties.User,
}
if pk.Properties.SessionExpiryIntervalFlag &&
pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
}
if pk.Connect.WillFlag {
cl.Properties.Will.Flag = 1 // atomic for checking
}
}
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
if cl.Net.Conn != nil {
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
// NextPacketID returns the next available (unused) packet id for the client.
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i
overflowed := false
for {
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
if i >= cl.ops.options.Capabilities.maximumPacketID {
overflowed = true
i = 0
continue
}
i++
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
}
}
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
func (cl *Client) ResendInflightMessages(force bool) error {
if cl.State.Inflight.Len() == 0 {
return nil
}
for _, tk := range cl.State.Inflight.GetAll(false) {
if tk.FixedHeader.Type == packets.Publish {
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
}
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
err := cl.WritePacket(tk)
if err != nil {
return err
}
if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosComplete(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
}
}
}
return nil
}
// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session.
func (cl *Client) ClearInflights() {
for _, tk := range cl.State.Inflight.GetAll(false) {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
}
}
}
// ClearExpiredInflights deletes any inflight messages which have expired.
func (cl *Client) ClearExpiredInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {
expired := tk.ProtocolVersion == 5 && tk.Expiry > 0 && tk.Expiry < now // [MQTT-3.3.2-5]
// If the maximum message expiry interval is set (greater than 0), and the message
// retention period exceeds the maximum expiry, the message will be forcibly removed.
enforced := maximumExpiry > 0 && now-tk.Created > maximumExpiry
if expired || enforced {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted = append(deleted, tk.PacketID)
}
}
}
return deleted
}
// Read reads incoming packets from the connected client and transforms them into
// packets to be handled by the packetHandler.
func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if cl.Closed() {
return nil
}
cl.refreshDeadline(cl.State.Keepalive)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
return err
}
pk, err := cl.ReadPacket(fh)
if err != nil {
return err
}
err = packetHandler(cl, pk) // Process inbound packet.
if err != nil {
return err
}
}
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
cl.State.endOnce.Do(func() {
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
if err != nil {
cl.State.stopCause.Store(err)
}
if cl.State.cancelOpen != nil {
cl.State.cancelOpen()
}
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
// StopCause returns the reason the client connection was stopped, if any.
func (cl *Client) StopCause() error {
if cl.State.stopCause.Load() == nil {
return nil
}
return cl.State.stopCause.Load().(error)
}
// StopTime returns the the time the client disconnected in unix time, else zero.
func (cl *Client) StopTime() int64 {
return atomic.LoadInt64(&cl.State.disconnected)
}
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return cl.State.open == nil || cl.State.open.Err() != nil
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
return ErrConnectionClosed
}
b, err := cl.Net.bconn.ReadByte()
if err != nil {
return err
}
err = fh.Decode(b)
if err != nil {
return err
}
var bu int
fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn)
if err != nil {
return err
}
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil
}
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
pk.FixedHeader = *fh
p := make([]byte, pk.FixedHeader.Remaining)
n, err := io.ReadFull(cl.Net.bconn, p)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
px := append([]byte{}, p[:]...)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectDecode(px)
case packets.Disconnect:
err = pk.DisconnectDecode(px)
case packets.Connack:
err = pk.ConnackDecode(px)
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
case packets.Pubrec:
err = pk.PubrecDecode(px)
case packets.Pubrel:
err = pk.PubrelDecode(px)
case packets.Pubcomp:
err = pk.PubcompDecode(px)
case packets.Subscribe:
err = pk.SubscribeDecode(px)
case packets.Suback:
err = pk.SubackDecode(px)
case packets.Unsubscribe:
err = pk.UnsubscribeDecode(px)
case packets.Unsuback:
err = pk.UnsubackDecode(px)
case packets.Pingreq:
case packets.Pingresp:
case packets.Auth:
err = pk.AuthDecode(px)
default:
err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
}
if err != nil {
return pk, err
}
pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) error {
if cl.Closed() {
return ErrConnectionClosed
}
if cl.Net.Conn == nil {
return nil
}
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
pk.ProtocolVersion = cl.Properties.ProtocolVersion
if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
}
if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
var err error
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectEncode(buf)
case packets.Connack:
err = pk.ConnackEncode(buf)
case packets.Publish:
err = pk.PublishEncode(buf)
case packets.Puback:
err = pk.PubackEncode(buf)
case packets.Pubrec:
err = pk.PubrecEncode(buf)
case packets.Pubrel:
err = pk.PubrelEncode(buf)
case packets.Pubcomp:
err = pk.PubcompEncode(buf)
case packets.Subscribe:
err = pk.SubscribeEncode(buf)
case packets.Suback:
err = pk.SubackEncode(buf)
case packets.Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case packets.Unsuback:
err = pk.UnsubackEncode(buf)
case packets.Pingreq:
err = pk.PingreqEncode(buf)
case packets.Pingresp:
err = pk.PingrespEncode(buf)
case packets.Disconnect:
err = pk.DisconnectEncode(buf)
case packets.Auth:
err = pk.AuthEncode(buf)
default:
err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
}
if err != nil {
return err
}
if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
}
n, err := func() (int64, error) {
cl.Lock()
defer cl.Unlock()
if len(cl.State.outbound) == 0 {
if cl.Net.outbuf == nil {
return buf.WriteTo(cl.Net.Conn)
}
// first write to buffer, then flush buffer
n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
err = cl.flushOutbuf()
return int64(n), err
}
// there are more writes in the queue
if cl.Net.outbuf == nil {
if buf.Len() >= cl.ops.options.ClientNetWriteBufferSize {
return buf.WriteTo(cl.Net.Conn)
}
cl.Net.outbuf = new(bytes.Buffer)
}
n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
if cl.Net.outbuf.Len() < cl.ops.options.ClientNetWriteBufferSize {
return int64(n), nil
}
err = cl.flushOutbuf()
return int64(n), err
}()
if err != nil {
return err
}
atomic.AddInt64(&cl.ops.info.BytesSent, n)
atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
if pk.FixedHeader.Type == packets.Publish {
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
}
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
return err
}
func (cl *Client) flushOutbuf() (err error) {
if cl.Net.outbuf == nil {
return
}
_, err = cl.Net.outbuf.WriteTo(cl.Net.Conn)
if err == nil {
cl.Net.outbuf = nil
}
return
}

930
mqtt/clients_test.go Normal file
View File

@@ -0,0 +1,930 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"log/slog"
"net"
"strings"
"sync/atomic"
"testing"
"time"
"testmqtt/packets"
"testmqtt/system"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var errClientStop = errors.New("test stop")
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: logger,
options: &Options{
Capabilities: &Capabilities{
ReceiveMaximum: 10,
MaximumInflight: 5,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
},
})
cl.ID = "mochi"
cl.State.Inflight.maximumSendQuota = 5
cl.State.Inflight.sendQuota = 5
cl.State.Inflight.maximumReceiveQuota = 10
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
go cl.WriteLoop()
return
}
func TestNewInflights(t *testing.T) {
require.NotNil(t, NewInflights().internal)
}
func TestNewClients(t *testing.T) {
cl := NewClients()
require.NotNil(t, cl.internal)
}
func TestClientsAdd(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
require.Contains(t, cl.internal, "t1")
}
func TestClientsGet(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
client, ok := cl.Get("t1")
require.Equal(t, true, ok)
require.Equal(t, "t1", client.ID)
}
func TestClientsGetAll(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
cl.Add(&Client{ID: "t3"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Contains(t, cl.internal, "t3")
clients := cl.GetAll()
require.Len(t, clients, 3)
}
func TestClientsLen(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Equal(t, 2, cl.Len())
}
func TestClientsDelete(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
require.Contains(t, cl.internal, "t1")
cl.Delete("t1")
_, ok := cl.Get("t1")
require.Equal(t, false, ok)
require.Nil(t, cl.internal["t1"])
}
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
clients := cl.GetByListener("tcp1")
require.NotEmpty(t, clients)
require.Equal(t, 1, len(clients))
require.Equal(t, "tcp1", clients[0].Net.Listener)
}
func TestNewClient(t *testing.T) {
cl, _, _ := newTestClient()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.NotNil(t, cl.ops.options.Capabilities)
require.False(t, cl.Net.Inline)
}
func TestClientParseConnect(t *testing.T) {
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillTopic: "lwt",
WillPayload: []byte("lol gg"),
WillQos: 1,
WillRetain: false,
},
Properties: packets.Properties{
ReceiveMaximum: uint16(5),
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
}
func TestClientParseConnectReceiveMaxExceedMaxInflight(t *testing.T) {
const MaxInflight uint16 = 1
cl, _, _ := newTestClient()
cl.ops.options.Capabilities.MaximumInflight = MaxInflight
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillTopic: "lwt",
WillPayload: []byte("lol gg"),
WillQos: 1,
WillRetain: false,
},
Properties: packets.Properties{
ReceiveMaximum: uint16(5),
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(MaxInflight), cl.State.Inflight.sendQuota)
require.Equal(t, int32(MaxInflight), cl.State.Inflight.maximumSendQuota)
}
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillProperties: packets.Properties{
WillDelayInterval: 200,
},
},
Properties: packets.Properties{
SessionExpiryInterval: 100,
SessionExpiryIntervalFlag: true,
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval)
}
func TestClientParseConnectNoID(t *testing.T) {
cl, _, _ := newTestClient()
cl.ParseConnect("tcp1", packets.Packet{})
require.NotEmpty(t, cl.ID)
}
func TestClientParseConnectBelowMinimumKeepalive(t *testing.T) {
cl, _, _ := newTestClient()
var b bytes.Buffer
x := bufio.NewWriter(&b)
cl.ops.log = slog.New(slog.NewTextHandler(x, nil))
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Keepalive: minimumKeepalive - 1,
ClientIdentifier: "mochi",
},
}
cl.ParseConnect("tcp1", pk)
err := x.Flush()
require.NoError(t, err)
require.True(t, strings.Contains(b.String(), ErrMinimumKeepalive.Error()))
require.NotEmpty(t, cl.ID)
}
func TestClientNextPacketID(t *testing.T) {
cl, _, _ := newTestClient()
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(2), i)
}
func TestClientNextPacketIDInUse(t *testing.T) {
cl, _, _ := newTestClient()
// skip over 2
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(3), i)
// Skip over overflow
cl.State.Inflight.Set(packets.Packet{PacketID: 65535})
atomic.StoreUint32(&cl.State.packetID, 65534)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
}
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
}
i, err := cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
require.Equal(t, uint32(0), i)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
_, err = cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
}
func TestClientClearInflights(t *testing.T) {
cl, _, _ := newTestClient()
n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1})
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2})
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
cl.ClearInflights()
require.Equal(t, 0, cl.State.Inflight.Len())
}
func TestClientClearExpiredInflights(t *testing.T) {
cl, _, _ := newTestClient()
n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1})
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2})
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
deleted := cl.ClearExpiredInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
require.Equal(t, 2, cl.State.Inflight.Len())
cl.State.Inflight.Set(packets.Packet{PacketID: 11, Expiry: n - 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 12, Expiry: n - 2}) // expiry is ineffective for v3.
cl.State.Inflight.Set(packets.Packet{PacketID: 13, Created: n - 3}) // within bounds for v3
cl.State.Inflight.Set(packets.Packet{PacketID: 15, Created: n - 5}) // over max server expiry limit
require.Equal(t, 6, cl.State.Inflight.Len())
deleted = cl.ClearExpiredInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{11, 12, 15}, deleted)
require.Equal(t, 3, cl.State.Inflight.Len())
cl.State.Inflight.Set(packets.Packet{PacketID: 17, Created: n - 1})
deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not process abandon messages
require.Len(t, deleted, 0)
require.Equal(t, 4, cl.State.Inflight.Len())
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 18, Expiry: n - 1})
deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not abandon messages
require.ElementsMatch(t, []uint16{18}, deleted) // expiry is still effective for v5.
require.Len(t, deleted, 1)
require.Equal(t, 4, cl.State.Inflight.Len())
}
func TestClientResendInflightMessages(t *testing.T) {
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
cl, r, w := newTestClient()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
go func() {
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
_ = w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, 0, cl.State.Inflight.Len())
require.Equal(t, pk1.RawBytes, buf)
}
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newTestClient()
_ = r.Close()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
err := cl.ResendInflightMessages(true)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrClosedPipe)
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
cl, _, _ := newTestClient()
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
}
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newTestClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{packets.Connect << 4, 0x00})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived))
}
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
}
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.options.Capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
_ = r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_ = r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.Equal(t, io.EOF, err)
}
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
}
func TestClientReadOK(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{
packets.Publish << 4, 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload,
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
_ = r.Close()
}()
var pks []packets.Packet
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
pks = append(pks, pk)
return nil
})
}()
err := <-o
require.Error(t, err)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 2, len(pks))
require.Equal(t, []packets.Packet{
{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 18,
},
TopicName: "a/b/c",
Payload: []byte("hello mochi"),
},
{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 11,
},
TopicName: "d/e/f",
Payload: []byte("yeah"),
},
}, pks)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived))
}
func TestClientReadDone(t *testing.T) {
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.cancelOpen()
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
return nil
})
}()
require.NoError(t, <-o)
}
func TestClientStop(t *testing.T) {
cl, _, _ := newTestClient()
require.Equal(t, int64(0), cl.StopTime())
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.InDelta(t, time.Now().Unix(), cl.State.disconnected, 1.0)
require.Equal(t, cl.State.disconnected, cl.StopTime())
require.True(t, cl.Closed())
require.Equal(t, nil, cl.StopCause())
}
func TestClientClosed(t *testing.T) {
cl, _, _ := newTestClient()
require.False(t, cl.Closed())
cl.Stop(nil)
require.True(t, cl.Closed())
}
func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
})
_ = r.Close()
}()
cl.Net.bconn = nil
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, ErrConnectionClosed, err)
}
func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
_ = r.Close()
}()
err := cl.Read(func(cl *Client, pk packets.Packet) error {
return errors.New("test")
})
require.Error(t, err)
}
func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
pk, err := cl.ReadPacket(fh)
require.NoError(t, err)
require.NotNil(t, pk)
require.Equal(t, packets.Packet{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 11,
},
TopicName: "d/e/f",
Payload: []byte("yeah"),
}, pk)
}
func TestClientReadPacket(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
for _, tx := range pkTable {
tt := tx // avoid data race
t.Run(tt.Desc, func(t *testing.T) {
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
go func() {
_, _ = r.Write(tt.RawBytes)
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
if tt.Packet.ProtocolVersion == 5 {
cl.Properties.ProtocolVersion = 5
} else {
cl.Properties.ProtocolVersion = 0
}
pk, err := cl.ReadPacket(fh)
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc)
require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc)
if tt.Packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc)
}
})
}
}
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
_ = cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
}
func TestClientWritePacket(t *testing.T) {
for _, tt := range pkTable {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
o := make(chan []byte)
go func() {
buf, err := io.ReadAll(r)
require.NoError(t, err)
o <- buf
}()
err := cl.WritePacket(*tt.Packet)
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
_ = cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
cl.Stop(errClientStop)
time.Sleep(time.Millisecond * 1)
// The stop cause is either the test error, EOF, or a
// closed pipe, depending on which goroutine runs first.
err = cl.StopCause()
require.True(t,
errors.Is(err, errClientStop) ||
errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe))
require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent))
if tt.Packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent))
}
}
}
func TestClientWritePacketBuffer(t *testing.T) {
r, w := net.Pipe()
cl := newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: logger,
options: &Options{
Capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
},
})
cl.ID = "mochi"
cl.State.Inflight.maximumSendQuota = 5
cl.State.Inflight.sendQuota = 5
cl.State.Inflight.maximumReceiveQuota = 10
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
cl.ops.options.ClientNetWriteBufferSize = 10
defer cl.Stop(errClientStop)
small := packets.TPacketData[packets.Publish].Get(packets.TPublishNoPayload).Packet
large := packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
cl.State.outbound <- small
tt := []struct {
pks []*packets.Packet
size int
}{
{
pks: []*packets.Packet{small, small},
size: 18,
},
{
pks: []*packets.Packet{large},
size: 20,
},
{
pks: []*packets.Packet{small},
size: 0,
},
}
go func() {
for i, tx := range tt {
for _, pk := range tx.pks {
cl.Properties.ProtocolVersion = pk.ProtocolVersion
err := cl.WritePacket(*pk)
require.NoError(t, err, "index: %d", i)
if i == len(tt)-1 {
cl.Net.Conn.Close()
}
time.Sleep(100 * time.Millisecond)
}
}
}()
var n int
var err error
for i, tx := range tt {
buf := make([]byte, 100)
if i == len(tt)-1 {
buf, err = io.ReadAll(r)
n = len(buf)
} else {
n, err = io.ReadAtLeast(r, buf, 1)
}
require.NoError(t, err, "index: %d", i)
require.Equal(t, tx.size, n, "index: %d", i)
}
}
func TestWriteClientOversizePacket(t *testing.T) {
cl, _, _ := newTestClient()
cl.Properties.Props.MaximumPacketSize = 2
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
err := cl.WritePacket(pk)
require.Error(t, err)
require.ErrorIs(t, packets.ErrPacketTooLarge, err)
}
func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
_ = r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
Type: 0,
Remaining: 11,
})
require.Error(t, err)
}
func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_, _ = r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
_ = r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
Remaining: 1,
})
require.Error(t, err)
}
func TestClientWritePacketWriteNoConn(t *testing.T) {
cl, _, _ := newTestClient()
cl.Stop(errClientStop)
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
require.Equal(t, ErrConnectionClosed, err)
}
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newTestClient()
_ = cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
}
func TestClientWritePacketInvalidPacket(t *testing.T) {
cl, _, _ := newTestClient()
err := cl.WritePacket(packets.Packet{})
require.Error(t, err)
}
var (
pkTable = []packets.TPacketCase{
packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311),
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5),
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession),
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic),
packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5),
packets.TPacketData[packets.Puback].Get(packets.TPuback),
packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe),
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5),
packets.TPacketData[packets.Suback].Get(packets.TSuback),
packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5),
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe),
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5),
packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback),
packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5),
packets.TPacketData[packets.Pingreq].Get(packets.TPingreq),
packets.TPacketData[packets.Pingresp].Get(packets.TPingresp),
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect),
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5),
packets.TPacketData[packets.Auth].Get(packets.TAuth),
}
)

864
mqtt/hooks.go Normal file
View File

@@ -0,0 +1,864 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co, thedevop, dgduncan
package mqtt
import (
"errors"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"testmqtt/hooks/storage"
"testmqtt/packets"
"testmqtt/system"
)
const (
SetOptions byte = iota
OnSysInfoTick
OnStarted
OnStopped
OnConnectAuthenticate
OnACLCheck
OnConnect
OnSessionEstablish
OnSessionEstablished
OnDisconnect
OnAuthPacket
OnPacketRead
OnPacketEncode
OnPacketSent
OnPacketProcessed
OnSubscribe
OnSubscribed
OnSelectSubscribers
OnUnsubscribe
OnUnsubscribed
OnPublish
OnPublished
OnPublishDropped
OnRetainMessage
OnRetainPublished
OnQosPublish
OnQosComplete
OnQosDropped
OnPacketIDExhausted
OnWill
OnWillSent
OnClientExpired
OnRetainedExpired
StoredClients
StoredSubscriptions
StoredInflightMessages
StoredRetainedMessages
StoredSysInfo
)
var (
// ErrInvalidConfigType indicates a different Type of config value was expected to what was received.
// ErrInvalidConfigType = errors.New("invalid config type provided")
ErrInvalidConfigType = errors.New("提供的配置类型无效")
)
// HookLoadConfig contains the hook and configuration as loaded from a configuration (usually file).
type HookLoadConfig struct {
Hook Hook
Config any
}
// Hook provides an interface of handlers for different events which occur
// during the lifecycle of the broker.
type Hook interface {
ID() string
Provides(b byte) bool
Init(config any) error
Stop() error
SetOpts(l *slog.Logger, o *HookOptions)
OnStarted()
OnStopped()
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
OnACLCheck(cl *Client, topic string, write bool) bool
OnSysInfoTick(*system.Info)
OnConnect(cl *Client, pk packets.Packet) error
OnSessionEstablish(cl *Client, pk packets.Packet)
OnSessionEstablished(cl *Client, pk packets.Packet)
OnDisconnect(cl *Client, err error, expire bool)
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation
OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client
OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client
OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled)
OnSubscribe(cl *Client, pk packets.Packet) packets.Packet
OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers
OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnPublishDropped(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnRetainPublished(cl *Client, pk packets.Packet)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
OnQosDropped(cl *Client, pk packets.Packet)
OnPacketIDExhausted(cl *Client, pk packets.Packet)
OnWill(cl *Client, will Will) (Will, error)
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
StoredRetainedMessages() ([]storage.Message, error)
StoredSysInfo() (storage.SystemInfo, error)
}
// HookOptions contains values which are inherited from the server on initialisation.
type HookOptions struct {
Capabilities *Capabilities
}
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *slog.Logger // a logger for the hook (from the server)
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
func (h *Hooks) Len() int64 {
return atomic.LoadInt64(&h.qty)
}
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.GetAll() {
for _, hb := range b {
if hook.Provides(hb) {
return true
}
}
}
return false
}
// Add adds and initializes a new hook.
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.GetAll() {
h.Log.Info("stopping hook", "hook", hook.ID())
if err := hook.Stop(); err != nil {
h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID())
}
h.wg.Done()
}
}()
h.wg.Wait()
}
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
}
}
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.GetAll() {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
}
}
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.GetAll() {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
}
}
// OnConnect is called when a new client connects, and may return a packets.Code as an error to halt the connection.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) error {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnect) {
err := hook.OnConnect(cl, pk)
if err != nil {
return err
}
}
}
return nil
}
// OnSessionEstablish is called right after a new client connects and authenticates and right before
// the session is established and CONNACK is sent.
func (h *Hooks) OnSessionEstablish(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablish) {
hook.OnSessionEstablish(cl, pk)
}
}
}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
}
}
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.GetAll() {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
}
}
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx)
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
return
}
// OnAuthPacket is called when an auth packet is received. It is intended to allow developers
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
return pk, err
}
pkx = npk
}
}
return
}
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
}
return pk
}
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
}
}
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
}
}
// OnSubscribe is called when a client subscribes to one or more filters. This method
// differs from OnSubscribed in that it allows you to modify the subscription values
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
}
return pk
}
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
}
}
// OnSelectSubscribers is called when subscribers have been collected for a topic, but before
// shared subscription subscribers have been selected. This hook can be used to programmatically
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.GetAll() {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
}
return subs
}
// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method
// differs from OnUnsubscribed in that it allows you to modify the unsubscription values
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
}
return pk
}
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
}
}
// OnPublish is called when a client publishes a message. This method differs from OnPublished
// in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil {
if errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug("publish packet rejected",
"error", err,
"hook", hook.ID(),
"packet", pkx)
return pk, err
}
h.Log.Error("publish packet error",
"error", err,
"hook", hook.ID(),
"packet", pkx)
return pk, err
}
pkx = npk
}
}
return
}
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
}
}
// OnPublishDropped is called when a message to a client was dropped instead of delivered
// such as when a client is too slow to respond.
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublishDropped) {
hook.OnPublishDropped(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
}
}
// OnRetainPublished is called when a retained message is published.
func (h *Hooks) OnRetainPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainPublished) {
hook.OnRetainPublished(cl, pk)
}
}
}
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
}
}
// OnQosComplete is called when the Qos flow for a message has been completed.
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
}
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
}
}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to
// assign to a packet.
func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketIDExhausted) {
hook.OnPacketIDExhausted(cl, pk)
}
}
}
// OnWill is called when a client disconnects and publishes an LWT message. This method
// differs from OnWillSent in that it allows you to modify the LWT message before it is
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.GetAll() {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
h.Log.Error("parse will error",
"error", err,
"hook", hook.ID(),
"will", will)
continue
}
will = mlwt
}
}
return will
}
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
}
}
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.GetAll() {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
}
}
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
}
}
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
h.Log.Error("failed to load clients", "error", err, "hook", hook.ID())
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID())
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID())
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID())
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID())
return v, err
}
if v.Version != "" {
return v, nil
}
}
}
return
}
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
// An implementation of this method MUST be used to allow or deny access to the
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
}
}
}
return false
}
// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter.
// An implementation of this method MUST be used to allow or deny access to the
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.GetAll() {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
}
}
}
return false
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
Hook
Log *slog.Logger
Opts *HookOptions
}
// ID returns the ID of the hook.
func (h *HookBase) ID() string {
return "base"
}
// Provides indicates which methods a hook provides. The default is none - this method
// should be overridden by the embedding hook.
func (h *HookBase) Provides(b byte) bool {
return false
}
// Init performs any pre-start initializations for the hook, such as connecting to databases
// or opening files.
func (h *HookBase) Init(config any) error {
return nil
}
// SetOpts is called by the server to propagate internal values and generally should
// not be called manually.
func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) {
h.Log = l
h.Opts = opts
}
// Stop is called to gracefully shut down the hook.
func (h *HookBase) Stop() error {
return nil
}
// OnStarted is called when the server starts.
func (h *HookBase) OnStarted() {}
// OnStopped is called when the server stops.
func (h *HookBase) OnStopped() {}
// OnSysInfoTick is called when the server publishes system info.
func (h *HookBase) OnSysInfoTick(*system.Info) {}
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return false
}
// OnACLCheck is called when a user attempts to subscribe or publish to a topic.
func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnConnect is called when a new client connects.
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) error {
return nil
}
// OnSessionEstablish is called right after a new client connects and authenticates and right before
// the session is established and CONNACK is sent.
func (h *HookBase) OnSessionEstablish(cl *Client, pk packets.Packet) {}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
// OnDisconnect is called when a client is disconnected for any reason.
func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {}
// OnAuthPacket is called when an auth packet is received from the client.
func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPacketRead is called when a packet is received.
func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPacketEncode is called before a packet is byte-encoded and written to the client.
func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnPacketSent is called immediately after a packet is written to a client.
func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {}
// OnPacketProcessed is called immediately after a packet from a client is processed.
func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {}
// OnSubscribe is called when a client subscribes to one or more filters.
func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {}
// OnSelectSubscribers is called when selecting subscribers to receive a message.
func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
return subs
}
// OnUnsubscribe is called when a client unsubscribes from one or more filters.
func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {}
// OnPublish is called when a client publishes a message.
func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
// OnRetainPublished is called when a retained message is published.
func (h *HookBase) OnRetainPublished(cl *Client, pk packets.Packet) {}
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
// OnQosDropped is called the Qos flow for a message expires.
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
// OnWill is called when a client disconnects and publishes an LWT message.
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
return will, nil
}
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {}
// OnClientExpired is called when a client session has expired.
func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return
}
// StoredSubscriptions returns all subcriptions from a store.
func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
return
}
// StoredInflightMessages returns all inflight messages from a store.
func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) {
return
}
// StoredRetainedMessages 返回存储区中所有保留的消息
// StoredRetainedMessages returns all retained messages from a store.
func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
return
}
// StoredSysInfo 返回一组系统信息值
// StoredSysInfo returns a set of system info values.
func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
return
}

667
mqtt/hooks_test.go Normal file
View File

@@ -0,0 +1,667 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"errors"
"strconv"
"sync/atomic"
"testing"
"time"
"testmqtt/hooks/storage"
"testmqtt/packets"
"testmqtt/system"
"github.com/stretchr/testify/require"
)
type modifiedHookBase struct {
HookBase
err error
fail bool
failAt int
}
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) Provides(b byte) bool {
return true
}
func (h *modifiedHookBase) Stop() error {
if h.fail {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) OnConnect(cl *Client, pk packets.Packet) error {
if h.fail {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return true
}
func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
return true
}
func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) {
if h.fail {
return will, errTestHook
}
return will, nil
}
func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) {
if h.fail || h.failAt == 1 {
return v, errTestHook
}
return []storage.Client{
{ID: "cl1"},
{ID: "cl2"},
{ID: "cl3"},
}, nil
}
func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.fail || h.failAt == 2 {
return v, errTestHook
}
return []storage.Subscription{
{ID: "sub1"},
{ID: "sub2"},
{ID: "sub3"},
}, nil
}
func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) {
if h.fail || h.failAt == 3 {
return v, errTestHook
}
return []storage.Message{
{ID: "r1"},
{ID: "r2"},
{ID: "r3"},
}, nil
}
func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) {
if h.fail || h.failAt == 4 {
return v, errTestHook
}
return []storage.Message{
{ID: "i1"},
{ID: "i2"},
{ID: "i3"},
}, nil
}
func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.fail || h.failAt == 5 {
return v, errTestHook
}
return storage.SystemInfo{
Info: system.Info{
Version: "2.0.0",
},
}, nil
}
type providesCheckHook struct {
HookBase
}
func (h *providesCheckHook) Provides(b byte) bool {
return b == OnConnect
}
func TestHooksProvides(t *testing.T) {
h := new(Hooks)
err := h.Add(new(providesCheckHook), nil)
require.NoError(t, err)
err = h.Add(new(HookBase), nil)
require.NoError(t, err)
require.True(t, h.Provides(OnConnect, OnDisconnect))
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), map[string]any{})
require.Error(t, err)
require.Equal(t, int64(0), atomic.LoadInt64(&h.qty))
}
func TestHooksStop(t *testing.T) {
h := new(Hooks)
h.Log = logger
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
h.Stop()
}
// coverage: also cover some empty functions
func TestHooksNonReturns(t *testing.T) {
h := new(Hooks)
cl := new(Client)
for i := 0; i < 2; i++ {
t.Run("step-"+strconv.Itoa(i), func(t *testing.T) {
// on first iteration, check without hook methods
h.OnStarted()
h.OnStopped()
h.OnSysInfoTick(new(system.Info))
h.OnSessionEstablish(cl, packets.Packet{})
h.OnSessionEstablished(cl, packets.Packet{})
h.OnDisconnect(cl, nil, false)
h.OnPacketSent(cl, packets.Packet{}, []byte{})
h.OnPacketProcessed(cl, packets.Packet{}, nil)
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnPublishDropped(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnRetainPublished(cl, packets.Packet{})
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})
h.OnQosDropped(cl, packets.Packet{})
h.OnPacketIDExhausted(cl, packets.Packet{})
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
})
}
}
func TestHooksOnConnectAuthenticate(t *testing.T) {
h := new(Hooks)
ok := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, ok)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
ok = h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.True(t, ok)
}
func TestHooksOnACLCheck(t *testing.T) {
h := new(Hooks)
ok := h.OnACLCheck(new(Client), "a/b/c", true)
require.False(t, ok)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
ok = h.OnACLCheck(new(Client), "a/b/c", true)
require.True(t, ok)
}
func TestHooksOnSubscribe(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
pki := packets.Packet{
Filters: packets.Subscriptions{
{Filter: "a/b/c", Qos: 1},
},
}
pk := h.OnSubscribe(new(Client), pki)
require.EqualValues(t, pk, pki)
}
func TestHooksOnSelectSubscribers(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
subs := &Subscribers{
Subscriptions: map[string]packets.Subscription{
"cl1": {Filter: "a/b/c"},
},
}
subs2 := h.OnSelectSubscribers(subs, packets.Packet{})
require.EqualValues(t, subs, subs2)
}
func TestHooksOnUnsubscribe(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
pki := packets.Packet{
Filters: packets.Subscriptions{
{Filter: "a/b/c", Qos: 1},
},
}
pk := h.OnUnsubscribe(new(Client), pki)
require.EqualValues(t, pk, pki)
}
func TestHooksOnPublish(t *testing.T) {
h := new(Hooks)
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: failure
hook.fail = true
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
hook.err = packets.ErrRejectPacket
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrRejectPacket)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnPacketRead(t *testing.T) {
h := new(Hooks)
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: failure
hook.fail = true
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
hook.err = packets.ErrRejectPacket
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrRejectPacket)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnAuthPacket(t *testing.T) {
h := new(Hooks)
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
hook.fail = true
pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnConnect(t *testing.T) {
h := new(Hooks)
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
hook.fail = true
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
}
func TestHooksOnPacketEncode(t *testing.T) {
h := new(Hooks)
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10})
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnLWT(t *testing.T) {
h := new(Hooks)
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.Equal(t, "a/b/c", lwt.TopicName)
// coverage: fail lwt
hook.fail = true
lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.Equal(t, "a/b/c", lwt.TopicName)
}
func TestHooksStoredClients(t *testing.T) {
h := new(Hooks)
h.Log = logger
v, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredClients()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredClients()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredSubscriptions(t *testing.T) {
h := new(Hooks)
h.Log = logger
v, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredSubscriptions()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredRetainedMessages(t *testing.T) {
h := new(Hooks)
h.Log = logger
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredRetainedMessages()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredInflightMessages(t *testing.T) {
h := new(Hooks)
h.Log = logger
v, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredInflightMessages()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredSysInfo(t *testing.T) {
h := new(Hooks)
h.Log = logger
v, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "", v.Info.Version)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", v.Info.Version)
hook.fail = true
v, err = h.StoredSysInfo()
require.Error(t, err)
require.Equal(t, "", v.Info.Version)
}
func TestHookBaseID(t *testing.T) {
h := new(HookBase)
require.Equal(t, "base", h.ID())
}
func TestHookBaseProvidesNone(t *testing.T) {
h := new(HookBase)
require.False(t, h.Provides(OnConnect))
require.False(t, h.Provides(OnDisconnect))
}
func TestHookBaseInit(t *testing.T) {
h := new(HookBase)
require.Nil(t, h.Init(nil))
}
func TestHookBaseSetOpts(t *testing.T) {
h := new(HookBase)
h.SetOpts(logger, new(HookOptions))
require.NotNil(t, h.Log)
require.NotNil(t, h.Opts)
}
func TestHookBaseClose(t *testing.T) {
h := new(HookBase)
require.Nil(t, h.Stop())
}
func TestHookBaseOnConnectAuthenticate(t *testing.T) {
h := new(HookBase)
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, v)
}
func TestHookBaseOnACLCheck(t *testing.T) {
h := new(HookBase)
v := h.OnACLCheck(new(Client), "topic", true)
require.False(t, v)
}
func TestHookBaseOnConnect(t *testing.T) {
h := new(HookBase)
err := h.OnConnect(new(Client), packets.Packet{})
require.NoError(t, err)
}
func TestHookBaseOnPublish(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnPacketRead(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnAuthPacket(t *testing.T) {
h := new(HookBase)
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnLWT(t *testing.T) {
h := new(HookBase)
lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.NoError(t, err)
require.Equal(t, "a/b/c", lwt.TopicName)
}
func TestHookBaseStoredClients(t *testing.T) {
h := new(HookBase)
v, err := h.StoredClients()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredSubscriptions(t *testing.T) {
h := new(HookBase)
v, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredInflightMessages(t *testing.T) {
h := new(HookBase)
v, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredRetainedMessages(t *testing.T) {
h := new(HookBase)
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoreSysInfo(t *testing.T) {
h := new(HookBase)
v, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "", v.Version)
}

156
mqtt/inflight.go Normal file
View File

@@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sort"
"sync"
"sync/atomic"
"testmqtt/packets"
)
// Inflight is a map of InflightMessage keyed on packet id.
type Inflight struct {
sync.RWMutex
internal map[uint16]packets.Packet // internal contains the inflight packets
receiveQuota int32 // remaining inbound qos quota for flow control
sendQuota int32 // remaining outbound qos quota for flow control
maximumReceiveQuota int32 // maximum allowed receive quota
maximumSendQuota int32 // maximum allowed send quota
}
// NewInflights returns a new instance of an Inflight packets map.
func NewInflights() *Inflight {
return &Inflight{
internal: map[uint16]packets.Packet{},
}
}
// Set adds or updates an inflight packet by packet id.
func (i *Inflight) Set(m packets.Packet) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[m.PacketID]
i.internal[m.PacketID] = m
return !ok
}
// Get returns an inflight packet by packet id.
func (i *Inflight) Get(id uint16) (packets.Packet, bool) {
i.RLock()
defer i.RUnlock()
if m, ok := i.internal[id]; ok {
return m, true
}
return packets.Packet{}, false
}
// Len returns the size of the inflight messages map.
func (i *Inflight) Len() int {
i.RLock()
defer i.RUnlock()
return len(i.internal)
}
// Clone returns a new instance of Inflight with the same message data.
// This is used when transferring inflights from a taken-over session.
func (i *Inflight) Clone() *Inflight {
c := NewInflights()
i.RLock()
defer i.RUnlock()
for k, v := range i.internal {
c.internal[k] = v
}
return c
}
// GetAll returns all the inflight messages.
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
i.RLock()
defer i.RUnlock()
m := []packets.Packet{}
for _, v := range i.internal {
if !immediate || (immediate && v.Expiry < 0) {
m = append(m, v)
}
}
sort.Slice(m, func(i, j int) bool {
return uint16(m[i].Created) < uint16(m[j].Created)
})
return m
}
// NextImmediate returns the next inflight packet which is indicated to be sent immediately.
// This typically occurs when the quota has been exhausted, and we need to wait until new quota
// is free to continue sending.
func (i *Inflight) NextImmediate() (packets.Packet, bool) {
i.RLock()
defer i.RUnlock()
m := i.GetAll(true)
if len(m) > 0 {
return m[0], true
}
return packets.Packet{}, false
}
// Delete removes an in-flight message from the map. Returns true if the message existed.
func (i *Inflight) Delete(id uint16) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[id]
delete(i.internal, id)
return ok
}
// TakeRecieveQuota reduces the receive quota by 1.
func (i *Inflight) DecreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) > 0 {
atomic.AddInt32(&i.receiveQuota, -1)
}
}
// TakeRecieveQuota increases the receive quota by 1.
func (i *Inflight) IncreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
atomic.AddInt32(&i.receiveQuota, 1)
}
}
// ResetReceiveQuota resets the receive quota to the maximum allowed value.
func (i *Inflight) ResetReceiveQuota(n int32) {
atomic.StoreInt32(&i.receiveQuota, n)
atomic.StoreInt32(&i.maximumReceiveQuota, n)
}
// DecreaseSendQuota reduces the send quota by 1.
func (i *Inflight) DecreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) > 0 {
atomic.AddInt32(&i.sendQuota, -1)
}
}
// IncreaseSendQuota increases the send quota by 1.
func (i *Inflight) IncreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
atomic.AddInt32(&i.sendQuota, 1)
}
}
// ResetSendQuota resets the send quota to the maximum allowed value.
func (i *Inflight) ResetSendQuota(n int32) {
atomic.StoreInt32(&i.sendQuota, n)
atomic.StoreInt32(&i.maximumSendQuota, n)
}

199
mqtt/inflight_test.go Normal file
View File

@@ -0,0 +1,199 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"testmqtt/packets"
)
func TestInflightSet(t *testing.T) {
cl, _, _ := newTestClient()
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.True(t, r)
require.NotNil(t, cl.State.Inflight.internal[1])
require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID)
r = cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.False(t, r)
}
func TestInflightGet(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
msg, ok := cl.State.Inflight.Get(2)
require.True(t, ok)
require.NotEqual(t, 0, msg.PacketID)
}
func TestInflightGetAllAndImmediate(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
require.Equal(t, []packets.Packet{
{PacketID: 1, Created: 1},
{PacketID: 2, Created: 2},
{PacketID: 3, Created: 3, Expiry: -1},
{PacketID: 4, Created: 4, Expiry: -1},
{PacketID: 5, Created: 5},
}, cl.State.Inflight.GetAll(false))
require.Equal(t, []packets.Packet{
{PacketID: 3, Created: 3, Expiry: -1},
{PacketID: 4, Created: 4, Expiry: -1},
}, cl.State.Inflight.GetAll(true))
}
func TestInflightLen(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightClone(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
cloned := cl.State.Inflight.Clone()
require.NotNil(t, cloned)
require.NotSame(t, cloned, cl.State.Inflight)
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
require.NotNil(t, cl.State.Inflight.internal[3])
r := cl.State.Inflight.Delete(3)
require.True(t, r)
require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID)
_, ok := cl.State.Inflight.Get(3)
require.False(t, ok)
r = cl.State.Inflight.Delete(3)
require.False(t, r)
}
func TestResetReceiveQuota(t *testing.T) {
i := NewInflights()
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
i.ResetReceiveQuota(6)
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota))
}
func TestReceiveQuota(t *testing.T) {
i := NewInflights()
i.receiveQuota = 4
i.maximumReceiveQuota = 5
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
// Return 1
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Try to go over max limit
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Reset to max 1
i.ResetReceiveQuota(1)
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
// Take 1
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
// Try to go below zero
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}
func TestResetSendQuota(t *testing.T) {
i := NewInflights()
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
i.ResetSendQuota(6)
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota))
}
func TestSendQuota(t *testing.T) {
i := NewInflights()
i.sendQuota = 4
i.maximumSendQuota = 5
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
// Return 1
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Try to go over max limit
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Reset to max 1
i.ResetSendQuota(1)
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
// Take 1
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
// Try to go below zero
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}
func TestNextImmediate(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
pk, ok := cl.State.Inflight.NextImmediate()
require.True(t, ok)
require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk)
r := cl.State.Inflight.Delete(3)
require.True(t, r)
pk, ok = cl.State.Inflight.NextImmediate()
require.True(t, ok)
require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk)
r = cl.State.Inflight.Delete(4)
require.True(t, r)
_, ok = cl.State.Inflight.NextImmediate()
require.False(t, ok)
}

1759
mqtt/server.go Normal file

File diff suppressed because it is too large Load Diff

3915
mqtt/server_test.go Normal file

File diff suppressed because it is too large Load Diff

824
mqtt/topics.go Normal file
View File

@@ -0,0 +1,824 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"strings"
"sync"
"sync/atomic"
"testmqtt/packets"
)
var (
// SharePrefix 共享主题的前缀
SharePrefix = "$SHARE" // the prefix indicating a share topic
// SysPrefix 系统信息主题的前缀
SysPrefix = "$SYS" // the prefix indicating a system info topic
)
// TopicAliases contains inbound and outbound topic alias registrations.
type TopicAliases struct {
Inbound *InboundTopicAliases
Outbound *OutboundTopicAliases
}
// NewTopicAliases returns an instance of TopicAliases.
func NewTopicAliases(topicAliasMaximum uint16) TopicAliases {
return TopicAliases{
Inbound: NewInboundTopicAliases(topicAliasMaximum),
Outbound: NewOutboundTopicAliases(topicAliasMaximum),
}
}
// NewInboundTopicAliases returns a pointer to InboundTopicAliases.
func NewInboundTopicAliases(topicAliasMaximum uint16) *InboundTopicAliases {
return &InboundTopicAliases{
maximum: topicAliasMaximum,
internal: map[uint16]string{},
}
}
// InboundTopicAliases contains a map of topic aliases received from the client.
type InboundTopicAliases struct {
internal map[uint16]string
sync.RWMutex
maximum uint16
}
// Set sets a new alias for a specific topic.
func (a *InboundTopicAliases) Set(id uint16, topic string) string {
a.Lock()
defer a.Unlock()
if a.maximum == 0 {
return topic // ?
}
if existing, ok := a.internal[id]; ok && topic == "" {
return existing
}
a.internal[id] = topic
return topic
}
// OutboundTopicAliases contains a map of topic aliases sent from the broker to the client.
type OutboundTopicAliases struct {
internal map[string]uint16
sync.RWMutex
cursor uint32
maximum uint16
}
// NewOutboundTopicAliases returns a pointer to OutboundTopicAliases.
func NewOutboundTopicAliases(topicAliasMaximum uint16) *OutboundTopicAliases {
return &OutboundTopicAliases{
maximum: topicAliasMaximum,
internal: map[string]uint16{},
}
}
// Set sets a new topic alias for a topic and returns the alias value, and a boolean
// indicating if the alias already existed.
func (a *OutboundTopicAliases) Set(topic string) (uint16, bool) {
a.Lock()
defer a.Unlock()
if a.maximum == 0 {
return 0, false
}
if i, ok := a.internal[topic]; ok {
return i, true
}
i := atomic.LoadUint32(&a.cursor)
if i+1 > uint32(a.maximum) {
// if i+1 > math.MaxUint16 {
return 0, false
}
a.internal[topic] = uint16(i) + 1
atomic.StoreUint32(&a.cursor, i+1)
return uint16(i) + 1, false
}
// SharedSubscriptions contains a map of subscriptions to a shared filter,
// keyed on share group then client id.
type SharedSubscriptions struct {
internal map[string]map[string]packets.Subscription
sync.RWMutex
}
// NewSharedSubscriptions returns a new instance of Subscriptions.
func NewSharedSubscriptions() *SharedSubscriptions {
return &SharedSubscriptions{
internal: map[string]map[string]packets.Subscription{},
}
}
// Add creates a new shared subscription for a group and client id pair.
func (s *SharedSubscriptions) Add(group, id string, val packets.Subscription) {
s.Lock()
defer s.Unlock()
if _, ok := s.internal[group]; !ok {
s.internal[group] = map[string]packets.Subscription{}
}
s.internal[group][id] = val
}
// Delete deletes a client id from a shared subscription group.
func (s *SharedSubscriptions) Delete(group, id string) {
s.Lock()
defer s.Unlock()
delete(s.internal[group], id)
if len(s.internal[group]) == 0 {
delete(s.internal, group)
}
}
// Get returns the subscription properties for a client id in a share group, if one exists.
func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, ok bool) {
s.RLock()
defer s.RUnlock()
if _, ok := s.internal[group]; !ok {
return val, ok
}
val, ok = s.internal[group][id]
return val, ok
}
// GroupLen returns the number of groups subscribed to the filter.
func (s *SharedSubscriptions) GroupLen() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Len returns the total number of shared subscriptions to a filter across all groups.
func (s *SharedSubscriptions) Len() int {
s.RLock()
defer s.RUnlock()
n := 0
for _, group := range s.internal {
n += len(group)
}
return n
}
// GetAll returns all shared subscription groups and their subscriptions.
func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscription {
s.RLock()
defer s.RUnlock()
m := map[string]map[string]packets.Subscription{}
for group, subs := range s.internal {
if _, ok := m[group]; !ok {
m[group] = map[string]packets.Subscription{}
}
for id, sub := range subs {
m[group][id] = sub
}
}
return m
}
// InlineSubFn is the signature for a callback function which will be called
// when an inline client receives a message on a topic it is subscribed to.
// The sub argument contains information about the subscription that was matched for any filters.
type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet)
// InlineSubscriptions represents a map of internal subscriptions keyed on client.
type InlineSubscriptions struct {
internal map[int]InlineSubscription
sync.RWMutex
}
// NewInlineSubscriptions returns a new instance of InlineSubscriptions.
func NewInlineSubscriptions() *InlineSubscriptions {
return &InlineSubscriptions{
internal: map[int]InlineSubscription{},
}
}
// Add adds a new internal subscription for a client id.
func (s *InlineSubscriptions) Add(val InlineSubscription) {
s.Lock()
defer s.Unlock()
s.internal[val.Identifier] = val
}
// GetAll returns all internal subscriptions.
func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription {
s.RLock()
defer s.RUnlock()
m := map[int]InlineSubscription{}
for k, v := range s.internal {
m[k] = v
}
return m
}
// Get returns an internal subscription for a client id.
func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) {
s.RLock()
defer s.RUnlock()
val, ok = s.internal[id]
return val, ok
}
// Len returns the number of internal subscriptions.
func (s *InlineSubscriptions) Len() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Delete removes an internal subscription by the client id.
func (s *InlineSubscriptions) Delete(id int) {
s.Lock()
defer s.Unlock()
delete(s.internal, id)
}
// Subscriptions is a map of subscriptions keyed on client.
type Subscriptions struct {
internal map[string]packets.Subscription
sync.RWMutex
}
// NewSubscriptions returns a new instance of Subscriptions.
func NewSubscriptions() *Subscriptions {
return &Subscriptions{
internal: map[string]packets.Subscription{},
}
}
// Add adds a new subscription for a client. ID can be a filter in the
// case this map is client state, or a client id if particle state.
func (s *Subscriptions) Add(id string, val packets.Subscription) {
s.Lock()
defer s.Unlock()
s.internal[id] = val
}
// GetAll returns all subscriptions.
func (s *Subscriptions) GetAll() map[string]packets.Subscription {
s.RLock()
defer s.RUnlock()
m := map[string]packets.Subscription{}
for k, v := range s.internal {
m[k] = v
}
return m
}
// Get returns a subscriptions for a specific client or filter id.
func (s *Subscriptions) Get(id string) (val packets.Subscription, ok bool) {
s.RLock()
defer s.RUnlock()
val, ok = s.internal[id]
return val, ok
}
// Len returns the number of subscriptions.
func (s *Subscriptions) Len() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Delete removes a subscription by client or filter id.
func (s *Subscriptions) Delete(id string) {
s.Lock()
defer s.Unlock()
delete(s.internal, id)
}
// ClientSubscriptions is a map of aggregated subscriptions for a client.
type ClientSubscriptions map[string]packets.Subscription
type InlineSubscription struct {
packets.Subscription
Handler InlineSubFn
}
// Subscribers contains the shared and non-shared subscribers matching a topic.
type Subscribers struct {
Shared map[string]map[string]packets.Subscription
SharedSelected map[string]packets.Subscription
Subscriptions map[string]packets.Subscription
InlineSubscriptions map[int]InlineSubscription
}
// SelectShared returns one subscriber for each shared subscription group.
func (s *Subscribers) SelectShared() {
s.SharedSelected = map[string]packets.Subscription{}
for _, subs := range s.Shared {
for client, sub := range subs {
cls, ok := s.SharedSelected[client]
if !ok {
cls = sub
}
s.SharedSelected[client] = cls.Merge(sub)
break
}
}
}
// MergeSharedSelected merges the selected subscribers for a shared subscription group
// and the non-shared subscribers, to ensure that no subscriber gets multiple messages
// due to have both types of subscription matching the same filter.
func (s *Subscribers) MergeSharedSelected() {
for client, sub := range s.SharedSelected {
cls, ok := s.Subscriptions[client]
if !ok {
cls = sub
}
s.Subscriptions[client] = cls.Merge(sub)
}
}
// TopicsIndex is a prefix/trie tree containing topic subscribers and retained messages.
type TopicsIndex struct {
Retained *packets.Packets
root *particle // a leaf containing a message and more leaves.
}
// NewTopicsIndex returns a pointer to a new instance of Index.
func NewTopicsIndex() *TopicsIndex {
return &TopicsIndex{
Retained: packets.NewPackets(),
root: &particle{
particles: newParticles(),
subscriptions: NewSubscriptions(),
},
}
}
// InlineSubscribe adds a new internal subscription for a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
n := x.set(subscription.Filter, 0)
_, existed = n.inlineSubscriptions.Get(subscription.Identifier)
n.inlineSubscriptions.Add(subscription)
return !existed
}
// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client,
// returning true if the subscription existed.
func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool {
x.root.Lock()
defer x.root.Unlock()
particle := x.seek(filter, 0)
if particle == nil {
return false
}
particle.inlineSubscriptions.Delete(id)
if particle.inlineSubscriptions.Len() == 0 {
x.trim(particle)
}
return true
}
// Subscribe adds a new subscription for a client to a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
prefix, _ := isolateParticle(subscription.Filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
group, _ := isolateParticle(subscription.Filter, 1)
n := x.set(subscription.Filter, 2)
_, existed = n.shared.Get(group, client)
n.shared.Add(group, client, subscription)
} else {
n := x.set(subscription.Filter, 0)
_, existed = n.subscriptions.Get(client)
n.subscriptions.Add(client, subscription)
}
return !existed
}
// Unsubscribe removes a subscription filter for a client, returning true if the
// subscription existed.
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
x.root.Lock()
defer x.root.Unlock()
var d int
prefix, _ := isolateParticle(filter, 0)
shareSub := strings.EqualFold(prefix, SharePrefix)
if shareSub {
d = 2
}
particle := x.seek(filter, d)
if particle == nil {
return false
}
if shareSub {
group, _ := isolateParticle(filter, 1)
particle.shared.Delete(group, client)
} else {
particle.subscriptions.Delete(client)
}
x.trim(particle)
return true
}
// RetainMessage saves a message payload to the end of a topic address. Returns
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
x.root.Lock()
defer x.root.Unlock()
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()
if len(pk.Payload) > 0 {
n.retainPath = pk.TopicName
x.Retained.Add(pk.TopicName, pk)
return 1
}
var out int64
if pke, ok := x.Retained.Get(pk.TopicName); ok && len(pke.Payload) > 0 && pke.FixedHeader.Retain {
out = -1 // if a retained packet existed, return -1
}
n.retainPath = ""
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
x.trim(n)
return out
}
// set creates a topic address in the index and returns the final particle.
func (x *TopicsIndex) set(topic string, d int) *particle {
var key string
var hasNext = true
n := x.root
for hasNext {
key, hasNext = isolateParticle(topic, d)
d++
p := n.particles.get(key)
if p == nil {
p = newParticle(key, n)
n.particles.add(p)
}
n = p
}
return n
}
// seek finds the particle at a specific index in a topic filter.
func (x *TopicsIndex) seek(filter string, d int) *particle {
var key string
var hasNext = true
n := x.root
for hasNext {
key, hasNext = isolateParticle(filter, d)
n = n.particles.get(key)
d++
if n == nil {
return nil
}
}
return n
}
// trim removes empty filter particles from the index.
func (x *TopicsIndex) trim(n *particle) {
for n.parent != nil && n.retainPath == "" && n.particles.len()+n.subscriptions.Len()+n.shared.Len()+n.inlineSubscriptions.Len() == 0 {
key := n.key
n = n.parent
n.particles.delete(key)
}
}
// Messages returns a slice of any retained messages which match a filter.
func (x *TopicsIndex) Messages(filter string) []packets.Packet {
return x.scanMessages(filter, 0, nil, []packets.Packet{})
}
// scanMessages returns all retained messages on topics matching a given filter.
func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []packets.Packet) []packets.Packet {
if n == nil {
n = x.root
}
if len(filter) == 0 || x.Retained.Len() == 0 {
return pks
}
if !strings.ContainsRune(filter, '#') && !strings.ContainsRune(filter, '+') {
if pk, ok := x.Retained.Get(filter); ok {
pks = append(pks, pk)
}
return pks
}
key, hasNext := isolateParticle(filter, d)
if key == "+" || key == "#" || d == -1 {
for _, adjacent := range n.particles.getAll() {
if d == 0 && adjacent.key == SysPrefix {
continue
}
if !hasNext {
if adjacent.retainPath != "" {
if pk, ok := x.Retained.Get(adjacent.retainPath); ok {
pks = append(pks, pk)
}
}
}
if hasNext || (d >= 0 && key == "#") {
pks = x.scanMessages(filter, d+1, adjacent, pks)
}
}
return pks
}
if particle := n.particles.get(key); particle != nil {
if hasNext {
return x.scanMessages(filter, d+1, particle, pks)
}
if pk, ok := x.Retained.Get(particle.retainPath); ok {
pks = append(pks, pk)
}
}
return pks
}
// Subscribers returns a map of clients who are subscribed to matching filters,
// their subscription ids and highest qos.
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
return x.scanSubscribers(topic, 0, nil, &Subscribers{
Shared: map[string]map[string]packets.Subscription{},
SharedSelected: map[string]packets.Subscription{},
Subscriptions: map[string]packets.Subscription{},
InlineSubscriptions: map[int]InlineSubscription{},
})
}
// scanSubscribers returns a list of client subscriptions matching an indexed topic address.
func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Subscribers) *Subscribers {
if n == nil {
n = x.root
}
if len(topic) == 0 {
return subs
}
key, hasNext := isolateParticle(topic, d)
for _, partKey := range []string{key, "+"} {
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
if hasNext {
x.scanSubscribers(topic, d+1, particle, subs)
} else {
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
x.gatherInlineSubscriptions(particle, subs)
if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
x.gatherSharedSubscriptions(wild, subs)
x.gatherInlineSubscriptions(particle, subs)
}
}
}
}
if particle := n.particles.get("#"); particle != nil {
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
x.gatherInlineSubscriptions(particle, subs)
}
return subs
}
// gatherSubscriptions collects any matching subscriptions, and gathers any identifiers or highest qos values.
func (x *TopicsIndex) gatherSubscriptions(topic string, particle *particle, subs *Subscribers) {
if subs.Subscriptions == nil {
subs.Subscriptions = map[string]packets.Subscription{}
}
for client, sub := range particle.subscriptions.GetAll() {
if len(sub.Filter) > 0 && topic[0] == '$' && (sub.Filter[0] == '+' || sub.Filter[0] == '#') { // don't match $ topics with top level wildcards [MQTT-4.7.1-1] [MQTT-4.7.1-2]
continue
}
cls, ok := subs.Subscriptions[client]
if !ok {
cls = sub
}
subs.Subscriptions[client] = cls.Merge(sub)
}
}
// gatherSharedSubscriptions gathers all shared subscriptions for a particle.
func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscribers) {
if subs.Shared == nil {
subs.Shared = map[string]map[string]packets.Subscription{}
}
for _, shares := range particle.shared.GetAll() {
for client, sub := range shares {
if _, ok := subs.Shared[sub.Filter]; !ok {
subs.Shared[sub.Filter] = map[string]packets.Subscription{}
}
subs.Shared[sub.Filter][client] = sub
}
}
}
// gatherSharedSubscriptions gathers all inline subscriptions for a particle.
func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) {
if subs.InlineSubscriptions == nil {
subs.InlineSubscriptions = map[int]InlineSubscription{}
}
for id, inline := range particle.inlineSubscriptions.GetAll() {
subs.InlineSubscriptions[id] = inline
}
}
// isolateParticle extracts a particle between d / and d+1 / without allocations.
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
var next, end int
for i := 0; end > -1 && i <= d; i++ {
end = strings.IndexRune(filter, '/')
switch {
case d > -1 && i == d && end > -1:
hasNext = true
particle = filter[next:end]
case end > -1:
hasNext = false
filter = filter[end+1:]
default:
hasNext = false
particle = filter[next:]
}
}
return
}
// IsSharedFilter returns true if the filter uses the share prefix.
func IsSharedFilter(filter string) bool {
prefix, _ := isolateParticle(filter, 0)
return strings.EqualFold(prefix, SharePrefix)
}
// IsValidFilter returns true if the filter is valid.
func IsValidFilter(filter string, forPublish bool) bool {
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish.
return false // [MQTT-4.7.3-1]
}
if forPublish {
if len(filter) >= len(SysPrefix) && strings.EqualFold(filter[0:len(SysPrefix)], SysPrefix) {
// 4.7.2 Non-normative - The Server SHOULD prevent Clients from using such Topic Names [$SYS] to exchange messages with other Clients.
return false
}
if strings.ContainsRune(filter, '+') || strings.ContainsRune(filter, '#') {
return false //[MQTT-3.3.2-2]
}
}
wildhash := strings.IndexRune(filter, '#')
if wildhash >= 0 && wildhash != len(filter)-1 { // [MQTT-4.7.1-2]
return false
}
prefix, hasNext := isolateParticle(filter, 0)
if !hasNext && strings.EqualFold(prefix, SharePrefix) {
return false // [MQTT-4.8.2-1]
}
if hasNext && strings.EqualFold(prefix, SharePrefix) {
group, hasNext := isolateParticle(filter, 1)
if !hasNext {
return false // [MQTT-4.8.2-1]
}
if strings.ContainsRune(group, '+') || strings.ContainsRune(group, '#') {
return false // [MQTT-4.8.2-2]
}
}
return true
}
// particle is a child node on the tree.
type particle struct {
key string // the key of the particle
parent *particle // a pointer to the parent of the particle
particles particles // a map of child particles
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle
retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle
}
// newParticle returns a pointer to a new instance of particle.
func newParticle(key string, parent *particle) *particle {
return &particle{
key: key,
parent: parent,
particles: newParticles(),
subscriptions: NewSubscriptions(),
shared: NewSharedSubscriptions(),
inlineSubscriptions: NewInlineSubscriptions(),
}
}
// particles is a concurrency safe map of particles.
type particles struct {
internal map[string]*particle
sync.RWMutex
}
// newParticles returns a map of particles.
func newParticles() particles {
return particles{
internal: map[string]*particle{},
}
}
// add adds a new particle.
func (p *particles) add(val *particle) {
p.Lock()
p.internal[val.key] = val
p.Unlock()
}
// getAll returns all particles.
func (p *particles) getAll() map[string]*particle {
p.RLock()
defer p.RUnlock()
m := map[string]*particle{}
for k, v := range p.internal {
m[k] = v
}
return m
}
// get returns a particle by id (key).
func (p *particles) get(id string) *particle {
p.RLock()
defer p.RUnlock()
return p.internal[id]
}
// len returns the number of particles.
func (p *particles) len() int {
p.RLock()
defer p.RUnlock()
val := len(p.internal)
return val
}
// delete removes a particle.
func (p *particles) delete(id string) {
p.Lock()
defer p.Unlock()
delete(p.internal, id)
}

1068
mqtt/topics_test.go Normal file

File diff suppressed because it is too large Load Diff