代码整理
This commit is contained in:
99
listeners/http_healthcheck.go
Normal file
99
listeners/http_healthcheck.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: Derek Duncan
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const TypeHealthCheck = "healthcheck"
|
||||
|
||||
// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint.
|
||||
type HTTPHealthCheck struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewHTTPHealthCheck initializes and returns a new HTTP listener, listening on an address.
|
||||
func NewHTTPHealthCheck(config Config) *HTTPHealthCheck {
|
||||
return &HTTPHealthCheck{
|
||||
id: config.ID,
|
||||
address: config.Address,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *HTTPHealthCheck) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *HTTPHealthCheck) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *HTTPHealthCheck) Protocol() string {
|
||||
if l.listen != nil && l.listen.TLSConfig != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
return "http"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *HTTPHealthCheck) Init(_ *slog.Logger) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
l.listen = &http.Server{
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
if l.config.TLSConfig != nil {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPHealthCheck) Serve(establish EstablishFn) {
|
||||
if l.listen.TLSConfig != nil {
|
||||
_ = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
_ = l.listen.ListenAndServe()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *HTTPHealthCheck) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
137
listeners/http_healthcheck_test.go
Normal file
137
listeners/http_healthcheck_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: Derek Duncan
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHTTPHealthCheck(t *testing.T) {
|
||||
l := NewHTTPHealthCheck(basicConfig)
|
||||
require.Equal(t, basicConfig.ID, l.id)
|
||||
require.Equal(t, basicConfig.Address, l.address)
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckID(t *testing.T) {
|
||||
l := NewHTTPHealthCheck(basicConfig)
|
||||
require.Equal(t, basicConfig.ID, l.ID())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckAddress(t *testing.T) {
|
||||
l := NewHTTPHealthCheck(basicConfig)
|
||||
require.Equal(t, basicConfig.Address, l.Address())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckProtocol(t *testing.T) {
|
||||
l := NewHTTPHealthCheck(basicConfig)
|
||||
require.Equal(t, "http", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
|
||||
l := NewHTTPHealthCheck(tlsConfig)
|
||||
_ = l.Init(logger)
|
||||
require.Equal(t, "https", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckInit(t *testing.T) {
|
||||
l := NewHTTPHealthCheck(basicConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.listen)
|
||||
require.Equal(t, basicConfig.Address, l.listen.Addr)
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckServeAndClose(t *testing.T) {
|
||||
// setup http stats listener
|
||||
l := NewHTTPHealthCheck(basicConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// call healthcheck
|
||||
resp, err := http.Get("http://localhost" + testAddr + "/healthcheck")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
defer resp.Body.Close()
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
|
||||
_, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck")
|
||||
require.Error(t, err)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
|
||||
// setup http stats listener
|
||||
l := NewHTTPHealthCheck(basicConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// make disallowed method type http request
|
||||
resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
defer resp.Body.Close()
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
|
||||
_, err = http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody)
|
||||
require.Error(t, err)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) {
|
||||
l := NewHTTPHealthCheck(tlsConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
l.Close(MockCloser)
|
||||
}
|
||||
122
listeners/http_sysinfo.go
Normal file
122
listeners/http_sysinfo.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"testmqtt/system"
|
||||
)
|
||||
|
||||
const TypeSysInfo = "sysinfo"
|
||||
|
||||
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
|
||||
type HTTPStats struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
sysInfo *system.Info // pointers to the server data
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewHTTPStats initializes and returns a new HTTP listener, listening on an address.
|
||||
func NewHTTPStats(config Config, sysInfo *system.Info) *HTTPStats {
|
||||
return &HTTPStats{
|
||||
sysInfo: sysInfo,
|
||||
id: config.ID,
|
||||
address: config.Address,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *HTTPStats) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *HTTPStats) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *HTTPStats) Protocol() string {
|
||||
if l.listen != nil && l.listen.TLSConfig != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
return "http"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *HTTPStats) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.jsonHandler)
|
||||
l.listen = &http.Server{
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
if l.config.TLSConfig != nil {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPStats) Serve(establish EstablishFn) {
|
||||
|
||||
var err error
|
||||
if l.listen.TLSConfig != nil {
|
||||
err = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
err = l.listen.ListenAndServe()
|
||||
}
|
||||
|
||||
// After the listener has been shutdown, no need to print the http.ErrServerClosed error.
|
||||
if err != nil && atomic.LoadUint32(&l.end) == 0 {
|
||||
l.log.Error("failed to serve.", "error", err, "listener", l.id)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *HTTPStats) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
|
||||
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
|
||||
info := *l.sysInfo.Clone()
|
||||
|
||||
out, err := json.MarshalIndent(info, "", "\t")
|
||||
if err != nil {
|
||||
_, _ = io.WriteString(w, err.Error())
|
||||
}
|
||||
|
||||
_, _ = w.Write(out)
|
||||
}
|
||||
149
listeners/http_sysinfo_test.go
Normal file
149
listeners/http_sysinfo_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"testmqtt/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHTTPStats(t *testing.T) {
|
||||
l := NewHTTPStats(basicConfig, nil)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestHTTPStatsID(t *testing.T) {
|
||||
l := NewHTTPStats(basicConfig, nil)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestHTTPStatsAddress(t *testing.T) {
|
||||
l := NewHTTPStats(basicConfig, nil)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestHTTPStatsProtocol(t *testing.T) {
|
||||
l := NewHTTPStats(basicConfig, nil)
|
||||
require.Equal(t, "http", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPStatsTLSProtocol(t *testing.T) {
|
||||
l := NewHTTPStats(tlsConfig, nil)
|
||||
_ = l.Init(logger)
|
||||
require.Equal(t, "https", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPStatsInit(t *testing.T) {
|
||||
sysInfo := new(system.Info)
|
||||
l := NewHTTPStats(basicConfig, sysInfo)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.sysInfo)
|
||||
require.Equal(t, sysInfo, l.sysInfo)
|
||||
require.NotNil(t, l.listen)
|
||||
require.Equal(t, testAddr, l.listen.Addr)
|
||||
}
|
||||
|
||||
func TestHTTPStatsServeAndClose(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
// setup http stats listener
|
||||
l := NewHTTPStats(basicConfig, sysInfo)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// get body from stats address
|
||||
resp, err := http.Get("http://localhost" + testAddr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// decode body from json and check data
|
||||
v := new(system.Info)
|
||||
err = json.Unmarshal(body, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", v.Version)
|
||||
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
|
||||
_, err = http.Get("http://localhost" + testAddr)
|
||||
require.Error(t, err)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
l := NewHTTPStats(tlsConfig, sysInfo)
|
||||
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
l.Close(MockCloser)
|
||||
}
|
||||
|
||||
func TestHTTPStatsFailedToServe(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
// setup http stats listener
|
||||
config := basicConfig
|
||||
config.Address = "wrong_addr"
|
||||
l := NewHTTPStats(config, sysInfo)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
<-o
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
135
listeners/listeners.go
Normal file
135
listeners/listeners.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
type Config struct {
|
||||
Type string
|
||||
ID string
|
||||
Address string
|
||||
// TLSConfig is a tls.Config configuration to be used with the listener. See examples folder for basic and mutual-tls use.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// EstablishFn is a callback function for establishing new clients.
|
||||
type EstablishFn func(id string, c net.Conn) error
|
||||
|
||||
// CloseFn is a callback function for closing all listener clients.
|
||||
type CloseFn func(id string)
|
||||
|
||||
// Listener is an interface for network listeners. A network listener listens
|
||||
// for incoming client connections and adds them to the server.
|
||||
type Listener interface {
|
||||
Init(*slog.Logger) error // open the network address
|
||||
Serve(EstablishFn) // starting actively listening for new connections
|
||||
ID() string // return the id of the listener
|
||||
Address() string // the address of the listener
|
||||
Protocol() string // the protocol in use by the listener
|
||||
Close(CloseFn) // stop and close the listener
|
||||
}
|
||||
|
||||
// Listeners contains the network listeners for the broker.
|
||||
type Listeners struct {
|
||||
ClientsWg sync.WaitGroup // a waitgroup that waits for all clients in all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// New returns a new instance of Listeners.
|
||||
func New() *Listeners {
|
||||
return &Listeners{
|
||||
internal: map[string]Listener{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new listener to the listeners map, keyed on id.
|
||||
func (l *Listeners) Add(val Listener) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.internal[val.ID()] = val
|
||||
}
|
||||
|
||||
// Get returns the value of a listener if it exists.
|
||||
func (l *Listeners) Get(id string) (Listener, bool) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
val, ok := l.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the length of the listeners map.
|
||||
func (l *Listeners) Len() int {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
return len(l.internal)
|
||||
}
|
||||
|
||||
// Delete removes a listener from the internal map.
|
||||
func (l *Listeners) Delete(id string) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
delete(l.internal, id)
|
||||
}
|
||||
|
||||
// Serve starts a listener serving from the internal map.
|
||||
func (l *Listeners) Serve(id string, establisher EstablishFn) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
listener := l.internal[id]
|
||||
|
||||
go func(e EstablishFn) {
|
||||
listener.Serve(e)
|
||||
}(establisher)
|
||||
}
|
||||
|
||||
// ServeAll starts all listeners serving from the internal map.
|
||||
func (l *Listeners) ServeAll(establisher EstablishFn) {
|
||||
l.RLock()
|
||||
i := 0
|
||||
ids := make([]string, len(l.internal))
|
||||
for id := range l.internal {
|
||||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
l.RUnlock()
|
||||
|
||||
for _, id := range ids {
|
||||
l.Serve(id, establisher)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops a listener from the internal map.
|
||||
func (l *Listeners) Close(id string, closer CloseFn) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
if listener, ok := l.internal[id]; ok {
|
||||
listener.Close(closer)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseAll iterates and closes all registered listeners.
|
||||
func (l *Listeners) CloseAll(closer CloseFn) {
|
||||
l.RLock()
|
||||
i := 0
|
||||
ids := make([]string, len(l.internal))
|
||||
for id := range l.internal {
|
||||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
l.RUnlock()
|
||||
|
||||
for _, id := range ids {
|
||||
l.Close(id, closer)
|
||||
}
|
||||
l.ClientsWg.Wait()
|
||||
}
|
||||
182
listeners/listeners_test.go
Normal file
182
listeners/listeners_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testAddr = ":22222"
|
||||
|
||||
var (
|
||||
basicConfig = Config{ID: "t1", Address: testAddr}
|
||||
tlsConfig = Config{ID: "t1", Address: testAddr, TLSConfig: tlsConfigBasic}
|
||||
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
|
||||
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
|
||||
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
|
||||
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
|
||||
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
|
||||
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
|
||||
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
|
||||
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
|
||||
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
|
||||
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
|
||||
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
|
||||
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
|
||||
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
|
||||
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
|
||||
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
|
||||
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
|
||||
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
|
||||
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
|
||||
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
|
||||
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
tlsConfigBasic *tls.Config
|
||||
)
|
||||
|
||||
func init() {
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfigBasic = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
tlsConfig.TLSConfig = tlsConfigBasic
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
l := New()
|
||||
require.NotNil(t, l.internal)
|
||||
}
|
||||
|
||||
func TestAddListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
}
|
||||
|
||||
func TestGetListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
require.Contains(t, l.internal, "t2")
|
||||
|
||||
g, ok := l.Get("t1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, g.ID(), "t1")
|
||||
}
|
||||
|
||||
func TestLenListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
require.Contains(t, l.internal, "t2")
|
||||
require.Equal(t, 2, l.Len())
|
||||
}
|
||||
|
||||
func TestDeleteListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
l.Delete("t1")
|
||||
_, ok := l.Get("t1")
|
||||
require.False(t, ok)
|
||||
require.Nil(t, l.internal["t1"])
|
||||
}
|
||||
|
||||
func TestServeListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Serve("t1", MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
|
||||
l.Close("t1", MockCloser)
|
||||
require.False(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
}
|
||||
|
||||
func TestServeAllListeners(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
l.Add(NewMockListener("t3", testAddr))
|
||||
l.ServeAll(MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
|
||||
l.Close("t1", MockCloser)
|
||||
l.Close("t2", MockCloser)
|
||||
l.Close("t3", MockCloser)
|
||||
|
||||
require.False(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.False(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.False(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
}
|
||||
|
||||
func TestCloseListener(t *testing.T) {
|
||||
l := New()
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
l.Add(mocked)
|
||||
l.Serve("t1", MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close("t1", func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.True(t, closed)
|
||||
}
|
||||
|
||||
func TestCloseAllListeners(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
l.Add(NewMockListener("t3", testAddr))
|
||||
l.ServeAll(MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
|
||||
closed := make(map[string]bool)
|
||||
l.CloseAll(func(id string) {
|
||||
closed[id] = true
|
||||
})
|
||||
require.Contains(t, closed, "t1")
|
||||
require.Contains(t, closed, "t2")
|
||||
require.Contains(t, closed, "t3")
|
||||
require.True(t, closed["t1"])
|
||||
require.True(t, closed["t2"])
|
||||
require.True(t, closed["t3"])
|
||||
}
|
||||
105
listeners/mock.go
Normal file
105
listeners/mock.go
Normal file
@@ -0,0 +1,105 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
const TypeMock = "mock"
|
||||
|
||||
// MockEstablisher is a function signature which can be used in testing.
|
||||
func MockEstablisher(id string, c net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockCloser is a function signature which can be used in testing.
|
||||
func MockCloser(id string) {}
|
||||
|
||||
// MockListener is a mock listener for establishing client connections.
|
||||
type MockListener struct {
|
||||
sync.RWMutex
|
||||
id string // the id of the listener
|
||||
address string // the network address the listener binds to
|
||||
Config *Config // configuration for the listener
|
||||
done chan bool // indicate the listener is done
|
||||
Serving bool // indicate the listener is serving
|
||||
Listening bool // indiciate the listener is listening
|
||||
ErrListen bool // throw an error on listen
|
||||
}
|
||||
|
||||
// NewMockListener returns a new instance of MockListener.
|
||||
func NewMockListener(id, address string) *MockListener {
|
||||
return &MockListener{
|
||||
id: id,
|
||||
address: address,
|
||||
done: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Serve serves the mock listener.
|
||||
func (l *MockListener) Serve(establisher EstablishFn) {
|
||||
l.Lock()
|
||||
l.Serving = true
|
||||
l.Unlock()
|
||||
|
||||
for range l.done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *MockListener) Init(log *slog.Logger) error {
|
||||
if l.ErrListen {
|
||||
return fmt.Errorf("listen failure")
|
||||
}
|
||||
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Listening = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// ID returns the id of the mock listener.
|
||||
func (l *MockListener) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *MockListener) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *MockListener) Protocol() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
// Close closes the mock listener.
|
||||
func (l *MockListener) Close(closer CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Serving = false
|
||||
closer(l.id)
|
||||
close(l.done)
|
||||
}
|
||||
|
||||
// IsServing indicates whether the mock listener is serving.
|
||||
func (l *MockListener) IsServing() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Serving
|
||||
}
|
||||
|
||||
// IsListening indicates whether the mock listener is listening.
|
||||
func (l *MockListener) IsListening() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Listening
|
||||
}
|
||||
99
listeners/mock_test.go
Normal file
99
listeners/mock_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMockEstablisher(t *testing.T) {
|
||||
_, w := net.Pipe()
|
||||
err := MockEstablisher("t1", w)
|
||||
require.NoError(t, err)
|
||||
_ = w.Close()
|
||||
}
|
||||
|
||||
func TestNewMockListener(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.id)
|
||||
require.Equal(t, testAddr, mocked.address)
|
||||
}
|
||||
func TestMockListenerID(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.ID())
|
||||
}
|
||||
|
||||
func TestMockListenerAddress(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, testAddr, mocked.Address())
|
||||
}
|
||||
func TestMockListenerProtocol(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "mock", mocked.Protocol())
|
||||
}
|
||||
|
||||
func TestNewMockListenerIsListening(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsListening())
|
||||
}
|
||||
|
||||
func TestNewMockListenerIsServing(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsServing())
|
||||
}
|
||||
|
||||
func TestNewMockListenerInit(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.id)
|
||||
require.Equal(t, testAddr, mocked.address)
|
||||
|
||||
require.Equal(t, false, mocked.IsListening())
|
||||
err := mocked.Init(nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, mocked.IsListening())
|
||||
}
|
||||
|
||||
func TestNewMockListenerInitFailure(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
mocked.ErrListen = true
|
||||
err := mocked.Init(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMockListenerServe(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsServing())
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
mocked.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
|
||||
require.Equal(t, true, mocked.IsServing())
|
||||
|
||||
var closed bool
|
||||
mocked.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
|
||||
_ = mocked.Init(nil)
|
||||
}
|
||||
|
||||
func TestMockListenerClose(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
var closed bool
|
||||
mocked.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
92
listeners/net.go
Normal file
92
listeners/net.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: Jeroen Rinzema
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Net is a listener for establishing client connections on basic TCP protocol.
|
||||
type Net struct { // [MQTT-4.2.0-1]
|
||||
mu sync.Mutex
|
||||
listener net.Listener // a net.Listener which will listen for new clients
|
||||
id string // the internal id of the listener
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
|
||||
func NewNet(id string, listener net.Listener) *Net {
|
||||
return &Net{
|
||||
id: id,
|
||||
listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Net) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *Net) Address() string {
|
||||
return l.listener.Addr().String()
|
||||
}
|
||||
|
||||
// Protocol returns the network of the listener.
|
||||
func (l *Net) Protocol() string {
|
||||
return l.listener.Addr().Network()
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Net) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *Net) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Net) Close(closeClients CloseFn) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listener != nil {
|
||||
err := l.listener.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
105
listeners/net_test.go
Normal file
105
listeners/net_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewNet(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.id)
|
||||
}
|
||||
|
||||
func TestNetID(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestNetAddress(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, n.Addr().String(), l.Address())
|
||||
}
|
||||
|
||||
func TestNetProtocol(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestNetInit(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNetServeAndClose(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestNetEstablishThenEnd(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
_, _ = net.Dial("tcp", n.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
109
listeners/tcp.go
Normal file
109
listeners/tcp.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
const TypeTCP = "tcp"
|
||||
|
||||
// TCP is a listener for establishing client connections on basic TCP protocol.
|
||||
type TCP struct { // [MQTT-4.2.0-1]
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
listen net.Listener // a net.Listener which will listen for new clients
|
||||
config Config // configuration values for the listener
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewTCP initializes and returns a new TCP listener, listening on an address.
|
||||
func NewTCP(config Config) *TCP {
|
||||
return &TCP{
|
||||
id: config.ID,
|
||||
address: config.Address,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *TCP) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *TCP) Address() string {
|
||||
if l.listen != nil {
|
||||
return l.listen.Addr().String()
|
||||
}
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *TCP) Protocol() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *TCP) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
if l.config.TLSConfig != nil {
|
||||
l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig)
|
||||
} else {
|
||||
l.listen, err = net.Listen("tcp", l.address)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *TCP) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *TCP) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
124
listeners/tcp_test.go
Normal file
124
listeners/tcp_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTCP(t *testing.T) {
|
||||
l := NewTCP(basicConfig)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestTCPID(t *testing.T) {
|
||||
l := NewTCP(basicConfig)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestTCPAddress(t *testing.T) {
|
||||
l := NewTCP(basicConfig)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestTCPProtocol(t *testing.T) {
|
||||
l := NewTCP(basicConfig)
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestTCPProtocolTLS(t *testing.T) {
|
||||
l := NewTCP(tlsConfig)
|
||||
_ = l.Init(logger)
|
||||
defer l.listen.Close()
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestTCPInit(t *testing.T) {
|
||||
l := NewTCP(basicConfig)
|
||||
err := l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewTCP(tlsConfig)
|
||||
err = l2.Init(logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l2.config.TLSConfig)
|
||||
}
|
||||
|
||||
func TestTCPServeAndClose(t *testing.T) {
|
||||
l := NewTCP(basicConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestTCPServeTLSAndClose(t *testing.T) {
|
||||
l := NewTCP(tlsConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestTCPEstablishThenEnd(t *testing.T) {
|
||||
l := NewTCP(basicConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
_, _ = net.Dial("tcp", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
102
listeners/unixsock.go
Normal file
102
listeners/unixsock.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
const TypeUnix = "unix"
|
||||
|
||||
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
|
||||
type UnixSock struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
config Config // configuration values for the listener
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewUnixSock initializes and returns a new UnixSock listener, listening on an address.
|
||||
func NewUnixSock(config Config) *UnixSock {
|
||||
return &UnixSock{
|
||||
id: config.ID,
|
||||
address: config.Address,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *UnixSock) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *UnixSock) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *UnixSock) Protocol() string {
|
||||
return "unix"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *UnixSock) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
_ = os.Remove(l.address)
|
||||
l.listen, err = net.Listen("unix", l.address)
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new UnixSock connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *UnixSock) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *UnixSock) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
102
listeners/unixsock_test.go
Normal file
102
listeners/unixsock_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testUnixAddr = "mochi.sock"
|
||||
|
||||
var (
|
||||
unixConfig = Config{ID: "t1", Address: testUnixAddr}
|
||||
)
|
||||
|
||||
func TestNewUnixSock(t *testing.T) {
|
||||
l := NewUnixSock(unixConfig)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testUnixAddr, l.address)
|
||||
}
|
||||
|
||||
func TestUnixSockID(t *testing.T) {
|
||||
l := NewUnixSock(unixConfig)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestUnixSockAddress(t *testing.T) {
|
||||
l := NewUnixSock(unixConfig)
|
||||
require.Equal(t, testUnixAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestUnixSockProtocol(t *testing.T) {
|
||||
l := NewUnixSock(unixConfig)
|
||||
require.Equal(t, "unix", l.Protocol())
|
||||
}
|
||||
|
||||
func TestUnixSockInit(t *testing.T) {
|
||||
l := NewUnixSock(unixConfig)
|
||||
err := l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
t2Config := unixConfig
|
||||
t2Config.ID = "t2"
|
||||
l2 := NewUnixSock(t2Config)
|
||||
err = l2.Init(logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnixSockServeAndClose(t *testing.T) {
|
||||
l := NewUnixSock(unixConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||
l := NewUnixSock(unixConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
_, _ = net.Dial("unix", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
199
listeners/websocket.go
Normal file
199
listeners/websocket.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const TypeWS = "ws"
|
||||
|
||||
var (
|
||||
// ErrInvalidMessage indicates that a message payload was not valid.
|
||||
ErrInvalidMessage = errors.New("message type not binary")
|
||||
)
|
||||
|
||||
// Websocket is a listener for establishing websocket connections.
|
||||
type Websocket struct { // [MQTT-4.2.0-1]
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config Config // configuration values for the listener
|
||||
listen *http.Server // a http server for serving websocket connections
|
||||
log *slog.Logger // server logger
|
||||
establish EstablishFn // the server's establish connection handler
|
||||
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewWebsocket initializes and returns a new Websocket listener, listening on an address.
|
||||
func NewWebsocket(config Config) *Websocket {
|
||||
return &Websocket{
|
||||
id: config.ID,
|
||||
address: config.Address,
|
||||
config: config,
|
||||
upgrader: &websocket.Upgrader{
|
||||
Subprotocols: []string{"mqtt"},
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Websocket) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *Websocket) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *Websocket) Protocol() string {
|
||||
if l.config.TLSConfig != nil {
|
||||
return "wss"
|
||||
}
|
||||
|
||||
return "ws"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Websocket) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.handler)
|
||||
l.listen = &http.Server{
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
TLSConfig: l.config.TLSConfig,
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handler upgrades and handles an incoming websocket connection.
|
||||
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := l.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c})
|
||||
if err != nil {
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve starts waiting for new Websocket connections, and calls the connection
|
||||
// establishment callback for any received.
|
||||
func (l *Websocket) Serve(establish EstablishFn) {
|
||||
var err error
|
||||
l.establish = establish
|
||||
|
||||
if l.listen.TLSConfig != nil {
|
||||
err = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
err = l.listen.ListenAndServe()
|
||||
}
|
||||
|
||||
// After the listener has been shutdown, no need to print the http.ErrServerClosed error.
|
||||
if err != nil && atomic.LoadUint32(&l.end) == 0 {
|
||||
l.log.Error("failed to serve.", "error", err, "listener", l.id)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Websocket) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
// wsConn is a websocket connection which satisfies the net.Conn interface.
|
||||
type wsConn struct {
|
||||
net.Conn
|
||||
c *websocket.Conn
|
||||
|
||||
// reader for the current message (can be nil)
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
|
||||
func (ws *wsConn) Read(p []byte) (int, error) {
|
||||
if ws.r == nil {
|
||||
op, r, err := ws.c.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if op != websocket.BinaryMessage {
|
||||
err = ErrInvalidMessage
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ws.r = r
|
||||
}
|
||||
|
||||
var n int
|
||||
for {
|
||||
// buffer is full, return what we've read so far
|
||||
if n == len(p) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
br, err := ws.r.Read(p[n:])
|
||||
n += br
|
||||
if err != nil {
|
||||
// when ANY error occurs, we consider this the end of the current message (either because it really is, via
|
||||
// io.EOF, or because something bad happened, in which case we want to drop the remainder)
|
||||
ws.r = nil
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes bytes to the websocket connection.
|
||||
func (ws *wsConn) Write(p []byte) (int, error) {
|
||||
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Close signals the underlying websocket conn to close.
|
||||
func (ws *wsConn) Close() error {
|
||||
return ws.Conn.Close()
|
||||
}
|
||||
173
listeners/websocket_test.go
Normal file
173
listeners/websocket_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewWebsocket(t *testing.T) {
|
||||
l := NewWebsocket(basicConfig)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestWebsocketID(t *testing.T) {
|
||||
l := NewWebsocket(basicConfig)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestWebsocketAddress(t *testing.T) {
|
||||
l := NewWebsocket(basicConfig)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestWebsocketProtocol(t *testing.T) {
|
||||
l := NewWebsocket(basicConfig)
|
||||
require.Equal(t, "ws", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsocketProtocolTLS(t *testing.T) {
|
||||
l := NewWebsocket(tlsConfig)
|
||||
require.Equal(t, "wss", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsocketInit(t *testing.T) {
|
||||
l := NewWebsocket(basicConfig)
|
||||
require.Nil(t, l.listen)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketServeAndClose(t *testing.T) {
|
||||
l := NewWebsocket(basicConfig)
|
||||
_ = l.Init(logger)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestWebsocketServeTLSAndClose(t *testing.T) {
|
||||
l := NewWebsocket(tlsConfig)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestWebsocketFailedToServe(t *testing.T) {
|
||||
config := tlsConfig
|
||||
config.Address = "wrong_addr"
|
||||
l := NewWebsocket(config)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
<-o
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func TestWebsocketUpgrade(t *testing.T) {
|
||||
l := NewWebsocket(basicConfig)
|
||||
_ = l.Init(logger)
|
||||
|
||||
e := make(chan bool)
|
||||
l.establish = func(id string, c net.Conn) error {
|
||||
e <- true
|
||||
return nil
|
||||
}
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(l.handler))
|
||||
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, <-e)
|
||||
|
||||
s.Close()
|
||||
_ = ws.Close()
|
||||
}
|
||||
|
||||
func TestWebsocketConnectionReads(t *testing.T) {
|
||||
l := NewWebsocket(basicConfig)
|
||||
_ = l.Init(nil)
|
||||
|
||||
recv := make(chan []byte)
|
||||
l.establish = func(id string, c net.Conn) error {
|
||||
var out []byte
|
||||
for {
|
||||
buf := make([]byte, 2048)
|
||||
n, err := c.Read(buf)
|
||||
require.NoError(t, err)
|
||||
out = append(out, buf[:n]...)
|
||||
if n < 2048 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
recv <- out
|
||||
return nil
|
||||
}
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(l.handler))
|
||||
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkt := make([]byte, 3000) // make sure this is >2048
|
||||
for i := 0; i < len(pkt); i++ {
|
||||
pkt[i] = byte(i % 100)
|
||||
}
|
||||
|
||||
err = ws.WriteMessage(websocket.BinaryMessage, pkt)
|
||||
require.NoError(t, err)
|
||||
|
||||
got := <-recv
|
||||
require.Equal(t, 3000, len(got))
|
||||
require.Equal(t, pkt, got)
|
||||
|
||||
s.Close()
|
||||
_ = ws.Close()
|
||||
}
|
||||
Reference in New Issue
Block a user