代码整理

This commit is contained in:
2024-08-21 15:32:05 +08:00
commit 84e5e65ee7
71 changed files with 29733 additions and 0 deletions

172
packets/codec.go Normal file
View File

@@ -0,0 +1,172 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"encoding/binary"
"io"
"unicode/utf8"
"unsafe"
)
// bytesToString provides a zero-alloc no-copy byte to string conversion.
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
func bytesToString(bs []byte) string {
return *(*string)(unsafe.Pointer(&bs))
}
// decodeUint16 extracts the value of two bytes from a byte array.
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
if len(buf) < offset+2 {
return 0, 0, ErrMalformedOffsetUintOutOfRange
}
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
}
// decodeUint32 extracts the value of four bytes from a byte array.
func decodeUint32(buf []byte, offset int) (uint32, int, error) {
if len(buf) < offset+4 {
return 0, 0, ErrMalformedOffsetUintOutOfRange
}
return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil
}
// decodeString extracts a string from a byte array, beginning at an offset.
func decodeString(buf []byte, offset int) (string, int, error) {
b, n, err := decodeBytes(buf, offset)
if err != nil {
return "", 0, err
}
if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5]
return "", 0, ErrMalformedInvalidUTF8
}
return bytesToString(b), n, nil
}
// validUTF8 checks if the byte array contains valid UTF-8 characters.
func validUTF8(b []byte) bool {
return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2]
}
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
length, next, err := decodeUint16(buf, offset)
if err != nil {
return make([]byte, 0), 0, err
}
if next+int(length) > len(buf) {
return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange
}
return buf[next : next+int(length)], next + int(length), nil
}
// decodeByte extracts the value of a byte from a byte array.
func decodeByte(buf []byte, offset int) (byte, int, error) {
if len(buf) <= offset {
return 0, 0, ErrMalformedOffsetByteOutOfRange
}
return buf[offset], offset + 1, nil
}
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
if len(buf) <= offset {
return false, 0, ErrMalformedOffsetBoolOutOfRange
}
return 1&buf[offset] > 0, offset + 1, nil
}
// encodeBool returns a byte instead of a bool.
func encodeBool(b bool) byte {
if b {
return 1
}
return 0
}
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
func encodeBytes(val []byte) []byte {
// In most circumstances the number of bytes being encoded is small.
// Setting the cap to a low amount allows us to account for those without
// triggering allocation growth on append unless we need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, val...)
}
// encodeUint16 encodes a uint16 value to a byte array.
func encodeUint16(val uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, val)
return buf
}
// encodeUint32 encodes a uint16 value to a byte array.
func encodeUint32(val uint32) []byte {
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, val)
return buf
}
// encodeString encodes a string to a byte array.
func encodeString(val string) []byte {
// Like encodeBytes, we set the cap to a small number to avoid
// triggering allocation growth on append unless we absolutely need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, []byte(val)...)
}
// encodeLength writes length bits for the header.
func encodeLength(b *bytes.Buffer, length int64) {
// 1.5.5 Variable Byte Integer encode non-normative
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
for {
eb := byte(length % 128)
length /= 128
if length > 0 {
eb |= 0x80
}
b.WriteByte(eb)
if length == 0 {
break // [MQTT-1.5.5-1]
}
}
}
func DecodeLength(b io.ByteReader) (n, bu int, err error) {
// see 1.5.5 Variable Byte Integer decode non-normative
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
var multiplier uint32
var value uint32
bu = 1
for {
eb, err := b.ReadByte()
if err != nil {
return 0, bu, err
}
value |= uint32(eb&127) << multiplier
if value > 268435455 {
return 0, bu, ErrMalformedVariableByteInteger
}
if (eb & 128) == 0 {
break
}
multiplier += 7
bu++
}
return int(value), bu, nil
}

422
packets/codec_test.go Normal file
View File

