#include <utilmm/system/socket.hh>
#include <utilmm/system/system.hh>

#include <boost/noncopyable.hpp>
#include <boost/lexical_cast.hpp>
#include <boost/regex.hpp>

#include <algorithm>
#include <cerrno>
#include <climits>
#include <cstring>
#include <iostream>
#include <new>

#include <sys/socket.h>
#include <netdb.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <sys/time.h>
#include <sys/select.h>
#include <unistd.h>
00018
00019 using namespace std;
00020 using namespace boost;
00021
00022 namespace utilmm
00023 {
00024 int base_socket::to_unix(Domain d)
00025 {
00026 switch(d)
00027 {
00028 case Unix: return PF_UNIX;
00029 case Inet: return PF_INET;
00030 }
00031 return 0;
00032 };
00033
00034 int base_socket::to_unix(Type t)
00035 {
00036 switch(t)
00037 {
00038 case Stream: return SOCK_STREAM;
00039 case Datagram: return SOCK_DGRAM;
00040 }
00041 return 0;
00042 }
00043
00044 base_socket::base_socket(int _fd)
00045 : m_fd(_fd) {}
00046 base_socket::base_socket(Domain domain, Type type)
00047 : m_fd(-1)
00048 {
00049 m_fd = ::socket(base_socket::to_unix(domain), base_socket::to_unix(type), 0);
00050 if (m_fd == -1)
00051 throw unix_error("cannot open the socket");
00052 m_domain = domain;
00053 m_type = type;
00054 }
00055
00056 base_socket::~base_socket()
00057 {
00058 if (m_fd != -1)
00059 close(m_fd);
00060 }
00061
00062 vector<uint8_t> base_socket::to_sockaddr(std::string const& to) const
00063 {
00064 vector<uint8_t> ret;
00065 if (m_domain == Unix)
00066 {
00067 #if __APPLE__
00068 static const unsigned int UNIX_MAX_PATH = 104;
00069 #else
00070 static const unsigned int UNIX_MAX_PATH = 108;
00071 #endif
00072 if (to.size() >= UNIX_MAX_PATH)
00073 throw bad_address();
00074
00075 sockaddr_un addr;
00076 addr.sun_family = AF_UNIX;
00077 strncpy(addr.sun_path, to.c_str(), UNIX_MAX_PATH - 1);
00078 addr.sun_path[UNIX_MAX_PATH - 1] = 0;
00079
00080
00081 copy(
00082 reinterpret_cast<uint8_t*>(&addr),
00083 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr),
00084 back_inserter(ret));
00085 }
00086 else if (m_domain == Inet)
00087 {
00088 static regex rx_ip("^(.+):(\\d+)$");
00089 smatch match;
00090 if (!regex_match(to, match, rx_ip))
00091 throw bad_address();
00092
00093 std::string hostname = string(match[1].first, match[1].second);
00094 uint16_t port = lexical_cast<uint16_t>( string(match[2].first, match[2].second) );
00095
00096 struct hostent* host = gethostbyname(hostname.c_str());
00097 if (!host)
00098 throw unix_error("cannot get host address");
00099
00100 if (host->h_addrtype == AF_INET)
00101 {
00102 sockaddr_in addr;
00103 addr.sin_family = AF_INET;
00104 addr.sin_port = htons(port);
00105 memcpy(&addr.sin_addr, host->h_addr_list[0], host->h_length);
00106 copy(
00107 reinterpret_cast<uint8_t*>(&addr),
00108 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr),
00109 back_inserter(ret));
00110 }
00111 else
00112 {
00113 sockaddr_in6 addr;
00114 addr.sin6_family = AF_INET6;
00115 addr.sin6_port = htons(port);
00116 addr.sin6_flowinfo = 0;
00117 memcpy(&addr.sin6_addr, &host->h_addr_list[0], host->h_length);
00118 copy(
00119 reinterpret_cast<uint8_t*>(&addr),
00120 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr),
00121 back_inserter(ret));
00122 }
00123 }
00124 return ret;
00125 }
00126
00127 int base_socket::fd() const { return m_fd; }
00128 bool base_socket::try_wait(int what) const
00129 {
00130 timeval tv = { 0, 0 };
00131 return wait(what, &tv) > 0;
00132 }
00133 void base_socket::wait(int what) const { wait(what, 0); }
00134 int base_socket::wait(int what, timeval* tv) const
00135 {
00136 fd_set rd_set, wr_set, exc_set;
00137 fd_set *rd_p = 0, *wr_p = 0, *exc_p = 0;
00138 if (what & WaitRead)
00139 {
00140 FD_ZERO(&rd_set);
00141 FD_SET(fd(), &rd_set);
00142 rd_p = &rd_set;
00143 }
00144 if (what & WaitWrite)
00145 {
00146 FD_ZERO(&wr_set);
00147 FD_SET(fd(), &wr_set);
00148 wr_p = &wr_set;
00149 }
00150 if (what & WaitException)
00151 {
00152 FD_ZERO(&exc_set);
00153 FD_SET(fd(), &exc_set);
00154 exc_p = &exc_set;
00155 }
00156
00157 int ret = select(m_fd + 1, rd_p, wr_p, exc_p, tv);
00158 if (ret == -1)
00159 throw unix_error("error while waiting for socket");
00160 return ret;
00161 }
00162
00163 void base_socket::flush() const
00164 { fsync(fd()); }
00165
00166
00167 socket::socket(Domain domain, Type type, std::string const& connect_to)
00168 : base_socket(domain, type)
00169 { connect(connect_to); }
00170 socket::socket(int _fd)
00171 : base_socket(_fd) {}
00172
00173 void socket::connect(std::string const& to)
00174 {
00175 vector<uint8_t> addr = to_sockaddr(to);
00176 if (::connect(fd(), reinterpret_cast<sockaddr*>(&addr[0]), addr.size()) == -1)
00177 throw unix_error("cannot connect to " + to);
00178 }
00179 int socket::read(void* buf, size_t size) const
00180 {
00181 int read_bytes = ::read(fd(), buf, size);
00182 if (read_bytes == -1)
00183 throw unix_error("cannot read on socket");
00184 return read_bytes;
00185 }
00186 int socket::write(void const* buf, size_t size) const
00187 {
00188 int written = ::write(fd(), buf, size);
00189 if (written == -1)
00190 throw unix_error("cannot write on socket");
00191 return written;
00192 }
00193
00194
00195
00196
00197
00198 server_socket::server_socket(Domain domain, Type type, std::string const& bind_to, int backlog)
00199 : base_socket(domain, type)
00200 {
00201 bind(bind_to);
00202 if (listen(fd(), backlog) == -1)
00203 throw unix_error("cannot listen to " + bind_to);
00204 }
00205
00206 void server_socket::bind(std::string const& to)
00207 {
00208 vector<uint8_t> addr = to_sockaddr(to);
00209 if (::bind(fd(), reinterpret_cast<sockaddr*>(&addr[0]), addr.size()) == -1)
00210 throw unix_error("cannot bind to " + to);
00211 }
00212
00213 void server_socket::wait() const
00214 { return base_socket::wait(base_socket::WaitRead); }
00215 bool server_socket::try_wait() const
00216 { return base_socket::try_wait(base_socket::WaitRead); }
00217
00218 socket* server_socket::accept() const
00219 {
00220 int sock_fd = ::accept(fd(), NULL, NULL);
00221 if (sock_fd == -1)
00222 throw unix_error("failed in accept()");
00223
00224 return new socket(sock_fd);
00225 }
00226 }
00227