diff --git a/tcpserver/func.go b/tcpserver/func.go index 27bdf65..10a33b6 100644 --- a/tcpserver/func.go +++ b/tcpserver/func.go @@ -64,13 +64,10 @@ func (s *Server) HandleAuth(client *Client, message []byte) error { return fmt.Errorf("更新设备版本号失败") } - if client.authTimer != nil { - client.authTimer.Stop() - client.authTimer = nil - } - + // 设置客户端信息 client.Imei = msg.Imei client.IsAuth = true + client.LastPing = time.Now() //response := Message{ // MessageType: MessageType{Type: "reg"}, @@ -86,7 +83,7 @@ func (s *Server) HandleAuth(client *Client, message []byte) error { responseData, err := json.Marshal(response) if err != nil { - return err + return fmt.Errorf("failed to marshal response: %w", err) } _, err = client.Conn.Write(append(responseData, '\r', '\n')) diff --git a/tcpserver/handler.go b/tcpserver/handler.go index 3a41692..598d21b 100644 --- a/tcpserver/handler.go +++ b/tcpserver/handler.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "strings" + "time" ) type TCPHandler struct { @@ -18,17 +19,19 @@ type TCPHandler struct { func (h *TCPHandler) HandleClient(conn net.Conn) { reader := bufio.NewReader(conn) clientID := conn.RemoteAddr().String() - var client *Client - if value, ok := h.Server.GetClient(clientID); ok { - client = value - } else { - fmt.Println("找不到客户端:", clientID) - return + + // 创建临时客户端对象 + tempClient := &Client{ + ID: clientID, + Conn: conn, + ConnectedAt: time.Now(), + LastPing: time.Now(), + Done: make(chan struct{}), } broadcastMessage := func(data []byte) { - if h.Hub != nil { - h.Hub.Broadcast <- &ws.WsMessage{IMEI: client.Imei, Data: string(data)} + if h.Hub != nil && tempClient.IsAuth { + h.Hub.Broadcast <- &ws.WsMessage{IMEI: tempClient.Imei, Data: string(data)} } } @@ -66,143 +69,146 @@ func (h *TCPHandler) HandleClient(conn net.Conn) { switch msgType { case "reg": - if err := h.Server.HandleAuth(client, message); err != nil { + // 处理注册 + if err := h.Server.HandleAuth(tempClient, message); err != nil { fmt.Printf("客户端授权失败: %v\n", err) conn.Close() return } - fmt.Printf("客户端已授权: %s\n", client.Imei) + // 注册成功后,将客户端添加到 clients map + h.Server.addClient(tempClient) + fmt.Printf("客户端已授权并注册: %s\n", tempClient.Imei) broadcastMessage(message) case "ping": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的心跳 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.HandleHeartbeat(client, message); err != nil { + if err := h.Server.HandleHeartbeat(tempClient, message); err != nil { fmt.Printf("心跳错误: %v\n", err) continue } broadcastMessage(message) case "ota": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.HandleOta(client, message); err != nil { + if err := h.Server.HandleOta(tempClient, message); err != nil { fmt.Printf("OTA 错误: %v\n", err) continue } broadcastMessage(message) case "start": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.RealTimeReporting(client, message); err != nil { + if err := h.Server.RealTimeReporting(tempClient, message); err != nil { fmt.Printf("实时上报错误: %v\n", err) continue } broadcastMessage(message) case "stop": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.StopRealTimeReporting(client, message); err != nil { + if err := h.Server.StopRealTimeReporting(tempClient, message); err != nil { fmt.Printf("客户端停止实时上报数据 错误: %v\n", err) continue } broadcastMessage(message) case "up": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.TimingReporting(client, message); err != nil { + if err := h.Server.TimingReporting(tempClient, message); err != nil { fmt.Printf("客户端定时上报数据 错误: %v\n", err) continue } broadcastMessage(message) case "SetConfig": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.SetConfig(client, message); err != nil { + if err := h.Server.SetConfig(tempClient, message); err != nil { fmt.Printf("客户端楼层设置 错误: %v\n", err) continue } broadcastMessage(message) case "GetConfig": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.GetConfig(client, message); err != nil { + if err := h.Server.GetConfig(tempClient, message); err != nil { fmt.Printf("获取客户端楼层设置错误: %v\n", err) continue } broadcastMessage(message) case "mp3a", "mp3b", "mp3c": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.Mp3(client, message); err != nil { + if err := h.Server.Mp3(tempClient, message); err != nil { fmt.Printf("客户端设置语音内容错误: %v\n", err) continue } broadcastMessage(message) case "SetVoiceConf": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.SetVoiceConf(client, message); err != nil { + if err := h.Server.SetVoiceConf(tempClient, message); err != nil { fmt.Printf("客户端语音配置错误: %v\n", err) continue } broadcastMessage(message) case "GetVoiceConf": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.GetVoiceConf(client, message); err != nil { + if err := h.Server.GetVoiceConf(tempClient, message); err != nil { fmt.Printf("获取客户端语音配置错误: %v\n", err) continue } broadcastMessage(message) case "db": - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - if err := h.Server.Db(client, message); err != nil { + if err := h.Server.Db(tempClient, message); err != nil { fmt.Printf("获取客户端音量分贝错误: %v\n", err) continue } broadcastMessage(message) default: - if !client.IsAuth { + if !tempClient.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return diff --git a/tcpserver/tcpserver.go b/tcpserver/tcpserver.go index 021c5a4..df6e395 100644 --- a/tcpserver/tcpserver.go +++ b/tcpserver/tcpserver.go @@ -60,7 +60,6 @@ func (s *Server) Start() error { go s.handleShutdown() for { - select { case <-s.stopChan: return nil @@ -78,21 +77,36 @@ func (s *Server) Start() error { } } - s.connections.Store(conn.RemoteAddr(), conn) - client := s.addClient(conn) - fmt.Printf("客户端已连接: %s\n", client.ID) + // 创建临时客户端,不加入 clients map + tempClient := &Client{ + ID: conn.RemoteAddr().String(), + Conn: conn, + ConnectedAt: time.Now(), + LastPing: time.Now(), + Done: make(chan struct{}), + } + + // 设置认证超时 + time.AfterFunc(time.Minute, func() { + if !tempClient.IsAuth { + fmt.Printf("Client %s authentication timeout\n", tempClient.ID) + conn.Close() + } + }) s.wg.Add(1) - go func(c net.Conn, clientID string) { + go func(c net.Conn, client *Client) { defer s.wg.Done() defer func() { s.connections.Delete(c.RemoteAddr()) - s.removeClient(clientID) + if client.IsAuth { + s.removeClient(client.ID) + } c.Close() - fmt.Printf("客户端已断开连接: %s\n", clientID) + fmt.Printf("客户端已断开连接: %s\n", client.ID) }() s.Handler(c) - }(conn, client.ID) + }(conn, tempClient) } } @@ -126,29 +140,13 @@ func (s *Server) GetOnlineClients() []map[string]interface{} { } // addClient 添加新客户端 -func (s *Server) addClient(conn net.Conn) *Client { +func (s *Server) addClient(client *Client) { s.clientsMux.Lock() defer s.clientsMux.Unlock() - client := &Client{ - ID: conn.RemoteAddr().String(), - Conn: conn, - ConnectedAt: time.Now(), - LastPing: time.Now(), - Done: make(chan struct{}), - } - s.clients[client.ID] = client - + // 使用 IMEI 作为 key + s.clients[client.Imei] = client go client.startHeartbeat(s) - - client.authTimer = time.AfterFunc(time.Minute, func() { - if !client.IsAuth { - fmt.Printf("Client %s authentication timeout\n", client.ID) - client.Conn.Close() // 强制关闭连接 - } - }) - - return client } // removeClient 移除客户端