@@ -0,0 +1,422 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"errors"
"fmt"
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBytesToString(t *testing.T) {
b := []byte{'a', 'b', 'c'}
require.Equal(t, "abc", bytesToString(b))
}
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
rawBytes []byte
result string
offset int
shouldFail error
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: "a/b/c/d",
},
{
offset: 14,
rawBytes: []byte{
Connect << 4, 17, // Fixed header
0, 6, // Protocol Name - MSB+LSB
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
3, // Protocol Version
0, // Packet Flags
0, 30, // Keepalive
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "zen"},
},
result: "hey",
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: "1/2/3/4/a/b/c/d/e/^/@/!",
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: "x/y/z",
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
{
offset: 17,
rawBytes: []byte{
Connect << 4, 0, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 6, // Will Topic - MSB+LSB
'l',
},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrMalformedInvalidUTF8,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3]
result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0)
require.NoError(t, err)
require.Equal(t, "\ufeff", result)
}
func TestDecodeBytes(t *testing.T) {
expect := []struct {
rawBytes []byte
result []uint8
next int
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session)
result: []byte{0x4d, 0x51, 0x54, 0x54},
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start
result: []byte{0x4d, 0x51, 0x54, 0x54},
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 0,
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func TestDecodeByte(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint8
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
result: uint8(0x00),
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x04),
offset: 1,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x4d),
offset: 2,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x51),
offset: 3,
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
offset: 8,
shouldFail: ErrMalformedOffsetByteOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
})
}
}
func TestDecodeUint16(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint16
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x07),
offset: 0,
},
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x761),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
})
}
}
func TestDecodeUint32(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint32
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 0, 0, 7, 8},
result: uint32(7),
offset: 0,
},
{
rawBytes: []byte{0, 0, 1, 226, 64, 8},
result: uint32(123456),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+4, offset)
})
}
}
func TestDecodeByteBool(t *testing.T) {
expect := []struct {
rawBytes []byte
result bool
offset int
shouldFail error
}{
{
rawBytes: []byte{0x00, 0x00},
result: false,
},
{
rawBytes: []byte{0x01, 0x00},
result: true,
},
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: ErrMalformedOffsetBoolOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, 1, offset)
})
}
}
func TestDecodeLength(t *testing.T) {
b := bytes.NewBuffer([]byte{0x78})
n, bu, err := DecodeLength(b)
require.NoError(t, err)
require.Equal(t, 120, n)
require.Equal(t, 1, bu)
b = bytes.NewBuffer([]byte{255, 255, 255, 127})
n, bu, err = DecodeLength(b)
require.NoError(t, err)
require.Equal(t, 268435455, n)
require.Equal(t, 4, bu)
}
func TestDecodeLengthErrors(t *testing.T) {
b := bytes.NewBuffer([]byte{})
_, _, err := DecodeLength(b)
require.Error(t, err)
b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f})
_, _, err = DecodeLength(b)
require.Error(t, err)
require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
}
func TestEncodeBool(t *testing.T) {
result := encodeBool(true)
require.Equal(t, byte(1), result)
result = encodeBool(false)
require.Equal(t, byte(0), result)
// Check failure.
result = encodeBool(false)
require.NotEqual(t, byte(1), result)
}
func TestEncodeBytes(t *testing.T) {
result := encodeBytes([]byte("testing"))
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result)
result = encodeBytes([]byte("testing"))
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result)
}
func TestEncodeUint16(t *testing.T) {
result := encodeUint16(0)
require.Equal(t, []byte{0x00, 0x00}, result)
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result)
result = encodeUint16(math.MaxUint16)
require.Equal(t, []byte{0xff, 0xff}, result)
}
func TestEncodeUint32(t *testing.T) {
result := encodeUint32(7)
require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result)
result = encodeUint32(32767)
require.Equal(t, []byte{0, 0, 127, 255}, result)
result = encodeUint32(math.MaxUint32)
require.Equal(t, []byte{255, 255, 255, 255}, result)
}
func TestEncodeString(t *testing.T) {
result := encodeString("testing")
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result)
result = encodeString("")
require.Equal(t, []uint8{0x00, 0x00}, result)
result = encodeString("a")
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result)
result = encodeString("b")
require.NotEqual(t, []uint8{0x00, 0x00}, result)
}
func TestEncodeLength(t *testing.T) {
b := new(bytes.Buffer)
encodeLength(b, 120)
require.Equal(t, []byte{0x78}, b.Bytes())
b = new(bytes.Buffer)
encodeLength(b, math.MaxInt64)
require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes())
}
func TestValidUTF8(t *testing.T) {
require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}))
require.False(t, validUTF8([]byte{0xff, 0xff}))
require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74}))
}

149
packets/codes.go Normal file
View File

@@ -0,0 +1,149 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
// Code contains a reason code and reason string for a response.
type Code struct {
Reason string
Code byte
}
// String returns the readable reason for a code.
func (c Code) String() string {
return c.Reason
}
// Error returns the readable reason for a code.
func (c Code) Error() string {
return c.Reason
}
var (
// QosCodes indicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
2: CodeGrantedQos2,
}
CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"}
CodeSuccess = Code{Code: 0x00, Reason: "success"}
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"}
CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"}
CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"}
CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"}
CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"}
CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"}
CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"}
ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"}
ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"}
ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"}
ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"}
ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"}
ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"}
ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"}
ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"}
ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"}
ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"}
ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"}
ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"}
ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"}
ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"}
ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"}
ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"}
ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"}
ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"}
ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"}
ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"}
ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"}
ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"}
ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"}
ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"}
ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"}
ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"}
ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"}
ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"}
ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"}
ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"}
ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"}
ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"}
ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"}
ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"}
ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"}
ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"}
ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"}
ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"}
ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"}
ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"}
ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"}
ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"}
ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"}
ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"}
ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"}
ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"}
ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"}
ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"}
ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"}
ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"}
ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"}
ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"}
ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"}
ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"}
ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"}
ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"}
ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"}
ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"}
ErrServerBusy = Code{Code: 0x89, Reason: "server busy"}
ErrBanned = Code{Code: 0x8A, Reason: "banned"}
ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"}
ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"}
ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"}
ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"}
ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"}
ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"}
ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"}
ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"}
ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"}
ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"}
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."}
// MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
Err3ClientIdentifierNotValid = Code{Code: 0x02}
Err3ServerUnavailable = Code{Code: 0x03}
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
Err3NotAuthorized = Code{Code: 0x05}
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
// This is required because MQTTv3 has different return byte specification.
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
V5CodesToV3 = map[Code]Code{
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
ErrServerUnavailable: Err3ServerUnavailable,
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
ErrBadUsernameOrPassword: Err3NotAuthorized,
}
)

29
packets/codes_test.go Normal file
View File

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestCodesString(t *testing.T) {
c := Code{
Reason: "test",
Code: 0x1,
}
require.Equal(t, "test", c.String())
}
func TestCodesError(t *testing.T) {
c := Code{
Reason: "error",
Code: 0x1,
}
require.Equal(t, "error", error(c).Error())
}

