// Single-threaded TCP port forwarder built on epoll: every client accepted on
// local_port gets its own non-blocking connection to host:remote_port, and bytes
// are shuttled in both directions through one fixed-size buffer per direction.
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

// Report a failed expression with its source location, followed by the errno text.
static void ErrorHandler(const char* file, int line, const char* expr)
{
    int saved_errno = errno;  // fprintf may clobber errno before perror reads it
    fprintf(stderr, "Error at %s, %d: %s\n", file, line, expr);
    errno = saved_errno;
    perror(" errno");
}

#define ErrorLog(expr) ErrorHandler(__FILE__, __LINE__, expr)
// Evaluates expr once; yields 1 if it is true, otherwise logs it and yields 0.
#define AssertRun(expr) ((expr) ? 1 : (ErrorLog(#expr), 0))

// Which end of a Connection a Peer represents (also its index in Connection::peers).
enum {
    PeerSide_Client = 0,  // socket accepted from the local listener
    PeerSide_Server = 1,  // socket connected to the remote host
};

// Lifecycle of one direction of traffic (a Stream).
enum {
    State_Normal,  // holds data and still has free space
    State_Empty,   // nothing buffered
    State_Full,    // no free space: reading from the recver is paused
    State_Tail,    // recver hung up; sender drains what is left, then closes
    State_Closed,  // finished; no further I/O in this direction
};

struct Peer {
    struct Connection* connection;
    struct Stream* input;   // filled by reading this peer's socket
    struct Stream* output;  // drained by writing to this peer's socket
    int side;               // PeerSide_*
    int fd;                 // -1 once closed
    uint32_t event_mask;    // epoll events registered for fd (0 = not registered)
};

// Per-direction buffer size. Tiny values (e.g. 2) are handy for exercising the
// Full/Tail paths while debugging, but cost one syscall per couple of bytes.
static const size_t buffer_size = 16384;

// Bytes read from `recver` waiting to be written to `sender`; pending data is [begin, end).
struct Stream {
    int state;
    char data[buffer_size];
    char* begin;
    char* end;
    Peer* sender;  // peer whose socket this data is written to
    Peer* recver;  // peer whose socket this data was read from
    Stream() : state(State_Empty), begin(data), end(data), sender(NULL), recver(NULL) {}
};

// A proxied session: the client peer, the server peer and one Stream per direction.
struct Connection {
    Peer peers[2];
    Stream streams[2];
    bool connecting;           // outbound connect() still in progress
    bool retired;              // already queued for deletion
    Connection* next_retired;  // link in Context::retired
    Connection() : connecting(false), retired(false), next_retired(NULL)
    {
        for (int i = 0; i < 2; ++i) {
            // Peer i reads its socket into streams[i], which the other peer writes out.
            peers[i].connection = this;
            peers[i].side = i;
            peers[i].fd = -1;
            peers[i].event_mask = 0;
            peers[i].input = &streams[i];
            peers[i].output = &streams[1 - i];
            streams[i].recver = &peers[i];
            streams[i].sender = &peers[1 - i];
        }
    }
};

struct Context {
    struct hostent* hent;        // resolved remote host
    unsigned short remote_port;  // host byte order
    unsigned short local_port;   // host byte order
    int server_fd;               // listening socket
    int epoll_fd;
    Connection* retired;         // connections to delete once the current epoll batch is done
};

// True for errno values that mean "try again later" on a non-blocking socket.
static bool IsTransientError(int err)
{
    return err == EAGAIN || err == EWOULDBLOCK || err == EINTR;
}

// Replace the bits of peer's epoll registration selected by `mask` with `value`,
// issuing ADD/MOD/DEL as needed. A closed peer (fd == -1) is left untouched.
static void SwitchEvents(Context* context, Peer* peer, uint32_t mask, uint32_t value)
{
    if (peer->fd == -1) return;
    uint32_t wanted = (peer->event_mask & ~mask) | value;
    if (wanted == peer->event_mask) return;

    struct epoll_event epev;
    memset(&epev, 0, sizeof(epev));
    epev.events = wanted;
    epev.data.ptr = peer;

    int op;
    if (peer->event_mask == 0) {
        op = EPOLL_CTL_ADD;
    } else if (wanted == 0) {
        op = EPOLL_CTL_DEL;
    } else {
        op = EPOLL_CTL_MOD;
    }
    if (!AssertRun(0 == epoll_ctl(context->epoll_fd, op, peer->fd, &epev))) _exit(-1);
    peer->event_mask = wanted;
}

