使用go语言来实现一个网络框架,把 tcp 自定义通信协议做成一个模板
我们知道一般的网络协议为了适用大部分人的使用,会封装一些的功能,但是这样会对性能产生一定的影响,所以当我们的项目或者模块在条件允许的情况下,我们可以自己定义协议来实现高速的网络框架。
创建项目
创建一个文件夹,fast-web。新建protocol.go, request.go ,response.go, client.go, server.go。
我们的协议可以定为这种形式:
//请求
Request: //命令 //参数的个数 //参数的长度
version command argsLength {argLength arg}
1byte 1byte 4byte 4byte unknown
//响应
Response:
version reply bodyLength {body}
1byte 1byte 4byte unknown
复制代码
protocol.go
const (
ProtocolVersion = byte(1) // 协议版本号
headerLengthInProtocol = 6 // 协议中头部占用的字节数
argsLengthInProtocol = 4 // 协议中参数个数占用的字节数
argLengthInProtocol = 4 // 协议中参数长度占用的字节数
bodyLengthInProtocol = 4 // 协议体长度占用的字节数
)
var (
// 协议版本不匹配错误,如果客户端和服务端的版本不一样就会返回这个错误
ProtocolVersionMismatchErr = errors.New("protocol version between client and server doesn't match")
)
复制代码
request.go
// 从 reader 中读取请求,并解析出命令和参数。
func readRequestFrom(reader io.Reader) (command byte, args [][]byte, err error) {
// 读取头部,指定具体的大小,使用 ReadFull 读取满指定字节数据,如果数据还没传输过来,这个方法会进行等待
header := make([]byte, headerLengthInProtocol)
_, err = io.ReadFull(reader, header)
if err != nil {
return 0, nil, err
}
// 头部的第一个字节是协议版本号,拿出来判断协议版本号是否一致
version := header[0]
if version != ProtocolVersion {
return 0, nil, ProtocolVersionMismatchErr
}
// 头部的第二个字节是命令,后面的四个字节是参数个数
command = header[1]
header = header[2:]
// 所有的整数到字节数组的转换使用大端形式,所以这里使用 BigEndian 来将头部后四个字节转换为一个 uint32 数字
argsLength := binary.BigEndian.Uint32(header)
args = make([][]byte, argsLength)
if argsLength > 0 {
// 读取参数长度,同理使用大端处理,并一次性读取满参数
argLength := make([]byte, argLengthInProtocol)
for i := uint32(0); i < argsLength; i++ {
_, err = io.ReadFull(reader, argLength)
if err != nil {
return 0, nil, err
}
arg := make([]byte, binary.BigEndian.Uint32(argLength))
_, err = io.ReadFull(reader, arg)
if err != nil {
return 0, nil, err
}
args[i] = arg
}
}
return command, args, nil
}
// 将请求写入到 writer 中。
func writeRequestTo(writer io.Writer, command byte, args [][]byte) (int, error) {
// 创建一个缓存区,并将协议版本号、命令和参数个数等写入缓存区
request := make([]byte, headerLengthInProtocol)
request[0] = ProtocolVersion
request[1] = command
binary.BigEndian.PutUint32(request[2:], uint32(len(args)))
if len(args) > 0 {
// 将参数都添加到缓存区
argLength := make([]byte, argLengthInProtocol)
for _, arg := range args {
binary.BigEndian.PutUint32(argLength, uint32(len(arg)))
request = append(request, argLength...)
request = append(request, arg...)
}
}
return writer.Write(request)
}
复制代码
response.go
const (
SuccessReply = 0 // 成功的答复码
ErrorReply = 1 // 发生错误的答复码
)
// 从 reader 中读取数据并解析出响应内容。
func readResponseFrom(reader io.Reader) (reply byte, body []byte, err error) {
// 读取指定字节数据
header := make([]byte, headerLengthInProtocol)
_, err = io.ReadFull(reader, header)
if err != nil {
return ErrorReply, nil, err
}
// 头部的第一个字节是协议版本号,如果版本号不一致很可能解析不成功,所以需要检查
// 实际上这边可以做一个降级处理,就是尝试以响应的版本号去解析
version := header[0]
if version != ProtocolVersion {
return ErrorReply, nil, errors.New("response " + ProtocolVersionMismatchErr.Error())
}
// 从头部解析出答复码还有响应体长度,同理,使用大端解析数字
reply = header[1]
header = header[2:]
body = make([]byte, binary.BigEndian.Uint32(header))
_, err = io.ReadFull(reader, body)
if err != nil {
return ErrorReply, nil, err
}
return reply, body, nil
}
// 将响应写入到 writer。
func writeResponseTo(writer io.Writer, reply byte, body []byte) (int, error) {
// 将响应体相关数据写入响应缓存区,并发送
bodyLengthBytes := make([]byte, bodyLengthInProtocol)
binary.BigEndian.PutUint32(bodyLengthBytes, uint32(len(body)))
response := make([]byte, 2, headerLengthInProtocol+len(body))
response[0] = ProtocolVersion
response[1] = reply
response = append(response, bodyLengthBytes...)
response = append(response, body...)
return writer.Write(response)
}
// 向 writer 写入错误信息为 msg 的响应。
func writeErrorResponseTo(writer io.Writer, msg string) (int, error) {
return writeResponseTo(writer, ErrorReply, []byte(msg))
}
复制代码
client.go
// 客户端结构。
type Client struct {
// 和服务端建立的连接。
conn net.Conn
// 通往服务端的读取器。
reader io.Reader
}
// 创建新的客户端。
func NewClient(network string, address string) (*Client, error) {
// 和服务端建立连接
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return &Client{
conn: conn,
reader: bufio.NewReader(conn),
}, nil
}
// 执行命令。
func (c *Client) Do(command byte, args [][]byte) (body []byte, err error) {
// 包装请求然后发送给服务端
_, err = writeRequestTo(c.conn, command, args)
if err != nil {
return nil, err
}
// 读取服务端返回的响应
reply, body, err := readResponseFrom(c.reader)
if err != nil {
return body, err
}
// 如果是错误答复码,将内容包装成 error 并返回
if reply == ErrorReply {
return body, errors.New(string(body))
}
return body, nil
}
// 关闭客户端。
func (c *Client) Close() error {
return c.conn.Close()
}
复制代码
server.go
var (
// 找不到对应的命令处理器错误
commandHandlerNotFoundErr = errors.New("failed to find a handler of command")
)
// 服务端结构。
type Server struct {
// 监听器,这个应该大家都很熟悉了吧。
listener net.Listener
// 命令处理器,通过命令可以找到对应的处理器。
handlers map[byte]func(args [][]byte) (body []byte, err error)
}
// 创建新的服务端。
func NewServer() *Server {
return &Server{
handlers: map[byte]func(args [][]byte) (body []byte, err error){},
}
}
// 注册命令处理器。
func (s *Server) RegisterHandler(command byte, handler func(args [][]byte) (body []byte, err error)) {
s.handlers[command] = handler
}
// 监听并服务于 network 和 address。
func (s *Server) ListenAndServe(network string, address string) (err error) {
// 监听指定地址
s.listener, err = net.Listen(network, address)
if err != nil {
return err
}
// 使用 WaitGroup 记录连接数,并等待所有连接处理完毕
wg := &sync.WaitGroup{}
for {
// 等待客户端连接
conn, err := s.listener.Accept()
if err != nil {
// This error means listener has been closed
// See src/internal/poll/fd.go@ErrNetClosing
if strings.Contains(err.Error(), "use of closed network connection") {
break
}
continue
}
// 记录连接
wg.Add(1)
go func() {
defer wg.Done()
s.handleConn(conn)
}()
}
// 等待所有连接处理完毕
wg.Wait()
return nil
}
// 处理连接。
func (s *Server) handleConn(conn net.Conn) {
// 将连接包装成缓冲读取器,提高读取的性能
reader := bufio.NewReader(conn)
defer conn.Close()
for {
// 读取并解析请求请求
command, args, err := readRequestFrom(reader)
if err != nil {
if err == ProtocolVersionMismatchErr {
continue
}
return
}
// 处理请求
reply, body, err := s.handleRequest(command, args)
if err != nil {
writeErrorResponseTo(conn, err.Error())
continue
}
// 发送处理结果的响应
_, err = writeResponseTo(conn, reply, body)
if err != nil {
continue
}
}
}
// 处理请求。
func (s *Server) handleRequest(command byte, args [][]byte) (reply byte, body []byte, err error) {
// 从命令处理器集合中选出对应的处理器
handle, ok := s.handlers[command]
if !ok {
return ErrorReply, nil, commandHandlerNotFoundErr
}
// 将处理结果返回
body, err = handle(args)
if err != nil {
return ErrorReply, body, err
}
return SuccessReply, body, err
}
// 关闭服务端的方法。
func (s *Server) Close() error {
if s.listener == nil {
return nil
}
return s.listener.Close()
}
复制代码
注:代码提取自vex框架。