63
packets/fixedheader.go Normal file
View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
)
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Remaining int `json:"remaining"` // the number of remaining bytes in the payload.
Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte `json:"qos"` // indicates the quality of service expected.
Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time.
Retain bool `json:"retain"` // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, int64(fh.Remaining))
}
// Decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(hb byte) error {
fh.Type = hb >> 4 // Get the message type from the first 4 bytes.
switch fh.Type {
case Publish:
if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 {
return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4]
}
fh.Dup = (hb>>3)&0x01 > 0 // is duplicate
fh.Qos = (hb >> 1) & 0x03 // qos flag
fh.Retain = hb&0x01 > 0 // is retain flag
case Pubrel:
fallthrough
case Subscribe:
fallthrough
case Unsubscribe:
if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1]
return ErrMalformedFlags
}
fh.Qos = (hb >> 1) & 0x03
default:
if (hb>>0)&0x01 != 0 ||
(hb>>1)&0x01 != 0 ||
(hb>>2)&0x01 != 0 ||
(hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1]
return ErrMalformedFlags
}
}
if fh.Qos == 0 && fh.Dup {
return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2]
}
return nil
}

237
packets/fixedheader_test.go Normal file
View File

@@ -0,0 +1,237 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
type fixedHeaderTable struct {
desc string
rawBytes []byte
header FixedHeader
packetError bool
expect error
}
var fixedHeaderExpected = []fixedHeaderTable{
{
desc: "connect",
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "connack",
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "publish",
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "publish qos 1",
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "publish qos 1 retain",
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
desc: "publish qos 2",
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
desc: "publish qos 2 retain",
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
desc: "publish dup qos 0",
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
expect: ErrProtocolViolationDupNoQos,
},
{
desc: "publish dup qos 0 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
expect: ErrProtocolViolationDupNoQos,
},
{
desc: "publish dup qos 1 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
desc: "publish dup qos 2 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
desc: "puback",
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pubrec",
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pubrel",
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "pubcomp",
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "subscribe",
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "suback",
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "unsubscribe",
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "unsuback",
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pingreq",
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pingresp",
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "disconnect",
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "auth",
rawBytes: []byte{Auth << 4, 0x00},
header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
desc: "remaining length 10",
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
desc: "remaining length 512",
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
desc: "remaining length 978",
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
desc: "remaining length 20202",
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
desc: "remaining length oversize",
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
desc: "invalid type dup is true",
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid type qos is 1",
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid type retain is true",
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid publish qos bits 1 + 2 set",
rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00},
header: FixedHeader{Type: Publish},
expect: ErrProtocolViolationQosOutOfRange,
},
{
desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0",
rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00},
header: FixedHeader{Type: Pubrel, Qos: 1},
expect: ErrMalformedFlags,
},
{
desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0",
rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00},
header: FixedHeader{Type: Subscribe, Qos: 1},
expect: ErrMalformedFlags,
},
}
func TestFixedHeaderEncode(t *testing.T) {
for _, wanted := range fixedHeaderExpected {
t.Run(wanted.desc, func(t *testing.T) {
buf := new(bytes.Buffer)
wanted.header.Encode(buf)
if wanted.expect == nil {
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()))
require.EqualValues(t, wanted.rawBytes, buf.Bytes())
}
})
}
}
func TestFixedHeaderDecode(t *testing.T) {
for _, wanted := range fixedHeaderExpected {
t.Run(wanted.desc, func(t *testing.T) {
fh := new(FixedHeader)
err := fh.Decode(wanted.rawBytes[0])
if wanted.expect != nil {
require.Equal(t, wanted.expect, err)
} else {
require.NoError(t, err)
require.Equal(t, wanted.header.Type, fh.Type)
require.Equal(t, wanted.header.Dup, fh.Dup)
require.Equal(t, wanted.header.Qos, fh.Qos)
require.Equal(t, wanted.header.Retain, fh.Retain)
}
})
}
}

1173
packets/packets.go Normal file

File diff suppressed because it is too large Load Diff

505
packets/packets_test.go Normal file
View File

