C++
.h
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/epoll.h>
#include <sys/poll.h>
#include <sys/socket.h>
create_fd
/**
 * Create a TCP listening socket bound to all IPv4 interfaces.
 *
 * @param port TCP port to listen on, in host byte order. Defaults to 9999,
 *             the port the clients in this file connect to.
 * @return the listening socket descriptor. On any fatal failure the error is
 *         reported with perror() and the process exits with status 1.
 */
int create_fd(in_port_t port = 9999) {
    // Create the socket
    const int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        perror("Create socket error");
        exit(1);
    }
    // Best effort: allow an immediate restart of the server. Without
    // SO_REUSEADDR, bind() fails with EADDRINUSE while connections from a
    // previous run are still in TIME_WAIT.
    const int reuse = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) == -1) {
        perror("Setsockopt(SO_REUSEADDR) error");
    }
    // Initialise address and port (both stored in network byte order)
    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(port);
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    // Bind the address/port to the socket
    if (bind(fd, reinterpret_cast<sockaddr *>(&server_addr), sizeof(server_addr)) == -1) {
        perror("Bind error");
        close(fd);
        exit(1);
    }
    // Start listening; the second argument is the backlog of pending connections
    if (listen(fd, 10) == -1) {
        perror("Listen error");
        close(fd);
        exit(1);
    }
    return fd;
}
select
/**
 * Echo-style server built on select().
 *
 * fd_set is a fixed-size set of FD_SETSIZE descriptors (1024 on glibc), so
 * select() cannot watch descriptors >= FD_SETSIZE; such clients are rejected.
 *
 * Accepts clients on the socket returned by create_fd(), prints what each
 * client sends and replies with "Received!". Runs until select() fails with
 * an error other than EINTR, then closes the listening socket.
 */
void use_select() {
    const int listen_fd = create_fd();
    // Master set of watched descriptors. select() overwrites the set it is
    // given, so every iteration passes a fresh copy.
    fd_set watched;
    FD_ZERO(&watched);
    FD_SET(listen_fd, &watched);
    int max_fd = listen_fd;
    while (true) {
        fd_set ready_set = watched;
        const int ready = select(max_fd + 1, &ready_set, nullptr, nullptr, nullptr);
        if (ready == -1) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal: simply wait again
            }
            perror("Select error");
            break;
        }
        // The listening socket is readable: a new connection is pending
        if (FD_ISSET(listen_fd, &ready_set)) {
            sockaddr_in client_addr{};
            socklen_t client_addr_len = sizeof(client_addr);
            const int client_fd =
                accept(listen_fd, reinterpret_cast<sockaddr *>(&client_addr), &client_addr_len);
            if (client_fd == -1) {
                perror("Accept error");
            } else if (client_fd >= FD_SETSIZE) {
                // FD_SET with a descriptor >= FD_SETSIZE is undefined behavior
                fprintf(stderr, "Too many clients, rejecting fd %d\n", client_fd);
                close(client_fd);
            } else {
                FD_SET(client_fd, &watched);
                max_fd = client_fd > max_fd ? client_fd : max_fd;
                // Report the client's address
                char client_ip[INET_ADDRSTRLEN];
                inet_ntop(AF_INET, &client_addr.sin_addr, client_ip, sizeof(client_ip));
                printf("Client[%s:%d] connected: %d\n", client_ip, ntohs(client_addr.sin_port), client_fd);
            }
        }
        for (int i = 0; i <= max_fd; ++i) {
            // Skip the listening socket and descriptors that are not ready
            if (i == listen_fd || !FD_ISSET(i, &ready_set)) {
                continue;
            }
            char read_buf[1024];
            // Leave room for a terminator so the payload can be printed with %s
            const ssize_t n_bytes = read(i, read_buf, sizeof(read_buf) - 1);
            if (n_bytes <= 0) {
                if (n_bytes == -1) {
                    perror("Receive error");
                } else {
                    printf("Client disconnected\n");
                }
                FD_CLR(i, &watched);
                close(i);
                continue;
            }
            read_buf[n_bytes] = '\0';
            printf("Received: %s\n", read_buf);
            // Write back a response; sizeof includes the trailing '\0'.
            // MSG_NOSIGNAL: a peer that already went away yields EPIPE instead
            // of a SIGPIPE that would terminate the whole server.
            const char write_buf[] = "Received!";
            if (send(i, write_buf, sizeof(write_buf), MSG_NOSIGNAL) == -1) {
                perror("Send error");
                FD_CLR(i, &watched);
                close(i);
            }
        }
    }
    close(listen_fd);
}
poll
/**
 * Echo server built on poll().
 *
 * struct pollfd {
 *     int fd;             // descriptor to poll; negative entries are ignored by poll()
 *     short int events;   // events the caller is interested in
 *     short int revents;  // events that actually occurred (filled in by poll())
 * };
 *
 * Slot 0 of the table holds the listening socket and slots 1..SIZE-1 hold
 * clients. Every message a client sends is printed and echoed back unchanged.
 * Runs until poll() fails with an error other than EINTR, then closes the
 * listening socket.
 */
