代码整理
This commit is contained in:
649
mqtt/clients.go
Normal file
649
mqtt/clients.go
Normal file
@@ -0,0 +1,649 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds.
|
||||
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
|
||||
minimumKeepalive uint16 = 5 // the minimum recommended keepalive - values under with display a warning.
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMinimumKeepalive = errors.New("client keepalive is below minimum recommended value and may exhibit connection instability")
|
||||
)
|
||||
|
||||
// ReadFn is the function signature for the function used for reading and processing new packets.
|
||||
type ReadFn func(*Client, packets.Packet) error
|
||||
|
||||
// Clients contains a map of the clients known by the broker.
|
||||
type Clients struct {
|
||||
internal map[string]*Client // clients known by the broker, keyed on client id.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClients returns an instance of Clients.
|
||||
func NewClients() *Clients {
|
||||
return &Clients{
|
||||
internal: make(map[string]*Client),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new client to the clients map, keyed on client id.
|
||||
func (cl *Clients) Add(val *Client) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
cl.internal[val.ID] = val
|
||||
}
|
||||
|
||||
// GetAll returns all the clients.
|
||||
func (cl *Clients) GetAll() map[string]*Client {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
m := map[string]*Client{}
|
||||
for k, v := range cl.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns the value of a client if it exists.
|
||||
func (cl *Clients) Get(id string) (*Client, bool) {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
val, ok := cl.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the length of the clients map.
|
||||
func (cl *Clients) Len() int {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
val := len(cl.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes a client from the internal map.
|
||||
func (cl *Clients) Delete(id string) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
delete(cl.internal, id)
|
||||
}
|
||||
|
||||
// GetByListener returns clients matching a listener id.
|
||||
func (cl *Clients) GetByListener(id string) []*Client {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
for _, client := range cl.internal {
|
||||
if client.Net.Listener == id && !client.Closed() {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
}
|
||||
return clients
|
||||
}
|
||||
|
||||
// Client contains information about a client known by the broker.
|
||||
type Client struct {
|
||||
Properties ClientProperties // client properties
|
||||
State ClientState // the operational state of the client.
|
||||
Net ClientConnection // network connection state of the client
|
||||
ID string // the client id.
|
||||
ops *ops // ops provides a reference to server ops.
|
||||
sync.RWMutex // mutex
|
||||
}
|
||||
|
||||
// ClientConnection contains the connection transport and metadata for the client.
|
||||
type ClientConnection struct {
|
||||
Conn net.Conn // the net.Conn used to establish the connection
|
||||
bconn *bufio.Reader // a buffered net.Conn for reading packets
|
||||
outbuf *bytes.Buffer // a buffer for writing packets
|
||||
Remote string // the remote address of the client
|
||||
Listener string // listener id of the client
|
||||
Inline bool // if true, the client is the built-in 'inline' embedded client
|
||||
}
|
||||
|
||||
// ClientProperties contains the properties which define the client behaviour.
|
||||
type ClientProperties struct {
|
||||
Props packets.Properties
|
||||
Will Will
|
||||
Username []byte
|
||||
ProtocolVersion byte
|
||||
Clean bool
|
||||
}
|
||||
|
||||
// Will contains the last will and testament details for a client connection.
|
||||
type Will struct {
|
||||
Payload []byte // -
|
||||
User []packets.UserProperty // -
|
||||
TopicName string // -
|
||||
Flag uint32 // 0,1
|
||||
WillDelayInterval uint32 // -
|
||||
Qos byte // -
|
||||
Retain bool // -
|
||||
}
|
||||
|
||||
// ClientState tracks the state of the client.
|
||||
type ClientState struct {
|
||||
TopicAliases TopicAliases // a map of topic aliases
|
||||
stopCause atomic.Value // reason for stopping
|
||||
Inflight *Inflight // a map of in-flight qos messages
|
||||
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
|
||||
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
|
||||
outbound chan *packets.Packet // queue for pending outbound packets
|
||||
endOnce sync.Once // only end once
|
||||
isTakenOver uint32 // used to identify orphaned clients
|
||||
packetID uint32 // the current highest packetID
|
||||
open context.Context // indicate that the client is open for packet exchange
|
||||
cancelOpen context.CancelFunc // cancel function for open context
|
||||
outboundQty int32 // number of messages currently in the outbound queue
|
||||
Keepalive uint16 // the number of seconds the connection can wait
|
||||
ServerKeepalive bool // keepalive was set by the server
|
||||
}
|
||||
|
||||
// newClient returns a new instance of Client. This is almost exclusively used by Server
|
||||
// for creating new clients, but it lives here because it's not dependent.
|
||||
func newClient(c net.Conn, o *ops) *Client {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cl := &Client{
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
|
||||
open: ctx,
|
||||
cancelOpen: cancel,
|
||||
Keepalive: defaultKeepalive,
|
||||
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
ops: o,
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
cl.Net = ClientConnection{
|
||||
Conn: c,
|
||||
bconn: bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
|
||||
func (cl *Client) WriteLoop() {
|
||||
for {
|
||||
select {
|
||||
case pk := <-cl.State.outbound:
|
||||
if err := cl.WritePacket(*pk); err != nil {
|
||||
// TODO : Figure out what to do with error
|
||||
cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
|
||||
}
|
||||
atomic.AddInt32(&cl.State.outboundQty, -1)
|
||||
case <-cl.State.open.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ParseConnect parses the connect parameters and properties for a client.
|
||||
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Net.Listener = lid
|
||||
|
||||
cl.Properties.ProtocolVersion = pk.ProtocolVersion
|
||||
cl.Properties.Username = pk.Connect.Username
|
||||
cl.Properties.Clean = pk.Connect.Clean
|
||||
cl.Properties.Props = pk.Properties.Copy(false)
|
||||
|
||||
if cl.Properties.Props.ReceiveMaximum > cl.ops.options.Capabilities.MaximumInflight { // 3.3.4 Non-normative
|
||||
cl.Properties.Props.ReceiveMaximum = cl.ops.options.Capabilities.MaximumInflight
|
||||
}
|
||||
|
||||
if pk.Connect.Keepalive <= minimumKeepalive {
|
||||
cl.ops.log.Warn(
|
||||
ErrMinimumKeepalive.Error(),
|
||||
"client", cl.ID,
|
||||
"keepalive", pk.Connect.Keepalive,
|
||||
"recommended", minimumKeepalive,
|
||||
)
|
||||
}
|
||||
|
||||
cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
|
||||
|
||||
cl.ID = pk.Connect.ClientIdentifier
|
||||
if cl.ID == "" {
|
||||
cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
|
||||
cl.Properties.Props.AssignedClientID = cl.ID
|
||||
}
|
||||
|
||||
if pk.Connect.WillFlag {
|
||||
cl.Properties.Will = Will{
|
||||
Qos: pk.Connect.WillQos,
|
||||
Retain: pk.Connect.WillRetain,
|
||||
Payload: pk.Connect.WillPayload,
|
||||
TopicName: pk.Connect.WillTopic,
|
||||
WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
|
||||
User: pk.Connect.WillProperties.User,
|
||||
}
|
||||
if pk.Properties.SessionExpiryIntervalFlag &&
|
||||
pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
|
||||
cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
|
||||
}
|
||||
if pk.Connect.WillFlag {
|
||||
cl.Properties.Will.Flag = 1 // atomic for checking
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||
func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
|
||||
}
|
||||
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
}
|
||||
}
|
||||
|
||||
// NextPacketID returns the next available (unused) packet id for the client.
|
||||
// If no unused packet ids are available, an error is returned and the client
|
||||
// should be disconnected.
|
||||
func (cl *Client) NextPacketID() (i uint32, err error) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
|
||||
i = atomic.LoadUint32(&cl.State.packetID)
|
||||
started := i
|
||||
overflowed := false
|
||||
for {
|
||||
if overflowed && i == started {
|
||||
return 0, packets.ErrQuotaExceeded
|
||||
}
|
||||
|
||||
if i >= cl.ops.options.Capabilities.maximumPacketID {
|
||||
overflowed = true
|
||||
i = 0
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
|
||||
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
|
||||
atomic.StoreUint32(&cl.State.packetID, i)
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
|
||||
func (cl *Client) ResendInflightMessages(force bool) error {
|
||||
if cl.State.Inflight.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if tk.FixedHeader.Type == packets.Publish {
|
||||
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
|
||||
}
|
||||
|
||||
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
|
||||
err := cl.WritePacket(tk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosComplete(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session.
|
||||
func (cl *Client) ClearInflights() {
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClearExpiredInflights deletes any inflight messages which have expired.
|
||||
func (cl *Client) ClearExpiredInflights(now, maximumExpiry int64) []uint16 {
|
||||
deleted := []uint16{}
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
expired := tk.ProtocolVersion == 5 && tk.Expiry > 0 && tk.Expiry < now // [MQTT-3.3.2-5]
|
||||
|
||||
// If the maximum message expiry interval is set (greater than 0), and the message
|
||||
// retention period exceeds the maximum expiry, the message will be forcibly removed.
|
||||
enforced := maximumExpiry > 0 && now-tk.Created > maximumExpiry
|
||||
|
||||
if expired || enforced {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
deleted = append(deleted, tk.PacketID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return deleted
|
||||
}
|
||||
|
||||
// Read reads incoming packets from the connected client and transforms them into
|
||||
// packets to be handled by the packetHandler.
|
||||
func (cl *Client) Read(packetHandler ReadFn) error {
|
||||
var err error
|
||||
|
||||
for {
|
||||
if cl.Closed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.Keepalive)
|
||||
fh := new(packets.FixedHeader)
|
||||
err = cl.ReadFixedHeader(fh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = packetHandler(cl, pk) // Process inbound packet.
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop(err error) {
|
||||
cl.State.endOnce.Do(func() {
|
||||
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.Close() // omit close error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cl.State.stopCause.Store(err)
|
||||
}
|
||||
|
||||
if cl.State.cancelOpen != nil {
|
||||
cl.State.cancelOpen()
|
||||
}
|
||||
|
||||
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
|
||||
})
|
||||
}
|
||||
|
||||
// StopCause returns the reason the client connection was stopped, if any.
|
||||
func (cl *Client) StopCause() error {
|
||||
if cl.State.stopCause.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
return cl.State.stopCause.Load().(error)
|
||||
}
|
||||
|
||||
// StopTime returns the the time the client disconnected in unix time, else zero.
|
||||
func (cl *Client) StopTime() int64 {
|
||||
return atomic.LoadInt64(&cl.State.disconnected)
|
||||
}
|
||||
|
||||
// Closed returns true if client connection is closed.
|
||||
func (cl *Client) Closed() bool {
|
||||
return cl.State.open == nil || cl.State.open.Err() != nil
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
if cl.Net.bconn == nil {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
b, err := cl.Net.bconn.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fh.Decode(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var bu int
|
||||
fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadPacket reads the remaining buffer into an MQTT packet.
|
||||
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
|
||||
atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
|
||||
|
||||
pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
|
||||
pk.FixedHeader = *fh
|
||||
p := make([]byte, pk.FixedHeader.Remaining)
|
||||
n, err := io.ReadFull(cl.Net.bconn, p)
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))
|
||||
|
||||
// Decode the remaining packet values using a fresh copy of the bytes,
|
||||
// otherwise the next packet will change the data of this one.
|
||||
px := append([]byte{}, p[:]...)
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
err = pk.ConnectDecode(px)
|
||||
case packets.Disconnect:
|
||||
err = pk.DisconnectDecode(px)
|
||||
case packets.Connack:
|
||||
err = pk.ConnackDecode(px)
|
||||
case packets.Publish:
|
||||
err = pk.PublishDecode(px)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackDecode(px)
|
||||
case packets.Pubrec:
|
||||
err = pk.PubrecDecode(px)
|
||||
case packets.Pubrel:
|
||||
err = pk.PubrelDecode(px)
|
||||
case packets.Pubcomp:
|
||||
err = pk.PubcompDecode(px)
|
||||
case packets.Subscribe:
|
||||
err = pk.SubscribeDecode(px)
|
||||
case packets.Suback:
|
||||
err = pk.SubackDecode(px)
|
||||
case packets.Unsubscribe:
|
||||
err = pk.UnsubscribeDecode(px)
|
||||
case packets.Unsuback:
|
||||
err = pk.UnsubackDecode(px)
|
||||
case packets.Pingreq:
|
||||
case packets.Pingresp:
|
||||
case packets.Auth:
|
||||
err = pk.AuthDecode(px)
|
||||
default:
|
||||
err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
|
||||
return
|
||||
}
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
if cl.Closed() {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if cl.Net.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pk.Expiry > 0 {
|
||||
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
|
||||
}
|
||||
|
||||
pk.ProtocolVersion = cl.Properties.ProtocolVersion
|
||||
if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
|
||||
pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
|
||||
}
|
||||
|
||||
if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
|
||||
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
|
||||
}
|
||||
|
||||
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
|
||||
|
||||
var err error
|
||||
buf := new(bytes.Buffer)
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
err = pk.ConnectEncode(buf)
|
||||
case packets.Connack:
|
||||
err = pk.ConnackEncode(buf)
|
||||
case packets.Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
case packets.Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
case packets.Pubrec:
|
||||
err = pk.PubrecEncode(buf)
|
||||
case packets.Pubrel:
|
||||
err = pk.PubrelEncode(buf)
|
||||
case packets.Pubcomp:
|
||||
err = pk.PubcompEncode(buf)
|
||||
case packets.Subscribe:
|
||||
err = pk.SubscribeEncode(buf)
|
||||
case packets.Suback:
|
||||
err = pk.SubackEncode(buf)
|
||||
case packets.Unsubscribe:
|
||||
err = pk.UnsubscribeEncode(buf)
|
||||
case packets.Unsuback:
|
||||
err = pk.UnsubackEncode(buf)
|
||||
case packets.Pingreq:
|
||||
err = pk.PingreqEncode(buf)
|
||||
case packets.Pingresp:
|
||||
err = pk.PingrespEncode(buf)
|
||||
case packets.Disconnect:
|
||||
err = pk.DisconnectEncode(buf)
|
||||
case packets.Auth:
|
||||
err = pk.AuthEncode(buf)
|
||||
default:
|
||||
err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
|
||||
}
|
||||
|
||||
n, err := func() (int64, error) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
if len(cl.State.outbound) == 0 {
|
||||
if cl.Net.outbuf == nil {
|
||||
return buf.WriteTo(cl.Net.Conn)
|
||||
}
|
||||
|
||||
// first write to buffer, then flush buffer
|
||||
n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
|
||||
err = cl.flushOutbuf()
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// there are more writes in the queue
|
||||
if cl.Net.outbuf == nil {
|
||||
if buf.Len() >= cl.ops.options.ClientNetWriteBufferSize {
|
||||
return buf.WriteTo(cl.Net.Conn)
|
||||
}
|
||||
cl.Net.outbuf = new(bytes.Buffer)
|
||||
}
|
||||
|
||||
n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
|
||||
if cl.Net.outbuf.Len() < cl.ops.options.ClientNetWriteBufferSize {
|
||||
return int64(n), nil
|
||||
}
|
||||
|
||||
err = cl.flushOutbuf()
|
||||
return int64(n), err
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesSent, n)
|
||||
atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
|
||||
if pk.FixedHeader.Type == packets.Publish {
|
||||
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
|
||||
}
|
||||
|
||||
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (cl *Client) flushOutbuf() (err error) {
|
||||
if cl.Net.outbuf == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = cl.Net.outbuf.WriteTo(cl.Net.Conn)
|
||||
if err == nil {
|
||||
cl.Net.outbuf = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
930
mqtt/clients_test.go
Normal file
930
mqtt/clients_test.go
Normal file
@@ -0,0 +1,930 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const pkInfo = "packet type %v, %s"
|
||||
|
||||
var errClientStop = errors.New("test stop")
|
||||
|
||||
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
r, w = net.Pipe()
|
||||
|
||||
cl = newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: logger,
|
||||
options: &Options{
|
||||
Capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
MaximumInflight: 5,
|
||||
TopicAliasMaximum: 10000,
|
||||
MaximumClientWritesPending: 3,
|
||||
maximumPacketID: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
cl.ID = "mochi"
|
||||
cl.State.Inflight.maximumSendQuota = 5
|
||||
cl.State.Inflight.sendQuota = 5
|
||||
cl.State.Inflight.maximumReceiveQuota = 10
|
||||
cl.State.Inflight.receiveQuota = 10
|
||||
cl.Properties.Props.TopicAliasMaximum = 0
|
||||
cl.Properties.Props.RequestResponseInfo = 0x1
|
||||
|
||||
go cl.WriteLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestNewInflights(t *testing.T) {
|
||||
require.NotNil(t, NewInflights().internal)
|
||||
}
|
||||
|
||||
func TestNewClients(t *testing.T) {
|
||||
cl := NewClients()
|
||||
require.NotNil(t, cl.internal)
|
||||
}
|
||||
|
||||
func TestClientsAdd(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
}
|
||||
|
||||
func TestClientsGet(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
client, ok := cl.Get("t1")
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, "t1", client.ID)
|
||||
}
|
||||
|
||||
func TestClientsGetAll(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
cl.Add(&Client{ID: "t3"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
require.Contains(t, cl.internal, "t3")
|
||||
|
||||
clients := cl.GetAll()
|
||||
require.Len(t, clients, 3)
|
||||
}
|
||||
|
||||
func TestClientsLen(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
require.Equal(t, 2, cl.Len())
|
||||
}
|
||||
|
||||
func TestClientsDelete(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
|
||||
cl.Delete("t1")
|
||||
_, ok := cl.Get("t1")
|
||||
require.Equal(t, false, ok)
|
||||
require.Nil(t, cl.internal["t1"])
|
||||
}
|
||||
|
||||
func TestClientsGetByListener(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
clients := cl.GetByListener("tcp1")
|
||||
require.NotEmpty(t, clients)
|
||||
require.Equal(t, 1, len(clients))
|
||||
require.Equal(t, "tcp1", clients[0].Net.Listener)
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.Conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.NotNil(t, cl.ops)
|
||||
require.NotNil(t, cl.ops.options.Capabilities)
|
||||
require.False(t, cl.Net.Inline)
|
||||
}
|
||||
|
||||
func TestClientParseConnect(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Clean: true,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "mochi",
|
||||
WillFlag: true,
|
||||
WillTopic: "lwt",
|
||||
WillPayload: []byte("lol gg"),
|
||||
WillQos: 1,
|
||||
WillRetain: false,
|
||||
},
|
||||
Properties: packets.Properties{
|
||||
ReceiveMaximum: uint16(5),
|
||||
},
|
||||
}
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
|
||||
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
|
||||
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
|
||||
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
|
||||
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
|
||||
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
|
||||
}
|
||||
|
||||
func TestClientParseConnectReceiveMaxExceedMaxInflight(t *testing.T) {
|
||||
const MaxInflight uint16 = 1
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ops.options.Capabilities.MaximumInflight = MaxInflight
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Clean: true,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "mochi",
|
||||
WillFlag: true,
|
||||
WillTopic: "lwt",
|
||||
WillPayload: []byte("lol gg"),
|
||||
WillQos: 1,
|
||||
WillRetain: false,
|
||||
},
|
||||
Properties: packets.Properties{
|
||||
ReceiveMaximum: uint16(5),
|
||||
},
|
||||
}
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
|
||||
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
|
||||
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
|
||||
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
|
||||
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
|
||||
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(MaxInflight), cl.State.Inflight.sendQuota)
|
||||
require.Equal(t, int32(MaxInflight), cl.State.Inflight.maximumSendQuota)
|
||||
}
|
||||
|
||||
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Clean: true,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "mochi",
|
||||
WillFlag: true,
|
||||
WillProperties: packets.Properties{
|
||||
WillDelayInterval: 200,
|
||||
},
|
||||
},
|
||||
Properties: packets.Properties{
|
||||
SessionExpiryInterval: 100,
|
||||
SessionExpiryIntervalFlag: true,
|
||||
},
|
||||
}
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval)
|
||||
}
|
||||
|
||||
func TestClientParseConnectNoID(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ParseConnect("tcp1", packets.Packet{})
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientParseConnectBelowMinimumKeepalive(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
var b bytes.Buffer
|
||||
x := bufio.NewWriter(&b)
|
||||
cl.ops.log = slog.New(slog.NewTextHandler(x, nil))
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Keepalive: minimumKeepalive - 1,
|
||||
ClientIdentifier: "mochi",
|
||||
},
|
||||
}
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
err := x.Flush()
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.Contains(b.String(), ErrMinimumKeepalive.Error()))
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientNextPacketID(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
// skip over 2
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), i)
|
||||
|
||||
// Skip over overflow
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 65535})
|
||||
atomic.StoreUint32(&cl.State.packetID, 65534)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDExhausted(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
|
||||
}
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
require.Equal(t, uint32(0), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
|
||||
}
|
||||
|
||||
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
|
||||
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
|
||||
|
||||
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
|
||||
_, err = cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
}
|
||||
|
||||
func TestClientClearInflights(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
n := time.Now().Unix()
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n})
|
||||
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
cl.ClearInflights()
|
||||
require.Equal(t, 0, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientClearExpiredInflights(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n})
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
|
||||
deleted := cl.ClearExpiredInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
|
||||
require.Equal(t, 2, cl.State.Inflight.Len())
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 11, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 12, Expiry: n - 2}) // expiry is ineffective for v3.
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 13, Created: n - 3}) // within bounds for v3
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 15, Created: n - 5}) // over max server expiry limit
|
||||
require.Equal(t, 6, cl.State.Inflight.Len())
|
||||
|
||||
deleted = cl.ClearExpiredInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{11, 12, 15}, deleted)
|
||||
require.Equal(t, 3, cl.State.Inflight.Len())
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 17, Created: n - 1})
|
||||
deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not process abandon messages
|
||||
require.Len(t, deleted, 0)
|
||||
require.Equal(t, 4, cl.State.Inflight.Len())
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 18, Expiry: n - 1})
|
||||
deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not abandon messages
|
||||
require.ElementsMatch(t, []uint16{18}, deleted) // expiry is still effective for v5.
|
||||
require.Len(t, deleted, 1)
|
||||
require.Equal(t, 4, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessages(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
go func() {
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
_ = w.Close()
|
||||
}()
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, cl.State.Inflight.Len())
|
||||
require.Equal(t, pk1.RawBytes, buf)
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
|
||||
cl, r, _ := newTestClient()
|
||||
_ = r.Close()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientRefreshDeadline(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.refreshDeadline(10)
|
||||
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeader(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{packets.Connect << 4, 0x00})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived))
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
cl.ops.options.Capabilities.MaximumPacketSize = 2
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
_, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadOK(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 18, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload,
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
var pks []packets.Packet
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
pks = append(pks, pk)
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
err := <-o
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.Equal(t, 2, len(pks))
|
||||
require.Equal(t, []packets.Packet{
|
||||
{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 18,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 11,
|
||||
},
|
||||
TopicName: "d/e/f",
|
||||
Payload: []byte("yeah"),
|
||||
},
|
||||
}, pks)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived))
|
||||
}
|
||||
|
||||
func TestClientReadDone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
cl.State.cancelOpen()
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
require.NoError(t, <-o)
|
||||
}
|
||||
|
||||
func TestClientStop(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
require.Equal(t, int64(0), cl.StopTime())
|
||||
cl.Stop(nil)
|
||||
require.Equal(t, nil, cl.State.stopCause.Load())
|
||||
require.InDelta(t, time.Now().Unix(), cl.State.disconnected, 1.0)
|
||||
require.Equal(t, cl.State.disconnected, cl.StopTime())
|
||||
require.True(t, cl.Closed())
|
||||
require.Equal(t, nil, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestClientClosed(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
require.False(t, cl.Closed())
|
||||
cl.Stop(nil)
|
||||
require.True(t, cl.Closed())
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
cl.Net.bconn = nil
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrConnectionClosed, err)
|
||||
}
|
||||
|
||||
func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
err := cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
return errors.New("test")
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadReadPacketOK(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pk)
|
||||
|
||||
require.Equal(t, packets.Packet{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 11,
|
||||
},
|
||||
TopicName: "d/e/f",
|
||||
Payload: []byte("yeah"),
|
||||
}, pk)
|
||||
}
|
||||
|
||||
func TestClientReadPacket(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
for _, tx := range pkTable {
|
||||
tt := tx // avoid data race
|
||||
t.Run(tt.Desc, func(t *testing.T) {
|
||||
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
|
||||
go func() {
|
||||
_, _ = r.Write(tt.RawBytes)
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.Packet.ProtocolVersion == 5 {
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
} else {
|
||||
cl.Properties.ProtocolVersion = 0
|
||||
}
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
if tt.Packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadPacketInvalidTypeError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
_ = cl.Net.Conn.Close()
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid packet type")
|
||||
}
|
||||
|
||||
func TestClientWritePacket(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
|
||||
|
||||
o := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
o <- buf
|
||||
}()
|
||||
|
||||
err := cl.WritePacket(*tt.Packet)
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
_ = cl.Net.Conn.Close()
|
||||
|
||||
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
cl.Stop(errClientStop)
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
|
||||
// The stop cause is either the test error, EOF, or a
|
||||
// closed pipe, depending on which goroutine runs first.
|
||||
err = cl.StopCause()
|
||||
require.True(t,
|
||||
errors.Is(err, errClientStop) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, io.ErrClosedPipe))
|
||||
|
||||
require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent))
|
||||
if tt.Packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWritePacketBuffer(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
|
||||
cl := newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: logger,
|
||||
options: &Options{
|
||||
Capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
TopicAliasMaximum: 10000,
|
||||
MaximumClientWritesPending: 3,
|
||||
maximumPacketID: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
cl.ID = "mochi"
|
||||
cl.State.Inflight.maximumSendQuota = 5
|
||||
cl.State.Inflight.sendQuota = 5
|
||||
cl.State.Inflight.maximumReceiveQuota = 10
|
||||
cl.State.Inflight.receiveQuota = 10
|
||||
cl.Properties.Props.TopicAliasMaximum = 0
|
||||
cl.Properties.Props.RequestResponseInfo = 0x1
|
||||
|
||||
cl.ops.options.ClientNetWriteBufferSize = 10
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
small := packets.TPacketData[packets.Publish].Get(packets.TPublishNoPayload).Packet
|
||||
large := packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
|
||||
|
||||
cl.State.outbound <- small
|
||||
|
||||
tt := []struct {
|
||||
pks []*packets.Packet
|
||||
size int
|
||||
}{
|
||||
{
|
||||
pks: []*packets.Packet{small, small},
|
||||
size: 18,
|
||||
},
|
||||
{
|
||||
pks: []*packets.Packet{large},
|
||||
size: 20,
|
||||
},
|
||||
{
|
||||
pks: []*packets.Packet{small},
|
||||
size: 0,
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
for i, tx := range tt {
|
||||
for _, pk := range tx.pks {
|
||||
cl.Properties.ProtocolVersion = pk.ProtocolVersion
|
||||
err := cl.WritePacket(*pk)
|
||||
require.NoError(t, err, "index: %d", i)
|
||||
if i == len(tt)-1 {
|
||||
cl.Net.Conn.Close()
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var n int
|
||||
var err error
|
||||
for i, tx := range tt {
|
||||
buf := make([]byte, 100)
|
||||
if i == len(tt)-1 {
|
||||
buf, err = io.ReadAll(r)
|
||||
n = len(buf)
|
||||
} else {
|
||||
n, err = io.ReadAtLeast(r, buf, 1)
|
||||
}
|
||||
require.NoError(t, err, "index: %d", i)
|
||||
require.Equal(t, tx.size, n, "index: %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteClientOversizePacket(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.MaximumPacketSize = 2
|
||||
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
|
||||
err := cl.WritePacket(pk)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, packets.ErrPacketTooLarge, err)
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadingError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
Type: 0,
|
||||
Remaining: 11,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
_, _ = r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
Remaining: 1,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(errClientStop)
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, ErrConnectionClosed, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
_ = cl.Net.Conn.Close()
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketInvalidPacket(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.WritePacket(packets.Packet{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
var (
|
||||
pkTable = []packets.TPacketCase{
|
||||
packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311),
|
||||
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5),
|
||||
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession),
|
||||
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic),
|
||||
packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5),
|
||||
packets.TPacketData[packets.Puback].Get(packets.TPuback),
|
||||
packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
|
||||
packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
|
||||
packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
|
||||
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe),
|
||||
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5),
|
||||
packets.TPacketData[packets.Suback].Get(packets.TSuback),
|
||||
packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5),
|
||||
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe),
|
||||
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5),
|
||||
packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback),
|
||||
packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5),
|
||||
packets.TPacketData[packets.Pingreq].Get(packets.TPingreq),
|
||||
packets.TPacketData[packets.Pingresp].Get(packets.TPingresp),
|
||||
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect),
|
||||
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5),
|
||||
packets.TPacketData[packets.Auth].Get(packets.TAuth),
|
||||
}
|
||||
)
|
||||
864
mqtt/hooks.go
Normal file
864
mqtt/hooks.go
Normal file
@@ -0,0 +1,864 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co, thedevop, dgduncan
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
)
|
||||
|
||||
const (
|
||||
SetOptions byte = iota
|
||||
OnSysInfoTick
|
||||
OnStarted
|
||||
OnStopped
|
||||
OnConnectAuthenticate
|
||||
OnACLCheck
|
||||
OnConnect
|
||||
OnSessionEstablish
|
||||
OnSessionEstablished
|
||||
OnDisconnect
|
||||
OnAuthPacket
|
||||
OnPacketRead
|
||||
OnPacketEncode
|
||||
OnPacketSent
|
||||
OnPacketProcessed
|
||||
OnSubscribe
|
||||
OnSubscribed
|
||||
OnSelectSubscribers
|
||||
OnUnsubscribe
|
||||
OnUnsubscribed
|
||||
OnPublish
|
||||
OnPublished
|
||||
OnPublishDropped
|
||||
OnRetainMessage
|
||||
OnRetainPublished
|
||||
OnQosPublish
|
||||
OnQosComplete
|
||||
OnQosDropped
|
||||
OnPacketIDExhausted
|
||||
OnWill
|
||||
OnWillSent
|
||||
OnClientExpired
|
||||
OnRetainedExpired
|
||||
StoredClients
|
||||
StoredSubscriptions
|
||||
StoredInflightMessages
|
||||
StoredRetainedMessages
|
||||
StoredSysInfo
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidConfigType indicates a different Type of config value was expected to what was received.
|
||||
// ErrInvalidConfigType = errors.New("invalid config type provided")
|
||||
ErrInvalidConfigType = errors.New("提供的配置类型无效")
|
||||
)
|
||||
|
||||
// HookLoadConfig contains the hook and configuration as loaded from a configuration (usually file).
|
||||
type HookLoadConfig struct {
|
||||
Hook Hook
|
||||
Config any
|
||||
}
|
||||
|
||||
// Hook provides an interface of handlers for different events which occur
|
||||
// during the lifecycle of the broker.
|
||||
type Hook interface {
|
||||
ID() string
|
||||
Provides(b byte) bool
|
||||
Init(config any) error
|
||||
Stop() error
|
||||
SetOpts(l *slog.Logger, o *HookOptions)
|
||||
|
||||
OnStarted()
|
||||
OnStopped()
|
||||
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
|
||||
OnACLCheck(cl *Client, topic string, write bool) bool
|
||||
OnSysInfoTick(*system.Info)
|
||||
OnConnect(cl *Client, pk packets.Packet) error
|
||||
OnSessionEstablish(cl *Client, pk packets.Packet)
|
||||
OnSessionEstablished(cl *Client, pk packets.Packet)
|
||||
OnDisconnect(cl *Client, err error, expire bool)
|
||||
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation
|
||||
OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client
|
||||
OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client
|
||||
OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled)
|
||||
OnSubscribe(cl *Client, pk packets.Packet) packets.Packet
|
||||
OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
|
||||
OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers
|
||||
OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet
|
||||
OnUnsubscribed(cl *Client, pk packets.Packet)
|
||||
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
OnPublished(cl *Client, pk packets.Packet)
|
||||
OnPublishDropped(cl *Client, pk packets.Packet)
|
||||
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
|
||||
OnRetainPublished(cl *Client, pk packets.Packet)
|
||||
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
|
||||
OnQosComplete(cl *Client, pk packets.Packet)
|
||||
OnQosDropped(cl *Client, pk packets.Packet)
|
||||
OnPacketIDExhausted(cl *Client, pk packets.Packet)
|
||||
OnWill(cl *Client, will Will) (Will, error)
|
||||
OnWillSent(cl *Client, pk packets.Packet)
|
||||
OnClientExpired(cl *Client)
|
||||
OnRetainedExpired(filter string)
|
||||
StoredClients() ([]storage.Client, error)
|
||||
StoredSubscriptions() ([]storage.Subscription, error)
|
||||
StoredInflightMessages() ([]storage.Message, error)
|
||||
StoredRetainedMessages() ([]storage.Message, error)
|
||||
StoredSysInfo() (storage.SystemInfo, error)
|
||||
}
|
||||
|
||||
// HookOptions contains values which are inherited from the server on initialisation.
|
||||
type HookOptions struct {
|
||||
Capabilities *Capabilities
|
||||
}
|
||||
|
||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||
type Hooks struct {
|
||||
Log *slog.Logger // a logger for the hook (from the server)
|
||||
internal atomic.Value // a slice of []Hook
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex for locking when adding hooks
|
||||
}
|
||||
|
||||
// Len returns the number of hooks added.
|
||||
func (h *Hooks) Len() int64 {
|
||||
return atomic.LoadInt64(&h.qty)
|
||||
}
|
||||
|
||||
// Provides returns true if any one hook provides any of the requested hook methods.
|
||||
func (h *Hooks) Provides(b ...byte) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
for _, hb := range b {
|
||||
if hook.Provides(hb) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Add adds and initializes a new hook.
|
||||
func (h *Hooks) Add(hook Hook, config any) error {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
|
||||
err := hook.Init(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
|
||||
}
|
||||
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
i = []Hook{}
|
||||
}
|
||||
|
||||
i = append(i, hook)
|
||||
h.internal.Store(i)
|
||||
atomic.AddInt64(&h.qty, 1)
|
||||
h.wg.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll returns a slice of all the hooks.
|
||||
func (h *Hooks) GetAll() []Hook {
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
return []Hook{}
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// Stop indicates all attached hooks to gracefully end.
|
||||
func (h *Hooks) Stop() {
|
||||
go func() {
|
||||
for _, hook := range h.GetAll() {
|
||||
h.Log.Info("stopping hook", "hook", hook.ID())
|
||||
if err := hook.Stop(); err != nil {
|
||||
h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID())
|
||||
}
|
||||
|
||||
h.wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
h.wg.Wait()
|
||||
}
|
||||
|
||||
// OnSysInfoTick is called when the $SYS topic values are published out.
|
||||
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSysInfoTick) {
|
||||
hook.OnSysInfoTick(sys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnStarted is called when the server has successfully started.
|
||||
func (h *Hooks) OnStarted() {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStarted) {
|
||||
hook.OnStarted()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnStopped is called when the server has successfully stopped.
|
||||
func (h *Hooks) OnStopped() {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStopped) {
|
||||
hook.OnStopped()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnConnect is called when a new client connects, and may return a packets.Code as an error to halt the connection.
|
||||
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) error {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnect) {
|
||||
err := hook.OnConnect(cl, pk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnSessionEstablish is called right after a new client connects and authenticates and right before
|
||||
// the session is established and CONNACK is sent.
|
||||
func (h *Hooks) OnSessionEstablish(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSessionEstablish) {
|
||||
hook.OnSessionEstablish(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSessionEstablished) {
|
||||
hook.OnSessionEstablished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnDisconnect) {
|
||||
hook.OnDisconnect(cl, err, expire)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a packet is received from a client.
|
||||
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketRead) {
|
||||
npk, err := hook.OnPacketRead(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx)
|
||||
return pk, err
|
||||
} else if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnAuthPacket is called when an auth packet is received. It is intended to allow developers
|
||||
// to create their own auth packet handling mechanisms.
|
||||
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnAuthPacket) {
|
||||
npk, err := hook.OnAuthPacket(cl, pkx)
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
|
||||
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketEncode) {
|
||||
pk = hook.OnPacketEncode(cl, pk)
|
||||
}
|
||||
}
|
||||
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
|
||||
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketProcessed) {
|
||||
hook.OnPacketProcessed(cl, pk, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
|
||||
// containing the bytes sent.
|
||||
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketSent) {
|
||||
hook.OnPacketSent(cl, pk, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribe is called when a client subscribes to one or more filters. This method
|
||||
// differs from OnSubscribed in that it allows you to modify the subscription values
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribe) {
|
||||
pk = hook.OnSubscribe(cl, pk)
|
||||
}
|
||||
}
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribed) {
|
||||
hook.OnSubscribed(cl, pk, reasonCodes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSelectSubscribers is called when subscribers have been collected for a topic, but before
|
||||
// shared subscription subscribers have been selected. This hook can be used to programmatically
|
||||
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
|
||||
// group in a custom manner (such as based on client id, ip, etc).
|
||||
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSelectSubscribers) {
|
||||
subs = hook.OnSelectSubscribers(subs, pk)
|
||||
}
|
||||
}
|
||||
return subs
|
||||
}
|
||||
|
||||
// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method
|
||||
// differs from OnUnsubscribed in that it allows you to modify the unsubscription values
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribe) {
|
||||
pk = hook.OnUnsubscribe(cl, pk)
|
||||
}
|
||||
}
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribed) {
|
||||
hook.OnUnsubscribed(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnPublished
|
||||
// in that it allows you to modify you to modify the incoming packet before it is processed.
|
||||
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
||||
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublish) {
|
||||
npk, err := hook.OnPublish(cl, pkx)
|
||||
if err != nil {
|
||||
if errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug("publish packet rejected",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"packet", pkx)
|
||||
return pk, err
|
||||
}
|
||||
h.Log.Error("publish packet error",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"packet", pkx)
|
||||
return pk, err
|
||||
}
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublished) {
|
||||
hook.OnPublished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublishDropped is called when a message to a client was dropped instead of delivered
|
||||
// such as when a client is too slow to respond.
|
||||
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublishDropped) {
|
||||
hook.OnPublishDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainMessage) {
|
||||
hook.OnRetainMessage(cl, pk, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainPublished is called when a retained message is published.
|
||||
func (h *Hooks) OnRetainPublished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainPublished) {
|
||||
hook.OnRetainPublished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
|
||||
// In other words, this method is called when a new inflight message is created or resent.
|
||||
// It is typically used to store a new inflight message.
|
||||
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosPublish) {
|
||||
hook.OnQosPublish(cl, pk, sent, resends)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
// In other words, when an inflight message is resolved.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosComplete) {
|
||||
hook.OnQosComplete(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires. In other words, when
|
||||
// an inflight message expires or is abandoned. It is typically used to delete an
|
||||
// inflight message from a store.
|
||||
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosDropped) {
|
||||
hook.OnQosDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketIDExhausted is called when the client runs out of unused packet ids to
|
||||
// assign to a packet.
|
||||
func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketIDExhausted) {
|
||||
hook.OnPacketIDExhausted(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message. This method
|
||||
// differs from OnWillSent in that it allows you to modify the LWT message before it is
|
||||
// published. The return values of the hook methods are passed-through in the order
|
||||
// the hooks were attached.
|
||||
func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWill) {
|
||||
mlwt, err := hook.OnWill(cl, will)
|
||||
if err != nil {
|
||||
h.Log.Error("parse will error",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"will", will)
|
||||
continue
|
||||
}
|
||||
will = mlwt
|
||||
}
|
||||
}
|
||||
|
||||
return will
|
||||
}
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWillSent) {
|
||||
hook.OnWillSent(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired is called when a client session has expired and should be deleted.
|
||||
func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnClientExpired) {
|
||||
hook.OnClientExpired(cl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired is called when a retained message has expired and should be deleted.
|
||||
func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainedExpired) {
|
||||
hook.OnRetainedExpired(filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all clients, e.g. from a persistent store, is used to
|
||||
// populate the server clients list before start.
|
||||
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredClients) {
|
||||
v, err := hook.StoredClients()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to load clients", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
|
||||
// used to populate the server subscriptions list before start.
|
||||
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSubscriptions) {
|
||||
v, err := hook.StoredSubscriptions()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
|
||||
// and is used to populate the restored clients with inflight messages before start.
|
||||
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredInflightMessages) {
|
||||
v, err := hook.StoredInflightMessages()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
|
||||
// and is used to populate the server topics with retained messages before start.
|
||||
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredRetainedMessages) {
|
||||
v, err := hook.StoredRetainedMessages()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSysInfo) {
|
||||
v, err := hook.StoredSysInfo()
|
||||
if err != nil {
|
||||
h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
if v.Version != "" {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
|
||||
// An implementation of this method MUST be used to allow or deny access to the
|
||||
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check connecting users against an existing user database.
|
||||
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnectAuthenticate) {
|
||||
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter.
|
||||
// An implementation of this method MUST be used to allow or deny access to the
|
||||
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check publishing and subscribing users against an existing permissions or roles database.
|
||||
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnACLCheck) {
|
||||
if ok := hook.OnACLCheck(cl, topic, write); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HookBase provides a set of default methods for each hook. It should be embedded in
|
||||
// all hooks.
|
||||
type HookBase struct {
|
||||
Hook
|
||||
Log *slog.Logger
|
||||
Opts *HookOptions
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *HookBase) ID() string {
|
||||
return "base"
|
||||
}
|
||||
|
||||
// Provides indicates which methods a hook provides. The default is none - this method
|
||||
// should be overridden by the embedding hook.
|
||||
func (h *HookBase) Provides(b byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Init performs any pre-start initializations for the hook, such as connecting to databases
|
||||
// or opening files.
|
||||
func (h *HookBase) Init(config any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetOpts is called by the server to propagate internal values and generally should
|
||||
// not be called manually.
|
||||
func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) {
|
||||
h.Log = l
|
||||
h.Opts = opts
|
||||
}
|
||||
|
||||
// Stop is called to gracefully shut down the hook.
|
||||
func (h *HookBase) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStarted is called when the server starts.
|
||||
func (h *HookBase) OnStarted() {}
|
||||
|
||||
// OnStopped is called when the server stops.
|
||||
func (h *HookBase) OnStopped() {}
|
||||
|
||||
// OnSysInfoTick is called when the server publishes system info.
|
||||
func (h *HookBase) OnSysInfoTick(*system.Info) {}
|
||||
|
||||
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
|
||||
func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck is called when a user attempts to subscribe or publish to a topic.
|
||||
func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnSessionEstablish is called right after a new client connects and authenticates and right before
|
||||
// the session is established and CONNACK is sent.
|
||||
func (h *HookBase) OnSessionEstablish(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {}
|
||||
|
||||
// OnAuthPacket is called when an auth packet is received from the client.
|
||||
func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a packet is received.
|
||||
func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketEncode is called before a packet is byte-encoded and written to the client.
|
||||
func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnPacketSent is called immediately after a packet is written to a client.
|
||||
func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {}
|
||||
|
||||
// OnPacketProcessed is called immediately after a packet from a client is processed.
|
||||
func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {}
|
||||
|
||||
// OnSubscribe is called when a client subscribes to one or more filters.
|
||||
func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {}
|
||||
|
||||
// OnSelectSubscribers is called when selecting subscribers to receive a message.
|
||||
func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
return subs
|
||||
}
|
||||
|
||||
// OnUnsubscribe is called when a client unsubscribes from one or more filters.
|
||||
func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPublish is called when a client publishes a message.
|
||||
func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
|
||||
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
|
||||
|
||||
// OnRetainPublished is called when a retained message is published.
|
||||
func (h *HookBase) OnRetainPublished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
|
||||
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
|
||||
func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message.
|
||||
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
|
||||
return will, nil
|
||||
}
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnClientExpired is called when a client session has expired.
|
||||
func (h *HookBase) OnClientExpired(cl *Client) {}
|
||||
|
||||
// OnRetainedExpired is called when a retained message for a topic has expired.
|
||||
func (h *HookBase) OnRetainedExpired(topic string) {}
|
||||
|
||||
// StoredClients returns all clients from a store.
|
||||
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all subcriptions from a store.
|
||||
func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all inflight messages from a store.
|
||||
func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredRetainedMessages 返回存储区中所有保留的消息
|
||||
// StoredRetainedMessages returns all retained messages from a store.
|
||||
func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSysInfo 返回一组系统信息值
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
return
|
||||
}
|
||||
667
mqtt/hooks_test.go
Normal file
667
mqtt/hooks_test.go
Normal file
@@ -0,0 +1,667 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"testmqtt/hooks/storage"
|
||||
"testmqtt/packets"
|
||||
"testmqtt/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type modifiedHookBase struct {
|
||||
HookBase
|
||||
err error
|
||||
fail bool
|
||||
failAt int
|
||||
}
|
||||
|
||||
var errTestHook = errors.New("error")
|
||||
|
||||
func (h *modifiedHookBase) ID() string {
|
||||
return "modified"
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Init(config any) error {
|
||||
if config != nil {
|
||||
return errTestHook
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Provides(b byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Stop() error {
|
||||
if h.fail {
|
||||
return errTestHook
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnConnect(cl *Client, pk packets.Packet) error {
|
||||
if h.fail {
|
||||
return errTestHook
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) {
|
||||
if h.fail {
|
||||
return will, errTestHook
|
||||
}
|
||||
|
||||
return will, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) {
|
||||
if h.fail || h.failAt == 1 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Client{
|
||||
{ID: "cl1"},
|
||||
{ID: "cl2"},
|
||||
{ID: "cl3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.fail || h.failAt == 2 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Subscription{
|
||||
{ID: "sub1"},
|
||||
{ID: "sub2"},
|
||||
{ID: "sub3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.fail || h.failAt == 3 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Message{
|
||||
{ID: "r1"},
|
||||
{ID: "r2"},
|
||||
{ID: "r3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.fail || h.failAt == 4 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Message{
|
||||
{ID: "i1"},
|
||||
{ID: "i2"},
|
||||
{ID: "i3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.fail || h.failAt == 5 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return storage.SystemInfo{
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type providesCheckHook struct {
|
||||
HookBase
|
||||
}
|
||||
|
||||
func (h *providesCheckHook) Provides(b byte) bool {
|
||||
return b == OnConnect
|
||||
}
|
||||
|
||||
func TestHooksProvides(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(providesCheckHook), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.Provides(OnConnect, OnDisconnect))
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHooksAddLenGetAll(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(2), h.Len())
|
||||
|
||||
all := h.GetAll()
|
||||
require.Equal(t, "base", all[0].ID())
|
||||
require.Equal(t, "modified", all[1].ID())
|
||||
}
|
||||
|
||||
func TestHooksAddInitFailure(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), map[string]any{})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&h.qty))
|
||||
}
|
||||
|
||||
func TestHooksStop(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(1), h.Len())
|
||||
|
||||
h.Stop()
|
||||
}
|
||||
|
||||
// coverage: also cover some empty functions
|
||||
func TestHooksNonReturns(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
cl := new(Client)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
t.Run("step-"+strconv.Itoa(i), func(t *testing.T) {
|
||||
// on first iteration, check without hook methods
|
||||
h.OnStarted()
|
||||
h.OnStopped()
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
h.OnSessionEstablish(cl, packets.Packet{})
|
||||
h.OnSessionEstablished(cl, packets.Packet{})
|
||||
h.OnDisconnect(cl, nil, false)
|
||||
h.OnPacketSent(cl, packets.Packet{}, []byte{})
|
||||
h.OnPacketProcessed(cl, packets.Packet{}, nil)
|
||||
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
|
||||
h.OnUnsubscribed(cl, packets.Packet{})
|
||||
h.OnPublished(cl, packets.Packet{})
|
||||
h.OnPublishDropped(cl, packets.Packet{})
|
||||
h.OnRetainMessage(cl, packets.Packet{}, 0)
|
||||
h.OnRetainPublished(cl, packets.Packet{})
|
||||
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
|
||||
h.OnQosComplete(cl, packets.Packet{})
|
||||
h.OnQosDropped(cl, packets.Packet{})
|
||||
h.OnPacketIDExhausted(cl, packets.Packet{})
|
||||
h.OnWillSent(cl, packets.Packet{})
|
||||
h.OnClientExpired(cl)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
|
||||
// on second iteration, check added hook methods
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHooksOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
|
||||
ok := h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.False(t, ok)
|
||||
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestHooksOnACLCheck(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
|
||||
ok := h.OnACLCheck(new(Client), "a/b/c", true)
|
||||
require.False(t, ok)
|
||||
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = h.OnACLCheck(new(Client), "a/b/c", true)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestHooksOnSubscribe(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pki := packets.Packet{
|
||||
Filters: packets.Subscriptions{
|
||||
{Filter: "a/b/c", Qos: 1},
|
||||
},
|
||||
}
|
||||
pk := h.OnSubscribe(new(Client), pki)
|
||||
require.EqualValues(t, pk, pki)
|
||||
}
|
||||
|
||||
func TestHooksOnSelectSubscribers(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
subs := &Subscribers{
|
||||
Subscriptions: map[string]packets.Subscription{
|
||||
"cl1": {Filter: "a/b/c"},
|
||||
},
|
||||
}
|
||||
|
||||
subs2 := h.OnSelectSubscribers(subs, packets.Packet{})
|
||||
require.EqualValues(t, subs, subs2)
|
||||
}
|
||||
|
||||
func TestHooksOnUnsubscribe(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pki := packets.Packet{
|
||||
Filters: packets.Subscriptions{
|
||||
{Filter: "a/b/c", Qos: 1},
|
||||
},
|
||||
}
|
||||
|
||||
pk := h.OnUnsubscribe(new(Client), pki)
|
||||
require.EqualValues(t, pk, pki)
|
||||
}
|
||||
|
||||
func TestHooksOnPublish(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: failure
|
||||
hook.fail = true
|
||||
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: reject packet
|
||||
hook.err = packets.ErrRejectPacket
|
||||
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrRejectPacket)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnPacketRead(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: failure
|
||||
hook.fail = true
|
||||
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: reject packet
|
||||
hook.err = packets.ErrRejectPacket
|
||||
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrRejectPacket)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnAuthPacket(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
hook.fail = true
|
||||
pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnConnect(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
|
||||
hook.fail = true
|
||||
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHooksOnPacketEncode(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnLWT(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
|
||||
// coverage: fail lwt
|
||||
hook.fail = true
|
||||
lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
}
|
||||
|
||||
func TestHooksStoredClients(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredClients()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredSubscriptions()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredRetainedMessages()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredInflightMessages()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredSysInfo(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", v.Info.Version)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", v.Info.Version)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredSysInfo()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "", v.Info.Version)
|
||||
}
|
||||
|
||||
func TestHookBaseID(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Equal(t, "base", h.ID())
|
||||
}
|
||||
|
||||
func TestHookBaseProvidesNone(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.False(t, h.Provides(OnConnect))
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHookBaseInit(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Nil(t, h.Init(nil))
|
||||
}
|
||||
|
||||
func TestHookBaseSetOpts(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
h.SetOpts(logger, new(HookOptions))
|
||||
require.NotNil(t, h.Log)
|
||||
require.NotNil(t, h.Opts)
|
||||
}
|
||||
|
||||
func TestHookBaseClose(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Nil(t, h.Stop())
|
||||
}
|
||||
|
||||
func TestHookBaseOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.False(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseOnACLCheck(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v := h.OnACLCheck(new(Client), "topic", true)
|
||||
require.False(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseOnConnect(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
err := h.OnConnect(new(Client), packets.Packet{})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHookBaseOnPublish(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnPacketRead(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnAuthPacket(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnLWT(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredClients(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredSubscriptions(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredInflightMessages(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredRetainedMessages(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoreSysInfo(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", v.Version)
|
||||
}
|
||||
156
mqtt/inflight.go
Normal file
156
mqtt/inflight.go
Normal file
@@ -0,0 +1,156 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
// Inflight is a map of InflightMessage keyed on packet id.
|
||||
type Inflight struct {
|
||||
sync.RWMutex
|
||||
internal map[uint16]packets.Packet // internal contains the inflight packets
|
||||
receiveQuota int32 // remaining inbound qos quota for flow control
|
||||
sendQuota int32 // remaining outbound qos quota for flow control
|
||||
maximumReceiveQuota int32 // maximum allowed receive quota
|
||||
maximumSendQuota int32 // maximum allowed send quota
|
||||
}
|
||||
|
||||
// NewInflights returns a new instance of an Inflight packets map.
|
||||
func NewInflights() *Inflight {
|
||||
return &Inflight{
|
||||
internal: map[uint16]packets.Packet{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds or updates an inflight packet by packet id.
|
||||
func (i *Inflight) Set(m packets.Packet) bool {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
_, ok := i.internal[m.PacketID]
|
||||
i.internal[m.PacketID] = m
|
||||
return !ok
|
||||
}
|
||||
|
||||
// Get returns an inflight packet by packet id.
|
||||
func (i *Inflight) Get(id uint16) (packets.Packet, bool) {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
if m, ok := i.internal[id]; ok {
|
||||
return m, true
|
||||
}
|
||||
|
||||
return packets.Packet{}, false
|
||||
}
|
||||
|
||||
// Len returns the size of the inflight messages map.
|
||||
func (i *Inflight) Len() int {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
return len(i.internal)
|
||||
}
|
||||
|
||||
// Clone returns a new instance of Inflight with the same message data.
|
||||
// This is used when transferring inflights from a taken-over session.
|
||||
func (i *Inflight) Clone() *Inflight {
|
||||
c := NewInflights()
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
for k, v := range i.internal {
|
||||
c.internal[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// GetAll returns all the inflight messages.
|
||||
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
m := []packets.Packet{}
|
||||
for _, v := range i.internal {
|
||||
if !immediate || (immediate && v.Expiry < 0) {
|
||||
m = append(m, v)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(m, func(i, j int) bool {
|
||||
return uint16(m[i].Created) < uint16(m[j].Created)
|
||||
})
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// NextImmediate returns the next inflight packet which is indicated to be sent immediately.
|
||||
// This typically occurs when the quota has been exhausted, and we need to wait until new quota
|
||||
// is free to continue sending.
|
||||
func (i *Inflight) NextImmediate() (packets.Packet, bool) {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
m := i.GetAll(true)
|
||||
if len(m) > 0 {
|
||||
return m[0], true
|
||||
}
|
||||
|
||||
return packets.Packet{}, false
|
||||
}
|
||||
|
||||
// Delete removes an in-flight message from the map. Returns true if the message existed.
|
||||
func (i *Inflight) Delete(id uint16) bool {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
_, ok := i.internal[id]
|
||||
delete(i.internal, id)
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// TakeRecieveQuota reduces the receive quota by 1.
|
||||
func (i *Inflight) DecreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) > 0 {
|
||||
atomic.AddInt32(&i.receiveQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// TakeRecieveQuota increases the receive quota by 1.
|
||||
func (i *Inflight) IncreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
|
||||
atomic.AddInt32(&i.receiveQuota, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetReceiveQuota resets the receive quota to the maximum allowed value.
|
||||
func (i *Inflight) ResetReceiveQuota(n int32) {
|
||||
atomic.StoreInt32(&i.receiveQuota, n)
|
||||
atomic.StoreInt32(&i.maximumReceiveQuota, n)
|
||||
}
|
||||
|
||||
// DecreaseSendQuota reduces the send quota by 1.
|
||||
func (i *Inflight) DecreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) > 0 {
|
||||
atomic.AddInt32(&i.sendQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// IncreaseSendQuota increases the send quota by 1.
|
||||
func (i *Inflight) IncreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
|
||||
atomic.AddInt32(&i.sendQuota, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetSendQuota resets the send quota to the maximum allowed value.
|
||||
func (i *Inflight) ResetSendQuota(n int32) {
|
||||
atomic.StoreInt32(&i.sendQuota, n)
|
||||
atomic.StoreInt32(&i.maximumSendQuota, n)
|
||||
}
|
||||
199
mqtt/inflight_test.go
Normal file
199
mqtt/inflight_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
func TestInflightSet(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.True(t, r)
|
||||
require.NotNil(t, cl.State.Inflight.internal[1])
|
||||
require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID)
|
||||
|
||||
r = cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.False(t, r)
|
||||
}
|
||||
|
||||
func TestInflightGet(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
msg, ok := cl.State.Inflight.Get(2)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, 0, msg.PacketID)
|
||||
}
|
||||
|
||||
func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
|
||||
|
||||
require.Equal(t, []packets.Packet{
|
||||
{PacketID: 1, Created: 1},
|
||||
{PacketID: 2, Created: 2},
|
||||
{PacketID: 3, Created: 3, Expiry: -1},
|
||||
{PacketID: 4, Created: 4, Expiry: -1},
|
||||
{PacketID: 5, Created: 5},
|
||||
}, cl.State.Inflight.GetAll(false))
|
||||
|
||||
require.Equal(t, []packets.Packet{
|
||||
{PacketID: 3, Created: 3, Expiry: -1},
|
||||
{PacketID: 4, Created: 4, Expiry: -1},
|
||||
}, cl.State.Inflight.GetAll(true))
|
||||
}
|
||||
|
||||
func TestInflightLen(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestInflightClone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
cloned := cl.State.Inflight.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.NotSame(t, cloned, cl.State.Inflight)
|
||||
}
|
||||
|
||||
func TestInflightDelete(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
|
||||
require.NotNil(t, cl.State.Inflight.internal[3])
|
||||
|
||||
r := cl.State.Inflight.Delete(3)
|
||||
require.True(t, r)
|
||||
require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID)
|
||||
|
||||
_, ok := cl.State.Inflight.Get(3)
|
||||
require.False(t, ok)
|
||||
|
||||
r = cl.State.Inflight.Delete(3)
|
||||
require.False(t, r)
|
||||
}
|
||||
|
||||
func TestResetReceiveQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
i.ResetReceiveQuota(6)
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota))
|
||||
}
|
||||
|
||||
func TestReceiveQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
i.receiveQuota = 4
|
||||
i.maximumReceiveQuota = 5
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Return 1
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Reset to max 1
|
||||
i.ResetReceiveQuota(1)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Take 1
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
}
|
||||
|
||||
func TestResetSendQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
i.ResetSendQuota(6)
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota))
|
||||
}
|
||||
|
||||
func TestSendQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
i.sendQuota = 4
|
||||
i.maximumSendQuota = 5
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Return 1
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Reset to max 1
|
||||
i.ResetSendQuota(1)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Take 1
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
}
|
||||
|
||||
func TestNextImmediate(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
|
||||
|
||||
pk, ok := cl.State.Inflight.NextImmediate()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk)
|
||||
|
||||
r := cl.State.Inflight.Delete(3)
|
||||
require.True(t, r)
|
||||
|
||||
pk, ok = cl.State.Inflight.NextImmediate()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk)
|
||||
|
||||
r = cl.State.Inflight.Delete(4)
|
||||
require.True(t, r)
|
||||
|
||||
_, ok = cl.State.Inflight.NextImmediate()
|
||||
require.False(t, ok)
|
||||
}
|
||||
1759
mqtt/server.go
Normal file
1759
mqtt/server.go
Normal file
File diff suppressed because it is too large
Load Diff
3915
mqtt/server_test.go
Normal file
3915
mqtt/server_test.go
Normal file
File diff suppressed because it is too large
Load Diff
824
mqtt/topics.go
Normal file
824
mqtt/topics.go
Normal file
@@ -0,0 +1,824 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"testmqtt/packets"
|
||||
)
|
||||
|
||||
var (
|
||||
// SharePrefix 共享主题的前缀
|
||||
SharePrefix = "$SHARE" // the prefix indicating a share topic
|
||||
// SysPrefix 系统信息主题的前缀
|
||||
SysPrefix = "$SYS" // the prefix indicating a system info topic
|
||||
)
|
||||
|
||||
// TopicAliases contains inbound and outbound topic alias registrations.
|
||||
type TopicAliases struct {
|
||||
Inbound *InboundTopicAliases
|
||||
Outbound *OutboundTopicAliases
|
||||
}
|
||||
|
||||
// NewTopicAliases returns an instance of TopicAliases.
|
||||
func NewTopicAliases(topicAliasMaximum uint16) TopicAliases {
|
||||
return TopicAliases{
|
||||
Inbound: NewInboundTopicAliases(topicAliasMaximum),
|
||||
Outbound: NewOutboundTopicAliases(topicAliasMaximum),
|
||||
}
|
||||
}
|
||||
|
||||
// NewInboundTopicAliases returns a pointer to InboundTopicAliases.
|
||||
func NewInboundTopicAliases(topicAliasMaximum uint16) *InboundTopicAliases {
|
||||
return &InboundTopicAliases{
|
||||
maximum: topicAliasMaximum,
|
||||
internal: map[uint16]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// InboundTopicAliases contains a map of topic aliases received from the client.
|
||||
type InboundTopicAliases struct {
|
||||
internal map[uint16]string
|
||||
sync.RWMutex
|
||||
maximum uint16
|
||||
}
|
||||
|
||||
// Set sets a new alias for a specific topic.
|
||||
func (a *InboundTopicAliases) Set(id uint16, topic string) string {
|
||||
a.Lock()
|
||||
defer a.Unlock()
|
||||
|
||||
if a.maximum == 0 {
|
||||
return topic // ?
|
||||
}
|
||||
|
||||
if existing, ok := a.internal[id]; ok && topic == "" {
|
||||
return existing
|
||||
}
|
||||
|
||||
a.internal[id] = topic
|
||||
return topic
|
||||
}
|
||||
|
||||
// OutboundTopicAliases contains a map of topic aliases sent from the broker to the client.
|
||||
type OutboundTopicAliases struct {
|
||||
internal map[string]uint16
|
||||
sync.RWMutex
|
||||
cursor uint32
|
||||
maximum uint16
|
||||
}
|
||||
|
||||
// NewOutboundTopicAliases returns a pointer to OutboundTopicAliases.
|
||||
func NewOutboundTopicAliases(topicAliasMaximum uint16) *OutboundTopicAliases {
|
||||
return &OutboundTopicAliases{
|
||||
maximum: topicAliasMaximum,
|
||||
internal: map[string]uint16{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a new topic alias for a topic and returns the alias value, and a boolean
|
||||
// indicating if the alias already existed.
|
||||
func (a *OutboundTopicAliases) Set(topic string) (uint16, bool) {
|
||||
a.Lock()
|
||||
defer a.Unlock()
|
||||
|
||||
if a.maximum == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if i, ok := a.internal[topic]; ok {
|
||||
return i, true
|
||||
}
|
||||
|
||||
i := atomic.LoadUint32(&a.cursor)
|
||||
if i+1 > uint32(a.maximum) {
|
||||
// if i+1 > math.MaxUint16 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
a.internal[topic] = uint16(i) + 1
|
||||
atomic.StoreUint32(&a.cursor, i+1)
|
||||
return uint16(i) + 1, false
|
||||
}
|
||||
|
||||
// SharedSubscriptions contains a map of subscriptions to a shared filter,
|
||||
// keyed on share group then client id.
|
||||
type SharedSubscriptions struct {
|
||||
internal map[string]map[string]packets.Subscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSharedSubscriptions returns a new instance of Subscriptions.
|
||||
func NewSharedSubscriptions() *SharedSubscriptions {
|
||||
return &SharedSubscriptions{
|
||||
internal: map[string]map[string]packets.Subscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add creates a new shared subscription for a group and client id pair.
|
||||
func (s *SharedSubscriptions) Add(group, id string, val packets.Subscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
if _, ok := s.internal[group]; !ok {
|
||||
s.internal[group] = map[string]packets.Subscription{}
|
||||
}
|
||||
s.internal[group][id] = val
|
||||
}
|
||||
|
||||
// Delete deletes a client id from a shared subscription group.
|
||||
func (s *SharedSubscriptions) Delete(group, id string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal[group], id)
|
||||
if len(s.internal[group]) == 0 {
|
||||
delete(s.internal, group)
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the subscription properties for a client id in a share group, if one exists.
|
||||
func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
if _, ok := s.internal[group]; !ok {
|
||||
return val, ok
|
||||
}
|
||||
|
||||
val, ok = s.internal[group][id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// GroupLen returns the number of groups subscribed to the filter.
|
||||
func (s *SharedSubscriptions) GroupLen() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Len returns the total number of shared subscriptions to a filter across all groups.
|
||||
func (s *SharedSubscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
n := 0
|
||||
for _, group := range s.internal {
|
||||
n += len(group)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// GetAll returns all shared subscription groups and their subscriptions.
|
||||
func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[string]map[string]packets.Subscription{}
|
||||
for group, subs := range s.internal {
|
||||
if _, ok := m[group]; !ok {
|
||||
m[group] = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for id, sub := range subs {
|
||||
m[group][id] = sub
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// InlineSubFn is the signature for a callback function which will be called
|
||||
// when an inline client receives a message on a topic it is subscribed to.
|
||||
// The sub argument contains information about the subscription that was matched for any filters.
|
||||
type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet)
|
||||
|
||||
// InlineSubscriptions represents a map of internal subscriptions keyed on client.
|
||||
type InlineSubscriptions struct {
|
||||
internal map[int]InlineSubscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewInlineSubscriptions returns a new instance of InlineSubscriptions.
|
||||
func NewInlineSubscriptions() *InlineSubscriptions {
|
||||
return &InlineSubscriptions{
|
||||
internal: map[int]InlineSubscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new internal subscription for a client id.
|
||||
func (s *InlineSubscriptions) Add(val InlineSubscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.internal[val.Identifier] = val
|
||||
}
|
||||
|
||||
// GetAll returns all internal subscriptions.
|
||||
func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[int]InlineSubscription{}
|
||||
for k, v := range s.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns an internal subscription for a client id.
|
||||
func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val, ok = s.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the number of internal subscriptions.
|
||||
func (s *InlineSubscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes an internal subscription by the client id.
|
||||
func (s *InlineSubscriptions) Delete(id int) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal, id)
|
||||
}
|
||||
|
||||
// Subscriptions is a map of subscriptions keyed on client.
|
||||
type Subscriptions struct {
|
||||
internal map[string]packets.Subscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSubscriptions returns a new instance of Subscriptions.
|
||||
func NewSubscriptions() *Subscriptions {
|
||||
return &Subscriptions{
|
||||
internal: map[string]packets.Subscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new subscription for a client. ID can be a filter in the
|
||||
// case this map is client state, or a client id if particle state.
|
||||
func (s *Subscriptions) Add(id string, val packets.Subscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.internal[id] = val
|
||||
}
|
||||
|
||||
// GetAll returns all subscriptions.
|
||||
func (s *Subscriptions) GetAll() map[string]packets.Subscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[string]packets.Subscription{}
|
||||
for k, v := range s.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns a subscriptions for a specific client or filter id.
|
||||
func (s *Subscriptions) Get(id string) (val packets.Subscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val, ok = s.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the number of subscriptions.
|
||||
func (s *Subscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes a subscription by client or filter id.
|
||||
func (s *Subscriptions) Delete(id string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal, id)
|
||||
}
|
||||
|
||||
// ClientSubscriptions is a map of aggregated subscriptions for a client.
|
||||
type ClientSubscriptions map[string]packets.Subscription
|
||||
|
||||
type InlineSubscription struct {
|
||||
packets.Subscription
|
||||
Handler InlineSubFn
|
||||
}
|
||||
|
||||
// Subscribers contains the shared and non-shared subscribers matching a topic.
|
||||
type Subscribers struct {
|
||||
Shared map[string]map[string]packets.Subscription
|
||||
SharedSelected map[string]packets.Subscription
|
||||
Subscriptions map[string]packets.Subscription
|
||||
InlineSubscriptions map[int]InlineSubscription
|
||||
}
|
||||
|
||||
// SelectShared returns one subscriber for each shared subscription group.
|
||||
func (s *Subscribers) SelectShared() {
|
||||
s.SharedSelected = map[string]packets.Subscription{}
|
||||
for _, subs := range s.Shared {
|
||||
for client, sub := range subs {
|
||||
cls, ok := s.SharedSelected[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
s.SharedSelected[client] = cls.Merge(sub)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MergeSharedSelected merges the selected subscribers for a shared subscription group
|
||||
// and the non-shared subscribers, to ensure that no subscriber gets multiple messages
|
||||
// due to have both types of subscription matching the same filter.
|
||||
func (s *Subscribers) MergeSharedSelected() {
|
||||
for client, sub := range s.SharedSelected {
|
||||
cls, ok := s.Subscriptions[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
s.Subscriptions[client] = cls.Merge(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// TopicsIndex is a prefix/trie tree containing topic subscribers and retained messages.
|
||||
type TopicsIndex struct {
|
||||
Retained *packets.Packets
|
||||
root *particle // a leaf containing a message and more leaves.
|
||||
}
|
||||
|
||||
// NewTopicsIndex returns a pointer to a new instance of Index.
|
||||
func NewTopicsIndex() *TopicsIndex {
|
||||
return &TopicsIndex{
|
||||
Retained: packets.NewPackets(),
|
||||
root: &particle{
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// InlineSubscribe adds a new internal subscription for a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var existed bool
|
||||
n := x.set(subscription.Filter, 0)
|
||||
_, existed = n.inlineSubscriptions.Get(subscription.Identifier)
|
||||
n.inlineSubscriptions.Add(subscription)
|
||||
|
||||
return !existed
|
||||
}
|
||||
|
||||
// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client,
|
||||
// returning true if the subscription existed.
|
||||
func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
particle := x.seek(filter, 0)
|
||||
if particle == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
particle.inlineSubscriptions.Delete(id)
|
||||
|
||||
if particle.inlineSubscriptions.Len() == 0 {
|
||||
x.trim(particle)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Subscribe adds a new subscription for a client to a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var existed bool
|
||||
prefix, _ := isolateParticle(subscription.Filter, 0)
|
||||
if strings.EqualFold(prefix, SharePrefix) {
|
||||
group, _ := isolateParticle(subscription.Filter, 1)
|
||||
n := x.set(subscription.Filter, 2)
|
||||
_, existed = n.shared.Get(group, client)
|
||||
n.shared.Add(group, client, subscription)
|
||||
} else {
|
||||
n := x.set(subscription.Filter, 0)
|
||||
_, existed = n.subscriptions.Get(client)
|
||||
n.subscriptions.Add(client, subscription)
|
||||
}
|
||||
|
||||
return !existed
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription filter for a client, returning true if the
|
||||
// subscription existed.
|
||||
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var d int
|
||||
prefix, _ := isolateParticle(filter, 0)
|
||||
shareSub := strings.EqualFold(prefix, SharePrefix)
|
||||
if shareSub {
|
||||
d = 2
|
||||
}
|
||||
|
||||
particle := x.seek(filter, d)
|
||||
if particle == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if shareSub {
|
||||
group, _ := isolateParticle(filter, 1)
|
||||
particle.shared.Delete(group, client)
|
||||
} else {
|
||||
particle.subscriptions.Delete(client)
|
||||
}
|
||||
|
||||
x.trim(particle)
|
||||
return true
|
||||
}
|
||||
|
||||
// RetainMessage saves a message payload to the end of a topic address. Returns
|
||||
// 1 if a retained message was added, and -1 if the retained message was removed.
|
||||
// 0 is returned if sequential empty payloads are received.
|
||||
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
n := x.set(pk.TopicName, 0)
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
if len(pk.Payload) > 0 {
|
||||
n.retainPath = pk.TopicName
|
||||
x.Retained.Add(pk.TopicName, pk)
|
||||
return 1
|
||||
}
|
||||
|
||||
var out int64
|
||||
if pke, ok := x.Retained.Get(pk.TopicName); ok && len(pke.Payload) > 0 && pke.FixedHeader.Retain {
|
||||
out = -1 // if a retained packet existed, return -1
|
||||
}
|
||||
|
||||
n.retainPath = ""
|
||||
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
|
||||
x.trim(n)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// set creates a topic address in the index and returns the final particle.
|
||||
func (x *TopicsIndex) set(topic string, d int) *particle {
|
||||
var key string
|
||||
var hasNext = true
|
||||
n := x.root
|
||||
for hasNext {
|
||||
key, hasNext = isolateParticle(topic, d)
|
||||
d++
|
||||
|
||||
p := n.particles.get(key)
|
||||
if p == nil {
|
||||
p = newParticle(key, n)
|
||||
n.particles.add(p)
|
||||
}
|
||||
n = p
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// seek finds the particle at a specific index in a topic filter.
|
||||
func (x *TopicsIndex) seek(filter string, d int) *particle {
|
||||
var key string
|
||||
var hasNext = true
|
||||
n := x.root
|
||||
for hasNext {
|
||||
key, hasNext = isolateParticle(filter, d)
|
||||
n = n.particles.get(key)
|
||||
d++
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// trim removes empty filter particles from the index.
|
||||
func (x *TopicsIndex) trim(n *particle) {
|
||||
for n.parent != nil && n.retainPath == "" && n.particles.len()+n.subscriptions.Len()+n.shared.Len()+n.inlineSubscriptions.Len() == 0 {
|
||||
key := n.key
|
||||
n = n.parent
|
||||
n.particles.delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Messages returns a slice of any retained messages which match a filter.
|
||||
func (x *TopicsIndex) Messages(filter string) []packets.Packet {
|
||||
return x.scanMessages(filter, 0, nil, []packets.Packet{})
|
||||
}
|
||||
|
||||
// scanMessages returns all retained messages on topics matching a given filter.
|
||||
func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []packets.Packet) []packets.Packet {
|
||||
if n == nil {
|
||||
n = x.root
|
||||
}
|
||||
|
||||
if len(filter) == 0 || x.Retained.Len() == 0 {
|
||||
return pks
|
||||
}
|
||||
|
||||
if !strings.ContainsRune(filter, '#') && !strings.ContainsRune(filter, '+') {
|
||||
if pk, ok := x.Retained.Get(filter); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
return pks
|
||||
}
|
||||
|
||||
key, hasNext := isolateParticle(filter, d)
|
||||
if key == "+" || key == "#" || d == -1 {
|
||||
for _, adjacent := range n.particles.getAll() {
|
||||
if d == 0 && adjacent.key == SysPrefix {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hasNext {
|
||||
if adjacent.retainPath != "" {
|
||||
if pk, ok := x.Retained.Get(adjacent.retainPath); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasNext || (d >= 0 && key == "#") {
|
||||
pks = x.scanMessages(filter, d+1, adjacent, pks)
|
||||
}
|
||||
}
|
||||
return pks
|
||||
}
|
||||
|
||||
if particle := n.particles.get(key); particle != nil {
|
||||
if hasNext {
|
||||
return x.scanMessages(filter, d+1, particle, pks)
|
||||
}
|
||||
|
||||
if pk, ok := x.Retained.Get(particle.retainPath); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
}
|
||||
|
||||
return pks
|
||||
}
|
||||
|
||||
// Subscribers returns a map of clients who are subscribed to matching filters,
|
||||
// their subscription ids and highest qos.
|
||||
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
|
||||
return x.scanSubscribers(topic, 0, nil, &Subscribers{
|
||||
Shared: map[string]map[string]packets.Subscription{},
|
||||
SharedSelected: map[string]packets.Subscription{},
|
||||
Subscriptions: map[string]packets.Subscription{},
|
||||
InlineSubscriptions: map[int]InlineSubscription{},
|
||||
})
|
||||
}
|
||||
|
||||
// scanSubscribers returns a list of client subscriptions matching an indexed topic address.
|
||||
func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Subscribers) *Subscribers {
|
||||
if n == nil {
|
||||
n = x.root
|
||||
}
|
||||
|
||||
if len(topic) == 0 {
|
||||
return subs
|
||||
}
|
||||
|
||||
key, hasNext := isolateParticle(topic, d)
|
||||
for _, partKey := range []string{key, "+"} {
|
||||
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
|
||||
if hasNext {
|
||||
x.scanSubscribers(topic, d+1, particle, subs)
|
||||
} else {
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
|
||||
if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
|
||||
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
|
||||
x.gatherSharedSubscriptions(wild, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if particle := n.particles.get("#"); particle != nil {
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
}
|
||||
|
||||
return subs
|
||||
}
|
||||
|
||||
// gatherSubscriptions collects any matching subscriptions, and gathers any identifiers or highest qos values.
|
||||
func (x *TopicsIndex) gatherSubscriptions(topic string, particle *particle, subs *Subscribers) {
|
||||
if subs.Subscriptions == nil {
|
||||
subs.Subscriptions = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for client, sub := range particle.subscriptions.GetAll() {
|
||||
if len(sub.Filter) > 0 && topic[0] == '$' && (sub.Filter[0] == '+' || sub.Filter[0] == '#') { // don't match $ topics with top level wildcards [MQTT-4.7.1-1] [MQTT-4.7.1-2]
|
||||
continue
|
||||
}
|
||||
|
||||
cls, ok := subs.Subscriptions[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
subs.Subscriptions[client] = cls.Merge(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// gatherSharedSubscriptions gathers all shared subscriptions for a particle.
|
||||
func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscribers) {
|
||||
if subs.Shared == nil {
|
||||
subs.Shared = map[string]map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for _, shares := range particle.shared.GetAll() {
|
||||
for client, sub := range shares {
|
||||
if _, ok := subs.Shared[sub.Filter]; !ok {
|
||||
subs.Shared[sub.Filter] = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
subs.Shared[sub.Filter][client] = sub
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gatherSharedSubscriptions gathers all inline subscriptions for a particle.
|
||||
func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) {
|
||||
if subs.InlineSubscriptions == nil {
|
||||
subs.InlineSubscriptions = map[int]InlineSubscription{}
|
||||
}
|
||||
|
||||
for id, inline := range particle.inlineSubscriptions.GetAll() {
|
||||
subs.InlineSubscriptions[id] = inline
|
||||
}
|
||||
}
|
||||
|
||||
// isolateParticle extracts a particle between d / and d+1 / without allocations.
|
||||
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
|
||||
var next, end int
|
||||
for i := 0; end > -1 && i <= d; i++ {
|
||||
end = strings.IndexRune(filter, '/')
|
||||
|
||||
switch {
|
||||
case d > -1 && i == d && end > -1:
|
||||
hasNext = true
|
||||
particle = filter[next:end]
|
||||
case end > -1:
|
||||
hasNext = false
|
||||
filter = filter[end+1:]
|
||||
default:
|
||||
hasNext = false
|
||||
particle = filter[next:]
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsSharedFilter returns true if the filter uses the share prefix.
|
||||
func IsSharedFilter(filter string) bool {
|
||||
prefix, _ := isolateParticle(filter, 0)
|
||||
return strings.EqualFold(prefix, SharePrefix)
|
||||
}
|
||||
|
||||
// IsValidFilter returns true if the filter is valid.
|
||||
func IsValidFilter(filter string, forPublish bool) bool {
|
||||
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish.
|
||||
return false // [MQTT-4.7.3-1]
|
||||
}
|
||||
|
||||
if forPublish {
|
||||
if len(filter) >= len(SysPrefix) && strings.EqualFold(filter[0:len(SysPrefix)], SysPrefix) {
|
||||
// 4.7.2 Non-normative - The Server SHOULD prevent Clients from using such Topic Names [$SYS] to exchange messages with other Clients.
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.ContainsRune(filter, '+') || strings.ContainsRune(filter, '#') {
|
||||
return false //[MQTT-3.3.2-2]
|
||||
}
|
||||
}
|
||||
|
||||
wildhash := strings.IndexRune(filter, '#')
|
||||
if wildhash >= 0 && wildhash != len(filter)-1 { // [MQTT-4.7.1-2]
|
||||
return false
|
||||
}
|
||||
|
||||
prefix, hasNext := isolateParticle(filter, 0)
|
||||
if !hasNext && strings.EqualFold(prefix, SharePrefix) {
|
||||
return false // [MQTT-4.8.2-1]
|
||||
}
|
||||
|
||||
if hasNext && strings.EqualFold(prefix, SharePrefix) {
|
||||
group, hasNext := isolateParticle(filter, 1)
|
||||
if !hasNext {
|
||||
return false // [MQTT-4.8.2-1]
|
||||
}
|
||||
|
||||
if strings.ContainsRune(group, '+') || strings.ContainsRune(group, '#') {
|
||||
return false // [MQTT-4.8.2-2]
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// particle is a child node on the tree.
|
||||
type particle struct {
|
||||
key string // the key of the particle
|
||||
parent *particle // a pointer to the parent of the particle
|
||||
particles particles // a map of child particles
|
||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||
inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle
|
||||
retainPath string // path of a retained message
|
||||
sync.Mutex // mutex for when making changes to the particle
|
||||
}
|
||||
|
||||
// newParticle returns a pointer to a new instance of particle.
|
||||
func newParticle(key string, parent *particle) *particle {
|
||||
return &particle{
|
||||
key: key,
|
||||
parent: parent,
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
shared: NewSharedSubscriptions(),
|
||||
inlineSubscriptions: NewInlineSubscriptions(),
|
||||
}
|
||||
}
|
||||
|
||||
// particles is a concurrency safe map of particles.
|
||||
type particles struct {
|
||||
internal map[string]*particle
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// newParticles returns a map of particles.
|
||||
func newParticles() particles {
|
||||
return particles{
|
||||
internal: map[string]*particle{},
|
||||
}
|
||||
}
|
||||
|
||||
// add adds a new particle.
|
||||
func (p *particles) add(val *particle) {
|
||||
p.Lock()
|
||||
p.internal[val.key] = val
|
||||
p.Unlock()
|
||||
}
|
||||
|
||||
// getAll returns all particles.
|
||||
func (p *particles) getAll() map[string]*particle {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
m := map[string]*particle{}
|
||||
for k, v := range p.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// get returns a particle by id (key).
|
||||
func (p *particles) get(id string) *particle {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
return p.internal[id]
|
||||
}
|
||||
|
||||
// len returns the number of particles.
|
||||
func (p *particles) len() int {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
val := len(p.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// delete removes a particle.
|
||||
func (p *particles) delete(id string) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
delete(p.internal, id)
|
||||
}
|
||||
1068
mqtt/topics_test.go
Normal file
1068
mqtt/topics_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user