230 lines
4.6 KiB
Go
230 lines
4.6 KiB
Go
package tcpserver
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
type Server struct {
|
|
Address string
|
|
Handler func(net.Conn)
|
|
listener net.Listener
|
|
connections sync.Map
|
|
wg sync.WaitGroup
|
|
stopChan chan struct{}
|
|
|
|
clients map[string]*Client
|
|
|
|
clientsMux sync.RWMutex
|
|
stopOnce sync.Once
|
|
}
|
|
|
|
// Client describes one connected client (apparently a device identified by
// its IMEI — confirm against the protocol handler).
type Client struct {
	// ID is a human-readable identifier; Start fills it with the remote
	// address. Imei is the key addClient uses in Server.clients.
	ID, Imei string

	Conn net.Conn // underlying TCP connection

	// ConnectedAt is when the connection was accepted; LastPing is compared
	// against the current time by startHeartbeat.
	ConnectedAt, LastPing time.Time

	Done      chan struct{} // closed (by removeClient) to stop startHeartbeat
	IsAuth    bool          // reported by GetOnlineClients
	authTimer *time.Timer   // stopped and cleared by removeClient when set
}
|
|
|
|
// NewServer 创建一个新的 TCP 服务器
|
|
func NewServer(address string, handler func(net.Conn)) *Server {
|
|
return &Server{
|
|
Address: address,
|
|
Handler: handler,
|
|
stopChan: make(chan struct{}),
|
|
clients: make(map[string]*Client),
|
|
}
|
|
}
|
|
|
|
func (s *Server) Start() error {
|
|
var err error
|
|
s.listener, err = net.Listen("tcp", s.Address)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start server: %w", err)
|
|
}
|
|
|
|
fmt.Println("Listening and serving TCP on", s.Address)
|
|
|
|
go s.handleShutdown()
|
|
|
|
for {
|
|
select {
|
|
case <-s.stopChan:
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
select {
|
|
case <-s.stopChan:
|
|
return nil
|
|
default:
|
|
fmt.Println("Error accepting connection:", err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// 创建临时客户端,不加入 clients map
|
|
tempClient := &Client{
|
|
ID: conn.RemoteAddr().String(),
|
|
Conn: conn,
|
|
ConnectedAt: time.Now(),
|
|
LastPing: time.Now(),
|
|
Done: make(chan struct{}),
|
|
}
|
|
|
|
// 设置认证超时
|
|
time.AfterFunc(time.Minute, func() {
|
|
if !tempClient.IsAuth {
|
|
fmt.Printf("Client %s authentication timeout\n", tempClient.ID)
|
|
conn.Close()
|
|
}
|
|
})
|
|
|
|
s.wg.Add(1)
|
|
go func(c net.Conn, client *Client) {
|
|
defer s.wg.Done()
|
|
defer func() {
|
|
s.connections.Delete(c.RemoteAddr())
|
|
if client.IsAuth {
|
|
s.removeClient(client.ID)
|
|
}
|
|
c.Close()
|
|
fmt.Printf("客户端已断开连接: %s\n", client.ID)
|
|
}()
|
|
s.Handler(c)
|
|
}(conn, tempClient)
|
|
}
|
|
}
|
|
|
|
// handleShutdown 处理优雅关闭
|
|
func (s *Server) handleShutdown() {
|
|
c := make(chan os.Signal, 1)
|
|
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
<-c
|
|
fmt.Println("\nReceived shutdown signal...")
|
|
s.Stop()
|
|
}
|
|
|
|
func (s *Server) GetOnlineClients() []map[string]interface{} {
|
|
s.clientsMux.RLock()
|
|
defer s.clientsMux.RUnlock()
|
|
|
|
clients := make([]map[string]interface{}, 0, len(s.clients))
|
|
for _, client := range s.clients {
|
|
clientInfo := map[string]interface{}{
|
|
"id": client.ID,
|
|
"imei": client.Imei,
|
|
"addr": client.Conn.RemoteAddr().String(),
|
|
"connected_at": client.ConnectedAt,
|
|
"last_ping": client.LastPing,
|
|
"is_auth": client.IsAuth,
|
|
}
|
|
clients = append(clients, clientInfo)
|
|
}
|
|
return clients
|
|
}
|
|
|
|
// addClient 添加新客户端
|
|
func (s *Server) addClient(client *Client) {
|
|
s.clientsMux.Lock()
|
|
defer s.clientsMux.Unlock()
|
|
|
|
// 使用 IMEI 作为 key
|
|
s.clients[client.Imei] = client
|
|
go client.startHeartbeat(s)
|
|
}
|
|
|
|
// removeClient 移除客户端
|
|
func (s *Server) removeClient(id string) {
|
|
s.clientsMux.Lock()
|
|
defer s.clientsMux.Unlock()
|
|
if client, ok := s.clients[id]; ok {
|
|
if client.authTimer != nil {
|
|
client.authTimer.Stop()
|
|
client.authTimer = nil
|
|
}
|
|
close(client.Done)
|
|
delete(s.clients, id)
|
|
}
|
|
}
|
|
|
|
// Stop 添加一个新方法用于外部调用关闭服务
|
|
func (s *Server) Stop() {
|
|
s.stopOnce.Do(func() {
|
|
fmt.Println("Stopping TCP server...")
|
|
|
|
close(s.stopChan)
|
|
if s.listener != nil {
|
|
s.listener.Close()
|
|
}
|
|
|
|
s.clientsMux.Lock()
|
|
for id, client := range s.clients {
|
|
client.Conn.Close()
|
|
delete(s.clients, id)
|
|
}
|
|
s.clientsMux.Unlock()
|
|
|
|
s.wg.Wait()
|
|
fmt.Println("TCP server stopped.")
|
|
})
|
|
}
|
|
|
|
// 添加心跳检测方法
|
|
func (c *Client) startHeartbeat(s *Server) {
|
|
ticker := time.NewTicker(20 * time.Second) // 每10秒检查一次
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-c.Done:
|
|
return
|
|
case <-ticker.C:
|
|
if time.Since(c.LastPing) > 120*time.Second {
|
|
fmt.Printf("客户端 %s 心跳超时 \n", c.ID)
|
|
c.Conn.Close()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetClient 获取指定ID的客户端
|
|
func (s *Server) GetClient(id string) (*Client, bool) {
|
|
s.clientsMux.RLock()
|
|
defer s.clientsMux.RUnlock()
|
|
client, ok := s.clients[id]
|
|
return client, ok
|
|
}
|
|
|
|
// readUntilDelimiter reads from reader until the accumulated data ends with
// delimiter. It returns that data with the trailing delimiter removed.
//
// Reads are split on the last byte of delimiter, so any multi-byte delimiter
// is recognized, not only ones ending in '\n'. An empty delimiter falls back
// to reading one '\n'-terminated chunk, which is returned whole.
//
// If the reader fails before the delimiter is seen (including io.EOF), the
// partially read data is discarded and the error is returned.
func readUntilDelimiter(reader *bufio.Reader, delimiter []byte) ([]byte, error) {
	// A complete delimiter can only end at its final byte, so each read stops
	// there. Then the suffix check decides whether the delimiter is complete.
	stop := byte('\n')
	if n := len(delimiter); n > 0 {
		stop = delimiter[n-1]
	}

	var buffer bytes.Buffer
	for {
		chunk, err := reader.ReadBytes(stop)
		if err != nil {
			return nil, err
		}
		buffer.Write(chunk)
		if bytes.HasSuffix(buffer.Bytes(), delimiter) {
			break
		}
	}

	data := buffer.Bytes()
	return data[:len(data)-len(delimiter)], nil
}
|