void use_poll() {
    printf("poll\n");
    const int sfd = create_fd();
    // The pollfd table is poll()'s counterpart to select()'s fd_set
    constexpr int SIZE = 100;
    pollfd pfds[SIZE];
    for (auto &pfd : pfds) {
        pfd.fd = -1;
        pfd.events = POLLIN;
        pfd.revents = 0;  // slots never passed to poll() must not hold garbage
    }
    pfds[0].fd = sfd;
    // Index of the highest slot that may be in use; poll() sees [0, max_index]
    int max_index = 0;
    while (true) {
        // Returns the number of descriptors with non-zero revents
        const int n = poll(pfds, max_index + 1, -1);
        if (n == -1) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal: simply wait again
            }
            perror("poll error");
            break;
        }
        // A new connection is pending on the listening socket
        if (pfds[0].revents & POLLIN) {
            const int cfd = accept(sfd, nullptr, nullptr);
            if (cfd == -1) {
                perror("Accept error");
            } else {
                printf("Someone connected\n");
                // Find a free slot (fd == -1)
                int slot = -1;
                for (int i = 1; i < SIZE; ++i) {
                    if (pfds[i].fd == -1) {
                        slot = i;
                        break;
                    }
                }
                if (slot == -1) {
                    // Table is full: close instead of leaking a descriptor nobody polls
                    fprintf(stderr, "Too many clients, rejecting fd %d\n", cfd);
                    close(cfd);
                } else {
                    pfds[slot].fd = cfd;
                    pfds[slot].revents = 0;  // nothing has happened on it yet
                    if (slot > max_index) {
                        max_index = slot;
                    }
                }
            }
        }
        // Service clients with pending data
        for (int i = 1; i <= max_index; ++i) {
            // POLLHUP/POLLERR are reported even when not requested. Treat them like
            // POLLIN so read() surfaces EOF/the error and the slot is released;
            // otherwise poll() would keep returning immediately for this fd.
            if (!(pfds[i].revents & (POLLIN | POLLHUP | POLLERR))) {
                continue;
            }
            char buf[1024];
            // Leave room for a terminator so the payload can be printed with %s
            const ssize_t n_bytes = read(pfds[i].fd, buf, sizeof(buf) - 1);
            if (n_bytes <= 0) {
                if (n_bytes == -1) {
                    perror("Read error");
                }
                close(pfds[i].fd);
                pfds[i].fd = -1;
                continue;
            }
            buf[n_bytes] = '\0';
            printf("客户端say: %s\n", buf);
            // Echo back exactly the bytes received. MSG_NOSIGNAL turns a write to a
            // vanished peer into EPIPE instead of a process-killing SIGPIPE.
            if (send(pfds[i].fd, buf, static_cast<size_t>(n_bytes), MSG_NOSIGNAL) == -1) {
                perror("Write error");
                close(pfds[i].fd);
                pfds[i].fd = -1;
            }
        }
    }
    close(sfd);
}
epoll
/**
 * Echo server built on epoll (level-triggered; only EPOLLIN is requested).
 *
 * The listening socket and every accepted client are registered with a single
 * epoll instance. Each message a client sends is printed and echoed back
 * unchanged. Runs until epoll_wait() fails with an error other than EINTR,
 * then closes the epoll instance and the listening socket.
 */
