⛏️ index : haiku.git

/*
 * Copyright 2012 Haiku, Inc. All rights reserved.
 * Distributed under the terms of the MIT License.
 *
 * Authors:
 *		Paweł Dziepak, pdziepak@quarnos.org
 */


#include "RPCServer.h"

#include <stdlib.h>

#include <util/AutoLock.h>
#include <util/Random.h>

#include "RPCCallbackServer.h"
#include "RPCReply.h"


using namespace RPC;


/*!	Sets up an empty request queue together with the mutex that guards it.
*/
RequestManager::RequestManager()
	:
	fQueueHead(NULL),
	fQueueTail(NULL)
{
	// Both ends being NULL is what AddRequest()/FindRequest() treat as an
	// empty queue.
	mutex_init(&fLock, NULL);
}


/*!	Tears down the queue mutex. Requests that are still linked into the queue
	are not touched here.
*/
RequestManager::~RequestManager()
{
	mutex_destroy(&fLock);
}


/*!	Appends \a request to the end of the pending request queue.

	\param request The request to enqueue; must not be \c NULL.
*/
void
RequestManager::AddRequest(Request* request)
{
	ASSERT(request != NULL);

	MutexLocker locker(fLock);

	if (fQueueTail == NULL) {
		// Queue is empty, the new request becomes its head as well.
		fQueueHead = request;
	} else
		fQueueTail->fNext = request;

	fQueueTail = request;

	// The new tail never has a successor.
	request->fNext = NULL;
}


/*!	Looks up the queued request with transaction ID \a xid and, if found,
	unlinks it from the queue.

	\param xid The transaction ID to look for.
	\return The request that was removed from the queue, or \c NULL if no
		queued request carries \a xid.
*/
Request*
RequestManager::FindRequest(uint32 xid)
{
	MutexLocker locker(fLock);

	Request* previous = NULL;
	for (Request* current = fQueueHead; current != NULL;
			previous = current, current = current->fNext) {
		if (current->fXID != xid)
			continue;

		// Unlink the match, fixing up whichever queue ends referenced it.
		if (previous != NULL)
			previous->fNext = current->fNext;
		if (current == fQueueTail)
			fQueueTail = previous;
		if (current == fQueueHead)
			fQueueHead = current->fNext;

		return current;
	}

	return NULL;
}


/*!	Creates a server object on top of an already established connection and
	starts its listener thread.

	\param connection The connection used to talk to the server; must not be
		\c NULL. It is deleted by the destructor.
	\param address The address identifying the server; must not be \c NULL.
*/
Server::Server(Connection* connection, PeerAddress* address)
	:
	fConnection(connection),
	fAddress(address),
	fPrivateData(NULL),
	fCallback(NULL),
	fRepairCount(0),
	fXID(get_random<uint32>())
{
	ASSERT(address != NULL);
	ASSERT(connection != NULL);

	mutex_init(&fRepairLock, NULL);
	mutex_init(&fCallbackLock, NULL);

	// NOTE(review): the status returned by _StartListening() is not checked
	// here, so a failure to spawn the listener goes unreported.
	_StartListening();
}


/*!	Unregisters and frees the callback (if any), releases the private data,
	stops the listener thread and finally deletes the connection.
*/
Server::~Server()
{
	if (fCallback != NULL) {
		fCallback->CBServer()->UnregisterCallback(fCallback);
		delete fCallback;
	}

	mutex_destroy(&fCallbackLock);
	mutex_destroy(&fRepairLock);

	delete fPrivateData;

	// Ask the listener to quit, then disconnect before joining it.
	// Presumably Disconnect() is what unblocks a pending Receive() -- verify
	// against Connection.
	fThreadCancel = true;
	fConnection->Disconnect();

	status_t listenerStatus;
	wait_for_thread(fThread, &listenerStatus);

	delete fConnection;
}


/*!	Resets the listener state and spawns and resumes a fresh listener thread.

	\return \c B_OK if the thread is running, the error from
		spawn_kernel_thread() or resume_thread() otherwise.
*/
status_t
Server::_StartListening()
{
	fThreadCancel = false;
	fThreadError = B_OK;

	fThread = spawn_kernel_thread(&Server::_ListenerThreadStart,
		"NFSv4 Listener", B_NORMAL_PRIORITY, this);
	if (fThread < B_OK)
		return fThread;

	const status_t status = resume_thread(fThread);
	if (status == B_OK)
		return B_OK;

	// The thread could not be resumed, so get rid of it again.
	kill_thread(fThread);
	return status;
}


