Skip to content

Commit

Permalink
fix:kqueue bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxuan520 committed Oct 11, 2024
1 parent bb5cdf6 commit 0191f85
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 56 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
## TODO
- [ ] 支持设置LT和ET模式下的事件触发方式
- [ ] 支持设置连接的超时时间
- [ ] 支持UDP协议(添加测试)
- [x] 支持UDP协议(添加测试)
- [ ] 支持Http协议(彻底迁移 cppweb -> cppnet)
- [x] 抽象出epoll层
- [ ] 支持SSL
- [ ] accept 改造
- [x] accept 改造
15 changes: 14 additions & 1 deletion src/cppnet/server/io_multiplexing/epoll.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#ifdef __linux__

#include "epoll.hpp"
#include "utils/const.hpp"
#include <cstring>
Expand Down Expand Up @@ -39,6 +41,11 @@ int Epoll::RemoveSoc(const Socket &fd) {
void Epoll::Close() { epoll_fd_.Close(); }

int Epoll::Loop(NotifyCallBack callback) {
if (callback == nullptr) {
err_msg_ = "[logicerr]:" + std::string("callback is nullptr");
return kLogicErr;
}

while (loop_flag_) {
struct epoll_event evs[max_event_num_];
int nfds = epoll_wait(epoll_fd_.fd(), evs, max_event_num_, -1);
Expand All @@ -50,8 +57,12 @@ int Epoll::Loop(NotifyCallBack callback) {
return kSysErr;
}
for (int i = 0; i < nfds; ++i) {
if (evs[i].events & EPOLLRDHUP || evs[i].events & EPOLLERR) {
if (evs[i].events & EPOLLRDHUP) {
callback(*this, evs[i].data.fd, kIOEventLeave);

} else if (evs[i].events & EPOLLERR) {
callback(*this, evs[i].data.fd, kIOEventError);

} else if (evs[i].events & EPOLLIN) {
callback(*this, evs[i].data.fd, kIOEventRead);
}
Expand All @@ -70,3 +81,5 @@ int Epoll::UpdateEpollEvents(int epfd, int op, int fd, int event) {
}

} // namespace cppnet

#endif
1 change: 1 addition & 0 deletions src/cppnet/server/io_multiplexing/io_multiplexing_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class IOMultiplexingBase {
enum IOEvent {
kIOEventRead = 1,
kIOEventLeave = 2,
kIOEventError = 3,
};
using NotifyCallBack =
std::function<void(IOMultiplexingBase &, Socket, IOEvent)>;
Expand Down
24 changes: 15 additions & 9 deletions src/cppnet/server/io_multiplexing/kqueue.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include "kqueue.hpp"
#include "utils/const.hpp"

#ifdef __APPLE__

#include "kqueue.hpp"
#include "utils/const.hpp"
#include <sys/event.h>

namespace cppnet {
Expand All @@ -13,7 +12,7 @@ KQueue::~KQueue() { Close(); }

int KQueue::Init() {
kq_fd_ = kqueue();
if (kq_fd_.status() == Socket::kInit) {
if (kq_fd_.status() != Socket::kInit) {
err_msg_ = "[syserr]:" + std::string(strerror(errno));
return kSysErr;
}
Expand Down Expand Up @@ -41,8 +40,13 @@ int KQueue::RemoveSoc(const Socket &fd) {
}

int KQueue::Loop(NotifyCallBack callback) {
if (callback == nullptr) {
err_msg_ = "[logicerr]:" + std::string("callback is nullptr");
return kLogicErr;
}

while (loop_flag_) {
struct kevent evs[1024];
struct kevent evs[max_event_num_];
int nfds = kevent(kq_fd_.fd(), nullptr, 0, evs, 1024, nullptr);
if (nfds < 0) {
if (errno == EINTR) {
Expand All @@ -51,10 +55,12 @@ int KQueue::Loop(NotifyCallBack callback) {
return kSysErr;
}
for (int i = 0; i < nfds; ++i) {
if (evs[i].filter == EVFILT_READ) {
if (callback != nullptr) {
callback(*this, evs[i].ident);
}
if (evs[i].flags & EV_EOF) {
callback(*this, evs[i].ident, kIOEventLeave);
} else if (evs[i].flags & EV_ERROR) {
callback(*this, evs[i].ident, kIOEventError);
} else if (evs[i].filter == EVFILT_READ) {
callback(*this, evs[i].ident, kIOEventRead);
}
}
}
Expand Down
81 changes: 47 additions & 34 deletions src/cppnet/server/tcp_server.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
#include "tcp_server.hpp"
#include "../utils/const.hpp"
#include <arpa/inet.h>
#include <fcntl.h>
#include <iostream>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <ostream>
#include <signal.h>
#include "io_multiplexing/io_multiplexing_factory.hpp"

#include <string.h>
#include <unistd.h>

Expand Down Expand Up @@ -66,18 +60,19 @@ int TcpServer::AddSoc(const Socket &soc) {

void TcpServer::HandleAccept() {
Address addr;
socklen_t in_len = sizeof(in_addr);

auto new_socket = listenfd_.Accept(addr, &in_len);
auto new_socket = listenfd_.Accept(addr);
if (new_socket.status() != Socket::kInit) {
err_msg_ = "[syserr]:" + std::string(strerror(errno));
event_callback_(kEventError, *this, new_socket);
return;
}

// ET mode need set to non block
auto rc = new_socket.SetNoBlock();
if (rc < 0) {
err_msg_ = "[syserr]:" + std::string(strerror(errno));
event_callback_(kEventError, *this, new_socket);
new_socket.Close();
return;
}
Expand All @@ -86,32 +81,34 @@ void TcpServer::HandleAccept() {
rc = io_multiplexing_->MonitorSoc(new_socket);
if (rc < 0) {
err_msg_ = "[syserr]:" + std::string(strerror(errno));
event_callback_(kEventError, *this, new_socket);
new_socket.Close();
return;
}

if (event_callback_) {
event_callback_(kEventAccept, *this, new_socket);
}
event_callback_(kEventAccept, *this, new_socket);
}

void TcpServer::HandleRead(int fd) {
if (event_callback_) {
Socket soc(fd);
event_callback_(kEventRead, *this, soc);
}
Socket soc(fd);
event_callback_(kEventRead, *this, soc);
}

void TcpServer::HandleLeave(int fd) {
Socket soc(fd);
if (event_callback_) {
event_callback_(kEventLeave, *this, soc);
}
event_callback_(kEventLeave, *this, soc);
RemoveSoc(soc);
soc.Close();
}

void TcpServer::HandleError(int fd) {
Socket soc(fd);
event_callback_(kEventError, *this, soc);
RemoveSoc(soc);
soc.Close();
}

int TcpServer::EpollLoop() {
int TcpServer::EventLoop() {
if (listenfd_.status() != Socket::kInit) {
err_msg_ = "[logicerr]:epfd or listenfd not init";
return kLogicErr;
Expand All @@ -122,21 +119,32 @@ int TcpServer::EpollLoop() {
return kLogicErr;
}

auto callback = [this](IOMultiplexingBase &, Socket fd,
IOMultiplexingBase::IOEvent event) {
if (fd == listenfd_) {
HandleAccept();
} else if (event == IOMultiplexingBase::kIOEventLeave) {
HandleLeave(fd.fd());
} else if (event == IOMultiplexingBase::kIOEventRead) {
HandleRead(fd.fd());
switch (mode_) {

case kIOMultiplexing: {
auto callback = [this](IOMultiplexingBase &, Socket fd,
IOMultiplexingBase::IOEvent event) {
if (fd == listenfd_) {
HandleAccept();
} else if (event == IOMultiplexingBase::kIOEventLeave) {
HandleLeave(fd.fd());
} else if (event == IOMultiplexingBase::kIOEventRead) {
HandleRead(fd.fd());
} else if (event == IOMultiplexingBase::kIOEventError) {
HandleError(fd.fd());
} else {
err_msg_ = "[logicerr]:unknown event";
}
};
auto rc = io_multiplexing_->Loop(callback);
if (rc < 0) {
err_msg_ = "[syserr]:" + std::string(strerror(errno));
return kSysErr;
}
};
} break;

auto rc = io_multiplexing_->Loop(callback);
if (rc < 0) {
err_msg_ = "[syserr]:" + std::string(strerror(errno));
return kSysErr;
default:
return kSuccess;
}
return kSuccess;
}
Expand Down Expand Up @@ -194,6 +202,11 @@ int TcpServer::Init() {
return kSysErr;
}
return kSuccess;

// init default callback
event_callback_ = [](Event, TcpServer &, Socket) {};

return kSuccess;
}

} // namespace cppnet
8 changes: 5 additions & 3 deletions src/cppnet/server/tcp_server.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "../socket/socket.hpp"
#include "io_multiplexing/io_multiplexing_factory.hpp"
#include "io_multiplexing/io_multiplexing_base.hpp"
#include <functional>
#include <memory>
#include <string>
Expand All @@ -16,6 +16,7 @@ class TcpServer {
kEventAccept = 0,
kEventRead = 1,
kEventLeave = 2,
kEventError = 3,
};
using EventCallBack = std::function<void(Event, TcpServer &, Socket)>;

Expand Down Expand Up @@ -52,10 +53,10 @@ class TcpServer {
*/
void Register(EventCallBack cb);
/**
* @brief: Epoll event loop.
* @brief: Event loop.
* @return: 0 if success, -1 if failed.
*/
int EpollLoop();
int EventLoop();
/**
* @brief: Close file descriptor.With remove from epoll.
*/
Expand Down Expand Up @@ -84,6 +85,7 @@ class TcpServer {
void HandleAccept();
void HandleRead(int fd);
void HandleLeave(int fd);
void HandleError(int fd);

private:
// server address
Expand Down
6 changes: 6 additions & 0 deletions src/cppnet/socket/address.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,10 @@ void Address::GetIPAndPort(std::string &ip, uint16_t &port) {
port = ntohs(addr_.sin_port);
}

socklen_t *Address::GetAddrLen() {
static socklen_t len = 0;
len = sizeof(sockaddr_in);
return &len;
}

} // namespace cppnet
1 change: 1 addition & 0 deletions src/cppnet/socket/address.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Address {
public:
sockaddr_in *GetAddr() { return &addr_; }
sockaddr *GetSockAddr() { return reinterpret_cast<sockaddr *>(&addr_); }
static socklen_t *GetAddrLen();

private:
sockaddr_in addr_;
Expand Down
4 changes: 2 additions & 2 deletions src/cppnet/socket/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ int Socket::InitUdp() {
return 0;
}

Socket Socket::Accept(Address &addr, socklen_t *plen) const {
Socket Socket::Accept(Address &addr) const {
if (status_ != kInit) {
return Socket(-1);
}
auto accept_fd = ::accept(fd_, addr.GetSockAddr(), plen);
auto accept_fd = ::accept(fd_, addr.GetSockAddr(), addr.GetAddrLen());
return Socket(accept_fd);
}

Expand Down
2 changes: 1 addition & 1 deletion src/cppnet/socket/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Socket {
* @param addr: server address.
* @param plen: pointer of address length.
*/
Socket Accept(Address &addr, socklen_t *plen) const;
Socket Accept(Address &addr) const;
/**
* @brief: Close socket.
*/
Expand Down
1 change: 1 addition & 0 deletions src/test/main.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "server/tcp_server_test.hpp"
#include "socket/socket_test.hpp"
#include "test.h"
#include "timer/timer_test.hpp"
#include "utils/host_test.hpp"
Expand Down
5 changes: 2 additions & 3 deletions src/test/server/tcp_server_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include <atomic>
#include <string>
#include <unistd.h>
#include <vector>

using namespace cppnet;
using namespace std;
Expand Down Expand Up @@ -66,7 +65,7 @@ TEST(TcpServer, SigleClient) {
});

// run server
server.EpollLoop();
server.EventLoop();
}

TEST(TcpServer, MultiClient) {
Expand Down Expand Up @@ -150,7 +149,7 @@ TEST(TcpServer, MultiClient) {
}

// run server
rc = server.EpollLoop();
rc = server.EventLoop();
MUST_EQUAL(rc, 0);

wait_group.Wait();
Expand Down
Loading

0 comments on commit 0191f85

Please sign in to comment.