// Unregister and close peer's socket; keeps the invariant fd == -1 => event_mask == 0.
static void ClosePeerFd(Context* context, Peer* peer)
{
    if (peer->fd == -1) return;
    SwitchEvents(context, peer, EPOLLIN | EPOLLOUT, 0);
    close(peer->fd);
    peer->fd = -1;
}

static void DestroyPeer(Context* context, Peer* peer)
{
    fprintf(stderr, "DestroyPeer: fd=%d\n", peer->fd);
    ClosePeerFd(context, peer);
}

// Queue a connection for deletion after the current epoll batch. Deleting it
// immediately would leave dangling Peer pointers in not-yet-processed events.
static void RetireConnection(Context* context, Connection* connection)
{
    if (connection->retired) return;
    connection->retired = true;
    connection->next_retired = context->retired;
    context->retired = connection;
}

static void DestroyConnection(Context* context, Connection* connection)
{
    DestroyPeer(context, &connection->peers[0]);
    DestroyPeer(context, &connection->peers[1]);
    RetireConnection(context, connection);
}

// Free every connection retired during the last batch of events.
static void ReapConnections(Context* context)
{
    while (context->retired != NULL) {
        Connection* dead = context->retired;
        context->retired = dead->next_retired;
        delete dead;
    }
}

// Accept one pending client and start a non-blocking connect to the remote host.
static void CreateConnection(Context* context)
{
    struct sockaddr_in client_addr;
    socklen_t addr_len = sizeof(client_addr);
    int accept_fd = accept(context->server_fd, (struct sockaddr*)&client_addr, &addr_len);
    if (accept_fd == -1) {
        // The listener is non-blocking, so a vanished pending connection is not an error.
        if (!IsTransientError(errno)) ErrorLog("accept");
        return;
    }
    if (!AssertRun(0 == fcntl(accept_fd, F_SETFL, O_NONBLOCK)) ||
        !AssertRun(context->hent->h_length == 4)) {
        close(accept_fd);
        return;
    }

    struct sockaddr_in remote_addr;
    memset(&remote_addr, 0, sizeof(remote_addr));
    remote_addr.sin_family = AF_INET;
    memcpy(&remote_addr.sin_addr.s_addr, context->hent->h_addr_list[0], 4);
    remote_addr.sin_port = htons(context->remote_port);

    int connect_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (!AssertRun(-1 != connect_fd)) {
        close(accept_fd);
        return;
    }
    if (!AssertRun(0 == fcntl(connect_fd, F_SETFL, O_NONBLOCK))) {
        close(connect_fd);
        close(accept_fd);
        return;
    }
    bool connecting = false;
    if (0 != connect(connect_fd, (struct sockaddr*)&remote_addr, sizeof(remote_addr))) {
        if (errno == EINPROGRESS) {
            connecting = true;
        } else {
            ErrorLog("Connection failed");
            close(connect_fd);
            close(accept_fd);
            return;
        }
    }

    Connection* connection = new Connection;
    connection->connecting = connecting;
    Peer* server_peer = &connection->peers[PeerSide_Server];
    Peer* client_peer = &connection->peers[PeerSide_Client];
    server_peer->fd = connect_fd;
    client_peer->fd = accept_fd;
    // While connecting, only wait for the outbound socket to become writable;
    // the client is not read until there is somewhere to send its data.
    SwitchEvents(context, client_peer, EPOLLIN | EPOLLOUT, connecting ? 0 : EPOLLIN);
    SwitchEvents(context, server_peer, EPOLLIN | EPOLLOUT, connecting ? EPOLLOUT : EPOLLIN);
}

static void HangUpPeer(Context* context, Peer* peer);

