跳转至
阅读量:

Socket Takeover

在 Meta 的《Zero Downtime Release:Disruption-free Load Balancing of a Multi-Billion User Website》论文中提到了一种 Socket Takeover 技术,

该技术能够 将一个打开的 Socket FD 从旧进程传递到新进程

相关技术

file descriptor

要理解这项技术,首先要明白 Linux 的 file descriptor。file descriptor (简称 fd)是一个抽象的指示符,以一个非负整数来表示。fd 和底层文件系统中的文件进行绑定,fd 只在打开文件的进程有效。

如果把 fd 当作一个普通的值传递给其他进程用于打开,其他进程是无法使用这个 fd 的。

Unix domain socket

Unix domain socket(简称 uds)是在 unix 环境下的一种特殊的 socket。uds 地址采用路径名的形式。与网络套接字不同,跨 Unix 域套接字的 I/O 不涉及底层设备上的操作(这使得 Unix 域套接字比在同一主机上执行 IPC 的网络套接字要快得多)。

uds 在传输数据时,除了能够像普通的 socket 那样传输数据,还能够进行特殊的 辅助数据传输(Ancillary Data Transfer )

在 Linux 上可以进行三种类型的辅助数据传输:

  • SCM_RIGHTS
  • SCM_CREDENTIALS
  • SCM_SECURITY

在实现 Socket Takeover 时就需要使用到 SCM_RIGHTS

SCM_RIGHTS 能够在不同进程之间传递 file descriptor(更准确来说是传递的文件所有权,类似于父子进程之间能够共享文件)。

实现

下面是用 Go 实现的 Socket Takeover demo,该demo中包含三个模块:client(TCP 客户端)、server(TCP server)和forward(UDS Server)。

// client.go
package main

import (
    "fmt"
    "log"
    "net"
)

func main() {
    addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
    if err != nil {
        log.Fatal(err)
    }

    for i := 0; i < 20; i++ {
        conn, err := net.DialTCP("tcp", nil, addr)
        if err != nil {
            return
        }

        conn.Write([]byte(fmt.Sprintf("seq num %d", i)))
        buff := make([]byte, 1024)
        n, err := conn.Read(buff)
        if err != nil {
            log.Fatal(err)
        }
        log.Printf("recv data: %s\n", buff[:n])
        conn.Close()
    }
}
// server.go
package main

import (
    "log"
    "net"
    "syscall"
)

const (
    forwardSocket = "/tmp/forward_demo.sock"
)

func main() {
    addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
    if err != nil {
        log.Fatal(err)
    }

    tcp, err := net.ListenTCP("tcp", addr)
    if err != nil {
        log.Fatal(err)
    }

    for {
        conn, err := tcp.AcceptTCP()
        if err != nil {
            log.Fatal(err)
        }
        handle(conn)
    }
}

func handle(conn *net.TCPConn) {
    defer conn.Close()

    err := forwardConn(conn)
    if err != nil {
        log.Printf("fail to forward conn: %v", err)

        buff := make([]byte, 1024)
        n, err := conn.Read(buff)
        if err != nil {
            log.Fatal(err)
        }
        log.Printf("recv data: %s\n", buff[:n])
        conn.Write([]byte("process by server"))
    }
}

func forwardConn(conn *net.TCPConn) error {
    addr := &net.UnixAddr{
        Name: forwardSocket,
        Net:  "unix",
    }

    unixConn, err := net.DialUnix("unix", nil, addr)
    if err != nil {
        return err
    }
    defer unixConn.Close()

    file, err := conn.File()
    if err != nil {
        return err
    }

    datan, oobn, err := unixConn.WriteMsgUnix([]byte("forward"), syscall.UnixRights(int(file.Fd())), nil)
    if err != nil {
        return err
    }
    log.Printf("%d bytes and %d oob written successfully", datan, oobn)

    return nil
}
// forward.go
package main

import (
    "errors"
    "fmt"
    "log"
    "net"
    "os"
    "syscall"
)

const (
    forwardSocket = "/tmp/forward_demo.sock"
)

func main() {
    syscall.Unlink(forwardSocket)
    unixListener, err := net.ListenUnix("unix", &net.UnixAddr{Name: forwardSocket, Net: "unix"})
    if err != nil {
        panic(err)
    }

    for {
        unixConn, err := unixListener.AcceptUnix()
        if err != nil {
            panic(err)
        }

        tcp, err := unixToTCP(unixConn)
        if err != nil {
            panic(err)
        }

        buff := make([]byte, 1024)
        n, err := tcp.Read(buff)
        if err != nil {
            panic(err)
        }
        log.Printf("tcp %s -> %s: %s", tcp.LocalAddr(), tcp.RemoteAddr(), buff[:n])
        tcp.Write([]byte("process by forward"))
        tcp.Close()
    }
}

func unixToTCP(conn *net.UnixConn) (*net.TCPConn, error) {
    msg, oob := make([]byte, 128), make([]byte, 128)

    msgn, oobn, flag, addr, err := conn.ReadMsgUnix(msg, oob)
    if err != nil {
        return nil, err
    }
    log.Println(msgn, oobn, flag, addr)
    log.Printf("recv msg: %s", msg[:msgn])

    cmsgs, err := syscall.ParseSocketControlMessage(oob[0:oobn])
    if err != nil {
        return nil, err
    }

    if len(cmsgs) != 1 {
        return nil, fmt.Errorf("expected 1 control message; got %d", len(cmsgs))
    }

    fds, err := syscall.ParseUnixRights(&cmsgs[0])
    if err != nil {
        return nil, err
    } else if len(fds) != 1 {
        return nil, errors.New("invalid number of fds received")
    }

    fd := os.NewFile(uintptr(fds[0]), "")
    if fd == nil {
        return nil, errors.New("could not open fd")
    }

    fileConn, err := net.FileConn(fd)
    if err != nil {
        return nil, err
    }

    return fileConn.(*net.TCPConn), nil
}

评论