void use_epoll() {
    printf("epoll\n");
    const int sfd = create_fd();
    // Create the epoll instance (the size argument is only a hint and must be > 0)
    const int efd = epoll_create(1);
    if (efd == -1) {
        perror("EPoll create error");
        close(sfd);
        exit(1);
    }
    epoll_event event{};
    event.data.fd = sfd;
    event.events = EPOLLIN;
    if (epoll_ctl(efd, EPOLL_CTL_ADD, sfd, &event) == -1) {
        perror("EPoll CTL error");
        close(efd);
        close(sfd);
        exit(1);
    }
    // Output array: epoll_wait fills it with the events that are ready
    epoll_event events[100];
    const int size = sizeof(events) / sizeof(events[0]);
    while (true) {
        // Number of ready entries written into `events`
        const int num = epoll_wait(efd, events, size, -1);
        if (num == -1) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal: simply wait again
            }
            perror("EPoll wait error");
            break;
        }
        for (int i = 0; i < num; ++i) {
            const int cur_fd = events[i].data.fd;
            if (cur_fd == sfd) {
                // New connection: start watching the client's descriptor
                const int cfd = accept(sfd, nullptr, nullptr);
                if (cfd == -1) {
                    perror("Accept error");
                    continue;
                }
                epoll_event client_event{};
                client_event.data.fd = cfd;
                client_event.events = EPOLLIN;
                if (epoll_ctl(efd, EPOLL_CTL_ADD, cfd, &client_event) == -1) {
                    perror("EPoll CTL error");
                    close(cfd);  // not watched, so nobody would ever close it
                }
                continue;
            }
            // Read data from a client
            char buf[1024];
            // Leave room for a terminator so the payload can be printed with %s
            const ssize_t n_bytes = read(cur_fd, buf, sizeof(buf) - 1);
            if (n_bytes <= 0) {
                // -1: read error; 0: the client closed the connection
                if (n_bytes == -1) {
                    perror("READ error");
                }
                epoll_ctl(efd, EPOLL_CTL_DEL, cur_fd, nullptr);
                close(cur_fd);
                continue;
            }
            buf[n_bytes] = '\0';
            printf("客户端say: %s\n", buf);
            // Echo back exactly the bytes received. MSG_NOSIGNAL turns a write to a
            // vanished peer into EPIPE instead of a process-killing SIGPIPE.
            if (send(cur_fd, buf, static_cast<size_t>(n_bytes), MSG_NOSIGNAL) == -1) {
                perror("Send error");
                epoll_ctl(efd, EPOLL_CTL_DEL, cur_fd, nullptr);
                close(cur_fd);
            }
        }
    }
    close(efd);
    close(sfd);
}
client
/**
 * Demo client: connects to 127.0.0.1:9999, sends "hello, world!" 20 times at
 * one-second intervals and prints each reply. It exits with status 1 on any
 * socket error and stops early if the server closes the connection.
 */
int main() {
    // Create the socket
    const int client_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (client_fd == -1) {
        perror("Create socket error");
        exit(1);
    }
    // Server IP and port
    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(9999);
    // inet_pton returns 1 on success, 0 for a malformed address, -1 for a bad family
    if (inet_pton(AF_INET, "127.0.0.1", &server_addr.sin_addr) != 1) {
        fprintf(stderr, "Cannot resolve ip address\n");
        close(client_fd);
        exit(1);
    }
    // Connect
    if (connect(client_fd, reinterpret_cast<const sockaddr *>(&server_addr), sizeof(server_addr)) == -1) {
        perror("Cannot connect to server");
        close(client_fd);
        exit(1);
    }
    for (int times = 0; times < 20; ++times) {
        const char w_buf[] = "hello, world!";
        char r_buf[1024]{};
        if (write(client_fd, w_buf, strlen(w_buf)) == -1) {
            perror("Send to server error");
            close(client_fd);
            exit(1);
        }
        // Leave the last byte as '\0' so the reply can be printed with %s
        const ssize_t ret = read(client_fd, r_buf, sizeof(r_buf) - 1);
        printf("ret: %zd\n", ret);
        if (ret == -1) {
            perror("Receive from server error");
            close(client_fd);
            exit(1);
        }
        if (ret == 0) {
            // EOF: the server closed the connection, further reads would return 0 forever
            printf("Server closed the connection\n");
            break;
        }
        printf("Response from server: %s\n", r_buf);
        sleep(1);
    }
    close(client_fd);
    return 0;
}
Java
server
/**
 * Single-threaded NIO server on port 9999.
 *
 * One Selector multiplexes the listening channel and all clients (the Java
 * counterpart of select/poll/epoll). Each message a client sends is printed;
 * the server then switches that client to write interest and replies with
 * "RECEIVED!" before going back to read interest.
 */