@@ -0,0 +1,505 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"testing"
"github.com/jinzhu/copier"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var packetList = []byte{
Connect,
Connack,
Publish,
Puback,
Pubrec,
Pubrel,
Pubcomp,
Subscribe,
Suback,
Unsubscribe,
Unsuback,
Pingreq,
Pingresp,
Disconnect,
Auth,
}
var pkTable = []TPacketCase{
TPacketData[Connect].Get(TConnectMqtt311),
TPacketData[Connect].Get(TConnectMqtt5),
TPacketData[Connect].Get(TConnectUserPassLWT),
TPacketData[Connack].Get(TConnackAcceptedMqtt5),
TPacketData[Connack].Get(TConnackAcceptedNoSession),
TPacketData[Publish].Get(TPublishBasic),
TPacketData[Publish].Get(TPublishMqtt5),
TPacketData[Puback].Get(TPuback),
TPacketData[Pubrec].Get(TPubrec),
TPacketData[Pubrel].Get(TPubrel),
TPacketData[Pubcomp].Get(TPubcomp),
TPacketData[Subscribe].Get(TSubscribe),
TPacketData[Subscribe].Get(TSubscribeMqtt5),
TPacketData[Suback].Get(TSuback),
TPacketData[Unsubscribe].Get(TUnsubscribe),
TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
TPacketData[Pingreq].Get(TPingreq),
TPacketData[Pingresp].Get(TPingresp),
TPacketData[Disconnect].Get(TDisconnect),
TPacketData[Disconnect].Get(TDisconnectMqtt5),
}
func TestNewPackets(t *testing.T) {
s := NewPackets()
require.NotNil(t, s.internal)
}
func TestPacketsAdd(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{})
require.Contains(t, s.internal, "cl1")
}
func TestPacketsGet(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
pk, ok := s.Get("cl1")
require.True(t, ok)
require.Equal(t, "a1", pk.TopicName)
}
func TestPacketsGetAll(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
s.Add("cl3", Packet{TopicName: "a3"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Contains(t, s.internal, "cl3")
subs := s.GetAll()
require.Len(t, subs, 3)
}
func TestPacketsLen(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Equal(t, 2, s.Len())
}
func TestSPacketsDelete(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
require.Contains(t, s.internal, "cl1")
s.Delete("cl1")
_, ok := s.Get("cl1")
require.False(t, ok)
}
func TestFormatPacketID(t *testing.T) {
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
packet := &Packet{PacketID: id}
require.Equal(t, fmt.Sprint(id), packet.FormatID())
}
}
func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
p := &Subscription{
Qos: 2,
NoLocal: true,
RetainAsPublished: true,
RetainHandling: 2,
}
x := new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
p = &Subscription{
Qos: 1,
NoLocal: false,
RetainAsPublished: false,
RetainHandling: 1,
}
x = new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
}
func TestPacketEncode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !encodeTestOK(wanted) {
return
}
pk := new(Packet)
_ = copier.Copy(pk, wanted.Packet)
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
pk.Mods.AllowResponseInfo = true
buf := new(bytes.Buffer)
var err error
switch pkt {
case Connect:
err = pk.ConnectEncode(buf)
case Connack:
err = pk.ConnackEncode(buf)
case Publish:
err = pk.PublishEncode(buf)
case Puback:
err = pk.PubackEncode(buf)
case Pubrec:
err = pk.PubrecEncode(buf)
case Pubrel:
err = pk.PubrelEncode(buf)
case Pubcomp:
err = pk.PubcompEncode(buf)
case Subscribe:
err = pk.SubscribeEncode(buf)
case Suback:
err = pk.SubackEncode(buf)
case Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case Unsuback:
err = pk.UnsubackEncode(buf)
case Pingreq:
err = pk.PingreqEncode(buf)
case Pingresp:
err = pk.PingrespEncode(buf)
case Disconnect:
err = pk.DisconnectEncode(buf)
case Auth:
err = pk.AuthEncode(buf)
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
encoded := buf.Bytes()
// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
if len(wanted.ActualBytes) > 0 {
wanted.RawBytes = wanted.ActualBytes
}
require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
})
}
}
}
func TestPacketDecode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !decodeTestOK(wanted) {
return
}
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
pk.Mods.AllowResponseInfo = true
_ = pk.FixedHeader.Decode(wanted.RawBytes[0])
if len(wanted.RawBytes) > 0 {
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
}
if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
pk.ProtocolVersion = wanted.Packet.ProtocolVersion
}
buf := wanted.RawBytes[2:]
var err error
switch pkt {
case Connect:
err = pk.ConnectDecode(buf)
case Connack:
err = pk.ConnackDecode(buf)
case Publish:
err = pk.PublishDecode(buf)
case Puback:
err = pk.PubackDecode(buf)
case Pubrec:
err = pk.PubrecDecode(buf)
case Pubrel:
err = pk.PubrelDecode(buf)
case Pubcomp:
err = pk.PubcompDecode(buf)
case Subscribe:
err = pk.SubscribeDecode(buf)
case Suback:
err = pk.SubackDecode(buf)
case Unsubscribe:
err = pk.UnsubscribeDecode(buf)
case Unsuback:
err = pk.UnsubackDecode(buf)
case Pingreq:
err = pk.PingreqDecode(buf)
case Pingresp:
err = pk.PingrespDecode(buf)
case Disconnect:
err = pk.DisconnectDecode(buf)
case Auth:
err = pk.AuthDecode(buf)
}
if wanted.FailFirst != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
if pkt == Connect {
// we use ProtocolVersion for controlling packet encoding, but we don't need to test
// against it unless it's a connect packet.
require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
}
require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
})
}
}
}
func TestValidate(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if wanted.Group == "validate" || wanted.Primary {
pk := wanted.Packet
var err error
switch pkt {
case Connect:
err = pk.ConnectValidate()
case Publish:
err = pk.PublishValidate(1024)
case Subscribe:
err = pk.SubscribeValidate()
case Unsubscribe:
err = pk.UnsubscribeValidate()
case Auth:
err = pk.AuthValidate()
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
}
}
})
}
}
}
func TestAckValidatePubrec(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoMatchingSubscribers.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicNameInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrPayloadFormatInvalid.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubrel(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubcomp(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateSuback(t *testing.T) {
for _, b := range []byte{
CodeGrantedQos0.Code,
CodeGrantedQos1.Code,
CodeGrantedQos2.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrSharedSubscriptionsNotSupported.Code,
ErrSubscriptionIdentifiersNotSupported.Code,
ErrWildcardSubscriptionsNotSupported.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateUnsuback(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoSubscriptionExisted.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestReasonCodeValidMisc(t *testing.T) {
pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
require.True(t, pk.ReasonCodeValid())
}
func TestCopy(t *testing.T) {
for _, tt := range pkTable {
pkc := tt.Packet.Copy(true)
require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
pkcc := tt.Packet.Copy(false)
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
}
}
func TestMergeSubscription(t *testing.T) {
sub := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 0,
RetainAsPublished: false,
NoLocal: false,
Identifier: 1,
}
sub2 := Subscription{
Filter: "a/b/d",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 2,
}
expect := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 1,
Identifiers: map[string]int{
"a/b/c": 1,
"a/b/d": 2,
},
}
require.Equal(t, expect, sub.Merge(sub2))
}

481
packets/properties.go Normal file
View File

@@ -0,0 +1,481 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"strings"
"testmqtt/mempool"
)
const (
PropPayloadFormat byte = 1
PropMessageExpiryInterval byte = 2
PropContentType byte = 3
PropResponseTopic byte = 8
PropCorrelationData byte = 9
PropSubscriptionIdentifier byte = 11
PropSessionExpiryInterval byte = 17
PropAssignedClientID byte = 18
PropServerKeepAlive byte = 19
PropAuthenticationMethod byte = 21
PropAuthenticationData byte = 22
PropRequestProblemInfo byte = 23
PropWillDelayInterval byte = 24
PropRequestResponseInfo byte = 25
PropResponseInfo byte = 26
PropServerReference byte = 28
PropReasonString byte = 31
PropReceiveMaximum byte = 33
PropTopicAliasMaximum byte = 34
PropTopicAlias byte = 35
PropMaximumQos byte = 36
PropRetainAvailable byte = 37
PropUser byte = 38
PropMaximumPacketSize byte = 39
PropWildcardSubAvailable byte = 40
PropSubIDAvailable byte = 41
PropSharedSubAvailable byte = 42
)
// validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1, WillProperties: 1},
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
PropContentType: {Publish: 1, WillProperties: 1},
PropResponseTopic: {Publish: 1, WillProperties: 1},
PropCorrelationData: {Publish: 1, WillProperties: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1},
PropServerKeepAlive: {Connack: 1},
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {WillProperties: 1},
PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1},
PropReasonString: {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropReceiveMaximum: {Connect: 1, Connack: 1},
PropTopicAliasMaximum: {Connect: 1, Connack: 1},
PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1},
PropSharedSubAvailable: {Connack: 1},
}
// UserProperty is an arbitrary key-value pair for a packet user properties array.
type UserProperty struct { // [MQTT-1.5.7-1]
Key string `json:"k"`
Val string `json:"v"`
}
// Properties contains all mqtt v5 properties available for a packet.
// Some properties have valid values of 0 or not-present. In this case, we opt for
// property flags to indicate the usage of property.
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
type Properties struct {
CorrelationData []byte `json:"cd"`
SubscriptionIdentifier []int `json:"si"`
AuthenticationData []byte `json:"ad"`
User []UserProperty `json:"user"`
ContentType string `json:"ct"`
ResponseTopic string `json:"rt"`
AssignedClientID string `json:"aci"`
AuthenticationMethod string `json:"am"`
ResponseInfo string `json:"ri"`
ServerReference string `json:"sr"`
ReasonString string `json:"rs"`
MessageExpiryInterval uint32 `json:"me"`
SessionExpiryInterval uint32 `json:"sei"`
WillDelayInterval uint32 `json:"wdi"`
MaximumPacketSize uint32 `json:"mps"`
ServerKeepAlive uint16 `json:"ska"`
ReceiveMaximum uint16 `json:"rm"`
TopicAliasMaximum uint16 `json:"tam"`
TopicAlias uint16 `json:"ta"`
PayloadFormat byte `json:"pf"`
PayloadFormatFlag bool `json:"fpf"`
SessionExpiryIntervalFlag bool `json:"fsei"`
ServerKeepAliveFlag bool `json:"fska"`
RequestProblemInfo byte `json:"rpi"`
RequestProblemInfoFlag bool `json:"frpi"`
RequestResponseInfo byte `json:"rri"`
TopicAliasFlag bool `json:"fta"`
MaximumQos byte `json:"mqos"`
MaximumQosFlag bool `json:"fmqos"`
RetainAvailable byte `json:"ra"`
RetainAvailableFlag bool `json:"fra"`
WildcardSubAvailable byte `json:"wsa"`
WildcardSubAvailableFlag bool `json:"fwsa"`
SubIDAvailable byte `json:"sida"`
SubIDAvailableFlag bool `json:"fsida"`
SharedSubAvailable byte `json:"ssa"`
SharedSubAvailableFlag bool `json:"fssa"`
}
// Copy creates a new Properties struct with copies of the values.
func (p *Properties) Copy(allowTransfer bool) Properties {
pr := Properties{
PayloadFormat: p.PayloadFormat, // [MQTT-3.3.2-4]
PayloadFormatFlag: p.PayloadFormatFlag,
MessageExpiryInterval: p.MessageExpiryInterval,
ContentType: p.ContentType, // [MQTT-3.3.2-20]
ResponseTopic: p.ResponseTopic, // [MQTT-3.3.2-15]
SessionExpiryInterval: p.SessionExpiryInterval,
SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag,
AssignedClientID: p.AssignedClientID,
ServerKeepAlive: p.ServerKeepAlive,
ServerKeepAliveFlag: p.ServerKeepAliveFlag,
AuthenticationMethod: p.AuthenticationMethod,
RequestProblemInfo: p.RequestProblemInfo,
RequestProblemInfoFlag: p.RequestProblemInfoFlag,
WillDelayInterval: p.WillDelayInterval,
RequestResponseInfo: p.RequestResponseInfo,
ResponseInfo: p.ResponseInfo,
ServerReference: p.ServerReference,
ReasonString: p.ReasonString,
ReceiveMaximum: p.ReceiveMaximum,
TopicAliasMaximum: p.TopicAliasMaximum,
TopicAlias: 0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27]
MaximumQos: p.MaximumQos,
MaximumQosFlag: p.MaximumQosFlag,
RetainAvailable: p.RetainAvailable,
RetainAvailableFlag: p.RetainAvailableFlag,
MaximumPacketSize: p.MaximumPacketSize,
WildcardSubAvailable: p.WildcardSubAvailable,
WildcardSubAvailableFlag: p.WildcardSubAvailableFlag,
SubIDAvailable: p.SubIDAvailable,
SubIDAvailableFlag: p.SubIDAvailableFlag,
SharedSubAvailable: p.SharedSubAvailable,
SharedSubAvailableFlag: p.SharedSubAvailableFlag,
}
if allowTransfer {
pr.TopicAlias = p.TopicAlias
pr.TopicAliasFlag = p.TopicAliasFlag
}
if len(p.CorrelationData) > 0 {
pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16]
}
if len(p.SubscriptionIdentifier) > 0 {
pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...)
}
if len(p.AuthenticationData) > 0 {
pr.AuthenticationData = append([]byte{}, p.AuthenticationData...)
}
if len(p.User) > 0 {
pr.User = []UserProperty{}
for _, v := range p.User {
pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17]
Key: v.Key,
Val: v.Val,
})
}
}
return pr
}
// canEncode returns true if the property type is valid for the packet type.
func (p *Properties) canEncode(pkt byte, k byte) bool {
return validPacketProperties[k][pkt] == 1
}
// Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
if p == nil {
return
}
buf := mempool.GetBuffer()
defer mempool.PutBuffer(buf)
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat)
}
if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 {
buf.WriteByte(PropMessageExpiryInterval)
buf.Write(encodeUint32(p.MessageExpiryInterval))
}
if p.canEncode(pkt, PropContentType) && p.ContentType != "" {
buf.WriteByte(PropContentType)
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
}
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
}
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData))
}
if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 {
for _, v := range p.SubscriptionIdentifier {
if v > 0 {
buf.WriteByte(PropSubscriptionIdentifier)
encodeLength(buf, int64(v))
}
}
}
if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2]
buf.WriteByte(PropSessionExpiryInterval)
buf.Write(encodeUint32(p.SessionExpiryInterval))
}
if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" {
buf.WriteByte(PropAssignedClientID)
buf.Write(encodeString(p.AssignedClientID))
}
if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag {
buf.WriteByte(PropServerKeepAlive)
buf.Write(encodeUint16(p.ServerKeepAlive))
}
if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" {
buf.WriteByte(PropAuthenticationMethod)
buf.Write(encodeString(p.AuthenticationMethod))
}
if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 {
buf.WriteByte(PropAuthenticationData)
buf.Write(encodeBytes(p.AuthenticationData))
}
if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag {
buf.WriteByte(PropRequestProblemInfo)
buf.WriteByte(p.RequestProblemInfo)
}
if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 {
buf.WriteByte(PropWillDelayInterval)
buf.Write(encodeUint32(p.WillDelayInterval))
}
if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 {
buf.WriteByte(PropRequestResponseInfo)
buf.WriteByte(p.RequestResponseInfo)
}
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo))
}
if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 {
buf.WriteByte(PropServerReference)
buf.Write(encodeString(p.ServerReference))
}
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString)
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
buf.WriteByte(PropReasonString)
buf.Write(b)
}
}
if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 {
buf.WriteByte(PropReceiveMaximum)
buf.Write(encodeUint16(p.ReceiveMaximum))
}
if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 {
buf.WriteByte(PropTopicAliasMaximum)
buf.Write(encodeUint16(p.TopicAliasMaximum))
}
if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8]
buf.WriteByte(PropTopicAlias)
buf.Write(encodeUint16(p.TopicAlias))
}
if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 {
buf.WriteByte(PropMaximumQos)
buf.WriteByte(p.MaximumQos)
}
if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag {
buf.WriteByte(PropRetainAvailable)
buf.WriteByte(p.RetainAvailable)
}
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
for _, v := range p.User {
pb.WriteByte(PropUser)
pb.Write(encodeString(v.Key))
pb.Write(encodeString(v.Val))
}
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
buf.Write(pb.Bytes())
}
}
if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 {
buf.WriteByte(PropMaximumPacketSize)
buf.Write(encodeUint32(p.MaximumPacketSize))
}
if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag {
buf.WriteByte(PropWildcardSubAvailable)
buf.WriteByte(p.WildcardSubAvailable)
}
if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag {
buf.WriteByte(PropSubIDAvailable)
buf.WriteByte(p.SubIDAvailable)
}
if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag {
buf.WriteByte(PropSharedSubAvailable)
buf.WriteByte(p.SharedSubAvailable)
}
encodeLength(b, int64(buf.Len()))
b.Write(buf.Bytes()) // [MQTT-3.1.3-10]
}
// Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
if p == nil {
return 0, nil
}
var bu int
n, bu, err = DecodeLength(b)
if err != nil {
return n + bu, err
}
if n == 0 {
return n + bu, nil
}
bt := b.Bytes()
var k byte
for offset := 0; offset < n; {
k, offset, err = decodeByte(bt, offset)
if err != nil {
return n + bu, err
}
if _, ok := validPacketProperties[k][pkt]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
}
switch k {
case PropPayloadFormat:
p.PayloadFormat, offset, err = decodeByte(bt, offset)
p.PayloadFormatFlag = true
case PropMessageExpiryInterval:
p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset)
case PropContentType:
p.ContentType, offset, err = decodeString(bt, offset)
case PropResponseTopic:
p.ResponseTopic, offset, err = decodeString(bt, offset)
case PropCorrelationData:
p.CorrelationData, offset, err = decodeBytes(bt, offset)
case PropSubscriptionIdentifier:
if p.SubscriptionIdentifier == nil {
p.SubscriptionIdentifier = []int{}
}
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
if err != nil {
return n + bu, err
}
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
offset += bu
case PropSessionExpiryInterval:
p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset)
p.SessionExpiryIntervalFlag = true
case PropAssignedClientID:
p.AssignedClientID, offset, err = decodeString(bt, offset)
case PropServerKeepAlive:
p.ServerKeepAlive, offset, err = decodeUint16(bt, offset)
p.ServerKeepAliveFlag = true
case PropAuthenticationMethod:
p.AuthenticationMethod, offset, err = decodeString(bt, offset)
case PropAuthenticationData:
p.AuthenticationData, offset, err = decodeBytes(bt, offset)
case PropRequestProblemInfo:
p.RequestProblemInfo, offset, err = decodeByte(bt, offset)
p.RequestProblemInfoFlag = true
case PropWillDelayInterval:
p.WillDelayInterval, offset, err = decodeUint32(bt, offset)
case PropRequestResponseInfo:
p.RequestResponseInfo, offset, err = decodeByte(bt, offset)
case PropResponseInfo:
p.ResponseInfo, offset, err = decodeString(bt, offset)
case PropServerReference:
p.ServerReference, offset, err = decodeString(bt, offset)
case PropReasonString:
p.ReasonString, offset, err = decodeString(bt, offset)
case PropReceiveMaximum:
p.ReceiveMaximum, offset, err = decodeUint16(bt, offset)
case PropTopicAliasMaximum:
p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset)
case PropTopicAlias:
p.TopicAlias, offset, err = decodeUint16(bt, offset)
p.TopicAliasFlag = true
case PropMaximumQos:
p.MaximumQos, offset, err = decodeByte(bt, offset)
p.MaximumQosFlag = true
case PropRetainAvailable:
p.RetainAvailable, offset, err = decodeByte(bt, offset)
p.RetainAvailableFlag = true
case PropUser:
var k, v string
k, offset, err = decodeString(bt, offset)
if err != nil {
return n + bu, err
}
v, offset, err = decodeString(bt, offset)
p.User = append(p.User, UserProperty{Key: k, Val: v})
case PropMaximumPacketSize:
p.MaximumPacketSize, offset, err = decodeUint32(bt, offset)
case PropWildcardSubAvailable:
p.WildcardSubAvailable, offset, err = decodeByte(bt, offset)
p.WildcardSubAvailableFlag = true
case PropSubIDAvailable:
p.SubIDAvailable, offset, err = decodeByte(bt, offset)
p.SubIDAvailableFlag = true
case PropSharedSubAvailable:
p.SharedSubAvailable, offset, err = decodeByte(bt, offset)
p.SharedSubAvailableFlag = true
}
if err != nil {
return n + bu, err
}
}
return n + bu, nil
}

