diff --git a/global/var.go b/global/var.go index 242fd68..d9b4357 100644 --- a/global/var.go +++ b/global/var.go @@ -6,14 +6,7 @@ import ( ) var ( - // AppConf 配置信息 AppConf *config.Config - // Db 数据库 - Db *gorm.DB - - //InFluxDb influxdb2.Client - - Log *Logger - - //Cron *cron.Cron + Db *gorm.DB + Log *Logger ) diff --git a/model/device.go b/model/device.go index 6810fda..23bb2cf 100644 --- a/model/device.go +++ b/model/device.go @@ -7,17 +7,14 @@ import ( ) type Device struct { - Id int `gorm:"column:id;primaryKey" json:"id"` - //DriverId string `gorm:"column:driver_id;comment:设备ID;type:varchar(255)" json:"driver_id"` - Imei string `gorm:"column:imei;comment:IMEI;type:varchar(255)" json:"imei"` - //DriverName string `gorm:"column:driver_name;comment:设备名称;type:varchar(255)" json:"driver_name"` - DriverPass string `gorm:"column:driver_pass;comment:设备密码;type:varchar(255)" json:"driver_pass"` - DriverVer string `gorm:"column:driver_ver;comment:固件版本;type:varchar(255)" json:"driver_ver"` - //DriverFd string `gorm:"column:driver_fd;comment:设备FD;type:varchar(255)" json:"driver_fd"` - Remark string `gorm:"column:remark;comment:备注;type:varchar(255)" json:"remark"` - Created time.Time `gorm:"column:created;autoCreateTime;comment:创建时间" json:"created"` - Updated time.Time `gorm:"column:updated;autoUpdateTime;comment:修改时间" json:"updated"` - DeletedAt gorm.DeletedAt `gorm:"index;comment:删除时间" json:"-"` + Id int `gorm:"column:id;primaryKey" json:"id"` + Imei string `gorm:"column:imei;comment:IMEI;type:varchar(255)" json:"imei"` + DriverPass string `gorm:"column:driver_pass;comment:设备密码;type:varchar(255)" json:"driver_pass"` + DriverVer string `gorm:"column:driver_ver;comment:固件版本;type:varchar(255)" json:"driver_ver"` + Remark string `gorm:"column:remark;comment:备注;type:varchar(255)" json:"remark"` + Created time.Time `gorm:"column:created;autoCreateTime;comment:创建时间" json:"created"` + Updated time.Time `gorm:"column:updated;autoUpdateTime;comment:修改时间" json:"updated"` + DeletedAt gorm.DeletedAt `gorm:"index;comment:删除时间" json:"-"` } func (r *Device) TableName() string { diff --git a/tcpserver/handler.go b/tcpserver/handler.go index 11e529b..769000a 100644 --- a/tcpserver/handler.go +++ b/tcpserver/handler.go @@ -35,28 +35,22 @@ func (h *TCPHandler) HandleClient(conn net.Conn) { } for { - //message, err := reader.ReadBytes('\n') - //if err != nil { - // fmt.Println("客户端已断开连接:", err) - // break - //} message, err := readUntilDelimiter(reader, []byte("\r\n")) if err != nil { fmt.Println("Error reading message:", err) break } - // 去除末尾的换行符 - message = bytes.TrimSpace(message) + message = bytes.TrimSpace(message) output := strings.ReplaceAll(string(message), "\t", "") output = strings.ReplaceAll(output, "\n", "") message = []byte(output) fmt.Printf("收到消息来自 %s: %s\n", conn.RemoteAddr(), message) - //if !json.Valid(message) { - // fmt.Printf("来自客户端的数据非法 %s\n", conn.RemoteAddr()) - // conn.Close() - // return - //} + if !json.Valid(message) { + fmt.Printf("来自客户端的数据非法 %s\n", conn.RemoteAddr()) + conn.Close() + return + } var fullMsg map[string]interface{} if err := json.Unmarshal(message, &fullMsg); err != nil { @@ -73,18 +67,15 @@ func (h *TCPHandler) HandleClient(conn net.Conn) { switch msgType { case "reg": - // 处理登录请求 if err := h.Server.HandleAuth(client, message); err != nil { fmt.Printf("客户端授权失败: %v\n", err) conn.Close() return } fmt.Printf("客户端已授权: %s\n", client.Imei) - // 广播登录消息 broadcastMessage(message) case "ping": - // 处理心跳 if !client.IsAuth { fmt.Printf("来自未授权客户端的心跳 %s\n", conn.RemoteAddr()) conn.Close() @@ -94,11 +85,9 @@ func (h *TCPHandler) HandleClient(conn net.Conn) { fmt.Printf("心跳错误: %v\n", err) continue } - // 广播心跳消息 broadcastMessage(message) case "ota": - // 处理 OTA if !client.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() @@ -108,11 +97,9 @@ func (h *TCPHandler) HandleClient(conn net.Conn) { fmt.Printf("OTA 错误: %v\n", err) continue } - // 广播 OTA 消息 broadcastMessage(message) case "start": - // 处理 客户端实时上报数据 if !client.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() @@ -122,11 +109,9 @@ func (h *TCPHandler) HandleClient(conn net.Conn) { fmt.Printf("OTA 错误: %v\n", err) continue } - // 广播 OTA 消息 broadcastMessage(message) case "stop": - // 处理 客户端停止实时上报数据 if !client.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() @@ -136,11 +121,9 @@ func (h *TCPHandler) HandleClient(conn net.Conn) { fmt.Printf("客户端停止实时上报数据 错误: %v\n", err) continue } - // 广播 OTA 消息 broadcastMessage(message) case "up": - // 处理 客户端定时上报数据 if !client.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() @@ -150,17 +133,14 @@ func (h *TCPHandler) HandleClient(conn net.Conn) { fmt.Printf("客户端定时上报数据 错误: %v\n", err) continue } - // 广播 OTA 消息 broadcastMessage(message) default: - // 处理其他消息类型 if !client.IsAuth { fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr()) conn.Close() return } - // 广播其他类型的消息 broadcastMessage(message) } } @@ -209,24 +189,20 @@ func (s *Server) HandleAuth(client *Client, message []byte) error { return fmt.Errorf("设备密码不正确") } - // 更新版本号 model.DriverVer = msg.Ver err = repository.GroupRepositorys.Device.UpdateDevice(model) if err != nil { return fmt.Errorf("更新设备版本号失败") } - // 认证成功,停止登录超时定时器 if client.authTimer != nil { client.authTimer.Stop() client.authTimer = nil } - // 认证成功 client.Imei = msg.Imei client.IsAuth = true - // 发送响应 response := Message{ MessageType: MessageType{Type: "reg"}, MessageTime: MessageTime{Time: time.Now().Unix()}, @@ -248,45 +224,7 @@ func (s *Server) HandleOta(client *Client, message []byte) error { if msg.Type != "ota" { return fmt.Errorf("unauthorized") } - fmt.Printf("设备升级结果:%s\r\n", msg.State) - - //model, err := repository.GroupRepositorys.Device.GetDevice(map[string]interface{}{"imei": msg.Imei}) - //if err != nil { - // return fmt.Errorf("设备不存在") - //} - //if msg.Pwd != model.DriverPass { - // return fmt.Errorf("设备密码不正确") - //} - - // 更新版本号 - //model.DriverVer = msg.Ver - //err = repository.GroupRepositorys.Device.UpdateDevice(model) - //if err != nil { - // return fmt.Errorf("更新设备版本号失败") - //} - - // 认证成功,停止登录超时定时器 - //if client.authTimer != nil { - // client.authTimer.Stop() - // client.authTimer = nil - //} - - // 认证成功 - //client.Imei = msg.Imei - //client.IsAuth = true - - // 发送响应 - //response := Message{ - // MessageType: MessageType{Type: "reg"}, - // MessageTime: MessageTime{Time: time.Now().Unix()}, - //} - //responseData, err := json.Marshal(response) - //if err != nil { - // return err - //} - // - //_, err = client.Conn.Write(append(responseData, '\r', '\n')) return nil } diff --git a/tcpserver/message.go b/tcpserver/message.go index 4398873..0a88479 100644 --- a/tcpserver/message.go +++ b/tcpserver/message.go @@ -3,7 +3,6 @@ package tcpserver type MessageType struct { Type string `json:"Type"` } - type MessagePassword struct { Pwd string `json:"Pwd,omitempty"` } @@ -24,12 +23,6 @@ type MessageData struct { UpDataStruct } -type UpDataStruct struct { - Sum int `json:"sum"` - Time int `json:"time"` - Mile int `json:"mile"` -} - type Message struct { MessageType MessageImei @@ -39,3 +32,9 @@ type Message struct { MessageState MessageData `json:"Data"` } + +type UpDataStruct struct { + Sum int `json:"sum"` + Time int `json:"time"` + Mile int `json:"mile"` +} diff --git a/tcpserver/tcpserver.go b/tcpserver/tcpserver.go index f1fbdf4..021c5a4 100644 --- a/tcpserver/tcpserver.go +++ b/tcpserver/tcpserver.go @@ -12,31 +12,30 @@ import ( "time" ) -// Server 定义 TCP 服务器结构 type Server struct { - Address string // 监听地址 - Handler func(net.Conn) // 客户端连接处理函数 - listener net.Listener // TCP 监听器 - connections sync.Map // 活跃的客户端连接 - wg sync.WaitGroup // 等待所有 Goroutine 完成 - stopChan chan struct{} // 关闭信号 + Address string + Handler func(net.Conn) + listener net.Listener + connections sync.Map + wg sync.WaitGroup + stopChan chan struct{} clients map[string]*Client clientsMux sync.RWMutex - stopOnce sync.Once // 添加 sync.Once 来确保 Stop 只被执行一次 + stopOnce sync.Once } // Client 定义客户端结构 type Client struct { ID string - Imei string // 添加 IMEI + Imei string Conn net.Conn ConnectedAt time.Time LastPing time.Time Done chan struct{} - IsAuth bool // 添加认证状态 - authTimer *time.Timer // 添加登录超时定时器 + IsAuth bool + authTimer *time.Timer } // NewServer 创建一个新的 TCP 服务器 @@ -49,7 +48,6 @@ func NewServer(address string, handler func(net.Conn)) *Server { } } -// Start 启动服务器 func (s *Server) Start() error { var err error s.listener, err = net.Listen("tcp", s.Address) @@ -59,35 +57,31 @@ func (s *Server) Start() error { fmt.Println("Listening and serving TCP on", s.Address) - // 捕获终止信号 go s.handleShutdown() for { - // 非阻塞检查关闭信号 + select { case <-s.stopChan: return nil default: } - // 接受新连接 conn, err := s.listener.Accept() if err != nil { select { case <-s.stopChan: - return nil // 服务器已停止 + return nil default: fmt.Println("Error accepting connection:", err) continue } } - // 记录活跃连接并添加到 clients 映射 s.connections.Store(conn.RemoteAddr(), conn) client := s.addClient(conn) fmt.Printf("客户端已连接: %s\n", client.ID) - // 使用 Goroutine 处理客户端 s.wg.Add(1) go func(c net.Conn, clientID string) { defer s.wg.Done() @@ -112,7 +106,6 @@ func (s *Server) handleShutdown() { s.Stop() } -// GetOnlineClients 获取所有在线客户端信息 func (s *Server) GetOnlineClients() []map[string]interface{} { s.clientsMux.RLock() defer s.clientsMux.RUnlock() @@ -146,10 +139,8 @@ func (s *Server) addClient(conn net.Conn) *Client { } s.clients[client.ID] = client - // 启动心跳检测 go client.startHeartbeat(s) - // 添加登录超时检测 client.authTimer = time.AfterFunc(time.Minute, func() { if !client.IsAuth { fmt.Printf("Client %s authentication timeout\n", client.ID) @@ -169,7 +160,7 @@ func (s *Server) removeClient(id string) { client.authTimer.Stop() client.authTimer = nil } - close(client.Done) // 停止心跳检测 + close(client.Done) delete(s.clients, id) } } @@ -179,13 +170,11 @@ func (s *Server) Stop() { s.stopOnce.Do(func() { fmt.Println("Stopping TCP server...") - // 关闭监听器 close(s.stopChan) if s.listener != nil { s.listener.Close() } - // 关闭所有连接 s.clientsMux.Lock() for id, client := range s.clients { client.Conn.Close() @@ -193,7 +182,6 @@ func (s *Server) Stop() { } s.clientsMux.Unlock() - // 等待所有 Goroutine 完成 s.wg.Wait() fmt.Println("TCP server stopped.") }) @@ -211,7 +199,7 @@ func (c *Client) startHeartbeat(s *Server) { case <-ticker.C: if time.Since(c.LastPing) > 120*time.Second { fmt.Printf("客户端 %s 心跳超时 \n", c.ID) - c.Conn.Close() // 强制关闭连接 + c.Conn.Close() return } }