/*
 * Copyright 2006-2009, Haiku, Inc. All Rights Reserved.
* Distributed under the terms of the MIT License.
*
* Authors:
* Axel Dörfler, axeld@pinc-software.de
* Hugo Santos, hugosantos@gmail.com
*/
#include "EndpointManager.h"
#include <new>
#include <unistd.h>
#include <KernelExport.h>
#include <NetUtilities.h>
#include <tracing.h>
#include "TCPEndpoint.h"
// Verbose manager debugging: define TRACE_ENDPOINT_MANAGER to route TRACE()
// output to dprintf(); otherwise TRACE() expands to nothing.
#if defined(TRACE_ENDPOINT_MANAGER)
#	define TRACE(x) dprintf x
#else
#	define TRACE(x)
#endif

// Endpoint trace entries (see EndpointTracing below) follow TCP_TRACING.
#if TCP_TRACING
#	define ENDPOINT_TRACING
#endif
#ifdef ENDPOINT_TRACING
// Trace entries created through the T() macro below. Addresses are formatted
// into fixed-size buffers when the entry is constructed.
namespace EndpointTracing {
//! Records the binding of an endpoint to an address (explicit or ephemeral).
class Bind : public AbstractTraceEntry {
public:
// Overload for the read-only address wrapper (used by _BindToAddress()).
Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral)
:
fEndpoint(endpoint),
fEphemeral(ephemeral)
{
address.AsString(fAddress, sizeof(fAddress), true);
Initialized();
}
// Overload used by _BindToEphemeral(), which passes a SocketAddressStorage
// (presumably a SocketAddress subclass -- confirm in the header).
Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral)
:
fEndpoint(endpoint),
fEphemeral(ephemeral)
{
address.AsString(fAddress, sizeof(fAddress), true);
Initialized();
}
virtual void AddDump(TraceOutput& out)
{
out.Print("tcp:e:%p bind%s address %s", fEndpoint,
fEphemeral ? " ephemeral" : "", fAddress);
}
protected:
TCPEndpoint* fEndpoint;
// NOTE(review): 32 bytes may be too short for some address strings;
// presumably AsString() truncates to the given size -- confirm.
char fAddress[32];
bool fEphemeral;
};
//! Records the local/peer address pair assigned by SetConnection().
class Connect : public AbstractTraceEntry {
public:
Connect(TCPEndpoint* endpoint)
:
fEndpoint(endpoint)
{
endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
Initialized();
}
virtual void AddDump(TraceOutput& out)
{
out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal,
fPeer);
}
protected:
TCPEndpoint* fEndpoint;
char fLocal[32];
char fPeer[32];
};
//! Records an endpoint's addresses at the time EndpointManager::Unbind() runs.
// The constructor dereferences the endpoint, so it must not be NULL.
class Unbind : public AbstractTraceEntry {
public:
Unbind(TCPEndpoint* endpoint)
:
fEndpoint(endpoint)
{
endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
Initialized();
}
// Disabled: refers to an fStackTrace member this class does not declare.
#if 0
virtual void DumpStackTrace(TraceOutput& out)
{
out.PrintStackTrace(fStackTrace);
}
#endif
virtual void AddDump(TraceOutput& out)
{
out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal,
fPeer);
}
protected:
TCPEndpoint* fEndpoint;
char fLocal[32];
char fPeer[32];
};
}
// T(x) allocates a trace entry; nothing here frees it -- presumably ownership
// passes to the tracing framework (confirm in tracing.h).
# define T(x) new(std::nothrow) EndpointTracing::x
#else
// Tracing disabled: T(x) expands to nothing.
# define T(x)
#endif
namespace {

// Highest port number treated as privileged: explicit binds at or below it
// require root (_BindToAddress()), and _BindToEphemeral() remaps counter
// values in this range upward.
const uint16 kLastReservedPort = 1023;

// Initial value of fLastPort, i.e. where the ephemeral port search begins.
const uint16 kFirstEphemeralPort = 40000;

}	// namespace
/*!	Hash definition for the connection table, which is keyed by the
	(local address, peer address) pair. \a manager provides the address
	module needed to hash raw sockaddr keys.
*/
ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager)
	:
	fManager(manager)
{
}


//! Hashes a raw (local, peer) key the same way Hash() hashes an endpoint.
size_t
ConnectionHashDefinition::HashKey(const KeyType& key) const
{
	ConstSocketAddress localAddress(fManager->AddressModule(), key.first);
	return localAddress.HashPair(key.second);
}


//! Hashes an endpoint by its current local and peer addresses.
size_t
ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const
{
	return endpoint->LocalAddress().HashPair(
		*endpoint->PeerAddress());
}


//! An endpoint matches a key if both its local and peer addresses match.
bool
ConnectionHashDefinition::Compare(const KeyType& key,
	TCPEndpoint* endpoint) const
{
	if (!endpoint->LocalAddress().EqualTo(key.first, true))
		return false;

	return endpoint->PeerAddress().EqualTo(key.second, true);
}


//! The intrusive chain link lives in the endpoint itself.
TCPEndpoint*&
ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const
{
	TCPEndpoint*& link = endpoint->fConnectionHashLink;
	return link;
}
//! The endpoint table is keyed by local port; the port value is its own hash.
size_t
EndpointHashDefinition::HashKey(uint16 port) const
{
	return port;
}


//! Hashes an endpoint by its local port, matching HashKey().
size_t
EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const
{
	return endpoint->LocalAddress().Port();
}


//! An endpoint matches a port key if it is bound to that local port.
bool
EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const
{
	return port == endpoint->LocalAddress().Port();
}


//! Two endpoints are equal for this table when they share a local port.
bool
EndpointHashDefinition::CompareValues(TCPEndpoint* first,
	TCPEndpoint* second) const
{
	return Compare(first->LocalAddress().Port(), second);
}


//! The intrusive chain link lives in the endpoint itself.
TCPEndpoint*&
EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const
{
	TCPEndpoint*& link = endpoint->fEndpointHashLink;
	return link;
}
/*!	Creates the endpoint manager for \a domain. Only the lock is set up here;
	the two hash tables are initialized separately by Init().
*/
EndpointManager::EndpointManager(net_domain* domain)
	:
	fDomain(domain),
	fConnectionHash(this),
	fLastPort(kFirstEphemeralPort)
{
	rw_lock_init(&fLock, "TCP endpoint manager");
}


//! Releases the manager's lock.
EndpointManager::~EndpointManager()
{
	rw_lock_destroy(&fLock);
}


/*!	Initializes the connection table, then the endpoint (port) table.
	Returns the first error encountered, or \c B_OK.
*/
status_t
EndpointManager::Init()
{
	status_t result = fConnectionHash.Init();
	if (result != B_OK)
		return result;

	return fEndpointHash.Init();
}
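/*!	Returns the endpoint registered in the connection table for the
	(\a local, \a peer) address pair, or \c NULL if there is none.
	You must hold the manager's lock when calling this method (either read or
	write).
*/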
TCPEndpoint*
EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer)
{
	// The connection table is keyed by the (local, peer) address pair.
	std::pair<const sockaddr*, const sockaddr*> key(local, peer);
	return fConnectionHash.Lookup(key);
}
/*!	Assigns the address pair (\a _local, \a peer) to \a endpoint and enters
	it into the connection table.
	If \a _local is a wildcard address, \a interfaceLocal is used as local
	address instead, while keeping the port from \a _local.
	Returns \c EADDRINUSE if another connection already uses the resulting
	address pair.
*/
status_t
EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local,
	const sockaddr* peer, const sockaddr* interfaceLocal)
{
	TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));

	WriteLocker _(fLock);

	SocketAddressStorage local(AddressModule());
	local.SetTo(_local);

	if (local.IsEmpty(false)) {
		// wildcard local address: take the interface address, keep our port
		uint16 port = local.Port();
		local.SetTo(interfaceLocal);
		local.SetPort(port);
	}

	if (_LookupConnection(*local, peer) != NULL)
		return EADDRINUSE;

	// The connection table hashes on the endpoint's current addresses, so an
	// endpoint that is already in the table has to be taken out *before*
	// those addresses change. Removing it afterwards would search the wrong
	// bucket, and the Insert() below would link it in twice. Unbind() likewise
	// calls Remove() on endpoints that may not be in this table.
	fConnectionHash.Remove(endpoint);

	endpoint->LocalAddress().SetTo(*local);
	endpoint->PeerAddress().SetTo(peer);
	T(Connect(endpoint));

	fConnectionHash.Insert(endpoint);
	return B_OK;
}
/*!	Registers \a endpoint in the connection table with a wildcard peer, which
	is the form FindConnection() looks up for segments from unknown peers.
	An endpoint that is not bound yet is first bound to an ephemeral port on
	the wildcard local address. Returns \c EADDRINUSE if the endpoint's local
	address is already registered with a wildcard peer.
*/
status_t
EndpointManager::SetPassive(TCPEndpoint* endpoint)
{
	WriteLocker writeLocker(fLock);

	if (!endpoint->IsBound()) {
		// no local address yet: choose an ephemeral port on the any-address
		SocketAddressStorage anyLocal(AddressModule());
		anyLocal.SetToEmpty();

		status_t bindStatus = _BindToEphemeral(endpoint, *anyLocal);
		if (bindStatus < B_OK)
			return bindStatus;
	}

	SocketAddressStorage anyPeer(AddressModule());
	anyPeer.SetToEmpty();

	TCPEndpoint* existing
		= _LookupConnection(*endpoint->LocalAddress(), *anyPeer);
	if (existing != NULL)
		return EADDRINUSE;

	endpoint->PeerAddress().SetTo(*anyPeer);
	fConnectionHash.Insert(endpoint);
	return B_OK;
}
/*!	Returns the endpoint responsible for a segment addressed to \a local and
	coming from \a peer, after acquiring its socket through
	gSocketModule->acquire_socket(); returns \c NULL if there is none.
	Lookups go from most to least specific:
	  1. the exact (local, peer) connection,
	  2. an entry for \a local with a wildcard peer,
	  3. an entry for the wildcard address on \a local's port, wildcard peer.
	A candidate whose socket cannot be acquired is skipped.
*/
TCPEndpoint*
EndpointManager::FindConnection(sockaddr* local, sockaddr* peer)
{
	ReadLocker readLocker(fLock);

	// 1. fully specified connection
	TCPEndpoint* match = _LookupConnection(local, peer);
	if (match != NULL) {
		TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n",
			match));
		if (gSocketModule->acquire_socket(match->socket))
			return match;
	}

	// 2. our exact local address, any peer
	SocketAddressStorage anyPeer(AddressModule());
	anyPeer.SetToEmpty();

	match = _LookupConnection(local, *anyPeer);
	if (match != NULL) {
		TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n",
			match));
		if (gSocketModule->acquire_socket(match->socket))
			return match;
	}

	// 3. wildcard local address on the same port, any peer
	SocketAddressStorage anyLocalSamePort(AddressModule());
	anyLocalSamePort.SetToEmpty();
	anyLocalSamePort.SetPort(AddressModule()->get_port(local));

	match = _LookupConnection(*anyLocalSamePort, *anyPeer);
	if (match != NULL) {
		TRACE(("TCP: Received packet corresponds to local wildcard endpoint "
			"%p\n", match));
		if (gSocketModule->acquire_socket(match->socket))
			return match;
	}

	TRACE(("TCP: no matching endpoint!\n"));
	return NULL;
}
/*!	Binds \a endpoint to \a address. A port of 0 requests an ephemeral port;
	any other port goes through the conflict checks of _BindToAddress().
	Returns \c EAFNOSUPPORT if \a address is not of this domain's family.
*/
status_t
EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address)
{
	if (!AddressModule()->is_same_family(address))
		return EAFNOSUPPORT;

	WriteLocker locker(fLock);

	const bool wantsEphemeral = AddressModule()->get_port(address) == 0;
	if (!wantsEphemeral)
		return _BindToAddress(locker, endpoint, address);

	return _BindToEphemeral(endpoint, address);
}
/*!	Binds \a endpoint to \a address directly via _Bind(), under the write
	lock, skipping the port-conflict and privilege checks that Bind()
	performs. Presumably meant for endpoints spawned from a listening
	socket -- confirm against callers.
*/
status_t
EndpointManager::BindChild(TCPEndpoint* endpoint, const sockaddr* address)
{
	WriteLocker locker(fLock);

	status_t status = _Bind(endpoint, address);
	return status;
}
/*!	Binds \a endpoint to the explicit, non-zero port contained in \a _address.
	Ports up to kLastReservedPort require an effective user ID of 0.
	If another endpoint already uses the port on the same local address (or
	on the wildcard address), the bind only succeeds if \a endpoint has
	SO_REUSEADDR set and that user is in TIME_WAIT or CLOSED state.
	A closing local connection still holding the port is waited for a limited
	number of times. The manager's lock is dropped while waiting, which is
	why \a locker is passed in.
*/
status_t
EndpointManager::_BindToAddress(WriteLocker& locker, TCPEndpoint* endpoint,
	const sockaddr* _address)
{
	ConstSocketAddress address(AddressModule(), _address);
	uint16 port = address.Port();

	TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint));
	T(Bind(endpoint, address, false));

	// TODO: this check follows very typical UNIX semantics and could be
	// improved.
	if (ntohs(port) <= kLastReservedPort && geteuid() != 0)
		return B_PERMISSION_DENIED;

	// How often we wait for a closing local connection to release the port.
	// Once exhausted, such connections are checked like any other port user.
	// (Previously the counter was reset on every pass, so after the second
	// wait the loop fell through and bound the port despite the conflict.)
	int32 waitsLeft = 5;

	while (true) {
		bool mustWait = false;

		EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port);
		while (portUsers.HasNext()) {
			TCPEndpoint* user = portUsers.Next();

			// only users on the wildcard address or our own address conflict
			if (!user->LocalAddress().IsEmpty(false)
				&& !address.EqualTo(*user->LocalAddress(), false))
				continue;

			// While we hold the lock the user cannot go away; it can only
			// change its state.
			tcp_state userState = user->State();
			if (waitsLeft > 0 && user->IsLocal()
				&& (userState > ESTABLISHED || userState == CLOSED)) {
				// A closing local connection: give it time to go away for
				// real. Stop scanning, as the iterator must not be used once
				// the lock has been dropped.
				mustWait = true;
				break;
			}

			if ((endpoint->socket->options & SO_REUSEADDR) == 0)
				return EADDRINUSE;

			if (userState != TIME_WAIT && userState != CLOSED)
				return EADDRINUSE;
		}

		if (!mustWait)
			break;

		waitsLeft--;
		locker.Unlock();
		snooze(10000);
		locker.Lock();
	}

	return _Bind(endpoint, *address);
}
/*!	Binds \a endpoint to a port that currently has no users, on the local
	address taken from \a address (its port is replaced).
	The search starts after the most recently chosen port (fLastPort) and
	advances in pseudo-random steps of 1..32, derived from system_time(). The
	last of the four rounds uses a step of 1. Port numbers up to
	kLastReservedPort are shifted out of the reserved range and never handed
	out. Returns \c EADDRINUSE if no unused port was found.
*/
status_t
EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint,
	const sockaddr* address)
{
	TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint));

	uint32 max = fLastPort + 65536;

	for (int32 i = 1; i < 5; i++) {
		// try to retrieve a more or less random port
		uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1;
		uint32 counter = fLastPort + step;

		while (counter < max) {
			uint16 port = counter & 0xffff;
			if (port <= kLastReservedPort) {
				// Move reserved numbers (including 0) above the reserved range.
				// Adding only kLastReservedPort would map 0 to 1023, which is
				// itself a reserved port.
				port += kLastReservedPort + 1;
			}

			fLastPort = port;
			port = htons(port);

			if (!fEndpointHash.Lookup(port).HasNext()) {
				// found a port without users
				SocketAddressStorage newAddress(AddressModule());
				newAddress.SetTo(address);
				newAddress.SetPort(port);

				TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n",
					endpoint, AddressString(Domain(), *newAddress,
					true).Data()));
				T(Bind(endpoint, newAddress, true));

				return _Bind(endpoint, *newAddress);
			}

			counter += step;
		}
	}

	// could not find a port!
	return EADDRINUSE;
}
/*!	Asks the next protocol layer to bind \a endpoint to \a address and, if
	that succeeds, records the endpoint in the port table.
	All callers in this file hold the manager's write lock.
*/
status_t
EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address)
{
	const status_t bindStatus
		= endpoint->next->module->bind(endpoint->next, address);
	if (bindStatus >= B_OK) {
		fEndpointHash.Insert(endpoint);
		return B_OK;
	}

	return bindStatus;
}
/*!	Removes \a endpoint from the endpoint (port) table and the connection
	table, and clears the \c sa_len of its local address.
	Returns \c B_BAD_VALUE if \a endpoint is \c NULL or not bound.
	Panics if a bound endpoint is missing from the port table.
*/
status_t
EndpointManager::Unbind(TCPEndpoint* endpoint)
{
	TRACE(("EndpointManager::Unbind(%p)\n", endpoint));

	if (endpoint == NULL) {
		TRACE((" endpoint is unbound.\n"));
		return B_BAD_VALUE;
	}

	// The trace entry dereferences the endpoint, so it must only be created
	// after the NULL check above.
	T(Unbind(endpoint));

	if (!endpoint->IsBound()) {
		TRACE((" endpoint is unbound.\n"));
		return B_BAD_VALUE;
	}

	WriteLocker _(fLock);

	if (!fEndpointHash.Remove(endpoint))
		panic("bound endpoint %p not in hash!", endpoint);

	// Not every bound endpoint is in the connection table; Remove() is
	// called regardless and its result ignored.
	fConnectionHash.Remove(endpoint);

	// NOTE(review): clearing sa_len presumably makes IsBound() report false
	// from now on -- confirm in TCPEndpoint.
	(*endpoint->LocalAddress())->sa_len = 0;

	return B_OK;
}
/*!	Sends a RST segment in response to \a segment (received in \a buffer),
	with source and destination addresses swapped.
	Following the TCP reset-generation rules:
	- if the offending segment carried an ACK, the RST uses that
	  acknowledgment number as its sequence number;
	- otherwise the RST has sequence number 0 and acknowledges the segment's
	  sequence number plus its length. SYN and FIN each count as one.
	Assumes \a buffer->size is the payload length, i.e. the TCP header has
	already been removed -- TODO confirm against callers.
	Returns \c B_NO_MEMORY if no reply buffer could be allocated; on any
	other failure the reply buffer is freed here.
*/
status_t
EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer)
{
	TRACE(("TCP: Sending RST...\n"));

	net_buffer* reply = gBufferModule->create(512);
	if (reply == NULL)
		return B_NO_MEMORY;

	AddressModule()->set_to(reply->source, buffer->destination);
	AddressModule()->set_to(reply->destination, buffer->source);

	tcp_segment_header outSegment(TCP_FLAG_RESET);
	outSegment.sequence = 0;
	outSegment.acknowledge = 0;
	outSegment.advertised_window = 0;
	outSegment.urgent_offset = 0;

	if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) {
		outSegment.flags |= TCP_FLAG_ACKNOWLEDGE;
		outSegment.acknowledge = segment.sequence + buffer->size;
		// SYN and FIN each occupy one sequence number, even when both are set
		if ((segment.flags & TCP_FLAG_SYNCHRONIZE) != 0)
			outSegment.acknowledge++;
		if ((segment.flags & TCP_FLAG_FINISH) != 0)
			outSegment.acknowledge++;
	} else
		outSegment.sequence = segment.acknowledge;

	status_t status = add_tcp_header(AddressModule(), outSegment, reply);
	if (status == B_OK)
		status = Domain()->module->send_data(NULL, reply);

	if (status != B_OK)
		gBufferModule->free(reply);

	return status;
}
void
EndpointManager::Dump() const
{
kprintf("-------- TCP Domain %p ---------\n", this);
kprintf("%10s %21s %21s %8s %8s %12s\n", "address", "local", "peer",
"recv-q", "send-q", "state");
ConnectionTable::Iterator iterator = fConnectionHash.GetIterator();
while (iterator.HasNext()) {
TCPEndpoint *endpoint = iterator.Next();
char localBuf[64], peerBuf[64];
endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
kprintf("%p %21s %21s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf,
endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(),
name_for_state(endpoint->State()));
}
}