// Read available bytes from peer's socket into its input stream.
static void ReadFromPeer(Context* context, Peer* peer)
{
    Stream* stream = peer->input;
    // EPOLLIN is switched off whenever the stream is Full or Tail, so reaching
    // here in those states means the bookkeeping is broken.
    if (!AssertRun(stream->state != State_Full) || !AssertRun(stream->state != State_Tail)) _exit(-1);

    // Slide unsent bytes to the front so the free space is contiguous.
    if (stream->begin > stream->data) {
        size_t pending = stream->end - stream->begin;
        if (pending > 0) memmove(stream->data, stream->begin, pending);
        stream->begin = stream->data;
        stream->end = stream->data + pending;
    }
    size_t room = stream->data + sizeof(stream->data) - stream->end;
    if (!AssertRun(room > 0)) return;

    ssize_t len = read(peer->fd, stream->end, room);
    if (len < 0 && IsTransientError(errno)) return;  // spurious wakeup; wait for the next one
    if (len <= 0) {                                  // EOF or hard error
        HangUpPeer(context, peer);
        return;
    }

    stream->end += len;
    if (stream->state == State_Empty) {
        stream->state = State_Normal;
        SwitchEvents(context, stream->sender, EPOLLOUT, EPOLLOUT);
    }
    if ((size_t)len >= room) {
        // Buffer is full: stop reading until the sender drains some of it.
        stream->state = State_Full;
        SwitchEvents(context, peer, EPOLLIN, 0);
    }
}

// Write as much of peer's output stream as the socket accepts.
static void WriteToPeer(Context* context, Peer* peer)
{
    Stream* stream = peer->output;
    size_t pending = stream->end - stream->begin;
    if (!AssertRun(stream->state != State_Empty) || !AssertRun(pending > 0)) _exit(-1);

    ssize_t len = write(peer->fd, stream->begin, pending);
    if (len < 0 && IsTransientError(errno)) return;  // socket not actually writable yet
    if (len <= 0) {
        HangUpPeer(context, peer);
        return;
    }

    stream->begin += len;
    if (stream->state == State_Full) {
        // Space freed up: resume reading from the side that fills this stream.
        stream->state = State_Normal;
        SwitchEvents(context, stream->recver, EPOLLIN, EPOLLIN);
    }
    if ((size_t)len >= pending) {
        SwitchEvents(context, peer, EPOLLOUT, 0);
        if (stream->state == State_Tail) {
            // The other side already hung up and everything it sent is delivered.
            stream->state = State_Closed;
            ClosePeerFd(context, peer);
        } else {
            stream->state = State_Empty;
        }
    }
}

// Close `peer`. If data it sent is still buffered, the other peer stays open
// (Tail) until that data is written; otherwise the other peer is closed too.
static void HangUpPeer(Context* context, Peer* peer)
{
    Stream* istream = peer->input;
    Peer* other = istream->sender;
    // Nothing more can be delivered to `peer`, so stop reading from `other`.
    SwitchEvents(context, other, EPOLLIN, 0);
    ClosePeerFd(context, peer);
    if (istream->state == State_Empty) {
        istream->state = State_Closed;
        ClosePeerFd(context, other);
    } else {
        istream->state = State_Tail;
    }
}

// Dispatch one epoll event for `peer`.
static void UpdateConnection(Context* context, Peer* peer, uint32_t events)
{
    Connection* connection = peer->connection;
    // Stale event for a peer closed earlier in this epoll batch.
    if (peer->fd == -1) return;

    if (peer->side == PeerSide_Server && connection->connecting) {
        int optval = 0;
        socklen_t optlen = sizeof(optval);
        if (!AssertRun(0 == getsockopt(peer->fd, SOL_SOCKET, SO_ERROR, &optval, &optlen))) _exit(-1);
        if (!AssertRun(optlen == sizeof(optval))) _exit(-1);
        if (optval != 0) {
            fprintf(stderr, "Connection failed: %s\n", strerror(optval));
            DestroyConnection(context, connection);
            return;
        }
        connection->connecting = false;
        for (int i = 0; i < 2; ++i) {
            SwitchEvents(context, &connection->peers[i], EPOLLIN | EPOLLOUT, EPOLLIN);
        }
    } else if (events & (EPOLLHUP | EPOLLERR)) {
        // EPOLLERR is always reported; ignoring it would spin the level-triggered loop.
        HangUpPeer(context, peer);
    } else {
        // Only act on events still wanted: an earlier event in this batch may have
        // switched them off (or closed the peer).
        if ((events & EPOLLOUT) && (peer->event_mask & EPOLLOUT)) WriteToPeer(context, peer);
        if ((events & EPOLLIN) && (peer->event_mask & EPOLLIN)) ReadFromPeer(context, peer);
    }

    if (connection->peers[0].fd == -1 && connection->peers[1].fd == -1) {
        RetireConnection(context, connection);
    }
}

