#include <utilmm/system/socket.hh>
#include <utilmm/system/system.hh>

#include <boost/noncopyable.hpp>
#include <boost/lexical_cast.hpp>
#include <boost/regex.hpp>

#include <algorithm>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <new>

#include <sys/socket.h>
#include <netdb.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <sys/time.h>
#include <sys/select.h>
#include <unistd.h>
00018
00019 using namespace std;
00020 using namespace boost;
00021
00022 namespace utilmm
00023 {
00024 int base_socket::to_unix(Domain d)
00025 {
00026 switch(d)
00027 {
00028 case Unix: return PF_UNIX;
00029 case Inet: return PF_INET;
00030 }
00031 return 0;
00032 };
00033
00034 int base_socket::to_unix(Type t)
00035 {
00036 switch(t)
00037 {
00038 case Stream: return SOCK_STREAM;
00039 case Datagram: return SOCK_DGRAM;
00040 }
00041 return 0;
00042 }
00043
00044 base_socket::base_socket(int _fd)
00045 : m_fd(_fd) {}
00046 base_socket::base_socket(Domain domain, Type type)
00047 : m_fd(-1)
00048 {
00049 m_fd = ::socket(base_socket::to_unix(domain), base_socket::to_unix(type), 0);
00050 if (m_fd == -1)
00051 throw unix_error("cannot open the socket");
00052 m_domain = domain;
00053 m_type = type;
00054 }
00055
00056 base_socket::~base_socket()
00057 {
00058 if (m_fd != -1)
00059 close(m_fd);
00060 }
00061
00062 vector<uint8_t> base_socket::to_sockaddr(std::string const& to) const
00063 {
00064 vector<uint8_t> ret;
00065 if (m_domain == Unix)
00066 {
00067 static const unsigned int UNIX_MAX_PATH = 108;
00068
00069 if (to.size() >= UNIX_MAX_PATH)
00070 throw bad_address();
00071
00072 sockaddr_un addr;
00073 addr.sun_family = AF_UNIX;
00074 strncpy(addr.sun_path, to.c_str(), UNIX_MAX_PATH - 1);
00075 addr.sun_path[UNIX_MAX_PATH - 1] = 0;
00076
00077
00078 copy(
00079 reinterpret_cast<uint8_t*>(&addr),
00080 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr),
00081 back_inserter(ret));
00082 }
00083 else if (m_domain == Inet)
00084 {
00085 static regex rx_ip("^(.+):(\\d+)$");
00086 smatch match;
00087 if (!regex_match(to, match, rx_ip))
00088 throw bad_address();
00089
00090 std::string hostname = string(match[1].first, match[1].second);
00091 uint16_t port = lexical_cast<uint16_t>( string(match[2].first, match[2].second) );
00092
00093 struct hostent* host = gethostbyname(hostname.c_str());
00094 if (!host)
00095 throw unix_error("cannot get host address");
00096
00097 if (host->h_addrtype == AF_INET)
00098 {
00099 sockaddr_in addr;
00100 addr.sin_family = AF_INET;
00101 addr.sin_port = htons(port);
00102 memcpy(&addr.sin_addr, host->h_addr_list[0], host->h_length);
00103 copy(
00104 reinterpret_cast<uint8_t*>(&addr),
00105 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr),
00106 back_inserter(ret));
00107 }
00108 else
00109 {
00110 sockaddr_in6 addr;
00111 addr.sin6_family = AF_INET6;
00112 addr.sin6_port = htons(port);
00113 addr.sin6_flowinfo = 0;
00114 memcpy(&addr.sin6_addr, &host->h_addr_list[0], host->h_length);
00115 copy(
00116 reinterpret_cast<uint8_t*>(&addr),
00117 reinterpret_cast<uint8_t*>(&addr) + sizeof(addr),
00118 back_inserter(ret));
00119 }
00120 }
00121 return ret;
00122 }
00123
00124 int base_socket::fd() const { return m_fd; }
00125 bool base_socket::try_wait(int what) const
00126 {
00127 timeval tv = { 0, 0 };
00128 return wait(what, &tv) > 0;
00129 }
00130 void base_socket::wait(int what) const { wait(what, 0); }
00131 int base_socket::wait(int what, timeval* tv) const
00132 {
00133 fd_set rd_set, wr_set, exc_set;
00134 fd_set *rd_p = 0, *wr_p = 0, *exc_p = 0;
00135 if (what & WaitRead)
00136 {
00137 FD_ZERO(&rd_set);
00138 FD_SET(fd(), &rd_set);
00139 rd_p = &rd_set;
00140 }
00141 if (what & WaitWrite)
00142 {
00143 FD_ZERO(&wr_set);
00144 FD_SET(fd(), &wr_set);
00145 wr_p = &wr_set;
00146 }
00147 if (what & WaitException)
00148 {
00149 FD_ZERO(&exc_set);
00150 FD_SET(fd(), &exc_set);
00151 exc_p = &exc_set;
00152 }
00153
00154 int ret = select(m_fd + 1, rd_p, wr_p, exc_p, tv);
00155 if (ret == -1)
00156 throw unix_error("error while waiting for socket");
00157 return ret;
00158 }
00159
00160 void base_socket::flush() const
00161 { fsync(fd()); }
00162
00163
00164 socket::socket(Domain domain, Type type, std::string const& connect_to)
00165 : base_socket(domain, type)
00166 { connect(connect_to); }
00167 socket::socket(int _fd)
00168 : base_socket(_fd) {}
00169
00170 void socket::connect(std::string const& to)
00171 {
00172 vector<uint8_t> addr = to_sockaddr(to);
00173 if (::connect(fd(), reinterpret_cast<sockaddr*>(&addr[0]), addr.size()) == -1)
00174 throw unix_error("cannot connect to " + to);
00175 }
00176 int socket::read(void* buf, size_t size) const
00177 {
00178 int read_bytes = ::read(fd(), buf, size);
00179 if (read_bytes == -1)
00180 throw unix_error("cannot read on socket");
00181 return read_bytes;
00182 }
00183 int socket::write(void const* buf, size_t size) const
00184 {
00185 int written = ::write(fd(), buf, size);
00186 if (written == -1)
00187 throw unix_error("cannot write on socket");
00188 return written;
00189 }
00190
00191
00192
00193
00194
00195 server_socket::server_socket(Domain domain, Type type, std::string const& bind_to, int backlog)
00196 : base_socket(domain, type)
00197 {
00198 bind(bind_to);
00199 if (listen(fd(), backlog) == -1)
00200 throw unix_error("cannot listen to " + bind_to);
00201 }
00202
00203 void server_socket::bind(std::string const& to)
00204 {
00205 vector<uint8_t> addr = to_sockaddr(to);
00206 if (::bind(fd(), reinterpret_cast<sockaddr*>(&addr[0]), addr.size()) == -1)
00207 throw unix_error("cannot bind to " + to);
00208 }
00209
00210 void server_socket::wait() const
00211 { return base_socket::wait(base_socket::WaitRead); }
00212 bool server_socket::try_wait() const
00213 { return base_socket::try_wait(base_socket::WaitRead); }
00214
00215 socket* server_socket::accept() const
00216 {
00217 int sock_fd = ::accept(fd(), NULL, NULL);
00218 if (sock_fd == -1)
00219 throw unix_error("failed in accept()");
00220
00221 return new socket(sock_fd);
00222 }
00223 }
00224