/*!	Queues a new request for \a call and sends the call to the server without
	waiting for the reply.

	If the listener thread has previously failed, a Repair() is attempted
	first.

	\param call The call to send; it is assigned a fresh XID. Must not be
		\c NULL.
	\param reply Location where the listener thread stores the reply once it
		arrives. Must not be \c NULL.
	\param request On success set to the newly created request. On failure of
		the send it is set to \c NULL, since the request has been freed.
		Must not be \c NULL.
	\return \c B_OK on success, \c B_NO_MEMORY if the request could not be
		allocated, or the error reported by the listener thread or the send.
*/
status_t
Server::SendCallAsync(Call* call, Reply** reply, Request** request)
{
	ASSERT(call != NULL);
	ASSERT(reply != NULL);
	ASSERT(request != NULL);

	if (fThreadError != B_OK && Repair() != B_OK)
		return fThreadError;

	Request* req = new(std::nothrow) Request;
	if (req == NULL)
		return B_NO_MEMORY;

	uint32 xid = _GetXID();
	call->SetXID(xid);
	req->fXID = xid;
	req->fReply = reply;
	req->fEvent.Init(&req->fEvent, NULL);
	req->fDone = false;
	req->fError = B_OK;
	req->fNext = NULL;

	// The request has to be queued before the call is sent, so that the
	// listener thread can match the reply to it.
	fRequests.AddRequest(req);

	*request = req;
	status_t error = ResendCallAsync(call, req);
	if (error != B_OK) {
		// ResendCallAsync() already removed the request from the queue on
		// failure. Don't hand a pointer to freed memory back to the caller.
		*request = NULL;
		delete req;
	}
	return error;
}


/*!	(Re)sends \a call over the connection for an already queued \a request.

	If the listener thread has previously failed, a Repair() is attempted
	first. On any failure \a request is removed from the pending queue; it is
	not freed here.

	\param call The call whose stream buffer is sent; must not be \c NULL.
	\param request The queued request belonging to \a call; must not be
		\c NULL.
	\return \c B_OK on success, the listener thread error or the send error
		otherwise.
*/
status_t
Server::ResendCallAsync(Call* call, Request* request)
{
	ASSERT(call != NULL);
	ASSERT(request != NULL);

	if (fThreadError != B_OK && Repair() != B_OK) {
		// No working listener: nobody would ever complete this request.
		fRequests.FindRequest(request->fXID);
		return fThreadError;
	}

	XDR::WriteStream& stream = call->Stream();
	const status_t sendStatus
		= fConnection->Send(stream.Buffer(), stream.Size());
	if (sendStatus == B_OK)
		return B_OK;

	// Sending failed, so take the request out of the queue again.
	fRequests.FindRequest(request->fXID);
	return sendStatus;
}


/*!	Aborts a pending request: removes it from the queue, marks it as failed
	with \c B_FILE_ERROR and wakes up everyone waiting on it.

	If the request is no longer queued nothing is done.

	\param request The request to abort; must not be \c NULL.
	\return Always \c B_OK.
*/
status_t
Server::WakeCall(Request* request)
{
	ASSERT(request != NULL);

	// Only the party that manages to dequeue the request may complete it.
	if (fRequests.FindRequest(request->fXID) == NULL)
		return B_OK;

	request->fError = B_FILE_ERROR;
	*request->fReply = NULL;
	request->fDone = true;
	request->fEvent.NotifyAll();

	return B_OK;
}


/*!	Reconnects to the server and restarts the listener thread.

	The repair counter is sampled before the repair lock is taken. If it has
	changed by the time the lock is held, another repair has completed in the
	meantime and nothing more is done.

	\return \c B_OK if the connection was repaired (by this or a concurrent
		call), the error from Connection::Reconnect() or _StartListening()
		otherwise.
*/
status_t
Server::Repair()
{
	const uint32 observedCount = fRepairCount;

	MutexLocker locker(fRepairLock);
	if (observedCount != fRepairCount)
		return B_OK;

	fThreadCancel = true;

	status_t status = fConnection->Reconnect();
	if (status != B_OK)
		return status;

	// Join the old listener before spawning its replacement; its exit
	// status is of no interest.
	wait_for_thread(fThread, &status);

	status = _StartListening();
	if (status != B_OK)
		return status;

	fRepairCount++;
	return B_OK;
}