// Create the non-blocking listening socket on local_port; returns -1 on failure.
static int OpenListener(unsigned short local_port)
{
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (!AssertRun(-1 != fd)) return -1;

    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_ANY);
    addr.sin_port = htons(local_port);
    int reuse = 1;  // allow quick restarts while old connections sit in TIME_WAIT

    if (AssertRun(0 == setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse))) &&
        AssertRun(0 == bind(fd, (struct sockaddr*)&addr, sizeof(addr))) &&
        AssertRun(0 == listen(fd, 5)) &&
        AssertRun(0 == fcntl(fd, F_SETFL, O_NONBLOCK))) {
        return fd;
    }
    close(fd);
    return -1;
}

// Accept clients on local_port and forward each to hent:remote_port until epoll fails.
static void RunProxy(struct hostent* hent, unsigned short remote_port, unsigned short local_port)
{
    int server_fd = OpenListener(local_port);
    if (server_fd == -1) return;
    int epoll_fd = epoll_create(1024);
    if (!AssertRun(-1 != epoll_fd)) {
        close(server_fd);
        return;
    }

    Context context;
    context.hent = hent;
    context.remote_port = remote_port;
    context.local_port = local_port;
    context.server_fd = server_fd;
    context.epoll_fd = epoll_fd;
    context.retired = NULL;

    struct epoll_event epev;
    memset(&epev, 0, sizeof(epev));
    epev.events = EPOLLIN;
    epev.data.ptr = NULL;  // NULL tags the listening socket; peers carry their Peer*
    if (AssertRun(0 == epoll_ctl(epoll_fd, EPOLL_CTL_ADD, server_fd, &epev))) {
        static const int maxevents = 1024;
        struct epoll_event ret_events[maxevents];
        for (;;) {
            int n_events = epoll_wait(epoll_fd, ret_events, maxevents, -1);
            if (n_events == -1 && errno == EINTR) continue;  // interrupted by a signal
            if (!AssertRun(n_events > 0)) break;
            for (int i = 0; i < n_events; ++i) {
                struct epoll_event* ev = &ret_events[i];
                if (ev->data.ptr == NULL) {
                    CreateConnection(&context);
                } else {
                    UpdateConnection(&context, (Peer*)ev->data.ptr, ev->events);
                }
            }
            // Safe now: no remaining event in this batch can reference them.
            ReapConnections(&context);
        }
    }
    ReapConnections(&context);
    close(epoll_fd);
    close(server_fd);
}

// Parse a decimal TCP port in [1, 65535]; logs and returns false on bad input.
static bool ParsePort(const char* text, unsigned short* port)
{
    char* end = NULL;
    errno = 0;
    unsigned long value = strtoul(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE || value == 0 || value > 65535) {
        fprintf(stderr, "Invalid port: %s\n", text);
        return false;
    }
    *port = (unsigned short)value;
    return true;
}

int main(int argc, char** argv)
{
    bool valid = false;
    struct hostent* hent = NULL;
    unsigned short remote_port = 0;
    unsigned short local_port = 0;
    if (AssertRun(argc > 3)) {
        hent = gethostbyname(argv[1]);
        if (AssertRun(0 != hent) && AssertRun(hent->h_addrtype == AF_INET && hent->h_length == 4)) {
            valid = ParsePort(argv[2], &remote_port) && ParsePort(argv[3], &local_port);
        }
    }
    if (!valid) {
        fprintf(stderr, "Usage: %s <host> <remote_port> <local_port>\n", argc > 0 ? argv[0] : "proxy");
        return -1;
    }
    // A write to a reset socket must fail with EPIPE instead of killing the proxy.
    signal(SIGPIPE, SIG_IGN);
    RunProxy(hent, remote_port, local_port);
    return 0;
}