修改代码

This commit is contained in:
2025-01-08 17:56:09 +08:00
parent 4e49712e5d
commit ec1755ef03
3 changed files with 67 additions and 66 deletions

View File

@@ -64,13 +64,10 @@ func (s *Server) HandleAuth(client *Client, message []byte) error {
return fmt.Errorf("更新设备版本号失败") return fmt.Errorf("更新设备版本号失败")
} }
if client.authTimer != nil { // 设置客户端信息
client.authTimer.Stop()
client.authTimer = nil
}
client.Imei = msg.Imei client.Imei = msg.Imei
client.IsAuth = true client.IsAuth = true
client.LastPing = time.Now()
//response := Message{ //response := Message{
// MessageType: MessageType{Type: "reg"}, // MessageType: MessageType{Type: "reg"},
@@ -86,7 +83,7 @@ func (s *Server) HandleAuth(client *Client, message []byte) error {
responseData, err := json.Marshal(response) responseData, err := json.Marshal(response)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to marshal response: %w", err)
} }
_, err = client.Conn.Write(append(responseData, '\r', '\n')) _, err = client.Conn.Write(append(responseData, '\r', '\n'))

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"time"
) )
type TCPHandler struct { type TCPHandler struct {
@@ -18,17 +19,19 @@ type TCPHandler struct {
func (h *TCPHandler) HandleClient(conn net.Conn) { func (h *TCPHandler) HandleClient(conn net.Conn) {
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
clientID := conn.RemoteAddr().String() clientID := conn.RemoteAddr().String()
var client *Client
if value, ok := h.Server.GetClient(clientID); ok { // 创建临时客户端对象
client = value tempClient := &Client{
} else { ID: clientID,
fmt.Println("找不到客户端:", clientID) Conn: conn,
return ConnectedAt: time.Now(),
LastPing: time.Now(),
Done: make(chan struct{}),
} }
broadcastMessage := func(data []byte) { broadcastMessage := func(data []byte) {
if h.Hub != nil { if h.Hub != nil && tempClient.IsAuth {
h.Hub.Broadcast <- &ws.WsMessage{IMEI: client.Imei, Data: string(data)} h.Hub.Broadcast <- &ws.WsMessage{IMEI: tempClient.Imei, Data: string(data)}
} }
} }
@@ -66,143 +69,146 @@ func (h *TCPHandler) HandleClient(conn net.Conn) {
switch msgType { switch msgType {
case "reg": case "reg":
if err := h.Server.HandleAuth(client, message); err != nil { // 处理注册
if err := h.Server.HandleAuth(tempClient, message); err != nil {
fmt.Printf("客户端授权失败: %v\n", err) fmt.Printf("客户端授权失败: %v\n", err)
conn.Close() conn.Close()
return return
} }
fmt.Printf("客户端已授权: %s\n", client.Imei) // 注册成功后,将客户端添加到 clients map
h.Server.addClient(tempClient)
fmt.Printf("客户端已授权并注册: %s\n", tempClient.Imei)
broadcastMessage(message) broadcastMessage(message)
case "ping": case "ping":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的心跳 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的心跳 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.HandleHeartbeat(client, message); err != nil { if err := h.Server.HandleHeartbeat(tempClient, message); err != nil {
fmt.Printf("心跳错误: %v\n", err) fmt.Printf("心跳错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "ota": case "ota":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.HandleOta(client, message); err != nil { if err := h.Server.HandleOta(tempClient, message); err != nil {
fmt.Printf("OTA 错误: %v\n", err) fmt.Printf("OTA 错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "start": case "start":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.RealTimeReporting(client, message); err != nil { if err := h.Server.RealTimeReporting(tempClient, message); err != nil {
fmt.Printf("实时上报错误: %v\n", err) fmt.Printf("实时上报错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "stop": case "stop":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.StopRealTimeReporting(client, message); err != nil { if err := h.Server.StopRealTimeReporting(tempClient, message); err != nil {
fmt.Printf("客户端停止实时上报数据 错误: %v\n", err) fmt.Printf("客户端停止实时上报数据 错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "up": case "up":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.TimingReporting(client, message); err != nil { if err := h.Server.TimingReporting(tempClient, message); err != nil {
fmt.Printf("客户端定时上报数据 错误: %v\n", err) fmt.Printf("客户端定时上报数据 错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "SetConfig": case "SetConfig":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.SetConfig(client, message); err != nil { if err := h.Server.SetConfig(tempClient, message); err != nil {
fmt.Printf("客户端楼层设置 错误: %v\n", err) fmt.Printf("客户端楼层设置 错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "GetConfig": case "GetConfig":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.GetConfig(client, message); err != nil { if err := h.Server.GetConfig(tempClient, message); err != nil {
fmt.Printf("获取客户端楼层设置错误: %v\n", err) fmt.Printf("获取客户端楼层设置错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "mp3a", "mp3b", "mp3c": case "mp3a", "mp3b", "mp3c":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.Mp3(client, message); err != nil { if err := h.Server.Mp3(tempClient, message); err != nil {
fmt.Printf("客户端设置语音内容错误: %v\n", err) fmt.Printf("客户端设置语音内容错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "SetVoiceConf": case "SetVoiceConf":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.SetVoiceConf(client, message); err != nil { if err := h.Server.SetVoiceConf(tempClient, message); err != nil {
fmt.Printf("客户端语音配置错误: %v\n", err) fmt.Printf("客户端语音配置错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "GetVoiceConf": case "GetVoiceConf":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.GetVoiceConf(client, message); err != nil { if err := h.Server.GetVoiceConf(tempClient, message); err != nil {
fmt.Printf("获取客户端语音配置错误: %v\n", err) fmt.Printf("获取客户端语音配置错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
case "db": case "db":
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return
} }
if err := h.Server.Db(client, message); err != nil { if err := h.Server.Db(tempClient, message); err != nil {
fmt.Printf("获取客户端音量分贝错误: %v\n", err) fmt.Printf("获取客户端音量分贝错误: %v\n", err)
continue continue
} }
broadcastMessage(message) broadcastMessage(message)
default: default:
if !client.IsAuth { if !tempClient.IsAuth {
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
conn.Close() conn.Close()
return return

View File

@@ -60,7 +60,6 @@ func (s *Server) Start() error {
go s.handleShutdown() go s.handleShutdown()
for { for {
select { select {
case <-s.stopChan: case <-s.stopChan:
return nil return nil
@@ -78,21 +77,36 @@ func (s *Server) Start() error {
} }
} }
s.connections.Store(conn.RemoteAddr(), conn) // 创建临时客户端,不加入 clients map
client := s.addClient(conn) tempClient := &Client{
fmt.Printf("客户端已连接: %s\n", client.ID) ID: conn.RemoteAddr().String(),
Conn: conn,
ConnectedAt: time.Now(),
LastPing: time.Now(),
Done: make(chan struct{}),
}
// 设置认证超时
time.AfterFunc(time.Minute, func() {
if !tempClient.IsAuth {
fmt.Printf("Client %s authentication timeout\n", tempClient.ID)
conn.Close()
}
})
s.wg.Add(1) s.wg.Add(1)
go func(c net.Conn, clientID string) { go func(c net.Conn, client *Client) {
defer s.wg.Done() defer s.wg.Done()
defer func() { defer func() {
s.connections.Delete(c.RemoteAddr()) s.connections.Delete(c.RemoteAddr())
s.removeClient(clientID) if client.IsAuth {
s.removeClient(client.ID)
}
c.Close() c.Close()
fmt.Printf("客户端已断开连接: %s\n", clientID) fmt.Printf("客户端已断开连接: %s\n", client.ID)
}() }()
s.Handler(c) s.Handler(c)
}(conn, client.ID) }(conn, tempClient)
} }
} }
@@ -126,29 +140,13 @@ func (s *Server) GetOnlineClients() []map[string]interface{} {
} }
// addClient 添加新客户端 // addClient 添加新客户端
func (s *Server) addClient(conn net.Conn) *Client { func (s *Server) addClient(client *Client) {
s.clientsMux.Lock() s.clientsMux.Lock()
defer s.clientsMux.Unlock() defer s.clientsMux.Unlock()
client := &Client{ // 使用 IMEI 作为 key
ID: conn.RemoteAddr().String(), s.clients[client.Imei] = client
Conn: conn,
ConnectedAt: time.Now(),
LastPing: time.Now(),
Done: make(chan struct{}),
}
s.clients[client.ID] = client
go client.startHeartbeat(s) go client.startHeartbeat(s)
client.authTimer = time.AfterFunc(time.Minute, func() {
if !client.IsAuth {
fmt.Printf("Client %s authentication timeout\n", client.ID)
client.Conn.Close() // 强制关闭连接
}
})
return client
} }
// removeClient 移除客户端 // removeClient 移除客户端