Why do we need a keepalive mechanism?
- Prevent connection drops caused by middleboxes. Many network middleboxes (NAT gateways, firewalls, load balancers) actively close TCP connections that have been idle for a long time in order to save resources.
- Detect connection health. When the network fails (e.g. loss of connectivity, a service crash), the TCP layer may not immediately notice that a connection is dead, so the application layer keeps treating it as live and subsequent requests fail.
- Improve resource utilization. With long-lived connections, a client that exits abnormally or a dropped network can leak server-side resources (unclosed connections holding memory and threads).
Implementation analysis
To implement keepalive in our IM project, we can start by studying how gRPC does it. Here is the core of the source:
grpc@v1.37.0\internal\transport\http2_server.go
// keepalive running in a separate goroutine does the following:
// 1. Gracefully closes an idle connection after a duration of keepalive.MaxConnectionIdle.
// 2. Gracefully closes any connection after a duration of keepalive.MaxConnectionAge.
// 3. Forcibly closes a connection after an additive period of keepalive.MaxConnectionAgeGrace over keepalive.MaxConnectionAge.
// 4. Makes sure a connection is alive by sending pings with a frequency of keepalive.Time and closes a non-responsive connection
// after an additional duration of keepalive.Timeout.
func (t *http2Server) keepalive() {
	p := &ping{}
	// True iff a ping has been sent, and no data has been received since then.
	outstandingPing := false
	// Amount of time remaining before which we should receive an ACK for the
	// last sent ping.
	kpTimeoutLeft := time.Duration(0)
	// Records the last value of t.lastRead before we go block on the timer.
	// This is required to check for read activity since then.
	prevNano := time.Now().UnixNano()
	// Initialize the different timers to their default values.
	idleTimer := time.NewTimer(t.kp.MaxConnectionIdle)
	ageTimer := time.NewTimer(t.kp.MaxConnectionAge)
	kpTimer := time.NewTimer(t.kp.Time)
	defer func() {
		// We need to drain the underlying channel in these timers after a call
		// to Stop(), only if we are interested in resetting them. Clearly we
		// are not interested in resetting them here.
		idleTimer.Stop()
		ageTimer.Stop()
		kpTimer.Stop()
	}()

	for {
		select {
		case <-idleTimer.C:
			t.mu.Lock()
			idle := t.idle
			if idle.IsZero() { // The connection is non-idle.
				t.mu.Unlock()
				idleTimer.Reset(t.kp.MaxConnectionIdle)
				continue
			}
			val := t.kp.MaxConnectionIdle - time.Since(idle)
			t.mu.Unlock()
			if val <= 0 {
				// The connection has been idle for a duration of keepalive.MaxConnectionIdle or more.
				// Gracefully close the connection.
				t.drain(http2.ErrCodeNo, []byte{})
				return
			}
			idleTimer.Reset(val)
		case <-ageTimer.C:
			t.drain(http2.ErrCodeNo, []byte{})
			ageTimer.Reset(t.kp.MaxConnectionAgeGrace)
			select {
			case <-ageTimer.C:
				// Close the connection after grace period.
				if logger.V(logLevel) {
					logger.Infof("transport: closing server transport due to maximum connection age.")
				}
				t.Close()
			case <-t.done:
			}
			return
		case <-kpTimer.C:
			lastRead := atomic.LoadInt64(&t.lastRead)
			if lastRead > prevNano {
				// There has been read activity since the last time we were
				// here. Setup the timer to fire at kp.Time seconds from
				// lastRead time and continue.
				outstandingPing = false
				kpTimer.Reset(time.Duration(lastRead) + t.kp.Time - time.Duration(time.Now().UnixNano()))
				prevNano = lastRead
				continue
			}
			if outstandingPing && kpTimeoutLeft <= 0 {
				if logger.V(logLevel) {
					logger.Infof("transport: closing server transport due to idleness.")
				}
				t.Close()
				return
			}
			if !outstandingPing {
				if channelz.IsOn() {
					atomic.AddInt64(&t.czData.kpCount, 1)
				}
				t.controlBuf.put(p)
				kpTimeoutLeft = t.kp.Timeout
				outstandingPing = true
			}
			// The amount of time to sleep here is the minimum of kp.Time and
			// timeoutLeft. This will ensure that we wait only for kp.Time
			// before sending out the next ping (for cases where the ping is
			// acked).
			sleepDuration := minTime(t.kp.Time, kpTimeoutLeft)
			kpTimeoutLeft -= sleepDuration
			kpTimer.Reset(sleepDuration)
		case <-t.done:
			return
		}
	}
}
As you can see, this implementation is actually fairly simple. It is driven by three timers:
- idleTimer, which tracks how long the connection has been idle.
- ageTimer, which enforces the connection's maximum lifetime.
- kpTimer, which checks connection health by periodically sending pings.
At first glance, idleTimer and kpTimer look redundant. In practice, however, the ping interval (kp.Time) is shorter than the idle timeout, so kpTimer detects an unhealthy connection much sooner and frees its resources earlier. The two timers complement each other.
To add health checking to our IM service, we can design it as follows:
Client:
Maintain an idle timer. If no message has been sent within the idle window, send a heartbeat message. A reply from the server resets the timer; no reply means the connection is treated as dead.
Server:
Track the idle time of every connection, resetting it whenever a normal message or a heartbeat arrives. Close any connection that stays idle past the limit.
Implementation
First, a bare websocket.Conn no longer meets our needs, so we wrap it in a connection type of our own:
im/ws/websocket/connection.go
package websocket

