C++

.h

#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/epoll.h>
#include <sys/poll.h>
#include <sys/select.h>
#include <sys/socket.h>

create_fd

/**
 * Create a TCP listening socket bound to 0.0.0.0:9999.
 *
 * On any failure the error (including the errno text) is printed to stderr,
 * the partially set-up socket is closed and the process exits with status 1,
 * so callers never receive an invalid descriptor.
 *
 * @return the listening socket file descriptor.
 */
int create_fd() {
    // Create the socket.
    const int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        perror("Create socket error");
        exit(1);
    }

    // Allow quick restarts: without SO_REUSEADDR, bind() fails with EADDRINUSE
    // while connections from a previous run linger in TIME_WAIT.
    const int reuse = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) == -1) {
        perror("Set SO_REUSEADDR error");
        close(fd);
        exit(1);
    }

    // Listen on every local interface, port 9999 (both in network byte order).
    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(9999);
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);

    // Bind the address/port to the socket.
    if (bind(fd, reinterpret_cast<sockaddr *>(&server_addr), sizeof(server_addr)) == -1) {
        perror("Bind error");
        close(fd);
        exit(1);
    }

    // Start listening with a backlog of 10 not-yet-accepted connections.
    if (listen(fd, 10) == -1) {
        perror("Listen error");
        close(fd);
        exit(1);
    }

    return fd;
}

select

/**
 * fd_set是一个数组,长度固定为1024
 */
/**
 * Run a single-threaded TCP server on port 9999, multiplexed with select(2).
 *
 * fd_set is a fixed-size bitmap that can only hold descriptors below
 * FD_SETSIZE (typically 1024). A connection whose descriptor does not fit is
 * closed immediately rather than passed to FD_SET, which would be undefined
 * behavior.
 *
 * Each chunk read from a client is printed and answered with "Received!"
 * (including its trailing NUL). The function returns, after closing the
 * listening socket, only if select() fails with something other than EINTR.
 */
void use_select() {
    const int fd = create_fd();

    // read_set is the persistent interest set. r_set is the per-call copy that
    // select() overwrites with the descriptors that are actually readable.
    fd_set read_set;
    fd_set r_set;
    FD_ZERO(&read_set);
    FD_SET(fd, &read_set);

    int max_fd = fd;
    while (true) {
        r_set = read_set;
        const int ready = select(max_fd + 1, &r_set, nullptr, nullptr, nullptr);
        if (ready == -1) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal: just wait again
            }
            perror("Select error");
            break;
        }

        // The listening socket is readable: a new connection is pending.
        if (FD_ISSET(fd, &r_set)) {
            sockaddr_in client_addr{};
            socklen_t client_addr_len = sizeof(client_addr);
            const int client_fd = accept(fd, reinterpret_cast<sockaddr *>(&client_addr), &client_addr_len);
            if (client_fd == -1) {
                perror("Accept error");
            } else if (client_fd >= FD_SETSIZE) {
                fprintf(stderr, "Too many clients, rejecting fd %d\n", client_fd);
                close(client_fd);
            } else {
                FD_SET(client_fd, &read_set);
                max_fd = client_fd > max_fd ? client_fd : max_fd;

                // Report the client's address; the port is in network byte order.
                char client_ip[INET_ADDRSTRLEN]{};
                inet_ntop(AF_INET, &client_addr.sin_addr, client_ip, sizeof(client_ip));
                printf("Client[%s:%d] connected: %d\n", client_ip, ntohs(client_addr.sin_port), client_fd);
            }
        }

        for (int i = 0; i <= max_fd; ++i) {
            // The listening socket was handled above.
            if (i == fd || !FD_ISSET(i, &r_set)) {
                continue;
            }

            // read() does not NUL-terminate; keep one byte free so the data can
            // be terminated before it is printed with %s.
            char read_buf[1024];
            const ssize_t n_bytes = read(i, read_buf, sizeof(read_buf) - 1);
            if (n_bytes <= 0) {
                if (n_bytes == -1) {
                    perror("Receive error");
                } else {
                    printf("Client disconnected\n");
                }
                FD_CLR(i, &read_set);
                close(i);
                continue;
            }
            read_buf[n_bytes] = '\0';
            printf("Received: %s\n", read_buf);

            // Write the response. MSG_NOSIGNAL turns a write to a peer that has
            // already gone into an EPIPE error instead of a process-killing SIGPIPE.
            char write_buf[] = "Received!";
            if (send(i, write_buf, sizeof(write_buf), MSG_NOSIGNAL) == -1) {
                perror("Send error");
                FD_CLR(i, &read_set);
                close(i);
            }
        }
    }

    close(fd);
}

poll

