代码整理
This commit is contained in:
172
packets/codec.go
Normal file
172
packets/codec.go
Normal file
@@ -0,0 +1,172 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// bytesToString provides a zero-alloc no-copy byte to string conversion.
|
||||
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
|
||||
func bytesToString(bs []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&bs))
|
||||
}
|
||||
|
||||
// decodeUint16 extracts the value of two bytes from a byte array.
|
||||
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
|
||||
if len(buf) < offset+2 {
|
||||
return 0, 0, ErrMalformedOffsetUintOutOfRange
|
||||
}
|
||||
|
||||
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
|
||||
}
|
||||
|
||||
// decodeUint32 extracts the value of four bytes from a byte array.
|
||||
func decodeUint32(buf []byte, offset int) (uint32, int, error) {
|
||||
if len(buf) < offset+4 {
|
||||
return 0, 0, ErrMalformedOffsetUintOutOfRange
|
||||
}
|
||||
|
||||
return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil
|
||||
}
|
||||
|
||||
// decodeString extracts a string from a byte array, beginning at an offset.
|
||||
func decodeString(buf []byte, offset int) (string, int, error) {
|
||||
b, n, err := decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5]
|
||||
return "", 0, ErrMalformedInvalidUTF8
|
||||
}
|
||||
|
||||
return bytesToString(b), n, nil
|
||||
}
|
||||
|
||||
// validUTF8 checks if the byte array contains valid UTF-8 characters.
|
||||
func validUTF8(b []byte) bool {
|
||||
return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2]
|
||||
}
|
||||
|
||||
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
|
||||
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
|
||||
length, next, err := decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return make([]byte, 0), 0, err
|
||||
}
|
||||
|
||||
if next+int(length) > len(buf) {
|
||||
return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange
|
||||
}
|
||||
|
||||
return buf[next : next+int(length)], next + int(length), nil
|
||||
}
|
||||
|
||||
// decodeByte extracts the value of a byte from a byte array.
|
||||
func decodeByte(buf []byte, offset int) (byte, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return 0, 0, ErrMalformedOffsetByteOutOfRange
|
||||
}
|
||||
return buf[offset], offset + 1, nil
|
||||
}
|
||||
|
||||
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
|
||||
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return false, 0, ErrMalformedOffsetBoolOutOfRange
|
||||
}
|
||||
return 1&buf[offset] > 0, offset + 1, nil
|
||||
}
|
||||
|
||||
// encodeBool returns a byte instead of a bool.
|
||||
func encodeBool(b bool) byte {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
|
||||
func encodeBytes(val []byte) []byte {
|
||||
// In most circumstances the number of bytes being encoded is small.
|
||||
// Setting the cap to a low amount allows us to account for those without
|
||||
// triggering allocation growth on append unless we need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, val...)
|
||||
}
|
||||
|
||||
// encodeUint16 encodes a uint16 value to a byte array.
|
||||
func encodeUint16(val uint16) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, val)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeUint32 encodes a uint16 value to a byte array.
|
||||
func encodeUint32(val uint32) []byte {
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, val)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeString encodes a string to a byte array.
|
||||
func encodeString(val string) []byte {
|
||||
// Like encodeBytes, we set the cap to a small number to avoid
|
||||
// triggering allocation growth on append unless we absolutely need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, []byte(val)...)
|
||||
}
|
||||
|
||||
// encodeLength writes length bits for the header.
|
||||
func encodeLength(b *bytes.Buffer, length int64) {
|
||||
// 1.5.5 Variable Byte Integer encode non-normative
|
||||
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
|
||||
for {
|
||||
eb := byte(length % 128)
|
||||
length /= 128
|
||||
if length > 0 {
|
||||
eb |= 0x80
|
||||
}
|
||||
b.WriteByte(eb)
|
||||
if length == 0 {
|
||||
break // [MQTT-1.5.5-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeLength(b io.ByteReader) (n, bu int, err error) {
|
||||
// see 1.5.5 Variable Byte Integer decode non-normative
|
||||
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
|
||||
var multiplier uint32
|
||||
var value uint32
|
||||
bu = 1
|
||||
for {
|
||||
eb, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, bu, err
|
||||
}
|
||||
|
||||
value |= uint32(eb&127) << multiplier
|
||||
if value > 268435455 {
|
||||
return 0, bu, ErrMalformedVariableByteInteger
|
||||
}
|
||||
|
||||
if (eb & 128) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
multiplier += 7
|
||||
bu++
|
||||
}
|
||||
|
||||
return int(value), bu, nil
|
||||
}
|
||||
422
packets/codec_test.go
Normal file
422
packets/codec_test.go
Normal file
@@ -0,0 +1,422 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBytesToString(t *testing.T) {
|
||||
b := []byte{'a', 'b', 'c'}
|
||||
require.Equal(t, "abc", bytesToString(b))
|
||||
}
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
rawBytes []byte
|
||||
result string
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: "a/b/c/d",
|
||||
},
|
||||
{
|
||||
offset: 14,
|
||||
rawBytes: []byte{
|
||||
Connect << 4, 17, // Fixed header
|
||||
0, 6, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
|
||||
3, // Protocol Version
|
||||
0, // Packet Flags
|
||||
0, 30, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'h', 'e', 'y', // Client ID "zen"},
|
||||
},
|
||||
result: "hey",
|
||||
},
|
||||
{
|
||||
offset: 2,
|
||||
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
|
||||
result: "1/2/3/4/a/b/c/d/e/^/@/!",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
|
||||
result: "x/y/z",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 5,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 9,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 17,
|
||||
rawBytes: []byte{
|
||||
Connect << 4, 0, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
0, // Flags
|
||||
0, 20, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 6, // Will Topic - MSB+LSB
|
||||
'l',
|
||||
},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrMalformedInvalidUTF8,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3]
|
||||
result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "\ufeff", result)
|
||||
}
|
||||
|
||||
func TestDecodeBytes(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result []uint8
|
||||
next int
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session)
|
||||
result: []byte{0x4d, 0x51, 0x54, 0x54},
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start
|
||||
result: []byte{0x4d, 0x51, 0x54, 0x54},
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 0,
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByte(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint8
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
|
||||
result: uint8(0x00),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x04),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x4d),
|
||||
offset: 2,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x51),
|
||||
offset: 3,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 80, 82, 84},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetByteOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUint16(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint16
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x07),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x761),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+2, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUint32(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint32
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 0, 0, 7, 8},
|
||||
result: uint32(7),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 0, 1, 226, 64, 8},
|
||||
result: uint32(123456),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+4, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByteBool(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result bool
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0x00, 0x00},
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
offset: 5,
|
||||
shouldFail: ErrMalformedOffsetBoolOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, 1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeLength(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{0x78})
|
||||
n, bu, err := DecodeLength(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 120, n)
|
||||
require.Equal(t, 1, bu)
|
||||
|
||||
b = bytes.NewBuffer([]byte{255, 255, 255, 127})
|
||||
n, bu, err = DecodeLength(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 268435455, n)
|
||||
require.Equal(t, 4, bu)
|
||||
}
|
||||
|
||||
func TestDecodeLengthErrors(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
_, _, err := DecodeLength(b)
|
||||
require.Error(t, err)
|
||||
|
||||
b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f})
|
||||
_, _, err = DecodeLength(b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
|
||||
}
|
||||
|
||||
func TestEncodeBool(t *testing.T) {
|
||||
result := encodeBool(true)
|
||||
require.Equal(t, byte(1), result)
|
||||
|
||||
result = encodeBool(false)
|
||||
require.Equal(t, byte(0), result)
|
||||
|
||||
// Check failure.
|
||||
result = encodeBool(false)
|
||||
require.NotEqual(t, byte(1), result)
|
||||
}
|
||||
|
||||
func TestEncodeBytes(t *testing.T) {
|
||||
result := encodeBytes([]byte("testing"))
|
||||
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result)
|
||||
|
||||
result = encodeBytes([]byte("testing"))
|
||||
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result)
|
||||
}
|
||||
|
||||
func TestEncodeUint16(t *testing.T) {
|
||||
result := encodeUint16(0)
|
||||
require.Equal(t, []byte{0x00, 0x00}, result)
|
||||
|
||||
result = encodeUint16(32767)
|
||||
require.Equal(t, []byte{0x7f, 0xff}, result)
|
||||
|
||||
result = encodeUint16(math.MaxUint16)
|
||||
require.Equal(t, []byte{0xff, 0xff}, result)
|
||||
}
|
||||
|
||||
func TestEncodeUint32(t *testing.T) {
|
||||
result := encodeUint32(7)
|
||||
require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result)
|
||||
|
||||
result = encodeUint32(32767)
|
||||
require.Equal(t, []byte{0, 0, 127, 255}, result)
|
||||
|
||||
result = encodeUint32(math.MaxUint32)
|
||||
require.Equal(t, []byte{255, 255, 255, 255}, result)
|
||||
}
|
||||
|
||||
func TestEncodeString(t *testing.T) {
|
||||
result := encodeString("testing")
|
||||
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result)
|
||||
|
||||
result = encodeString("")
|
||||
require.Equal(t, []uint8{0x00, 0x00}, result)
|
||||
|
||||
result = encodeString("a")
|
||||
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result)
|
||||
|
||||
result = encodeString("b")
|
||||
require.NotEqual(t, []uint8{0x00, 0x00}, result)
|
||||
}
|
||||
|
||||
func TestEncodeLength(t *testing.T) {
|
||||
b := new(bytes.Buffer)
|
||||
encodeLength(b, 120)
|
||||
require.Equal(t, []byte{0x78}, b.Bytes())
|
||||
|
||||
b = new(bytes.Buffer)
|
||||
encodeLength(b, math.MaxInt64)
|
||||
require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestValidUTF8(t *testing.T) {
|
||||
require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}))
|
||||
require.False(t, validUTF8([]byte{0xff, 0xff}))
|
||||
require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74}))
|
||||
}
|
||||
149
packets/codes.go
Normal file
149
packets/codes.go
Normal file
@@ -0,0 +1,149 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
// Code contains a reason code and reason string for a response.
|
||||
type Code struct {
|
||||
Reason string
|
||||
Code byte
|
||||
}
|
||||
|
||||
// String returns the readable reason for a code.
|
||||
func (c Code) String() string {
|
||||
return c.Reason
|
||||
}
|
||||
|
||||
// Error returns the readable reason for a code.
|
||||
func (c Code) Error() string {
|
||||
return c.Reason
|
||||
}
|
||||
|
||||
var (
|
||||
// QosCodes indicates the reason codes for each Qos byte.
|
||||
QosCodes = map[byte]Code{
|
||||
0: CodeGrantedQos0,
|
||||
1: CodeGrantedQos1,
|
||||
2: CodeGrantedQos2,
|
||||
}
|
||||
|
||||
CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"}
|
||||
CodeSuccess = Code{Code: 0x00, Reason: "success"}
|
||||
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
|
||||
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
|
||||
CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"}
|
||||
CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"}
|
||||
CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"}
|
||||
CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"}
|
||||
CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"}
|
||||
CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"}
|
||||
CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"}
|
||||
ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"}
|
||||
ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"}
|
||||
ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"}
|
||||
ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"}
|
||||
ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"}
|
||||
ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"}
|
||||
ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"}
|
||||
ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"}
|
||||
ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"}
|
||||
ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"}
|
||||
ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"}
|
||||
ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"}
|
||||
ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"}
|
||||
ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"}
|
||||
ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"}
|
||||
ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"}
|
||||
ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"}
|
||||
ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"}
|
||||
ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"}
|
||||
ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"}
|
||||
ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"}
|
||||
ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"}
|
||||
ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"}
|
||||
ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"}
|
||||
ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"}
|
||||
ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"}
|
||||
ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"}
|
||||
ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"}
|
||||
ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"}
|
||||
ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"}
|
||||
ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
|
||||
ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
|
||||
ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"}
|
||||
ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"}
|
||||
ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"}
|
||||
ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"}
|
||||
ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"}
|
||||
ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"}
|
||||
ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"}
|
||||
ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"}
|
||||
ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"}
|
||||
ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"}
|
||||
ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"}
|
||||
ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"}
|
||||
ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"}
|
||||
ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"}
|
||||
ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"}
|
||||
ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"}
|
||||
ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"}
|
||||
ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"}
|
||||
ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"}
|
||||
ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"}
|
||||
ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"}
|
||||
ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"}
|
||||
ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"}
|
||||
ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"}
|
||||
ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"}
|
||||
ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"}
|
||||
ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"}
|
||||
ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"}
|
||||
ErrServerBusy = Code{Code: 0x89, Reason: "server busy"}
|
||||
ErrBanned = Code{Code: 0x8A, Reason: "banned"}
|
||||
ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"}
|
||||
ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"}
|
||||
ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"}
|
||||
ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"}
|
||||
ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"}
|
||||
ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"}
|
||||
ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"}
|
||||
ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"}
|
||||
ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"}
|
||||
ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"}
|
||||
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
|
||||
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
|
||||
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
|
||||
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
|
||||
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
|
||||
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
|
||||
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
|
||||
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
|
||||
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
|
||||
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
|
||||
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
|
||||
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
|
||||
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
||||
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
||||
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
||||
ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."}
|
||||
|
||||
// MQTTv3 specific bytes.
|
||||
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
||||
Err3ClientIdentifierNotValid = Code{Code: 0x02}
|
||||
Err3ServerUnavailable = Code{Code: 0x03}
|
||||
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
|
||||
Err3NotAuthorized = Code{Code: 0x05}
|
||||
|
||||
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
|
||||
// This is required because MQTTv3 has different return byte specification.
|
||||
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
|
||||
V5CodesToV3 = map[Code]Code{
|
||||
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
|
||||
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
|
||||
ErrServerUnavailable: Err3ServerUnavailable,
|
||||
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
|
||||
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
|
||||
ErrBadUsernameOrPassword: Err3NotAuthorized,
|
||||
}
|
||||
)
|
||||
29
packets/codes_test.go
Normal file
29
packets/codes_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCodesString(t *testing.T) {
|
||||
c := Code{
|
||||
Reason: "test",
|
||||
Code: 0x1,
|
||||
}
|
||||
|
||||
require.Equal(t, "test", c.String())
|
||||
}
|
||||
|
||||
func TestCodesError(t *testing.T) {
|
||||
c := Code{
|
||||
Reason: "error",
|
||||
Code: 0x1,
|
||||
}
|
||||
|
||||
require.Equal(t, "error", error(c).Error())
|
||||
}
|
||||
63
packets/fixedheader.go
Normal file
63
packets/fixedheader.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
|
||||
type FixedHeader struct {
|
||||
Remaining int `json:"remaining"` // the number of remaining bytes in the payload.
|
||||
Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Qos byte `json:"qos"` // indicates the quality of service expected.
|
||||
Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time.
|
||||
Retain bool `json:"retain"` // whether the message should be retained.
|
||||
}
|
||||
|
||||
// Encode encodes the FixedHeader and returns a bytes buffer.
|
||||
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
|
||||
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
|
||||
encodeLength(buf, int64(fh.Remaining))
|
||||
}
|
||||
|
||||
// Decode extracts the specification bits from the header byte.
|
||||
func (fh *FixedHeader) Decode(hb byte) error {
|
||||
fh.Type = hb >> 4 // Get the message type from the first 4 bytes.
|
||||
|
||||
switch fh.Type {
|
||||
case Publish:
|
||||
if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 {
|
||||
return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4]
|
||||
}
|
||||
|
||||
fh.Dup = (hb>>3)&0x01 > 0 // is duplicate
|
||||
fh.Qos = (hb >> 1) & 0x03 // qos flag
|
||||
fh.Retain = hb&0x01 > 0 // is retain flag
|
||||
case Pubrel:
|
||||
fallthrough
|
||||
case Subscribe:
|
||||
fallthrough
|
||||
case Unsubscribe:
|
||||
if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1]
|
||||
return ErrMalformedFlags
|
||||
}
|
||||
|
||||
fh.Qos = (hb >> 1) & 0x03
|
||||
default:
|
||||
if (hb>>0)&0x01 != 0 ||
|
||||
(hb>>1)&0x01 != 0 ||
|
||||
(hb>>2)&0x01 != 0 ||
|
||||
(hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1]
|
||||
return ErrMalformedFlags
|
||||
}
|
||||
}
|
||||
|
||||
if fh.Qos == 0 && fh.Dup {
|
||||
return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
237
packets/fixedheader_test.go
Normal file
237
packets/fixedheader_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fixedHeaderTable struct {
|
||||
desc string
|
||||
rawBytes []byte
|
||||
header FixedHeader
|
||||
packetError bool
|
||||
expect error
|
||||
}
|
||||
|
||||
var fixedHeaderExpected = []fixedHeaderTable{
|
||||
{
|
||||
desc: "connect",
|
||||
rawBytes: []byte{Connect << 4, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "connack",
|
||||
rawBytes: []byte{Connack << 4, 0x00},
|
||||
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish",
|
||||
rawBytes: []byte{Publish << 4, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 1",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 1 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 2",
|
||||
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 2 retain",
|
||||
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 0",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
expect: ErrProtocolViolationDupNoQos,
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 0 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
|
||||
expect: ErrProtocolViolationDupNoQos,
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 1 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 2 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "puback",
|
||||
rawBytes: []byte{Puback << 4, 0x00},
|
||||
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubrec",
|
||||
rawBytes: []byte{Pubrec << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubrel",
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubcomp",
|
||||
rawBytes: []byte{Pubcomp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "subscribe",
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "suback",
|
||||
rawBytes: []byte{Suback << 4, 0x00},
|
||||
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "unsubscribe",
|
||||
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "unsuback",
|
||||
rawBytes: []byte{Unsuback << 4, 0x00},
|
||||
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pingreq",
|
||||
rawBytes: []byte{Pingreq << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pingresp",
|
||||
rawBytes: []byte{Pingresp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "disconnect",
|
||||
rawBytes: []byte{Disconnect << 4, 0x00},
|
||||
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "auth",
|
||||
rawBytes: []byte{Auth << 4, 0x00},
|
||||
header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
|
||||
// remaining length
|
||||
{
|
||||
desc: "remaining length 10",
|
||||
rawBytes: []byte{Publish << 4, 0x0a},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 512",
|
||||
rawBytes: []byte{Publish << 4, 0x80, 0x04},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 978",
|
||||
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 20202",
|
||||
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
|
||||
},
|
||||
{
|
||||
desc: "remaining length oversize",
|
||||
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
|
||||
packetError: true,
|
||||
},
|
||||
|
||||
// Invalid flags for packet
|
||||
{
|
||||
desc: "invalid type dup is true",
|
||||
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid type qos is 1",
|
||||
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid type retain is true",
|
||||
rawBytes: []byte{Connect<<4 | 1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid publish qos bits 1 + 2 set",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00},
|
||||
header: FixedHeader{Type: Publish},
|
||||
expect: ErrProtocolViolationQosOutOfRange,
|
||||
},
|
||||
{
|
||||
desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0",
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00},
|
||||
header: FixedHeader{Type: Pubrel, Qos: 1},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0",
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00},
|
||||
header: FixedHeader{Type: Subscribe, Qos: 1},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
}
|
||||
|
||||
func TestFixedHeaderEncode(t *testing.T) {
|
||||
for _, wanted := range fixedHeaderExpected {
|
||||
t.Run(wanted.desc, func(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
wanted.header.Encode(buf)
|
||||
if wanted.expect == nil {
|
||||
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()))
|
||||
require.EqualValues(t, wanted.rawBytes, buf.Bytes())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixedHeaderDecode(t *testing.T) {
|
||||
for _, wanted := range fixedHeaderExpected {
|
||||
t.Run(wanted.desc, func(t *testing.T) {
|
||||
fh := new(FixedHeader)
|
||||
err := fh.Decode(wanted.rawBytes[0])
|
||||
if wanted.expect != nil {
|
||||
require.Equal(t, wanted.expect, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.header.Type, fh.Type)
|
||||
require.Equal(t, wanted.header.Dup, fh.Dup)
|
||||
require.Equal(t, wanted.header.Qos, fh.Qos)
|
||||
require.Equal(t, wanted.header.Retain, fh.Retain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1173
packets/packets.go
Normal file
1173
packets/packets.go
Normal file
File diff suppressed because it is too large
Load Diff
505
packets/packets_test.go
Normal file
505
packets/packets_test.go
Normal file
@@ -0,0 +1,505 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const pkInfo = "packet type %v, %s"
|
||||
|
||||
var packetList = []byte{
|
||||
Connect,
|
||||
Connack,
|
||||
Publish,
|
||||
Puback,
|
||||
Pubrec,
|
||||
Pubrel,
|
||||
Pubcomp,
|
||||
Subscribe,
|
||||
Suback,
|
||||
Unsubscribe,
|
||||
Unsuback,
|
||||
Pingreq,
|
||||
Pingresp,
|
||||
Disconnect,
|
||||
Auth,
|
||||
}
|
||||
|
||||
var pkTable = []TPacketCase{
|
||||
TPacketData[Connect].Get(TConnectMqtt311),
|
||||
TPacketData[Connect].Get(TConnectMqtt5),
|
||||
TPacketData[Connect].Get(TConnectUserPassLWT),
|
||||
TPacketData[Connack].Get(TConnackAcceptedMqtt5),
|
||||
TPacketData[Connack].Get(TConnackAcceptedNoSession),
|
||||
TPacketData[Publish].Get(TPublishBasic),
|
||||
TPacketData[Publish].Get(TPublishMqtt5),
|
||||
TPacketData[Puback].Get(TPuback),
|
||||
TPacketData[Pubrec].Get(TPubrec),
|
||||
TPacketData[Pubrel].Get(TPubrel),
|
||||
TPacketData[Pubcomp].Get(TPubcomp),
|
||||
TPacketData[Subscribe].Get(TSubscribe),
|
||||
TPacketData[Subscribe].Get(TSubscribeMqtt5),
|
||||
TPacketData[Suback].Get(TSuback),
|
||||
TPacketData[Unsubscribe].Get(TUnsubscribe),
|
||||
TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
|
||||
TPacketData[Pingreq].Get(TPingreq),
|
||||
TPacketData[Pingresp].Get(TPingresp),
|
||||
TPacketData[Disconnect].Get(TDisconnect),
|
||||
TPacketData[Disconnect].Get(TDisconnectMqtt5),
|
||||
}
|
||||
|
||||
func TestNewPackets(t *testing.T) {
|
||||
s := NewPackets()
|
||||
require.NotNil(t, s.internal)
|
||||
}
|
||||
|
||||
func TestPacketsAdd(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
}
|
||||
|
||||
func TestPacketsGet(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
|
||||
pk, ok := s.Get("cl1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "a1", pk.TopicName)
|
||||
}
|
||||
|
||||
func TestPacketsGetAll(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
s.Add("cl3", Packet{TopicName: "a3"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Contains(t, s.internal, "cl3")
|
||||
|
||||
subs := s.GetAll()
|
||||
require.Len(t, subs, 3)
|
||||
}
|
||||
|
||||
func TestPacketsLen(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Equal(t, 2, s.Len())
|
||||
}
|
||||
|
||||
func TestSPacketsDelete(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
|
||||
s.Delete("cl1")
|
||||
_, ok := s.Get("cl1")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestFormatPacketID(t *testing.T) {
|
||||
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
|
||||
packet := &Packet{PacketID: id}
|
||||
require.Equal(t, fmt.Sprint(id), packet.FormatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
|
||||
p := &Subscription{
|
||||
Qos: 2,
|
||||
NoLocal: true,
|
||||
RetainAsPublished: true,
|
||||
RetainHandling: 2,
|
||||
}
|
||||
x := new(Subscription)
|
||||
x.decode(p.encode())
|
||||
require.Equal(t, *p, *x)
|
||||
|
||||
p = &Subscription{
|
||||
Qos: 1,
|
||||
NoLocal: false,
|
||||
RetainAsPublished: false,
|
||||
RetainHandling: 1,
|
||||
}
|
||||
x = new(Subscription)
|
||||
x.decode(p.encode())
|
||||
require.Equal(t, *p, *x)
|
||||
}
|
||||
|
||||
func TestPacketEncode(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if !encodeTestOK(wanted) {
|
||||
return
|
||||
}
|
||||
|
||||
pk := new(Packet)
|
||||
_ = copier.Copy(pk, wanted.Packet)
|
||||
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectEncode(buf)
|
||||
case Connack:
|
||||
err = pk.ConnackEncode(buf)
|
||||
case Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
case Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
case Pubrec:
|
||||
err = pk.PubrecEncode(buf)
|
||||
case Pubrel:
|
||||
err = pk.PubrelEncode(buf)
|
||||
case Pubcomp:
|
||||
err = pk.PubcompEncode(buf)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeEncode(buf)
|
||||
case Suback:
|
||||
err = pk.SubackEncode(buf)
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeEncode(buf)
|
||||
case Unsuback:
|
||||
err = pk.UnsubackEncode(buf)
|
||||
case Pingreq:
|
||||
err = pk.PingreqEncode(buf)
|
||||
case Pingresp:
|
||||
err = pk.PingrespEncode(buf)
|
||||
case Disconnect:
|
||||
err = pk.DisconnectEncode(buf)
|
||||
case Auth:
|
||||
err = pk.AuthEncode(buf)
|
||||
}
|
||||
if wanted.Expect != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
|
||||
encoded := buf.Bytes()
|
||||
|
||||
// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
|
||||
if len(wanted.ActualBytes) > 0 {
|
||||
wanted.RawBytes = wanted.ActualBytes
|
||||
}
|
||||
require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketDecode(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if !decodeTestOK(wanted) {
|
||||
return
|
||||
}
|
||||
|
||||
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
_ = pk.FixedHeader.Decode(wanted.RawBytes[0])
|
||||
if len(wanted.RawBytes) > 0 {
|
||||
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
|
||||
}
|
||||
|
||||
if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
|
||||
pk.ProtocolVersion = wanted.Packet.ProtocolVersion
|
||||
}
|
||||
|
||||
buf := wanted.RawBytes[2:]
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectDecode(buf)
|
||||
case Connack:
|
||||
err = pk.ConnackDecode(buf)
|
||||
case Publish:
|
||||
err = pk.PublishDecode(buf)
|
||||
case Puback:
|
||||
err = pk.PubackDecode(buf)
|
||||
case Pubrec:
|
||||
err = pk.PubrecDecode(buf)
|
||||
case Pubrel:
|
||||
err = pk.PubrelDecode(buf)
|
||||
case Pubcomp:
|
||||
err = pk.PubcompDecode(buf)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeDecode(buf)
|
||||
case Suback:
|
||||
err = pk.SubackDecode(buf)
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeDecode(buf)
|
||||
case Unsuback:
|
||||
err = pk.UnsubackDecode(buf)
|
||||
case Pingreq:
|
||||
err = pk.PingreqDecode(buf)
|
||||
case Pingresp:
|
||||
err = pk.PingrespDecode(buf)
|
||||
case Disconnect:
|
||||
err = pk.DisconnectDecode(buf)
|
||||
case Auth:
|
||||
err = pk.AuthDecode(buf)
|
||||
}
|
||||
|
||||
if wanted.FailFirst != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
if pkt == Connect {
|
||||
// we use ProtocolVersion for controlling packet encoding, but we don't need to test
|
||||
// against it unless it's a connect packet.
|
||||
require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
|
||||
}
|
||||
require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
|
||||
require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if wanted.Group == "validate" || wanted.Primary {
|
||||
pk := wanted.Packet
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectValidate()
|
||||
case Publish:
|
||||
err = pk.PublishValidate(1024)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeValidate()
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeValidate()
|
||||
case Auth:
|
||||
err = pk.AuthValidate()
|
||||
}
|
||||
|
||||
if wanted.Expect != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAckValidatePubrec(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
CodeNoMatchingSubscribers.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicNameInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
ErrQuotaExceeded.Code,
|
||||
ErrPayloadFormatInvalid.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidatePubrel(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
ErrPacketIdentifierNotFound.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidatePubcomp(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
ErrPacketIdentifierNotFound.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidateSuback(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeGrantedQos0.Code,
|
||||
CodeGrantedQos1.Code,
|
||||
CodeGrantedQos2.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicFilterInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
ErrQuotaExceeded.Code,
|
||||
ErrSharedSubscriptionsNotSupported.Code,
|
||||
ErrSubscriptionIdentifiersNotSupported.Code,
|
||||
ErrWildcardSubscriptionsNotSupported.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidateUnsuback(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
CodeNoSubscriptionExisted.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicFilterInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestReasonCodeValidMisc(t *testing.T) {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestCopy(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
pkc := tt.Packet.Copy(true)
|
||||
|
||||
require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
|
||||
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
|
||||
|
||||
pkcc := tt.Packet.Copy(false)
|
||||
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSubscription(t *testing.T) {
|
||||
sub := Subscription{
|
||||
Filter: "a/b/c",
|
||||
RetainHandling: 0,
|
||||
Qos: 0,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: false,
|
||||
Identifier: 1,
|
||||
}
|
||||
|
||||
sub2 := Subscription{
|
||||
Filter: "a/b/d",
|
||||
RetainHandling: 0,
|
||||
Qos: 2,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: true,
|
||||
Identifier: 2,
|
||||
}
|
||||
|
||||
expect := Subscription{
|
||||
Filter: "a/b/c",
|
||||
RetainHandling: 0,
|
||||
Qos: 2,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: true,
|
||||
Identifier: 1,
|
||||
Identifiers: map[string]int{
|
||||
"a/b/c": 1,
|
||||
"a/b/d": 2,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expect, sub.Merge(sub2))
|
||||
}
|
||||
481
packets/properties.go
Normal file
481
packets/properties.go
Normal file
@@ -0,0 +1,481 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"testmqtt/mempool"
|
||||
)
|
||||
|
||||
const (
|
||||
PropPayloadFormat byte = 1
|
||||
PropMessageExpiryInterval byte = 2
|
||||
PropContentType byte = 3
|
||||
PropResponseTopic byte = 8
|
||||
PropCorrelationData byte = 9
|
||||
PropSubscriptionIdentifier byte = 11
|
||||
PropSessionExpiryInterval byte = 17
|
||||
PropAssignedClientID byte = 18
|
||||
PropServerKeepAlive byte = 19
|
||||
PropAuthenticationMethod byte = 21
|
||||
PropAuthenticationData byte = 22
|
||||
PropRequestProblemInfo byte = 23
|
||||
PropWillDelayInterval byte = 24
|
||||
PropRequestResponseInfo byte = 25
|
||||
PropResponseInfo byte = 26
|
||||
PropServerReference byte = 28
|
||||
PropReasonString byte = 31
|
||||
PropReceiveMaximum byte = 33
|
||||
PropTopicAliasMaximum byte = 34
|
||||
PropTopicAlias byte = 35
|
||||
PropMaximumQos byte = 36
|
||||
PropRetainAvailable byte = 37
|
||||
PropUser byte = 38
|
||||
PropMaximumPacketSize byte = 39
|
||||
PropWildcardSubAvailable byte = 40
|
||||
PropSubIDAvailable byte = 41
|
||||
PropSharedSubAvailable byte = 42
|
||||
)
|
||||
|
||||
// validPacketProperties indicates which properties are valid for which packet types.
|
||||
var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropPayloadFormat: {Publish: 1, WillProperties: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
|
||||
PropContentType: {Publish: 1, WillProperties: 1},
|
||||
PropResponseTopic: {Publish: 1, WillProperties: 1},
|
||||
PropCorrelationData: {Publish: 1, WillProperties: 1},
|
||||
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
|
||||
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
|
||||
PropAssignedClientID: {Connack: 1},
|
||||
PropServerKeepAlive: {Connack: 1},
|
||||
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropRequestProblemInfo: {Connect: 1},
|
||||
PropWillDelayInterval: {WillProperties: 1},
|
||||
PropRequestResponseInfo: {Connect: 1},
|
||||
PropResponseInfo: {Connack: 1},
|
||||
PropServerReference: {Connack: 1, Disconnect: 1},
|
||||
PropReasonString: {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
|
||||
PropReceiveMaximum: {Connect: 1, Connack: 1},
|
||||
PropTopicAliasMaximum: {Connect: 1, Connack: 1},
|
||||
PropTopicAlias: {Publish: 1},
|
||||
PropMaximumQos: {Connack: 1},
|
||||
PropRetainAvailable: {Connack: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
|
||||
PropMaximumPacketSize: {Connect: 1, Connack: 1},
|
||||
PropWildcardSubAvailable: {Connack: 1},
|
||||
PropSubIDAvailable: {Connack: 1},
|
||||
PropSharedSubAvailable: {Connack: 1},
|
||||
}
|
||||
|
||||
// UserProperty is an arbitrary key-value pair for a packet user properties array.
|
||||
type UserProperty struct { // [MQTT-1.5.7-1]
|
||||
Key string `json:"k"`
|
||||
Val string `json:"v"`
|
||||
}
|
||||
|
||||
// Properties contains all mqtt v5 properties available for a packet.
|
||||
// Some properties have valid values of 0 or not-present. In this case, we opt for
|
||||
// property flags to indicate the usage of property.
|
||||
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
|
||||
type Properties struct {
|
||||
CorrelationData []byte `json:"cd"`
|
||||
SubscriptionIdentifier []int `json:"si"`
|
||||
AuthenticationData []byte `json:"ad"`
|
||||
User []UserProperty `json:"user"`
|
||||
ContentType string `json:"ct"`
|
||||
ResponseTopic string `json:"rt"`
|
||||
AssignedClientID string `json:"aci"`
|
||||
AuthenticationMethod string `json:"am"`
|
||||
ResponseInfo string `json:"ri"`
|
||||
ServerReference string `json:"sr"`
|
||||
ReasonString string `json:"rs"`
|
||||
MessageExpiryInterval uint32 `json:"me"`
|
||||
SessionExpiryInterval uint32 `json:"sei"`
|
||||
WillDelayInterval uint32 `json:"wdi"`
|
||||
MaximumPacketSize uint32 `json:"mps"`
|
||||
ServerKeepAlive uint16 `json:"ska"`
|
||||
ReceiveMaximum uint16 `json:"rm"`
|
||||
TopicAliasMaximum uint16 `json:"tam"`
|
||||
TopicAlias uint16 `json:"ta"`
|
||||
PayloadFormat byte `json:"pf"`
|
||||
PayloadFormatFlag bool `json:"fpf"`
|
||||
SessionExpiryIntervalFlag bool `json:"fsei"`
|
||||
ServerKeepAliveFlag bool `json:"fska"`
|
||||
RequestProblemInfo byte `json:"rpi"`
|
||||
RequestProblemInfoFlag bool `json:"frpi"`
|
||||
RequestResponseInfo byte `json:"rri"`
|
||||
TopicAliasFlag bool `json:"fta"`
|
||||
MaximumQos byte `json:"mqos"`
|
||||
MaximumQosFlag bool `json:"fmqos"`
|
||||
RetainAvailable byte `json:"ra"`
|
||||
RetainAvailableFlag bool `json:"fra"`
|
||||
WildcardSubAvailable byte `json:"wsa"`
|
||||
WildcardSubAvailableFlag bool `json:"fwsa"`
|
||||
SubIDAvailable byte `json:"sida"`
|
||||
SubIDAvailableFlag bool `json:"fsida"`
|
||||
SharedSubAvailable byte `json:"ssa"`
|
||||
SharedSubAvailableFlag bool `json:"fssa"`
|
||||
}
|
||||
|
||||
// Copy creates a new Properties struct with copies of the values.
|
||||
func (p *Properties) Copy(allowTransfer bool) Properties {
|
||||
pr := Properties{
|
||||
PayloadFormat: p.PayloadFormat, // [MQTT-3.3.2-4]
|
||||
PayloadFormatFlag: p.PayloadFormatFlag,
|
||||
MessageExpiryInterval: p.MessageExpiryInterval,
|
||||
ContentType: p.ContentType, // [MQTT-3.3.2-20]
|
||||
ResponseTopic: p.ResponseTopic, // [MQTT-3.3.2-15]
|
||||
SessionExpiryInterval: p.SessionExpiryInterval,
|
||||
SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag,
|
||||
AssignedClientID: p.AssignedClientID,
|
||||
ServerKeepAlive: p.ServerKeepAlive,
|
||||
ServerKeepAliveFlag: p.ServerKeepAliveFlag,
|
||||
AuthenticationMethod: p.AuthenticationMethod,
|
||||
RequestProblemInfo: p.RequestProblemInfo,
|
||||
RequestProblemInfoFlag: p.RequestProblemInfoFlag,
|
||||
WillDelayInterval: p.WillDelayInterval,
|
||||
RequestResponseInfo: p.RequestResponseInfo,
|
||||
ResponseInfo: p.ResponseInfo,
|
||||
ServerReference: p.ServerReference,
|
||||
ReasonString: p.ReasonString,
|
||||
ReceiveMaximum: p.ReceiveMaximum,
|
||||
TopicAliasMaximum: p.TopicAliasMaximum,
|
||||
TopicAlias: 0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27]
|
||||
MaximumQos: p.MaximumQos,
|
||||
MaximumQosFlag: p.MaximumQosFlag,
|
||||
RetainAvailable: p.RetainAvailable,
|
||||
RetainAvailableFlag: p.RetainAvailableFlag,
|
||||
MaximumPacketSize: p.MaximumPacketSize,
|
||||
WildcardSubAvailable: p.WildcardSubAvailable,
|
||||
WildcardSubAvailableFlag: p.WildcardSubAvailableFlag,
|
||||
SubIDAvailable: p.SubIDAvailable,
|
||||
SubIDAvailableFlag: p.SubIDAvailableFlag,
|
||||
SharedSubAvailable: p.SharedSubAvailable,
|
||||
SharedSubAvailableFlag: p.SharedSubAvailableFlag,
|
||||
}
|
||||
|
||||
if allowTransfer {
|
||||
pr.TopicAlias = p.TopicAlias
|
||||
pr.TopicAliasFlag = p.TopicAliasFlag
|
||||
}
|
||||
|
||||
if len(p.CorrelationData) > 0 {
|
||||
pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16]
|
||||
}
|
||||
|
||||
if len(p.SubscriptionIdentifier) > 0 {
|
||||
pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...)
|
||||
}
|
||||
|
||||
if len(p.AuthenticationData) > 0 {
|
||||
pr.AuthenticationData = append([]byte{}, p.AuthenticationData...)
|
||||
}
|
||||
|
||||
if len(p.User) > 0 {
|
||||
pr.User = []UserProperty{}
|
||||
for _, v := range p.User {
|
||||
pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17]
|
||||
Key: v.Key,
|
||||
Val: v.Val,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return pr
|
||||
}
|
||||
|
||||
// canEncode returns true if the property type is valid for the packet type.
|
||||
func (p *Properties) canEncode(pkt byte, k byte) bool {
|
||||
return validPacketProperties[k][pkt] == 1
|
||||
}
|
||||
|
||||
// Encode encodes properties into a bytes buffer.
|
||||
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
buf := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(buf)
|
||||
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
|
||||
buf.WriteByte(PropPayloadFormat)
|
||||
buf.WriteByte(p.PayloadFormat)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 {
|
||||
buf.WriteByte(PropMessageExpiryInterval)
|
||||
buf.Write(encodeUint32(p.MessageExpiryInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropContentType) && p.ContentType != "" {
|
||||
buf.WriteByte(PropContentType)
|
||||
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
|
||||
}
|
||||
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseTopic)
|
||||
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
|
||||
}
|
||||
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropCorrelationData)
|
||||
buf.Write(encodeBytes(p.CorrelationData))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 {
|
||||
for _, v := range p.SubscriptionIdentifier {
|
||||
if v > 0 {
|
||||
buf.WriteByte(PropSubscriptionIdentifier)
|
||||
encodeLength(buf, int64(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2]
|
||||
buf.WriteByte(PropSessionExpiryInterval)
|
||||
buf.Write(encodeUint32(p.SessionExpiryInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" {
|
||||
buf.WriteByte(PropAssignedClientID)
|
||||
buf.Write(encodeString(p.AssignedClientID))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag {
|
||||
buf.WriteByte(PropServerKeepAlive)
|
||||
buf.Write(encodeUint16(p.ServerKeepAlive))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" {
|
||||
buf.WriteByte(PropAuthenticationMethod)
|
||||
buf.Write(encodeString(p.AuthenticationMethod))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 {
|
||||
buf.WriteByte(PropAuthenticationData)
|
||||
buf.Write(encodeBytes(p.AuthenticationData))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag {
|
||||
buf.WriteByte(PropRequestProblemInfo)
|
||||
buf.WriteByte(p.RequestProblemInfo)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 {
|
||||
buf.WriteByte(PropWillDelayInterval)
|
||||
buf.Write(encodeUint32(p.WillDelayInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 {
|
||||
buf.WriteByte(PropRequestResponseInfo)
|
||||
buf.WriteByte(p.RequestResponseInfo)
|
||||
}
|
||||
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseInfo)
|
||||
buf.Write(encodeString(p.ResponseInfo))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 {
|
||||
buf.WriteByte(PropServerReference)
|
||||
buf.Write(encodeString(p.ServerReference))
|
||||
}
|
||||
|
||||
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
|
||||
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
b := encodeString(p.ReasonString)
|
||||
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
|
||||
buf.WriteByte(PropReasonString)
|
||||
buf.Write(b)
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 {
|
||||
buf.WriteByte(PropReceiveMaximum)
|
||||
buf.Write(encodeUint16(p.ReceiveMaximum))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 {
|
||||
buf.WriteByte(PropTopicAliasMaximum)
|
||||
buf.Write(encodeUint16(p.TopicAliasMaximum))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8]
|
||||
buf.WriteByte(PropTopicAlias)
|
||||
buf.Write(encodeUint16(p.TopicAlias))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 {
|
||||
buf.WriteByte(PropMaximumQos)
|
||||
buf.WriteByte(p.MaximumQos)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag {
|
||||
buf.WriteByte(PropRetainAvailable)
|
||||
buf.WriteByte(p.RetainAvailable)
|
||||
}
|
||||
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
for _, v := range p.User {
|
||||
pb.WriteByte(PropUser)
|
||||
pb.Write(encodeString(v.Key))
|
||||
pb.Write(encodeString(v.Val))
|
||||
}
|
||||
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
|
||||
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
|
||||
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
|
||||
buf.Write(pb.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 {
|
||||
buf.WriteByte(PropMaximumPacketSize)
|
||||
buf.Write(encodeUint32(p.MaximumPacketSize))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag {
|
||||
buf.WriteByte(PropWildcardSubAvailable)
|
||||
buf.WriteByte(p.WildcardSubAvailable)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag {
|
||||
buf.WriteByte(PropSubIDAvailable)
|
||||
buf.WriteByte(p.SubIDAvailable)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag {
|
||||
buf.WriteByte(PropSharedSubAvailable)
|
||||
buf.WriteByte(p.SharedSubAvailable)
|
||||
}
|
||||
|
||||
encodeLength(b, int64(buf.Len()))
|
||||
b.Write(buf.Bytes()) // [MQTT-3.1.3-10]
|
||||
}
|
||||
|
||||
// Decode decodes property bytes into a properties struct.
|
||||
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
|
||||
if p == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var bu int
|
||||
n, bu, err = DecodeLength(b)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return n + bu, nil
|
||||
}
|
||||
|
||||
bt := b.Bytes()
|
||||
var k byte
|
||||
for offset := 0; offset < n; {
|
||||
k, offset, err = decodeByte(bt, offset)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if _, ok := validPacketProperties[k][pkt]; !ok {
|
||||
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
switch k {
|
||||
case PropPayloadFormat:
|
||||
p.PayloadFormat, offset, err = decodeByte(bt, offset)
|
||||
p.PayloadFormatFlag = true
|
||||
case PropMessageExpiryInterval:
|
||||
p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset)
|
||||
case PropContentType:
|
||||
p.ContentType, offset, err = decodeString(bt, offset)
|
||||
case PropResponseTopic:
|
||||
p.ResponseTopic, offset, err = decodeString(bt, offset)
|
||||
case PropCorrelationData:
|
||||
p.CorrelationData, offset, err = decodeBytes(bt, offset)
|
||||
case PropSubscriptionIdentifier:
|
||||
if p.SubscriptionIdentifier == nil {
|
||||
p.SubscriptionIdentifier = []int{}
|
||||
}
|
||||
|
||||
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
|
||||
offset += bu
|
||||
case PropSessionExpiryInterval:
|
||||
p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset)
|
||||
p.SessionExpiryIntervalFlag = true
|
||||
case PropAssignedClientID:
|
||||
p.AssignedClientID, offset, err = decodeString(bt, offset)
|
||||
case PropServerKeepAlive:
|
||||
p.ServerKeepAlive, offset, err = decodeUint16(bt, offset)
|
||||
p.ServerKeepAliveFlag = true
|
||||
case PropAuthenticationMethod:
|
||||
p.AuthenticationMethod, offset, err = decodeString(bt, offset)
|
||||
case PropAuthenticationData:
|
||||
p.AuthenticationData, offset, err = decodeBytes(bt, offset)
|
||||
case PropRequestProblemInfo:
|
||||
p.RequestProblemInfo, offset, err = decodeByte(bt, offset)
|
||||
p.RequestProblemInfoFlag = true
|
||||
case PropWillDelayInterval:
|
||||
p.WillDelayInterval, offset, err = decodeUint32(bt, offset)
|
||||
case PropRequestResponseInfo:
|
||||
p.RequestResponseInfo, offset, err = decodeByte(bt, offset)
|
||||
case PropResponseInfo:
|
||||
p.ResponseInfo, offset, err = decodeString(bt, offset)
|
||||
case PropServerReference:
|
||||
p.ServerReference, offset, err = decodeString(bt, offset)
|
||||
case PropReasonString:
|
||||
p.ReasonString, offset, err = decodeString(bt, offset)
|
||||
case PropReceiveMaximum:
|
||||
p.ReceiveMaximum, offset, err = decodeUint16(bt, offset)
|
||||
case PropTopicAliasMaximum:
|
||||
p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset)
|
||||
case PropTopicAlias:
|
||||
p.TopicAlias, offset, err = decodeUint16(bt, offset)
|
||||
p.TopicAliasFlag = true
|
||||
case PropMaximumQos:
|
||||
p.MaximumQos, offset, err = decodeByte(bt, offset)
|
||||
p.MaximumQosFlag = true
|
||||
case PropRetainAvailable:
|
||||
p.RetainAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.RetainAvailableFlag = true
|
||||
case PropUser:
|
||||
var k, v string
|
||||
k, offset, err = decodeString(bt, offset)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
v, offset, err = decodeString(bt, offset)
|
||||
p.User = append(p.User, UserProperty{Key: k, Val: v})
|
||||
case PropMaximumPacketSize:
|
||||
p.MaximumPacketSize, offset, err = decodeUint32(bt, offset)
|
||||
case PropWildcardSubAvailable:
|
||||
p.WildcardSubAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.WildcardSubAvailableFlag = true
|
||||
case PropSubIDAvailable:
|
||||
p.SubIDAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.SubIDAvailableFlag = true
|
||||
case PropSharedSubAvailable:
|
||||
p.SharedSubAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.SharedSubAvailableFlag = true
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
}
|
||||
|
||||
return n + bu, nil
|
||||
}
|
||||
333
packets/properties_test.go
Normal file
333
packets/properties_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
propertiesStruct = Properties{
|
||||
PayloadFormat: byte(1), // UTF-8 Format
|
||||
PayloadFormatFlag: true,
|
||||
MessageExpiryInterval: uint32(2),
|
||||
ContentType: "text/plain",
|
||||
ResponseTopic: "a/b/c",
|
||||
CorrelationData: []byte("data"),
|
||||
SubscriptionIdentifier: []int{322122},
|
||||
SessionExpiryInterval: uint32(120),
|
||||
SessionExpiryIntervalFlag: true,
|
||||
AssignedClientID: "mochi-v5",
|
||||
ServerKeepAlive: uint16(20),
|
||||
ServerKeepAliveFlag: true,
|
||||
AuthenticationMethod: "SHA-1",
|
||||
AuthenticationData: []byte("auth-data"),
|
||||
RequestProblemInfo: byte(1),
|
||||
RequestProblemInfoFlag: true,
|
||||
WillDelayInterval: uint32(600),
|
||||
RequestResponseInfo: byte(1),
|
||||
ResponseInfo: "response",
|
||||
ServerReference: "mochi-2",
|
||||
ReasonString: "reason",
|
||||
ReceiveMaximum: uint16(500),
|
||||
TopicAliasMaximum: uint16(999),
|
||||
TopicAlias: uint16(3),
|
||||
TopicAliasFlag: true,
|
||||
MaximumQos: byte(1),
|
||||
MaximumQosFlag: true,
|
||||
RetainAvailable: byte(1),
|
||||
RetainAvailableFlag: true,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
{
|
||||
Key: "key2",
|
||||
Val: "value2",
|
||||
},
|
||||
},
|
||||
MaximumPacketSize: uint32(32000),
|
||||
WildcardSubAvailable: byte(1),
|
||||
WildcardSubAvailableFlag: true,
|
||||
SubIDAvailable: byte(1),
|
||||
SubIDAvailableFlag: true,
|
||||
SharedSubAvailable: byte(1),
|
||||
SharedSubAvailableFlag: true,
|
||||
}
|
||||
|
||||
propertiesBytes = []byte{
|
||||
172, 1, // VBI
|
||||
|
||||
// Payload Format (1) (vbi:2)
|
||||
1, 1,
|
||||
|
||||
// Message Expiry (2) (vbi:7)
|
||||
2, 0, 0, 0, 2,
|
||||
|
||||
// Content Type (3) (vbi:20)
|
||||
3,
|
||||
0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n',
|
||||
|
||||
// Response Topic (8) (vbi:28)
|
||||
8,
|
||||
0, 5, 'a', '/', 'b', '/', 'c',
|
||||
|
||||
// Correlations Data (9) (vbi:35)
|
||||
9,
|
||||
0, 4, 'd', 'a', 't', 'a',
|
||||
|
||||
// Subscription Identifier (11) (vbi:39)
|
||||
11,
|
||||
202, 212, 19,
|
||||
|
||||
// Session Expiry Interval (17) (vbi:43)
|
||||
17,
|
||||
0, 0, 0, 120,
|
||||
|
||||
// Assigned Client ID (18) (vbi:55)
|
||||
18,
|
||||
0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5',
|
||||
|
||||
// Server Keep Alive (19) (vbi:58)
|
||||
19,
|
||||
0, 20,
|
||||
|
||||
// Authentication Method (21) (vbi:66)
|
||||
21,
|
||||
0, 5, 'S', 'H', 'A', '-', '1',
|
||||
|
||||
// Authentication Data (22) (vbi:78)
|
||||
22,
|
||||
0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a',
|
||||
|
||||
// Request Problem Info (23) (vbi:80)
|
||||
23, 1,
|
||||
|
||||
// Will Delay Interval (24) (vbi:85)
|
||||
24,
|
||||
0, 0, 2, 88,
|
||||
|
||||
// Request Response Info (25) (vbi:87)
|
||||
25, 1,
|
||||
|
||||
// Response Info (26) (vbi:98)
|
||||
26,
|
||||
0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e',
|
||||
|
||||
// Server Reference (28) (vbi:108)
|
||||
28,
|
||||
0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2',
|
||||
|
||||
// Reason String (31) (vbi:117)
|
||||
31,
|
||||
0, 6, 'r', 'e', 'a', 's', 'o', 'n',
|
||||
|
||||
// Receive Maximum (33) (vbi:120)
|
||||
33,
|
||||
1, 244,
|
||||
|
||||
// Topic Alias Maximum (34) (vbi:123)
|
||||
34,
|
||||
3, 231,
|
||||
|
||||
// Topic Alias (35) (vbi:126)
|
||||
35,
|
||||
0, 3,
|
||||
|
||||
// Maximum Qos (36) (vbi:128)
|
||||
36, 1,
|
||||
|
||||
// Retain Available (37) (vbi: 130)
|
||||
37, 1,
|
||||
|
||||
// User Properties (38) (vbi:161)
|
||||
38,
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
38,
|
||||
0, 4, 'k', 'e', 'y', '2',
|
||||
0, 6, 'v', 'a', 'l', 'u', 'e', '2',
|
||||
|
||||
// Maximum Packet Size (39) (vbi:166)
|
||||
39,
|
||||
0, 0, 125, 0,
|
||||
|
||||
// Wildcard Subscriptions Available (40) (vbi:168)
|
||||
40, 1,
|
||||
|
||||
// Subscription ID Available (41) (vbi:170)
|
||||
41, 1,
|
||||
|
||||
// Shared Subscriptions Available (42) (vbi:172)
|
||||
42, 1,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
validPacketProperties[PropPayloadFormat][Reserved] = 1
|
||||
validPacketProperties[PropMessageExpiryInterval][Reserved] = 1
|
||||
validPacketProperties[PropContentType][Reserved] = 1
|
||||
validPacketProperties[PropResponseTopic][Reserved] = 1
|
||||
validPacketProperties[PropCorrelationData][Reserved] = 1
|
||||
validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1
|
||||
validPacketProperties[PropSessionExpiryInterval][Reserved] = 1
|
||||
validPacketProperties[PropAssignedClientID][Reserved] = 1
|
||||
validPacketProperties[PropServerKeepAlive][Reserved] = 1
|
||||
validPacketProperties[PropAuthenticationMethod][Reserved] = 1
|
||||
validPacketProperties[PropAuthenticationData][Reserved] = 1
|
||||
validPacketProperties[PropRequestProblemInfo][Reserved] = 1
|
||||
validPacketProperties[PropWillDelayInterval][Reserved] = 1
|
||||
validPacketProperties[PropRequestResponseInfo][Reserved] = 1
|
||||
validPacketProperties[PropResponseInfo][Reserved] = 1
|
||||
validPacketProperties[PropServerReference][Reserved] = 1
|
||||
validPacketProperties[PropReasonString][Reserved] = 1
|
||||
validPacketProperties[PropReceiveMaximum][Reserved] = 1
|
||||
validPacketProperties[PropTopicAliasMaximum][Reserved] = 1
|
||||
validPacketProperties[PropTopicAlias][Reserved] = 1
|
||||
validPacketProperties[PropMaximumQos][Reserved] = 1
|
||||
validPacketProperties[PropRetainAvailable][Reserved] = 1
|
||||
validPacketProperties[PropUser][Reserved] = 1
|
||||
validPacketProperties[PropMaximumPacketSize][Reserved] = 1
|
||||
validPacketProperties[PropWildcardSubAvailable][Reserved] = 1
|
||||
validPacketProperties[PropSubIDAvailable][Reserved] = 1
|
||||
validPacketProperties[PropSharedSubAvailable][Reserved] = 1
|
||||
}
|
||||
|
||||
func TestEncodeProperties(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
|
||||
require.Equal(t, propertiesBytes, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8}))
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
|
||||
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
|
||||
}
|
||||
|
||||
func TestEncodePropertiesNil(t *testing.T) {
|
||||
type tmp struct {
|
||||
p *Properties
|
||||
}
|
||||
|
||||
pr := tmp{}
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
pr.p.Encode(Reserved, Mods{}, b, 0)
|
||||
require.Equal(t, []byte{}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodeZeroProperties(t *testing.T) {
|
||||
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
|
||||
props := new(Properties)
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
|
||||
require.Equal(t, []byte{0x00}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestDecodeProperties(t *testing.T) {
|
||||
b := bytes.NewBuffer(propertiesBytes)
|
||||
|
||||
props := new(Properties)
|
||||
n, err := props.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 172+2, n)
|
||||
require.EqualValues(t, propertiesStruct, *props)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesNil(t *testing.T) {
|
||||
b := bytes.NewBuffer(propertiesBytes)
|
||||
|
||||
type tmp struct {
|
||||
p *Properties
|
||||
}
|
||||
|
||||
pr := tmp{}
|
||||
n, err := pr.p.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadInitialVBI(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrMalformedVariableByteInteger, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesZeroLengthVBI(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{0})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, props, new(Properties))
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadKeyByte(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{64, 1})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesInvalidForPacket(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{1, 99})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesGeneralFailure(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadSubscriptionID(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadUserProps(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCopyProperties(t *testing.T) {
|
||||
require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true))
|
||||
}
|
||||
|
||||
func TestCopyPropertiesNoTransfer(t *testing.T) {
|
||||
pkA := propertiesStruct
|
||||
pkB := pkA.Copy(false)
|
||||
|
||||
// Properties which should never be transferred from one connection to another
|
||||
require.Equal(t, uint16(0), pkB.TopicAlias)
|
||||
}
|
||||
4031
packets/tpackets.go
Normal file
4031
packets/tpackets.go
Normal file
File diff suppressed because it is too large
Load Diff
33
packets/tpackets_test.go
Normal file
33
packets/tpackets_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func encodeTestOK(wanted TPacketCase) bool {
|
||||
if wanted.RawBytes == nil {
|
||||
return false
|
||||
}
|
||||
if wanted.Group != "" && wanted.Group != "encode" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func decodeTestOK(wanted TPacketCase) bool {
|
||||
if wanted.Group != "" && wanted.Group != "decode" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestTPacketCaseGet(t *testing.T) {
|
||||
require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311))
|
||||
require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128)))
|
||||
}
|
||||
Reference in New Issue
Block a user