init
This commit is contained in:
267
tcpserver/handler.go
Normal file
267
tcpserver/handler.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package tcpserver
|
||||
|
||||
import (
|
||||
"DT/repository"
|
||||
"DT/ws"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TCPHandler struct {
|
||||
Server *Server
|
||||
Hub *ws.Hub
|
||||
}
|
||||
|
||||
func (h *TCPHandler) HandleClient(conn net.Conn) {
|
||||
reader := bufio.NewReader(conn)
|
||||
clientID := conn.RemoteAddr().String()
|
||||
var client *Client
|
||||
if value, ok := h.Server.GetClient(clientID); ok {
|
||||
client = value
|
||||
} else {
|
||||
fmt.Println("找不到客户端:", clientID)
|
||||
return
|
||||
}
|
||||
|
||||
broadcastMessage := func(data []byte) {
|
||||
if h.Hub != nil {
|
||||
h.Hub.Broadcast <- &ws.WsMessage{IMEI: client.Imei, Data: string(data)}
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
message, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
fmt.Println("客户端已断开连接:", err)
|
||||
break
|
||||
}
|
||||
|
||||
// 去除末尾的换行符
|
||||
message = bytes.TrimSpace(message)
|
||||
fmt.Printf("收到消息来自 %s: %s\n", conn.RemoteAddr(), string(message))
|
||||
|
||||
if !json.Valid(message) {
|
||||
fmt.Printf("来自客户端的数据非法 %s\n", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var fullMsg map[string]interface{}
|
||||
if err := json.Unmarshal(message, &fullMsg); err != nil {
|
||||
fmt.Printf("Error parsing message: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, _ := fullMsg["Type"].(string)
|
||||
if msgType == "" {
|
||||
fmt.Printf("接收到的消息类型为空 %s\n", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case "reg":
|
||||
// 处理登录请求
|
||||
if err := h.Server.HandleAuth(client, message); err != nil {
|
||||
fmt.Printf("客户端授权失败: %v\n", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
fmt.Printf("客户端已授权: %s\n", client.Imei)
|
||||
// 广播登录消息
|
||||
broadcastMessage(message)
|
||||
|
||||
case "ping":
|
||||
// 处理心跳
|
||||
if !client.IsAuth {
|
||||
fmt.Printf("来自未授权客户端的心跳 %s\n", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if err := h.Server.HandleHeartbeat(client, message); err != nil {
|
||||
fmt.Printf("心跳错误: %v\n", err)
|
||||
continue
|
||||
}
|
||||
// 广播心跳消息
|
||||
broadcastMessage(message)
|
||||
|
||||
case "ota":
|
||||
// 处理 OTA
|
||||
if !client.IsAuth {
|
||||
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if err := h.Server.HandleOta(client, message); err != nil {
|
||||
fmt.Printf("OTA 错误: %v\n", err)
|
||||
continue
|
||||
}
|
||||
// 广播 OTA 消息
|
||||
broadcastMessage(message)
|
||||
|
||||
case "start":
|
||||
// 处理 客户端实时上报数据
|
||||
if !client.IsAuth {
|
||||
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if err := h.Server.RealTimeReporting(client, message); err != nil {
|
||||
fmt.Printf("OTA 错误: %v\n", err)
|
||||
continue
|
||||
}
|
||||
// 广播 OTA 消息
|
||||
broadcastMessage(message)
|
||||
|
||||
default:
|
||||
// 处理其他消息类型
|
||||
if !client.IsAuth {
|
||||
fmt.Printf("来自未授权客户端的消息 %s\n", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
// 广播其他类型的消息
|
||||
broadcastMessage(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) HandleHeartbeat(client *Client, message []byte) error {
|
||||
if !client.IsAuth {
|
||||
return fmt.Errorf("unauthorized")
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.Type == "ping" {
|
||||
client.LastPing = time.Now()
|
||||
response := Message{
|
||||
MessageType: MessageType{Type: "pong"},
|
||||
}
|
||||
|
||||
responseData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.Conn.Write(append(responseData, '\r', '\n'))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) HandleAuth(client *Client, message []byte) error {
|
||||
var msg Message
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Type != "reg" {
|
||||
return fmt.Errorf("unauthorized")
|
||||
}
|
||||
|
||||
model, err := repository.GroupRepositorys.Device.GetDevice(map[string]interface{}{"imei": msg.Imei})
|
||||
if err != nil {
|
||||
return fmt.Errorf("设备不存在")
|
||||
}
|
||||
if msg.Pwd != model.DriverPass {
|
||||
return fmt.Errorf("设备密码不正确")
|
||||
}
|
||||
|
||||
// 更新版本号
|
||||
model.DriverVer = msg.Ver
|
||||
err = repository.GroupRepositorys.Device.UpdateDevice(model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新设备版本号失败")
|
||||
}
|
||||
|
||||
// 认证成功,停止登录超时定时器
|
||||
if client.authTimer != nil {
|
||||
client.authTimer.Stop()
|
||||
client.authTimer = nil
|
||||
}
|
||||
|
||||
// 认证成功
|
||||
client.Imei = msg.Imei
|
||||
client.IsAuth = true
|
||||
|
||||
// 发送响应
|
||||
response := Message{
|
||||
MessageType: MessageType{Type: "reg"},
|
||||
MessageTime: MessageTime{Time: time.Now().Unix()},
|
||||
}
|
||||
responseData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = client.Conn.Write(append(responseData, '\r', '\n'))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) HandleOta(client *Client, message []byte) error {
|
||||
var msg Message
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Type != "ota" {
|
||||
return fmt.Errorf("unauthorized")
|
||||
}
|
||||
|
||||
fmt.Printf("设备升级结果:%s\r\n", msg.State)
|
||||
|
||||
//model, err := repository.GroupRepositorys.Device.GetDevice(map[string]interface{}{"imei": msg.Imei})
|
||||
//if err != nil {
|
||||
// return fmt.Errorf("设备不存在")
|
||||
//}
|
||||
//if msg.Pwd != model.DriverPass {
|
||||
// return fmt.Errorf("设备密码不正确")
|
||||
//}
|
||||
|
||||
// 更新版本号
|
||||
//model.DriverVer = msg.Ver
|
||||
//err = repository.GroupRepositorys.Device.UpdateDevice(model)
|
||||
//if err != nil {
|
||||
// return fmt.Errorf("更新设备版本号失败")
|
||||
//}
|
||||
|
||||
// 认证成功,停止登录超时定时器
|
||||
//if client.authTimer != nil {
|
||||
// client.authTimer.Stop()
|
||||
// client.authTimer = nil
|
||||
//}
|
||||
|
||||
// 认证成功
|
||||
//client.Imei = msg.Imei
|
||||
//client.IsAuth = true
|
||||
|
||||
// 发送响应
|
||||
//response := Message{
|
||||
// MessageType: MessageType{Type: "reg"},
|
||||
// MessageTime: MessageTime{Time: time.Now().Unix()},
|
||||
//}
|
||||
//responseData, err := json.Marshal(response)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//_, err = client.Conn.Write(append(responseData, '\r', '\n'))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) RealTimeReporting(client *Client, message []byte) error {
|
||||
var msg Message
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Type != "start" {
|
||||
return fmt.Errorf("unauthorized")
|
||||
}
|
||||
fmt.Printf("设备实时上报数据:%s\r\n", msg.Data)
|
||||
return nil
|
||||
}
|
||||
35
tcpserver/message.go
Normal file
35
tcpserver/message.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package tcpserver
|
||||
|
||||
type MessageType struct {
|
||||
Type string `json:"Type"`
|
||||
}
|
||||
|
||||
type MessagePassword struct {
|
||||
Pwd string `json:"Pwd,omitempty"`
|
||||
}
|
||||
type MessageImei struct {
|
||||
Imei string `json:"Imei,omitempty"`
|
||||
}
|
||||
type MessageVer struct {
|
||||
Ver string `json:"Ver,omitempty"`
|
||||
}
|
||||
type MessageTime struct {
|
||||
Time int64 `json:"Time,omitempty"`
|
||||
}
|
||||
type MessageState struct {
|
||||
State string `json:"State,omitempty"`
|
||||
}
|
||||
|
||||
type MessageData struct {
|
||||
Data string `json:"Data,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
MessageType
|
||||
MessageImei
|
||||
MessagePassword
|
||||
MessageVer
|
||||
MessageTime
|
||||
MessageState
|
||||
MessageData
|
||||
}
|
||||
225
tcpserver/tcpserver.go
Normal file
225
tcpserver/tcpserver.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package tcpserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Server 定义 TCP 服务器结构
|
||||
type Server struct {
|
||||
Address string // 监听地址
|
||||
Handler func(net.Conn) // 客户端连接处理函数
|
||||
listener net.Listener // TCP 监听器
|
||||
connections sync.Map // 活跃的客户端连接
|
||||
wg sync.WaitGroup // 等待所有 Goroutine 完成
|
||||
stopChan chan struct{} // 关闭信号
|
||||
|
||||
clients map[string]*Client
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
stopOnce sync.Once // 添加 sync.Once 来确保 Stop 只被执行一次
|
||||
}
|
||||
|
||||
// Client 定义客户端结构
|
||||
type Client struct {
|
||||
ID string
|
||||
Imei string // 添加 IMEI
|
||||
Conn net.Conn
|
||||
ConnectedAt time.Time
|
||||
LastPing time.Time
|
||||
Done chan struct{}
|
||||
IsAuth bool // 添加认证状态
|
||||
authTimer *time.Timer // 添加登录超时定时器
|
||||
}
|
||||
|
||||
// NewServer 创建一个新的 TCP 服务器
|
||||
func NewServer(address string, handler func(net.Conn)) *Server {
|
||||
return &Server{
|
||||
Address: address,
|
||||
Handler: handler,
|
||||
stopChan: make(chan struct{}),
|
||||
clients: make(map[string]*Client),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动服务器
|
||||
func (s *Server) Start() error {
|
||||
var err error
|
||||
s.listener, err = net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("Listening and serving TCP on", s.Address)
|
||||
|
||||
// 捕获终止信号
|
||||
go s.handleShutdown()
|
||||
|
||||
for {
|
||||
// 非阻塞检查关闭信号
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// 接受新连接
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return nil // 服务器已停止
|
||||
default:
|
||||
fmt.Println("Error accepting connection:", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 记录活跃连接并添加到 clients 映射
|
||||
s.connections.Store(conn.RemoteAddr(), conn)
|
||||
client := s.addClient(conn)
|
||||
fmt.Printf("客户端已连接: %s\n", client.ID)
|
||||
|
||||
// 使用 Goroutine 处理客户端
|
||||
s.wg.Add(1)
|
||||
go func(c net.Conn, clientID string) {
|
||||
defer s.wg.Done()
|
||||
defer func() {
|
||||
s.connections.Delete(c.RemoteAddr())
|
||||
s.removeClient(clientID)
|
||||
c.Close()
|
||||
fmt.Printf("客户端已断开连接: %s\n", clientID)
|
||||
}()
|
||||
s.Handler(c)
|
||||
}(conn, client.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleShutdown 处理优雅关闭
|
||||
func (s *Server) handleShutdown() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
<-c
|
||||
fmt.Println("\nReceived shutdown signal...")
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
// GetOnlineClients 获取所有在线客户端信息
|
||||
func (s *Server) GetOnlineClients() []map[string]interface{} {
|
||||
s.clientsMux.RLock()
|
||||
defer s.clientsMux.RUnlock()
|
||||
|
||||
clients := make([]map[string]interface{}, 0, len(s.clients))
|
||||
for _, client := range s.clients {
|
||||
clientInfo := map[string]interface{}{
|
||||
"id": client.ID,
|
||||
"imei": client.Imei,
|
||||
"addr": client.Conn.RemoteAddr().String(),
|
||||
"connected_at": client.ConnectedAt,
|
||||
"last_ping": client.LastPing,
|
||||
"is_auth": client.IsAuth,
|
||||
}
|
||||
clients = append(clients, clientInfo)
|
||||
}
|
||||
return clients
|
||||
}
|
||||
|
||||
// addClient 添加新客户端
|
||||
func (s *Server) addClient(conn net.Conn) *Client {
|
||||
s.clientsMux.Lock()
|
||||
defer s.clientsMux.Unlock()
|
||||
|
||||
client := &Client{
|
||||
ID: conn.RemoteAddr().String(),
|
||||
Conn: conn,
|
||||
ConnectedAt: time.Now(),
|
||||
LastPing: time.Now(),
|
||||
Done: make(chan struct{}),
|
||||
}
|
||||
s.clients[client.ID] = client
|
||||
|
||||
// 启动心跳检测
|
||||
go client.startHeartbeat(s)
|
||||
|
||||
// 添加登录超时检测
|
||||
client.authTimer = time.AfterFunc(time.Minute, func() {
|
||||
if !client.IsAuth {
|
||||
fmt.Printf("Client %s authentication timeout\n", client.ID)
|
||||
client.Conn.Close() // 强制关闭连接
|
||||
}
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// removeClient 移除客户端
|
||||
func (s *Server) removeClient(id string) {
|
||||
s.clientsMux.Lock()
|
||||
defer s.clientsMux.Unlock()
|
||||
if client, ok := s.clients[id]; ok {
|
||||
if client.authTimer != nil {
|
||||
client.authTimer.Stop()
|
||||
client.authTimer = nil
|
||||
}
|
||||
close(client.Done) // 停止心跳检测
|
||||
delete(s.clients, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 添加一个新方法用于外部调用关闭服务
|
||||
func (s *Server) Stop() {
|
||||
s.stopOnce.Do(func() {
|
||||
fmt.Println("Stopping TCP server...")
|
||||
|
||||
// 关闭监听器
|
||||
close(s.stopChan)
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
// 关闭所有连接
|
||||
s.clientsMux.Lock()
|
||||
for id, client := range s.clients {
|
||||
client.Conn.Close()
|
||||
delete(s.clients, id)
|
||||
}
|
||||
s.clientsMux.Unlock()
|
||||
|
||||
// 等待所有 Goroutine 完成
|
||||
s.wg.Wait()
|
||||
fmt.Println("TCP server stopped.")
|
||||
})
|
||||
}
|
||||
|
||||
// 添加心跳检测方法
|
||||
func (c *Client) startHeartbeat(s *Server) {
|
||||
ticker := time.NewTicker(20 * time.Second) // 每10秒检查一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
if time.Since(c.LastPing) > 120*time.Second {
|
||||
fmt.Printf("客户端 %s 心跳超时 \n", c.ID)
|
||||
c.Conn.Close() // 强制关闭连接
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetClient 获取指定ID的客户端
|
||||
func (s *Server) GetClient(id string) (*Client, bool) {
|
||||
s.clientsMux.RLock()
|
||||
defer s.clientsMux.RUnlock()
|
||||
client, ok := s.clients[id]
|
||||
return client, ok
|
||||
}
|
||||
Reference in New Issue
Block a user