$search
00001 #include <utilmm/system/socket.hh> 00002 #include <utilmm/system/system.hh> 00003 00004 #include <boost/noncopyable.hpp> 00005 #include <boost/lexical_cast.hpp> 00006 #include <boost/regex.hpp> 00007 #include <algorithm> 00008 00009 #include <sys/socket.h> 00010 #include <netdb.h> 00011 #include <sys/un.h> 00012 #include <netinet/in.h> 00013 #include <netinet/ip.h> 00014 #include <sys/time.h> 00015 #include <sys/select.h> 00016 00017 #include <iostream> 00018 00019 using namespace std; 00020 using namespace boost; 00021 00022 namespace utilmm 00023 { 00024 int base_socket::to_unix(Domain d) 00025 { 00026 switch(d) 00027 { 00028 case Unix: return PF_UNIX; 00029 case Inet: return PF_INET; 00030 } 00031 return 0; 00032 }; 00033 00034 int base_socket::to_unix(Type t) 00035 { 00036 switch(t) 00037 { 00038 case Stream: return SOCK_STREAM; 00039 case Datagram: return SOCK_DGRAM; 00040 } 00041 return 0; 00042 } 00043 00044 base_socket::base_socket(int _fd) 00045 : m_fd(_fd) {} 00046 base_socket::base_socket(Domain domain, Type type) 00047 : m_fd(-1) 00048 { 00049 m_fd = ::socket(base_socket::to_unix(domain), base_socket::to_unix(type), 0); 00050 if (m_fd == -1) 00051 throw unix_error("cannot open the socket"); 00052 m_domain = domain; 00053 m_type = type; 00054 } 00055 00056 base_socket::~base_socket() 00057 { 00058 if (m_fd != -1) 00059 close(m_fd); 00060 } 00061 00062 vector<uint8_t> base_socket::to_sockaddr(std::string const& to) const 00063 { 00064 vector<uint8_t> ret; 00065 if (m_domain == Unix) 00066 { 00067 static const unsigned int UNIX_MAX_PATH = 108; 00068 00069 if (to.size() >= UNIX_MAX_PATH) 00070 throw bad_address(); 00071 00072 sockaddr_un addr; 00073 addr.sun_family = AF_UNIX; 00074 strncpy(addr.sun_path, to.c_str(), UNIX_MAX_PATH - 1); 00075 addr.sun_path[UNIX_MAX_PATH - 1] = 0; 00076 00077 00078 copy( 00079 reinterpret_cast<uint8_t*>(&addr), 00080 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr), 00081 back_inserter(ret)); 00082 } 00083 else if (m_domain == Inet) 00084 { 00085 static regex rx_ip("^(.+):(\\d+)$"); 00086 smatch match; 00087 if (!regex_match(to, match, rx_ip)) 00088 throw bad_address(); 00089 00090 std::string hostname = string(match[1].first, match[1].second); 00091 uint16_t port = lexical_cast<uint16_t>( string(match[2].first, match[2].second) ); 00092 00093 struct hostent* host = gethostbyname(hostname.c_str()); 00094 if (!host) 00095 throw unix_error("cannot get host address"); 00096 00097 if (host->h_addrtype == AF_INET) 00098 { 00099 sockaddr_in addr; 00100 addr.sin_family = AF_INET; 00101 addr.sin_port = htons(port); 00102 memcpy(&addr.sin_addr, host->h_addr_list[0], host->h_length); 00103 copy( 00104 reinterpret_cast<uint8_t*>(&addr), 00105 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr), 00106 back_inserter(ret)); 00107 } 00108 else 00109 { 00110 sockaddr_in6 addr; 00111 addr.sin6_family = AF_INET6; 00112 addr.sin6_port = htons(port); 00113 addr.sin6_flowinfo = 0; // don't know what this is 00114 memcpy(&addr.sin6_addr, &host->h_addr_list[0], host->h_length); 00115 copy( 00116 reinterpret_cast<uint8_t*>(&addr), 00117 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr), 00118 back_inserter(ret)); 00119 } 00120 } 00121 return ret; 00122 } 00123 00124 int base_socket::fd() const { return m_fd; } 00125 bool base_socket::try_wait(int what) const 00126 { 00127 timeval tv = { 0, 0 }; 00128 return wait(what, &tv) > 0; 00129 } 00130 void base_socket::wait(int what) const { wait(what, 0); } 00131 int base_socket::wait(int what, timeval* tv) const 00132 { 00133 fd_set rd_set, wr_set, exc_set; 00134 fd_set *rd_p = 0, *wr_p = 0, *exc_p = 0; 00135 if (what & WaitRead) 00136 { 00137 FD_ZERO(&rd_set); 00138 FD_SET(fd(), &rd_set); 00139 rd_p = &rd_set; 00140 } 00141 if (what & WaitWrite) 00142 { 00143 FD_ZERO(&wr_set); 00144 FD_SET(fd(), &wr_set); 00145 wr_p = &wr_set; 00146 } 00147 if (what & WaitException) 00148 { 00149 FD_ZERO(&exc_set); 00150 FD_SET(fd(), &exc_set); 00151 exc_p = &exc_set; 00152 } 00153 00154 int ret = select(m_fd + 1, rd_p, wr_p, exc_p, tv); 00155 if (ret == -1) 00156 throw unix_error("error while waiting for socket"); 00157 return ret; 00158 } 00159 00160 void base_socket::flush() const 00161 { fsync(fd()); } 00162 00163 00164 socket::socket(Domain domain, Type type, std::string const& connect_to) 00165 : base_socket(domain, type) 00166 { connect(connect_to); } 00167 socket::socket(int _fd) 00168 : base_socket(_fd) {} 00169 00170 void socket::connect(std::string const& to) 00171 { 00172 vector<uint8_t> addr = to_sockaddr(to); 00173 if (::connect(fd(), reinterpret_cast<sockaddr*>(&addr[0]), addr.size()) == -1) 00174 throw unix_error("cannot connect to " + to); 00175 } 00176 int socket::read(void* buf, size_t size) const 00177 { 00178 int read_bytes = ::read(fd(), buf, size); 00179 if (read_bytes == -1) 00180 throw unix_error("cannot read on socket"); 00181 return read_bytes; 00182 } 00183 int socket::write(void const* buf, size_t size) const 00184 { 00185 int written = ::write(fd(), buf, size); 00186 if (written == -1) 00187 throw unix_error("cannot write on socket"); 00188 return written; 00189 } 00190 00191 00192 00193 00194 00195 server_socket::server_socket(Domain domain, Type type, std::string const& bind_to, int backlog) 00196 : base_socket(domain, type) 00197 { 00198 bind(bind_to); 00199 if (listen(fd(), backlog) == -1) 00200 throw unix_error("cannot listen to " + bind_to); 00201 } 00202 00203 void server_socket::bind(std::string const& to) 00204 { 00205 vector<uint8_t> addr = to_sockaddr(to); 00206 if (::bind(fd(), reinterpret_cast<sockaddr*>(&addr[0]), addr.size()) == -1) 00207 throw unix_error("cannot bind to " + to); 00208 } 00209 00210 void server_socket::wait() const 00211 { return base_socket::wait(base_socket::WaitRead); } 00212 bool server_socket::try_wait() const 00213 { return base_socket::try_wait(base_socket::WaitRead); } 00214 00215 socket* server_socket::accept() const 00216 { 00217 int sock_fd = ::accept(fd(), NULL, NULL); 00218 if (sock_fd == -1) 00219 throw unix_error("failed in accept()"); 00220 00221 return new socket(sock_fd); 00222 } 00223 } 00224