333
packets/properties_test.go Normal file
View File

@@ -0,0 +1,333 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var (
propertiesStruct = Properties{
PayloadFormat: byte(1), // UTF-8 Format
PayloadFormatFlag: true,
MessageExpiryInterval: uint32(2),
ContentType: "text/plain",
ResponseTopic: "a/b/c",
CorrelationData: []byte("data"),
SubscriptionIdentifier: []int{322122},
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
AssignedClientID: "mochi-v5",
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
AuthenticationMethod: "SHA-1",
AuthenticationData: []byte("auth-data"),
RequestProblemInfo: byte(1),
RequestProblemInfoFlag: true,
WillDelayInterval: uint32(600),
RequestResponseInfo: byte(1),
ResponseInfo: "response",
ServerReference: "mochi-2",
ReasonString: "reason",
ReceiveMaximum: uint16(500),
TopicAliasMaximum: uint16(999),
TopicAlias: uint16(3),
TopicAliasFlag: true,
MaximumQos: byte(1),
MaximumQosFlag: true,
RetainAvailable: byte(1),
RetainAvailableFlag: true,
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
{
Key: "key2",
Val: "value2",
},
},
MaximumPacketSize: uint32(32000),
WildcardSubAvailable: byte(1),
WildcardSubAvailableFlag: true,
SubIDAvailable: byte(1),
SubIDAvailableFlag: true,
SharedSubAvailable: byte(1),
SharedSubAvailableFlag: true,
}
propertiesBytes = []byte{
172, 1, // VBI
// Payload Format (1) (vbi:2)
1, 1,
// Message Expiry (2) (vbi:7)
2, 0, 0, 0, 2,
// Content Type (3) (vbi:20)
3,
0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n',
// Response Topic (8) (vbi:28)
8,
0, 5, 'a', '/', 'b', '/', 'c',
// Correlations Data (9) (vbi:35)
9,
0, 4, 'd', 'a', 't', 'a',
// Subscription Identifier (11) (vbi:39)
11,
202, 212, 19,
// Session Expiry Interval (17) (vbi:43)
17,
0, 0, 0, 120,
// Assigned Client ID (18) (vbi:55)
18,
0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5',
// Server Keep Alive (19) (vbi:58)
19,
0, 20,
// Authentication Method (21) (vbi:66)
21,
0, 5, 'S', 'H', 'A', '-', '1',
// Authentication Data (22) (vbi:78)
22,
0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a',
// Request Problem Info (23) (vbi:80)
23, 1,
// Will Delay Interval (24) (vbi:85)
24,
0, 0, 2, 88,
// Request Response Info (25) (vbi:87)
25, 1,
// Response Info (26) (vbi:98)
26,
0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e',
// Server Reference (28) (vbi:108)
28,
0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2',
// Reason String (31) (vbi:117)
31,
0, 6, 'r', 'e', 'a', 's', 'o', 'n',
// Receive Maximum (33) (vbi:120)
33,
1, 244,
// Topic Alias Maximum (34) (vbi:123)
34,
3, 231,
// Topic Alias (35) (vbi:126)
35,
0, 3,
// Maximum Qos (36) (vbi:128)
36, 1,
// Retain Available (37) (vbi: 130)
37, 1,
// User Properties (38) (vbi:161)
38,
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
38,
0, 4, 'k', 'e', 'y', '2',
0, 6, 'v', 'a', 'l', 'u', 'e', '2',
// Maximum Packet Size (39) (vbi:166)
39,
0, 0, 125, 0,
// Wildcard Subscriptions Available (40) (vbi:168)
40, 1,
// Subscription ID Available (41) (vbi:170)
41, 1,
// Shared Subscriptions Available (42) (vbi:172)
42, 1,
}
)
func init() {
validPacketProperties[PropPayloadFormat][Reserved] = 1
validPacketProperties[PropMessageExpiryInterval][Reserved] = 1
validPacketProperties[PropContentType][Reserved] = 1
validPacketProperties[PropResponseTopic][Reserved] = 1
validPacketProperties[PropCorrelationData][Reserved] = 1
validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1
validPacketProperties[PropSessionExpiryInterval][Reserved] = 1
validPacketProperties[PropAssignedClientID][Reserved] = 1
validPacketProperties[PropServerKeepAlive][Reserved] = 1
validPacketProperties[PropAuthenticationMethod][Reserved] = 1
validPacketProperties[PropAuthenticationData][Reserved] = 1
validPacketProperties[PropRequestProblemInfo][Reserved] = 1
validPacketProperties[PropWillDelayInterval][Reserved] = 1
validPacketProperties[PropRequestResponseInfo][Reserved] = 1
validPacketProperties[PropResponseInfo][Reserved] = 1
validPacketProperties[PropServerReference][Reserved] = 1
validPacketProperties[PropReasonString][Reserved] = 1
validPacketProperties[PropReceiveMaximum][Reserved] = 1
validPacketProperties[PropTopicAliasMaximum][Reserved] = 1
validPacketProperties[PropTopicAlias][Reserved] = 1
validPacketProperties[PropMaximumQos][Reserved] = 1
validPacketProperties[PropRetainAvailable][Reserved] = 1
validPacketProperties[PropUser][Reserved] = 1
validPacketProperties[PropMaximumPacketSize][Reserved] = 1
validPacketProperties[PropWildcardSubAvailable][Reserved] = 1
validPacketProperties[PropSubIDAvailable][Reserved] = 1
validPacketProperties[PropSharedSubAvailable][Reserved] = 1
}
func TestEncodeProperties(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes())
}
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8}))
}
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
}
func TestEncodePropertiesNil(t *testing.T) {
type tmp struct {
p *Properties
}
pr := tmp{}
b := bytes.NewBuffer([]byte{})
pr.p.Encode(Reserved, Mods{}, b, 0)
require.Equal(t, []byte{}, b.Bytes())
}
func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties)
b := bytes.NewBuffer([]byte{})
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes())
}
func TestDecodeProperties(t *testing.T) {
b := bytes.NewBuffer(propertiesBytes)
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172+2, n)
require.EqualValues(t, propertiesStruct, *props)
}
func TestDecodePropertiesNil(t *testing.T) {
b := bytes.NewBuffer(propertiesBytes)
type tmp struct {
p *Properties
}
pr := tmp{}
n, err := pr.p.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 0, n)
}
func TestDecodePropertiesBadInitialVBI(t *testing.T) {
b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, ErrMalformedVariableByteInteger, err)
}
func TestDecodePropertiesZeroLengthVBI(t *testing.T) {
b := bytes.NewBuffer([]byte{0})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, props, new(Properties))
}
func TestDecodePropertiesBadKeyByte(t *testing.T) {
b := bytes.NewBuffer([]byte{64, 1})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange)
}
func TestDecodePropertiesInvalidForPacket(t *testing.T) {
b := bytes.NewBuffer([]byte{1, 99})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty)
}
func TestDecodePropertiesGeneralFailure(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestDecodePropertiesBadSubscriptionID(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestDecodePropertiesBadUserProps(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestCopyProperties(t *testing.T) {
require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true))
}
func TestCopyPropertiesNoTransfer(t *testing.T) {
pkA := propertiesStruct
pkB := pkA.Copy(false)
// Properties which should never be transferred from one connection to another
require.Equal(t, uint16(0), pkB.TopicAlias)
}

4031
packets/tpackets.go Normal file

File diff suppressed because it is too large Load Diff

33
packets/tpackets_test.go Normal file
View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func encodeTestOK(wanted TPacketCase) bool {
if wanted.RawBytes == nil {
return false
}
if wanted.Group != "" && wanted.Group != "encode" {
return false
}
return true
}
func decodeTestOK(wanted TPacketCase) bool {
if wanted.Group != "" && wanted.Group != "decode" {
return false
}
return true
}
func TestTPacketCaseGet(t *testing.T) {
require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311))
require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128)))
}