import (
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

type Conn struct {
	mu          sync.Mutex
	conn        *websocket.Conn
	s           *Server
	idle        time.Time
	maxConnIdle time.Duration
	done        chan struct{}
}

func NewConn(s *Server, w http.ResponseWriter, r *http.Request) *Conn {
	c, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		s.Errorf("failed to upgrade: %v", err)
		return nil
	}
	conn := &Conn{
		conn:        c,
		s:           s,
		idle:        time.Now(),
		maxConnIdle: defaultMaxConnIdle,
		done:        make(chan struct{}),
	}
	go conn.keepalive()
	return conn
}
Since our design is simpler and the server only needs a single timer, we can lift gRPC's idleTimer logic almost verbatim:
func (c *Conn) keepalive() {
	idleTimer := time.NewTimer(c.maxConnIdle)
	defer idleTimer.Stop()

	for {
		select {
		case <-idleTimer.C:
			c.mu.Lock()
			idle := c.idle
			if idle.IsZero() { // The connection is non-idle.
				c.mu.Unlock()
				idleTimer.Reset(c.maxConnIdle)
				continue
			}
			val := c.maxConnIdle - time.Since(idle)
			c.mu.Unlock()
			if val <= 0 {
				// Idle for maxConnIdle or more: close the connection.
				c.s.CloseConn(c)
				return
			}
			idleTimer.Reset(val)
		case <-c.done:
			return
		}
	}
}
To keep idle up to date, we need to rework the original message read/receive path and give the new Conn its own read and write methods:
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	messageType, p, err = c.conn.ReadMessage()
	c.idle = time.Time{}
	return messageType, p, err
}

func (c *Conn) WriteMessage(messageType int, data []byte) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	err := c.conn.WriteMessage(messageType, data)
	c.idle = time.Now()
	return err
}
Careful: the code above is wrong! ReadMessage blocks, so it must be called before taking the lock; otherwise the mutex is held for the entire blocking read and the keepalive goroutine can never acquire it.
The correct version:
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
	messageType, p, err = c.conn.ReadMessage()
	c.mu.Lock()
	defer c.mu.Unlock()
	c.idle = time.Time{}
	return messageType, p, err
}

func (c *Conn) WriteMessage(messageType int, data []byte) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	err := c.conn.WriteMessage(messageType, data)
	c.idle = time.Now()
	return err
}
Close also needs reworking. Because of possible concurrent calls, c.done could be closed twice, which would panic. So we first check with a select whether a receive succeeds: nothing is ever sent on done, so a receive only succeeds once the channel has been closed.
func (c *Conn) Close() error {
	select {
	case <-c.done:
	default:
		close(c.done)
	}
	return c.conn.Close()
}
Then replace every place the server used websocket.Conn with the new Conn.
To distinguish ordinary client messages from pings, we tweak the Message definition slightly, adding a FrameType field:
im/ws/websocket/message.go
package websocket

type FrameType uint8

const (
	FrameData FrameType = iota
	FramePing
)

type Message struct {
	FrameType FrameType `json:"frame_type"`
	Method    string    `json:"method"`
	FromID    string    `json:"from_id"`
	Data      any       `json:"data"`
}

func NewMessage(fromID string, data any) *Message {
	return &Message{
		FrameType: FrameData,
		FromID:    fromID,
		Data:      data,
	}
}
Then update the connection handler to treat pings specially:
im/ws/websocket/server.go
func (s *Server) handleConn(conn *Conn) {
	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			s.Errorf("failed to read message: %v", err)
			s.CloseConn(conn)
			return
		}

		var message Message
		if err := json.Unmarshal(msg, &message); err != nil {
			s.Errorf("failed to unmarshal message: %v", err)
			s.CloseConn(conn)
			return
		}

		switch message.FrameType {
		case FramePing:
			s.SendToConns(&Message{
				FrameType: FramePing,
			}, conn)
		case FrameData:
			if handler, ok := s.routes[message.Method]; ok {
				handler(s, conn, &message)
			} else {
				s.SendToConns(&Message{
					FrameType: FrameData,
					Data:      []byte("method not found"),
				}, conn)
			}
		}
	}
}
Add a maxConnIdle option to the server options:
im/ws/websocket/options.go
package websocket

import "time"

type ServerOptions func(*ServerOption)

type ServerOption struct {
	Auth
	pattern     string
	maxConnIdle time.Duration
}

func newServerOption(opts ...ServerOptions) *ServerOption {
	opt := &ServerOption{
		Auth:        NewAuth(),
		pattern:     "/ws",
		maxConnIdle: defaultMaxConnIdle,
	}
	for _, o := range opts {
		o(opt)
	}
	return opt
}

func WithAuth(auth Auth) ServerOptions {
	return func(o *ServerOption) {
		o.Auth = auth
	}
}

func WithPattern(pattern string) ServerOptions {
	return func(o *ServerOption) {
		o.pattern = pattern
	}
}

func WithMaxConnIdle(maxConnIdle time.Duration) ServerOptions {
	return func(o *ServerOption) {
		if maxConnIdle > 0 {
			o.maxConnIdle = maxConnIdle
		}
	}
}
At the entry point, set a ten-second idle timeout, then use Apipost to verify that the server closes an idle connection as expected.
var configFile = flag.String("f", "etc/im.yaml", "the config file")

func main() {
	flag.Parse()

	var c config.Config
	conf.MustLoad(*configFile, &c)
	if err := c.SetUp(); err != nil {
		panic(err)
	}

	ctx := svc.NewServiceContext(c)
	srv := websocket.NewServer(c.ListenOn,
		websocket.WithAuth(handler.NewJwtAuth(ctx)),
		websocket.WithMaxConnIdle(10*time.Second),
	)
	handler.RegisterRoutes(srv, ctx)

	fmt.Println("start websocket server at ", c.ListenOn)
	srv.Start()
	defer srv.Stop()
}