/*!	Returns the callback object of this server, creating and registering it
	with the callback server on first use.

	\return The callback, or \c NULL if it could not be allocated, no
		callback server is available, or registration failed.
*/
Callback*
Server::GetCallback()
{
	MutexLocker _(fCallbackLock);

	if (fCallback == NULL) {
		fCallback = new(std::nothrow) Callback(this);
		if (fCallback == NULL)
			return NULL;

		CallbackServer* server = CallbackServer::Get(this);
		if (server == NULL) {
			// Reset the member, otherwise later calls and the destructor
			// would use the freed object.
			delete fCallback;
			fCallback = NULL;
			return NULL;
		}

		if (server->RegisterCallback(fCallback) != B_OK) {
			delete fCallback;
			fCallback = NULL;
			return NULL;
		}
	}

	return fCallback;
}


/*!	Atomically hands out the next transaction ID.

	\return The value of the XID counter before it was incremented.
*/
uint32
Server::_GetXID()
{
	const int32 previous = atomic_add(&fXID, 1);
	return static_cast<uint32>(previous);
}


/*!	Body of the listener thread: receives replies from the connection and
	hands each one to the pending request with the matching XID.

	The loop runs until cancellation is requested or receiving fails with
	anything other than \c B_NO_MEMORY. In the latter case the error is also
	stored in fThreadError.

	\return \c B_OK if the thread was cancelled, the receive error otherwise.
*/
status_t
Server::_Listener()
{
	void* buffer = NULL;
	uint32 size;

	while (!fThreadCancel) {
		const status_t status = fConnection->Receive(&buffer, &size);
		if (status == B_NO_MEMORY) {
			// Out of memory is not treated as fatal, just try again.
			continue;
		}
		if (status != B_OK) {
			fThreadError = status;
			return status;
		}

		ASSERT(buffer != NULL && size > 0);
		Reply* reply = new(std::nothrow) Reply(buffer, size);
		if (reply == NULL) {
			// Presumably a Reply takes ownership of the buffer; without one
			// it has to be released here.
			free(buffer);
			continue;
		}

		Request* pending = fRequests.FindRequest(reply->GetXID());
		if (pending == NULL) {
			// Nobody is waiting for this reply (anymore).
			delete reply;
			continue;
		}

		*pending->fReply = reply;
		pending->fDone = true;
		pending->fEvent.NotifyAll();
	}

	return B_OK;
}


/*!	Thread entry point; forwards to _Listener() of the Server passed in as
	thread argument.

	\param object The Server the thread belongs to; must not be \c NULL.
	\return Whatever _Listener() returns.
*/
status_t
Server::_ListenerThreadStart(void* object)
{
	ASSERT(object != NULL);

	return reinterpret_cast<Server*>(object)->_Listener();
}


/*!	Creates a manager with an empty server tree and initializes its lock.
*/
ServerManager::ServerManager()
	:
	fRoot(NULL)
{
	// A NULL root is how _Find(), _Insert() and _Delete() see an empty tree.
	mutex_init(&fLock, NULL);
}


/*!	Destroys the manager's lock. Nodes still present in the tree are not
	freed here.
*/
ServerManager::~ServerManager()
{
	mutex_destroy(&fLock);
}


/*!	Acquires a reference to a server, trying the addresses supplied by
	\a resolver one after another until one of them works.

	\param _server Set to the acquired server on success.
	\param resolver Source of candidate addresses.
	\param createPrivateData Hook creating the program specific data for a
		newly created server.
	\return \c B_OK on success. Otherwise the error from the resolver once it
		has run out of addresses.
*/
status_t
ServerManager::Acquire(Server** _server, AddressResolver* resolver,
	ProgramData* (*createPrivateData)(Server*))
{
	PeerAddress address;

	status_t status = resolver->GetNextAddress(&address);
	while (status == B_OK) {
		status = _Acquire(_server, address, createPrivateData);
		if (status == B_OK)
			break;

		// This address did not work out, move on to the next candidate.
		status = resolver->GetNextAddress(&address);
	}

	return status;
}