/**
 * struct pollfd
 * {
 *      int fd;		            File descriptor to poll.
 *      short int events;		Types of events poller cares about.
 *      short int revents;		Types of events that actually occurred.
 * }
 */
/**
 * Run a single-threaded TCP echo server on port 9999, multiplexed with poll(2).
 *
 * struct pollfd {
 *     int   fd;       // descriptor to watch; negative entries are ignored
 *     short events;   // events the caller is interested in
 *     short revents;  // events that actually occurred (filled in by poll)
 * };
 *
 * Slot 0 holds the listening socket. Slots 1..SIZE-1 hold clients, with
 * fd == -1 marking a free slot. A connection that arrives when every slot is
 * taken is closed immediately.
 *
 * Each chunk received from a client is printed and echoed back up to its first
 * NUL, including the terminator. The function returns, after closing the
 * listening socket, only if poll() fails with something other than EINTR.
 */
void use_poll() {
    printf("poll\n");
    const int sfd = create_fd();

    // pollfd array: poll's counterpart of select's fd_set. revents is cleared
    // too, because slots past the current max_fd are inspected before poll()
    // has ever written to them.
    static constexpr int SIZE = 100;
    pollfd pfds[SIZE];
    for (auto &pfd : pfds) {
        pfd.fd = -1;
        pfd.events = POLLIN;
        pfd.revents = 0;
    }
    pfds[0].fd = sfd;

    // Index of the highest slot that has ever been used; poll() scans 0..max_fd.
    int max_fd = 0;
    while (true) {
        // Returns the number of entries whose revents is non-zero.
        const int n = poll(pfds, max_fd + 1, -1);
        if (n == -1) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal: just wait again
            }
            perror("poll error");
            break;
        }

        // New connection on the listening socket?
        if (pfds[0].revents & POLLIN) {
            // The peer address is not needed, so both out-parameters are nullptr.
            const int cfd = accept(sfd, nullptr, nullptr);
            if (cfd == -1) {
                perror("Accept error");
            } else {
                printf("Someone connected\n");
                bool stored = false;
                for (int i = 1; i < SIZE; ++i) {
                    // Use the first free slot (fd == -1).
                    if (pfds[i].fd == -1) {
                        pfds[i].fd = cfd;
                        pfds[i].revents = 0;
                        max_fd = i > max_fd ? i : max_fd;
                        stored = true;
                        break;
                    }
                }
                if (!stored) {
                    fprintf(stderr, "Too many clients, rejecting fd %d\n", cfd);
                    close(cfd);
                }
            }
        }

        // Service readable clients. POLLERR/POLLHUP are reported even though
        // they were not requested. They are treated like POLLIN so that read()
        // surfaces the error or EOF and the slot is freed; otherwise poll()
        // would keep returning immediately for that descriptor.
        for (int i = 1; i <= max_fd; ++i) {
            if (pfds[i].fd == -1 || !(pfds[i].revents & (POLLIN | POLLERR | POLLHUP))) {
                continue;
            }

            // Zero-initialized, and one byte is kept free, so buf is always a C string.
            char buf[1024]{};
            const ssize_t n_bytes = read(pfds[i].fd, buf, sizeof(buf) - 1);
            if (n_bytes <= 0) {
                if (n_bytes == -1) {
                    perror("Read error");
                }
                close(pfds[i].fd);
                pfds[i].fd = -1;
                continue;
            }

            printf("客户端say: %s\n", buf);
            // MSG_NOSIGNAL: a vanished peer yields EPIPE instead of SIGPIPE.
            if (send(pfds[i].fd, buf, strlen(buf) + 1, MSG_NOSIGNAL) == -1) {
                perror("Write error");
                close(pfds[i].fd);
                pfds[i].fd = -1;
            }
        }
    }

    close(sfd);
}

epoll

/**
 * Run a single-threaded TCP echo server on port 9999, multiplexed with epoll(7).
 * Both the listening socket and every client are registered for EPOLLIN.
 *
 * Each chunk received from a client is printed and echoed back up to its first
 * NUL. The process exits if the epoll instance cannot be created or the
 * listening socket cannot be registered. Otherwise the function returns, after
 * closing the epoll instance and the listening socket, only if epoll_wait()
 * fails with something other than EINTR.
 */