public class NIOServer {
    public static void main(String[] args) {
        // Both resources are closed when the loop exits on an I/O failure
        try (Selector selector = Selector.open();
             ServerSocketChannel serverSocketChannel = ServerSocketChannel.open()) {
            // Non-blocking mode is required to register with a Selector
            serverSocketChannel.configureBlocking(false);
            // Bind the port
            ServerSocket socket = serverSocketChannel.socket();
            socket.bind(new InetSocketAddress("0.0.0.0", 9999));
            // Register interest in incoming connections
            serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
            while (true) {
                // select/poll/epoll equivalent: blocks until at least one key is ready
                int select = selector.select();
                if (select == 0) {
                    continue;
                }
                for (SelectionKey selectedKey : selector.selectedKeys()) {
                    if (!selectedKey.isValid()) {
                        continue;
                    }
                    if (selectedKey.isAcceptable()) {
                        handleAccept(serverSocketChannel, selector);
                    } else if (selectedKey.isReadable()) {
                        handleRead(selectedKey);
                    } else if (selectedKey.isWritable()) {
                        handleWrite(selectedKey);
                    }
                }
                // Selected keys are not removed automatically
                selector.selectedKeys().clear();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Accepts a pending connection and registers it for read events. */
    private static void handleAccept(ServerSocketChannel server, Selector selector) throws IOException {
        SocketChannel client = server.accept();
        if (client == null) {
            // Non-blocking accept returns null when no connection is actually pending
            return;
        }
        // Client channels must also be non-blocking to be registered with the selector
        client.configureBlocking(false);
        // From now on the selector also watches this client for readable data
        client.register(selector, SelectionKey.OP_READ);
        System.out.println("Client connected: " + client.getRemoteAddress());
    }

    /** Reads one chunk from a client, prints it and switches the key to also wait for writability. */
    private static void handleRead(SelectionKey selectedKey) throws IOException {
        SocketChannel channel = (SocketChannel) selectedKey.channel();
        ByteBuffer byteBuffer = ByteBuffer.allocate(1024);
        int read;
        try {
            read = channel.read(byteBuffer);
        } catch (IOException e) {
            // e.g. connection reset by peer: drop this client only, keep serving others
            channel.close();
            return;
        }
        if (read == -1) {
            // The peer closed the connection
            channel.close();
            return;
        }
        // read() only advances position. flip() sets limit = position and
        // position = 0 so the bytes just read can be consumed.
        //   limit:       index of the first element that must not be read/written
        //   position:    index of the next element to read/write
        //   capacity:    total size of the buffer
        //   remaining(): limit - position
        byteBuffer.flip();
        byte[] bytes = new byte[byteBuffer.remaining()];
        byteBuffer.get(bytes);
        String data = new String(bytes, StandardCharsets.UTF_8);
        System.out.println("Received from client[" + channel.getRemoteAddress() + "]: " + data);
        // Also listen for writability so the reply is sent on the next select
        selectedKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
    }

    /** Sends the acknowledgement to a client and goes back to read-only interest. */
    private static void handleWrite(SelectionKey selectedKey) throws IOException {
        SocketChannel channel = (SocketChannel) selectedKey.channel();
        // Locale.ROOT: the default locale could change the casing (e.g. Turkish dotted I)
        ByteBuffer byteBuffer = ByteBuffer.wrap(
                "Received!".toUpperCase(java.util.Locale.ROOT).getBytes(StandardCharsets.UTF_8));
        try {
            // NOTE(review): a non-blocking write may be partial; acceptable for this short reply
            channel.write(byteBuffer);
        } catch (IOException e) {
            // The client went away: drop it instead of terminating the server
            channel.close();
            return;
        }
        // Back to waiting for the next request
        selectedKey.interestOps(SelectionKey.OP_READ);
    }
}
client
/**
 * Blocking NIO client that sends "Hello, I'm Java" once per second, forever.
 *
 * Usage: {@code java Client [host] [port]}. Without arguments it connects to
 * 172.17.224.1:9999, the original hard-coded target.
 */
public class Client {
    public static void main(String[] args) {
        String host = args.length > 0 ? args[0] : "172.17.224.1";
        int port = args.length > 1 ? Integer.parseInt(args[1]) : 9999;
        try (SocketChannel socket = SocketChannel.open()) {
            socket.connect(new InetSocketAddress(host, port));
            // The payload never changes, so encode it once
            byte[] payload = "Hello, I'm Java".getBytes(StandardCharsets.UTF_8);
            while (true) {
                ByteBuffer byteBuffer = ByteBuffer.wrap(payload);
                while (byteBuffer.hasRemaining()) {
                    socket.write(byteBuffer);
                }
                // NOTE(review): replies from the server are never read, so they accumulate
                // in this socket's receive buffer.
                TimeUnit.SECONDS.sleep(1);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            // Restore the interrupt status rather than silently clearing it
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }
}