/*!	Acquires a reference to the server at \a address, connecting to it and
	creating a new Server object if it is not yet known.

	\param _server Set to the acquired server on success; must not be
		\c NULL.
	\param address The address of the server.
	\param createPrivateData Hook creating the program specific data for a
		newly created server; must not be \c NULL.
	\return \c B_OK on success, \c B_NO_MEMORY if an allocation failed, or
		the error from Connection::Connect().
*/
status_t
ServerManager::_Acquire(Server** _server, const PeerAddress& address,
	ProgramData* (*createPrivateData)(Server*))
{
	ASSERT(_server != NULL);
	ASSERT(createPrivateData != NULL);

	status_t result;

	MutexLocker locker(fLock);
	ServerNode* node = _Find(address);
	if (node != NULL) {
		// Server already known, just hand out another reference.
		node->fRefCount++;
		*_server = node->fServer;

		return B_OK;
	}

	node = new(std::nothrow) ServerNode;
	if (node == NULL)
		return B_NO_MEMORY;

	node->fID = address;

	Connection* conn;
	result = Connection::Connect(&conn, address);
	if (result != B_OK) {
		delete node;
		return result;
	}

	// Use the non-throwing new like everywhere else in this file, so that
	// the NULL check below is meaningful.
	node->fServer = new(std::nothrow) Server(conn, &node->fID);
	if (node->fServer == NULL) {
		delete node;
		// No Server was created that would own the connection.
		delete conn;
		return B_NO_MEMORY;
	}
	node->fServer->SetPrivateData(createPrivateData(node->fServer));

	node->fRefCount = 1;
	node->fLeft = node->fRight = NULL;

	ServerNode* nd = _Insert(node);
	if (nd != node) {
		// A node for this address already exists; use it and drop ours.
		nd->fRefCount++;

		delete node->fServer;
		delete node;
		*_server = nd->fServer;
		return B_OK;
	}

	*_server = node->fServer;
	return B_OK;
}


/*!	Drops one reference to \a server. When the last reference is gone the
	server is removed from the tree and deleted together with its node.

	\param server The server to release; must not be \c NULL.
*/
void
ServerManager::Release(Server* server)
{
	ASSERT(server != NULL);

	MutexLocker locker(fLock);

	ServerNode* node = _Find(server->ID());
	if (node == NULL)
		return;

	if (--node->fRefCount != 0)
		return;

	// That was the last reference.
	_Delete(node);
	delete node->fServer;
	delete node;
}


/*!	Searches the server tree for the node with the given \a address.

	The fLock mutex is not taken here; callers in this file hold it.

	\param address The address to look for.
	\return The matching node, or \c NULL if there is none.
*/
ServerNode*
ServerManager::_Find(const PeerAddress& address)
{
	ServerNode* current = fRoot;
	while (current != NULL && !(current->fID == address)) {
		// Larger IDs live in the right subtree, smaller ones in the left.
		current = current->fID < address ? current->fRight : current->fLeft;
	}

	return current;
}


/*!	Unlinks the node with the same ID as \a node from the server tree.

	The node is only removed from the tree, it is not freed. If no node with
	that ID is in the tree, nothing happens. The fLock mutex is not taken
	here.

	\param node The node to remove; must not be \c NULL.
*/
void
ServerManager::_Delete(ServerNode* node)
{
	ASSERT(node != NULL);

	// Walk down to the node while tracking the link (fRoot or a parent's
	// child pointer) that refers to it. That way the root needs no special
	// treatment.
	ServerNode** link = &fRoot;
	while (*link != NULL && !((*link)->fID == node->fID)) {
		if ((*link)->fID < node->fID)
			link = &(*link)->fRight;
		else
			link = &(*link)->fLeft;
	}

	ServerNode* current = *link;
	if (current == NULL)
		return;

	if (current->fLeft == NULL) {
		// No children or only a right one: the right subtree (possibly
		// NULL) takes the node's place.
		*link = current->fRight;
	} else if (current->fRight == NULL) {
		// Only a left child.
		*link = current->fLeft;
	} else {
		// Two children: replace the node with its in-order successor, i.e.
		// the leftmost node of the right subtree. This keeps the ordering
		// of the tree intact.
		ServerNode** successorLink = &current->fRight;
		while ((*successorLink)->fLeft != NULL)
			successorLink = &(*successorLink)->fLeft;

		ServerNode* successor = *successorLink;

		// The successor has no left child, so its right subtree can simply
		// move up into its old position.
		*successorLink = successor->fRight;

		successor->fLeft = current->fLeft;
		successor->fRight = current->fRight;
		*link = successor;
	}

	// The removed node no longer references anything in the tree.
	current->fLeft = current->fRight = NULL;
}


/*!	Inserts \a node into the server tree unless a node with the same ID is
	already present.

	The fLock mutex is not taken here.

	\param node The node to insert; must not be \c NULL.
	\return \a node if it was inserted, otherwise the already existing node
		with the same ID.
*/
ServerNode*
ServerManager::_Insert(ServerNode* node)
{
	ASSERT(node != NULL);

	// Follow the child links down to the empty slot the node belongs into.
	ServerNode** link = &fRoot;
	while (*link != NULL) {
		ServerNode* current = *link;
		if (current->fID == node->fID)
			return current;

		link = current->fID < node->fID ? &current->fRight : &current->fLeft;
	}

	*link = node;
	return node;
}