void use_epoll() {
    printf("epoll\n");
    const int sfd = create_fd();

    // Create the epoll instance (the size hint is ignored but must be > 0).
    const int efd = epoll_create(1);
    if (efd == -1) {
        perror("EPoll create error");
        close(sfd);
        exit(1);
    }

    epoll_event event{};
    event.data.fd = sfd;
    event.events = EPOLLIN;
    if (epoll_ctl(efd, EPOLL_CTL_ADD, sfd, &event) == -1) {
        perror("EPoll CTL error");
        close(efd);
        close(sfd);
        exit(1);
    }

    // Output array: epoll_wait fills it with up to `size` ready events.
    epoll_event events[100];
    const int size = sizeof(events) / sizeof(events[0]);

    while (true) {
        // Number of ready descriptors stored in events[0..num).
        const int num = epoll_wait(efd, events, size, -1);
        if (num == -1) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal: just wait again
            }
            perror("EPoll wait error");
            break;
        }

        for (int i = 0; i < num; ++i) {
            const int cur_fd = events[i].data.fd;

            if (cur_fd == sfd) {
                // New connection: start watching the client's descriptor.
                const int cfd = accept(sfd, nullptr, nullptr);
                if (cfd == -1) {
                    perror("Accept error");
                    continue;
                }
                event.data.fd = cfd;
                event.events = EPOLLIN;
                if (epoll_ctl(efd, EPOLL_CTL_ADD, cfd, &event) == -1) {
                    perror("EPoll CTL error");
                    close(cfd);  // not being watched, so it would leak otherwise
                }
                continue;
            }

            // Client data. read() does not NUL-terminate; keep one byte free so
            // the data can be terminated for printf/strlen.
            char buf[1024];
            const ssize_t n_bytes = read(cur_fd, buf, sizeof(buf) - 1);
            if (n_bytes <= 0) {
                // < 0: read error; == 0: client disconnected.
                if (n_bytes < 0) {
                    perror("READ error");
                }
                epoll_ctl(efd, EPOLL_CTL_DEL, cur_fd, nullptr);
                close(cur_fd);
                continue;
            }
            buf[n_bytes] = '\0';

            printf("客户端say: %s\n", buf);
            // MSG_NOSIGNAL: a vanished peer yields EPIPE instead of SIGPIPE.
            if (send(cur_fd, buf, strlen(buf), MSG_NOSIGNAL) == -1) {
                perror("Send error");
                epoll_ctl(efd, EPOLL_CTL_DEL, cur_fd, nullptr);
                close(cur_fd);
            }
        }
    }

    close(efd);
    close(sfd);
}

client

/**
 * Demo client: connect to 127.0.0.1:9999 and, 20 times at one-second
 * intervals, send "hello, world!" and print the server's reply.
 *
 * Exits with status 1 on any socket error. Stops early if the server closes
 * the connection.
 */
int main() {
    // Create the socket.
    const int client_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (client_fd == -1) {
        perror("Cannot create socket");
        exit(1);
    }

    // Server IP and port (network byte order).
    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(9999);
    // inet_pton returns 1 on success, 0 for a malformed address string (errno
    // is not set), and -1 for an unsupported address family.
    const int pton_ret = inet_pton(AF_INET, "127.0.0.1", &server_addr.sin_addr);
    if (pton_ret != 1) {
        fprintf(stderr, "Cannot resolve ip address\n");
        close(client_fd);
        exit(1);
    }

    // Connect.
    if (connect(client_fd, reinterpret_cast<const sockaddr *>(&server_addr), sizeof(server_addr)) == -1) {
        perror("Cannot connect to server");
        close(client_fd);
        exit(1);
    }

    for (int times = 20; times > 0; --times) {
        const char w_buf[] = "hello, world!";
        char r_buf[1024]{};

        ssize_t ret = write(client_fd, w_buf, strlen(w_buf));
        if (ret == -1) {
            perror("Send to server error");
            close(client_fd);
            exit(1);
        }

        // Keep one byte free so r_buf (zero-initialized) is always a C string.
        ret = read(client_fd, r_buf, sizeof(r_buf) - 1);
        printf("ret: %zd\n", ret);
        if (ret == -1) {
            perror("Receive from server error");
            close(client_fd);
            exit(1);
        }
        if (ret == 0) {
            printf("Server closed the connection\n");
            break;
        }

        printf("Response from server: %s\n", r_buf);
        sleep(1);
    }

    close(client_fd);
    return 0;
}

Java

server

/**
 * Single-threaded NIO server listening on 0.0.0.0:9999.
 *
 * For each chunk a client sends, the server prints it. It then switches the
 * key to also watch OP_WRITE, and on the next writable event replies
 * "RECEIVED!". An I/O failure on one client closes only that client's channel;
 * it never terminates the server loop.
 */
public class NIOServer {

    /**
     * Reply payload. It is written in upper case directly rather than with
     * toUpperCase(), which is locale-sensitive: a Turkish locale would yield
     * "RECEİVED!".
     */
    private static final byte[] REPLY = "RECEIVED!".getBytes(StandardCharsets.UTF_8);

