Files
elevator-server/tcpserver/tcpserver.go
2024-12-24 14:07:17 +08:00

232 lines
4.7 KiB
Go

package tcpserver
import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"net"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)
// Server is a TCP server that passes each accepted connection to Handler on
// its own goroutine and keeps a registry of connected Clients.
type Server struct {
	// Address is the listen address given to net.Listen("tcp", ...).
	Address string
	// Handler serves one connection. When it returns, the connection is
	// closed and its Client is unregistered.
	Handler func(net.Conn)
	// listener is assigned by Start.
	listener net.Listener
	// connections maps each live connection's RemoteAddr() to its net.Conn.
	// NOTE(review): entries are stored and deleted in this file but never
	// read here — confirm whether anything else relies on it.
	connections sync.Map
	// wg counts running connection-handler goroutines; Stop waits on it.
	wg sync.WaitGroup
	// stopChan is closed (once) by Stop to tell Start to return.
	stopChan chan struct{}
	// clients is keyed by Client.ID (the remote address string) and is
	// guarded by clientsMux.
	clients    map[string]*Client
	clientsMux sync.RWMutex
	// stopOnce makes Stop idempotent.
	stopOnce sync.Once
}
// Client describes one connected peer.
type Client struct {
	// ID is the peer's remote address string (conn.RemoteAddr().String());
	// it is the key in Server.clients.
	ID string
	// Imei is not assigned in this file; presumably the connection handler
	// fills it in once the device identifies itself — confirm.
	Imei        string
	Conn        net.Conn
	ConnectedAt time.Time
	// LastPing starts at connect time; the heartbeat watcher closes Conn once
	// it is more than 120 seconds old.
	// NOTE(review): never refreshed in this file (the handler presumably does
	// it) and read from another goroutine without synchronization.
	LastPing time.Time
	// Done is closed when the client is unregistered; closing it stops the
	// heartbeat goroutine (startHeartbeat).
	Done chan struct{}
	// IsAuth must become true within one minute of connecting, otherwise the
	// authentication timer closes Conn.
	// NOTE(review): read from the timer goroutine without synchronization.
	IsAuth bool
	// authTimer is the pending one-minute authentication deadline.
	authTimer *time.Timer
}
// NewServer returns a Server that will listen on address and pass every
// accepted connection to handler. The server does nothing until Start is
// called.
func NewServer(address string, handler func(net.Conn)) *Server {
	srv := new(Server)
	srv.Address = address
	srv.Handler = handler
	srv.stopChan = make(chan struct{})
	srv.clients = make(map[string]*Client)
	return srv
}
// Start listens on s.Address and serves connections until Stop is called,
// either directly or by handleShutdown on SIGINT/SIGTERM. Each accepted
// connection is registered with addClient and passed to s.Handler on its own
// goroutine. When Handler returns, the connection is closed and the client
// is unregistered.
//
// Start blocks. It returns nil after Stop. It returns a non-nil error if the
// listener cannot be created or is closed by something other than Stop.
func (s *Server) Start() error {
	var err error
	s.listener, err = net.Listen("tcp", s.Address)
	if err != nil {
		return fmt.Errorf("failed to start server: %w", err)
	}
	fmt.Println("Listening and serving TCP on", s.Address)
	go s.handleShutdown()

	stopped := func() bool {
		select {
		case <-s.stopChan:
			return true
		default:
			return false
		}
	}

	// retryDelay throttles the loop after a failed Accept; reset on success.
	var retryDelay time.Duration
	for {
		conn, err := s.listener.Accept()
		if err != nil {
			if stopped() {
				return nil
			}
			if errors.Is(err, net.ErrClosed) {
				// The listener was closed without Stop. Every further Accept
				// would fail immediately, so retrying would spin forever.
				return fmt.Errorf("accept on %s: %w", s.Address, err)
			}
			// Other errors (e.g. running out of file descriptors) may be
			// transient. Back off exponentially, capped at 1s, instead of
			// busy-looping and flooding the log.
			if retryDelay == 0 {
				retryDelay = 5 * time.Millisecond
			} else {
				retryDelay *= 2
			}
			if retryDelay > time.Second {
				retryDelay = time.Second
			}
			fmt.Println("Error accepting connection:", err)
			select {
			case <-s.stopChan:
				return nil
			case <-time.After(retryDelay):
			}
			continue
		}
		retryDelay = 0

		// Stop closes stopChan before it closes the listener, so a connection
		// can slip through in between. Stop may already have swept the client
		// registry, so don't register a client it will never close.
		if stopped() {
			conn.Close()
			return nil
		}

		s.connections.Store(conn.RemoteAddr(), conn)
		client := s.addClient(conn)
		fmt.Printf("客户端已连接: %s\n", client.ID)
		s.wg.Add(1)
		go func(c net.Conn, clientID string) {
			defer s.wg.Done()
			defer func() {
				s.connections.Delete(c.RemoteAddr())
				s.removeClient(clientID)
				c.Close()
				fmt.Printf("客户端已断开连接: %s\n", clientID)
			}()
			s.Handler(c)
		}(conn, client.ID)
	}
}
// handleShutdown performs graceful shutdown. It waits for SIGINT or SIGTERM
// and then calls Stop. If the server is stopped some other way first (a
// direct Stop call), it returns without waiting for a signal. In both cases
// it unregisters its signal channel, so the process regains default handling
// of those signals instead of silently swallowing them.
func (s *Server) handleShutdown() {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigCh)

	select {
	case <-sigCh:
		fmt.Println("\nReceived shutdown signal...")
		s.Stop()
	case <-s.stopChan:
		// Already stopping; nothing to do but release the signal hook.
	}
}
// GetOnlineClients returns one map per registered client. Each map has the
// keys "id", "imei", "addr", "connected_at", "last_ping" and "is_auth".
// Slice order follows Go map iteration and is therefore unspecified.
//
// NOTE(review): the registry is read under clientsMux, but LastPing and
// IsAuth may be written elsewhere without that lock.
func (s *Server) GetOnlineClients() []map[string]interface{} {
	s.clientsMux.RLock()
	defer s.clientsMux.RUnlock()

	snapshot := make([]map[string]interface{}, 0, len(s.clients))
	for _, c := range s.clients {
		info := make(map[string]interface{}, 6)
		info["id"] = c.ID
		info["imei"] = c.Imei
		info["addr"] = c.Conn.RemoteAddr().String()
		info["connected_at"] = c.ConnectedAt
		info["last_ping"] = c.LastPing
		info["is_auth"] = c.IsAuth
		snapshot = append(snapshot, info)
	}
	return snapshot
}
// addClient registers conn as a new Client keyed by its remote address
// string. It starts the client's heartbeat watcher and arms a one-minute
// authentication deadline; if IsAuth is still false when the deadline
// fires, the connection is closed. addClient takes clientsMux itself, so the
// caller must not already hold it.
func (s *Server) addClient(conn net.Conn) *Client {
	c := &Client{
		ID:          conn.RemoteAddr().String(),
		Conn:        conn,
		ConnectedAt: time.Now(),
		LastPing:    time.Now(),
		Done:        make(chan struct{}),
	}

	s.clientsMux.Lock()
	defer s.clientsMux.Unlock()

	s.clients[c.ID] = c
	go c.startHeartbeat(s)
	// authTimer is assigned under the lock because removeClient reads it
	// under the same lock.
	c.authTimer = time.AfterFunc(time.Minute, func() {
		if c.IsAuth {
			return
		}
		fmt.Printf("Client %s authentication timeout\n", c.ID)
		c.Conn.Close() // force-close the unauthenticated connection
	})
	return c
}
// removeClient unregisters the client with the given id. It stops the
// client's pending authentication timer, closes Done (which ends its
// heartbeat goroutine) and deletes the registry entry. Unknown ids are
// ignored.
func (s *Server) removeClient(id string) {
	s.clientsMux.Lock()
	defer s.clientsMux.Unlock()

	c, found := s.clients[id]
	if !found {
		return
	}
	if t := c.authTimer; t != nil {
		t.Stop()
		c.authTimer = nil
	}
	close(c.Done)
	delete(s.clients, id)
}
// Stop shuts the server down and may be called from outside (e.g. by the
// application) as well as by handleShutdown. Only the first call has any
// effect.
//
// Stop does the following, in order:
//   - signals Start to return and closes the listener;
//   - for every registered client: stops the authentication timer, closes
//     Done (ending the heartbeat goroutine), closes the connection and
//     unregisters it;
//   - blocks until all connection-handler goroutines started by Start have
//     returned.
//
// NOTE(review): if a Handler does not return after its connection is closed,
// Stop blocks until it does.
func (s *Server) Stop() {
	s.stopOnce.Do(func() {
		fmt.Println("Stopping TCP server...")
		close(s.stopChan)
		if s.listener != nil {
			s.listener.Close()
		}

		s.clientsMux.Lock()
		for id, client := range s.clients {
			if client.authTimer != nil {
				client.authTimer.Stop()
				client.authTimer = nil
			}
			// The entry is deleted below, so removeClient (the only other
			// place Done is closed) will no longer find it. Close Done here,
			// or the heartbeat goroutine lingers until its own timeout.
			close(client.Done)
			client.Conn.Close()
			delete(s.clients, id)
		}
		s.clientsMux.Unlock()

		// Wait outside the lock: handler goroutines call removeClient, which
		// needs clientsMux.
		s.wg.Wait()
		fmt.Println("TCP server stopped.")
	})
}
// startHeartbeat watches c for liveness until c.Done is closed. Every 20
// seconds it checks how long ago LastPing was. If that is more than 120
// seconds, it logs a timeout, closes c.Conn and returns. Closing the
// connection presumably unblocks the handler's reads so it can clean up.
// The s parameter is currently unused.
//
// NOTE(review): LastPing is read here without synchronization, while it is
// presumably updated by the connection handler on another goroutine. This is
// a data race under -race; fixing it needs a synchronized field on Client.
func (c *Client) startHeartbeat(s *Server) {
	const (
		checkInterval    = 20 * time.Second  // how often LastPing is checked
		heartbeatTimeout = 120 * time.Second // max allowed age of LastPing
	)

	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()
	for {
		select {
		case <-c.Done:
			return
		case <-ticker.C:
			if time.Since(c.LastPing) > heartbeatTimeout {
				fmt.Printf("客户端 %s 心跳超时 \n", c.ID)
				c.Conn.Close()
				return
			}
		}
	}
}
// GetClient looks up a registered client by ID (its remote address string).
// The boolean result reports whether such a client is currently registered.
func (s *Server) GetClient(id string) (*Client, bool) {
	s.clientsMux.RLock()
	c, ok := s.clients[id]
	s.clientsMux.RUnlock()
	return c, ok
}
// readUntilDelimiter reads from reader until the accumulated data ends with
// delimiter. It returns everything before the delimiter; the delimiter itself
// is consumed and dropped. The delimiter may span several lines and need not
// end in '\n'.
//
// If reader returns an error first (including io.EOF before a delimiter is
// seen), the partial data is discarded and that error is returned. An empty
// delimiter yields the next line, including its trailing '\n'.
func readUntilDelimiter(reader *bufio.Reader, delimiter []byte) ([]byte, error) {
	// Read in chunks that end at the delimiter's last byte. The buffer can
	// only end with delimiter right after such a byte, so each suffix check
	// happens at every possible match point and never reads past a frame.
	stop := byte('\n')
	if n := len(delimiter); n > 0 {
		stop = delimiter[n-1]
	}

	var buffer bytes.Buffer
	for {
		chunk, err := reader.ReadBytes(stop)
		if err != nil {
			return nil, err
		}
		buffer.Write(chunk)
		if bytes.HasSuffix(buffer.Bytes(), delimiter) {
			break
		}
	}
	data := buffer.Bytes()
	return data[:len(data)-len(delimiter)], nil
}