    public static void main(String[] args) {
        try (Selector selector = Selector.open();
             ServerSocketChannel serverSocketChannel = ServerSocketChannel.open()) {
            // Channels registered with a Selector must be non-blocking.
            serverSocketChannel.configureBlocking(false);
            // Bind the port.
            serverSocketChannel.bind(new InetSocketAddress("0.0.0.0", 9999));
            // Register interest in incoming connections.
            serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);

            while (true) {
                // Implemented on top of select/poll/epoll depending on the platform.
                int select = selector.select();
                if (select == 0) {
                    continue;
                }

                for (SelectionKey selectedKey : selector.selectedKeys()) {
                    // A key whose channel has been closed is cancelled; querying
                    // its ready set would throw CancelledKeyException.
                    if (!selectedKey.isValid()) {
                        continue;
                    }
                    if (selectedKey.isAcceptable()) {
                        acceptClient(selector, serverSocketChannel);
                    } else if (selectedKey.isReadable()) {
                        readFromClient(selectedKey);
                    } else if (selectedKey.isWritable()) {
                        writeToClient(selectedKey);
                    }
                }

                // The selector never removes keys from the selected set itself.
                selector.selectedKeys().clear();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Accept one pending connection, if there is one, and register it for reads. */
    private static void acceptClient(Selector selector, ServerSocketChannel server) throws IOException {
        SocketChannel client = server.accept();
        if (client == null) {
            return;  // non-blocking accept: the pending connection is already gone
        }
        try {
            // Client sockets must be non-blocking too. Once registered, the
            // selector also reports this socket's events.
            client.configureBlocking(false);
            client.register(selector, SelectionKey.OP_READ);
            System.out.println("Client connected: " + client.getRemoteAddress());
        } catch (IOException e) {
            closeQuietly(client);
        }
    }

    /** Read and print one chunk from a client, then schedule the reply. */
    private static void readFromClient(SelectionKey selectedKey) {
        SocketChannel channel = (SocketChannel) selectedKey.channel();
        ByteBuffer byteBuffer = ByteBuffer.allocate(1024);
        try {
            int read = channel.read(byteBuffer);
            if (read == -1) {
                // End of stream: the client closed its side of the connection.
                channel.close();
                return;
            }

            // read() only advances position. flip() sets limit = position and
            // position = 0, so the bytes just received can be consumed.
            // position: next index to read/write; limit: first index that may
            // not be read/written; capacity: buffer size;
            // remaining() = limit - position.
            byteBuffer.flip();
            byte[] bytes = new byte[byteBuffer.remaining()];
            byteBuffer.get(bytes);
            String data = new String(bytes, StandardCharsets.UTF_8);
            System.out.println("Received from client[" + channel.getRemoteAddress() + "]: " + data);

            // Keep reading, and also wait until the socket is writable to reply.
            selectedKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
        } catch (IOException e) {
            // Includes ClosedChannelException: drop only this client.
            closeQuietly(channel);
        }
    }

    /** Send the reply to a client, then go back to watching only for reads. */
    private static void writeToClient(SelectionKey selectedKey) {
        SocketChannel channel = (SocketChannel) selectedKey.channel();
        ByteBuffer byteBuffer = ByteBuffer.wrap(REPLY);
        try {
            // A non-blocking write may be partial; any unsent bytes are dropped.
            channel.write(byteBuffer);
            selectedKey.interestOps(SelectionKey.OP_READ);
        } catch (IOException e) {
            closeQuietly(channel);
        }
    }

    /** Close a client channel (which also cancels its keys), ignoring close errors. */
    private static void closeQuietly(SocketChannel channel) {
        try {
            channel.close();
        } catch (IOException ignored) {
            // Nothing useful to do: the connection is being discarded anyway.
        }
    }
}

client

/**
 * Demo client using a blocking SocketChannel: connects to the server and sends
 * "Hello, I'm Java" once per second until an I/O error occurs or the thread is
 * interrupted.
 *
 * Usage: {@code Client [host [port]]}. The defaults are the previously
 * hard-coded 172.17.224.1 and port 9999.
 */
public class Client {

    private static final String DEFAULT_HOST = "172.17.224.1";
    private static final int DEFAULT_PORT = 9999;

    public static void main(String[] args) {
        String host = args.length > 0 ? args[0] : DEFAULT_HOST;
        int port = args.length > 1 ? Integer.parseInt(args[1]) : DEFAULT_PORT;

        try (SocketChannel socket = SocketChannel.open()) {
            socket.connect(new InetSocketAddress(host, port));
            while (true) {
                // In blocking mode, write() returns only after all bytes are written.
                ByteBuffer byteBuffer = ByteBuffer.wrap("Hello, I'm Java".getBytes(StandardCharsets.UTF_8));
                socket.write(byteBuffer);

                TimeUnit.SECONDS.sleep(1);
            }
        } catch (IOException | InterruptedException e) {
            e.printStackTrace();
        }
    }
}