diff options
| author | Remko Tronçon <git@el-tramo.be> | 2009-11-10 21:24:03 (GMT) | 
|---|---|---|
| committer | Remko Tronçon <git@el-tramo.be> | 2009-11-10 21:24:03 (GMT) | 
| commit | 54781ce12f7654f8136e645d4ebc5934d90c6bea (patch) | |
| tree | 90bad869f9f64d57a3c0af209b83a538a47c7762 | |
| parent | fcfac59db5cb4503554f2b30854b2e91928296f6 (diff) | |
| parent | 66ced3654ad295478b33d3e4f1716f66ab4048b5 (diff) | |
| download | swift-54781ce12f7654f8136e645d4ebc5934d90c6bea.zip swift-54781ce12f7654f8136e645d4ebc5934d90c6bea.tar.bz2 | |
Refactored session management.
46 files changed, 479 insertions, 351 deletions
| diff --git a/Limber/main.cpp b/Limber/main.cpp index 965abc2..25cccec 100644 --- a/Limber/main.cpp +++ b/Limber/main.cpp @@ -63,11 +63,11 @@ class Server {  							session->sendElement(IQ::createResult(iq->getFrom(), iq->getID(), vcard));  						}  						else { -							session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::Forbidden, Error::Cancel)); +							session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::Forbidden, ErrorPayload::Cancel));  						}  					}  					else { -						session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::FeatureNotImplemented, Error::Cancel)); +						session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::FeatureNotImplemented, ErrorPayload::Cancel));  					}  				}  			} diff --git a/Slimber/Server.cpp b/Slimber/Server.cpp index e07fb41..278a572 100644 --- a/Slimber/Server.cpp +++ b/Slimber/Server.cpp @@ -211,7 +211,7 @@ void Server::handleElementReceived(boost::shared_ptr<Element> element, boost::sh  					}  				}  				else { -					session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::Forbidden, Error::Cancel)); +					session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::Forbidden, ErrorPayload::Cancel));  				}  			}  			if (boost::shared_ptr<VCard> vcard = iq->getPayload<VCard>()) { @@ -227,7 +227,7 @@ void Server::handleElementReceived(boost::shared_ptr<Element> element, boost::sh  				}  			}  			else { -				session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::FeatureNotImplemented, Error::Cancel)); +				session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::FeatureNotImplemented, ErrorPayload::Cancel));  			}  		}  	} @@ -260,7 +260,7 @@ void Server::handleElementReceived(boost::shared_ptr<Element> element, boost::sh  			else {  				session->sendElement(IQ::createError(  						stanza->getFrom(), stanza->getID(),  -						Error::RecipientUnavailable, Error::Wait)); +						ErrorPayload::RecipientUnavailable, ErrorPayload::Wait));  			}  		}  	} diff --git a/Swift/Controllers/ChatControllerBase.cpp b/Swift/Controllers/ChatControllerBase.cpp index baa715b..2b873f1 100644 --- a/Swift/Controllers/ChatControllerBase.cpp +++ b/Swift/Controllers/ChatControllerBase.cpp @@ -67,7 +67,7 @@ void ChatControllerBase::handleSendMessageRequest(const String &body) {  	postSendMessage(message->getBody());  } -void ChatControllerBase::handleSecurityLabelsCatalogResponse(boost::shared_ptr<SecurityLabelsCatalog> catalog, const boost::optional<Error>& error) { +void ChatControllerBase::handleSecurityLabelsCatalogResponse(boost::shared_ptr<SecurityLabelsCatalog> catalog, const boost::optional<ErrorPayload>& error) {  	if (!error) {  		if (catalog->getLabels().size() == 0) {  			chatWindow_->setSecurityLabelsEnabled(false); @@ -97,7 +97,7 @@ void ChatControllerBase::handleIncomingMessage(boost::shared_ptr<MessageEvent> m  	preHandleIncomingMessage(message);	  	String body = message->getBody();  	if (message->isError()) { -		String errorMessage = getErrorMessage(message->getPayload<Error>()); +		String errorMessage = getErrorMessage(message->getPayload<ErrorPayload>());  		chatWindow_->addErrorMessage(errorMessage);  	}  	else { @@ -109,35 +109,35 @@ void ChatControllerBase::handleIncomingMessage(boost::shared_ptr<MessageEvent> m  	}  } -String ChatControllerBase::getErrorMessage(boost::shared_ptr<Error> error) { +String ChatControllerBase::getErrorMessage(boost::shared_ptr<ErrorPayload> error) {  	String defaultMessage = "Error sending message";  	if (!error->getText().isEmpty()) {  		return error->getText();  	}  	else {  		switch (error->getCondition()) { -			case Error::BadRequest: return defaultMessage; break; -			case Error::Conflict: return defaultMessage; break; -			case Error::FeatureNotImplemented: return defaultMessage; break; -			case Error::Forbidden: return defaultMessage; break; -			case Error::Gone: return "Recipient can no longer be contacted"; break; -			case Error::InternalServerError: return "Internal server error"; break; -			case Error::ItemNotFound: return defaultMessage; break; -			case Error::JIDMalformed: return defaultMessage; break; -			case Error::NotAcceptable: return "Message was rejected"; break; -			case Error::NotAllowed: return defaultMessage; break; -			case Error::NotAuthorized: return defaultMessage; break; -			case Error::PaymentRequired: return defaultMessage; break; -			case Error::RecipientUnavailable: return "Recipient is unavailable."; break; -			case Error::Redirect: return defaultMessage; break; -			case Error::RegistrationRequired: return defaultMessage; break; -			case Error::RemoteServerNotFound: return "Recipient's server not found."; break; -			case Error::RemoteServerTimeout: return defaultMessage; break; -			case Error::ResourceConstraint: return defaultMessage; break; -			case Error::ServiceUnavailable: return defaultMessage; break; -			case Error::SubscriptionRequired: return defaultMessage; break; -			case Error::UndefinedCondition: return defaultMessage; break; -			case Error::UnexpectedRequest: return defaultMessage; break; +			case ErrorPayload::BadRequest: return defaultMessage; break; +			case ErrorPayload::Conflict: return defaultMessage; break; +			case ErrorPayload::FeatureNotImplemented: return defaultMessage; break; +			case ErrorPayload::Forbidden: return defaultMessage; break; +			case ErrorPayload::Gone: return "Recipient can no longer be contacted"; break; +			case ErrorPayload::InternalServerError: return "Internal server error"; break; +			case ErrorPayload::ItemNotFound: return defaultMessage; break; +			case ErrorPayload::JIDMalformed: return defaultMessage; break; +			case ErrorPayload::NotAcceptable: return "Message was rejected"; break; +			case ErrorPayload::NotAllowed: return defaultMessage; break; +			case ErrorPayload::NotAuthorized: return defaultMessage; break; +			case ErrorPayload::PaymentRequired: return defaultMessage; break; +			case ErrorPayload::RecipientUnavailable: return "Recipient is unavailable."; break; +			case ErrorPayload::Redirect: return defaultMessage; break; +			case ErrorPayload::RegistrationRequired: return defaultMessage; break; +			case ErrorPayload::RemoteServerNotFound: return "Recipient's server not found."; break; +			case ErrorPayload::RemoteServerTimeout: return defaultMessage; break; +			case ErrorPayload::ResourceConstraint: return defaultMessage; break; +			case ErrorPayload::ServiceUnavailable: return defaultMessage; break; +			case ErrorPayload::SubscriptionRequired: return defaultMessage; break; +			case ErrorPayload::UndefinedCondition: return defaultMessage; break; +			case ErrorPayload::UnexpectedRequest: return defaultMessage; break;  		}  	}  	return defaultMessage; diff --git a/Swift/Controllers/ChatControllerBase.h b/Swift/Controllers/ChatControllerBase.h index 601e56b..91b72a8 100644 --- a/Swift/Controllers/ChatControllerBase.h +++ b/Swift/Controllers/ChatControllerBase.h @@ -12,7 +12,7 @@  #include "Swiften/Events/MessageEvent.h"  #include "Swiften/JID/JID.h"  #include "Swiften/Elements/SecurityLabelsCatalog.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  #include "Swiften/Presence/PresenceOracle.h"  #include "Swiften/Queries/IQRouter.h" @@ -44,8 +44,8 @@ namespace Swift {  		private:  			void handleSendMessageRequest(const String &body);  			void handleAllMessagesRead(); -			void handleSecurityLabelsCatalogResponse(boost::shared_ptr<SecurityLabelsCatalog>, const boost::optional<Error>& error); -			String getErrorMessage(boost::shared_ptr<Error>); +			void handleSecurityLabelsCatalogResponse(boost::shared_ptr<SecurityLabelsCatalog>, const boost::optional<ErrorPayload>& error); +			String getErrorMessage(boost::shared_ptr<ErrorPayload>);  		protected:  			JID selfJID_; diff --git a/Swift/Controllers/MainController.cpp b/Swift/Controllers/MainController.cpp index 9df2308..6c60783 100644 --- a/Swift/Controllers/MainController.cpp +++ b/Swift/Controllers/MainController.cpp @@ -389,7 +389,7 @@ void MainController::handleIncomingMessage(boost::shared_ptr<Message> message) {  	}  } -void MainController::handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo> info, const boost::optional<Error>& error) { +void MainController::handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo> info, const boost::optional<ErrorPayload>& error) {  	if (!error) {  		serverDiscoInfo_ = info;  		foreach (JIDChatControllerPair pair, chatControllers_) { @@ -405,7 +405,7 @@ bool MainController::isMUC(const JID& jid) const {  	return mucControllers_.find(jid.toBare()) != mucControllers_.end();  } -void MainController::handleOwnVCardReceived(boost::shared_ptr<VCard> vCard, const boost::optional<Error>& error) { +void MainController::handleOwnVCardReceived(boost::shared_ptr<VCard> vCard, const boost::optional<ErrorPayload>& error) {  	if (!error && !vCard->getPhoto().isEmpty()) {  		vCardPhotoHash_ = SHA1::getHexHash(vCard->getPhoto());  		if (lastSentPresence_) { diff --git a/Swift/Controllers/MainController.h b/Swift/Controllers/MainController.h index 3179df9..db6a110 100644 --- a/Swift/Controllers/MainController.h +++ b/Swift/Controllers/MainController.h @@ -10,7 +10,7 @@  #include "Swiften/JID/JID.h"  #include "Swiften/Elements/VCard.h"  #include "Swiften/Elements/DiscoInfo.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  #include "Swiften/Elements/Presence.h"  #include "Swiften/Elements/Message.h"  #include "Swiften/Settings/SettingsProvider.h" @@ -64,9 +64,9 @@ namespace Swift {  			void handleIncomingMessage(boost::shared_ptr<Message> message);  			void handleChangeStatusRequest(StatusShow::Type show, const String &statusText);  			void handleError(const ClientError& error); -			void handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo>, const boost::optional<Error>&); +			void handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo>, const boost::optional<ErrorPayload>&);  			void handleEventQueueLengthChange(int count); -			void handleOwnVCardReceived(boost::shared_ptr<VCard> vCard, const boost::optional<Error>& error); +			void handleOwnVCardReceived(boost::shared_ptr<VCard> vCard, const boost::optional<ErrorPayload>& error);  			ChatController* getChatController(const JID &contact);  			void sendPresence(boost::shared_ptr<Presence> presence);  			void handleInputIdle(); diff --git a/Swiften/Avatars/AvatarManager.cpp b/Swiften/Avatars/AvatarManager.cpp index 6a1efc6..574e199 100644 --- a/Swiften/Avatars/AvatarManager.cpp +++ b/Swiften/Avatars/AvatarManager.cpp @@ -35,7 +35,7 @@ void AvatarManager::handlePresenceReceived(boost::shared_ptr<Presence> presence)  	}  } -void AvatarManager::handleVCardReceived(const JID& from, const String& promisedHash, boost::shared_ptr<VCard> vCard, const boost::optional<Error>& error) { +void AvatarManager::handleVCardReceived(const JID& from, const String& promisedHash, boost::shared_ptr<VCard> vCard, const boost::optional<ErrorPayload>& error) {  	if (error) {  		// FIXME: What to do here?  		std::cerr << "Warning: " << from << ": Could not get vCard" << std::endl; diff --git a/Swiften/Avatars/AvatarManager.h b/Swiften/Avatars/AvatarManager.h index 3ac4433..65ec372 100644 --- a/Swiften/Avatars/AvatarManager.h +++ b/Swiften/Avatars/AvatarManager.h @@ -9,7 +9,7 @@  #include "Swiften/JID/JID.h"  #include "Swiften/Elements/Presence.h"  #include "Swiften/Elements/VCard.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  namespace Swift {  	class MUCRegistry; @@ -30,7 +30,7 @@ namespace Swift {  		private:  			void handlePresenceReceived(boost::shared_ptr<Presence>); -			void handleVCardReceived(const JID& from, const String& hash, boost::shared_ptr<VCard>, const boost::optional<Error>&); +			void handleVCardReceived(const JID& from, const String& hash, boost::shared_ptr<VCard>, const boost::optional<ErrorPayload>&);  			void setAvatarHash(const JID& from, const String& hash);  			JID getAvatarJID(const JID& o) const; diff --git a/Swiften/Base/Error.cpp b/Swiften/Base/Error.cpp new file mode 100644 index 0000000..597c155 --- /dev/null +++ b/Swiften/Base/Error.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Base/Error.h" + +namespace Swift { + +Error::~Error() { +} + +} diff --git a/Swiften/Base/Error.h b/Swiften/Base/Error.h new file mode 100644 index 0000000..4c729ff --- /dev/null +++ b/Swiften/Base/Error.h @@ -0,0 +1,8 @@ +#pragma once + +namespace Swift { +	class Error { +		public: +			virtual ~Error(); +	}; +}; diff --git a/Swiften/Base/SConscript b/Swiften/Base/SConscript index d308e11..a0984e5 100644 --- a/Swiften/Base/SConscript +++ b/Swiften/Base/SConscript @@ -2,6 +2,7 @@ Import("swiften_env")  objects = swiften_env.StaticObject([  			"ByteArray.cpp", +			"Error.cpp",  			"IDGenerator.cpp",  			"String.cpp",  			"sleep.cpp", diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 60dfade..9e38626 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -10,6 +10,7 @@  #include "Swiften/Network/BoostConnectionFactory.h"  #include "Swiften/Network/DomainNameResolveException.h"  #include "Swiften/TLS/PKCS12Certificate.h" +#include "Swiften/Session/BasicSessionStream.h"  namespace Swift { @@ -20,6 +21,9 @@ Client::Client(const JID& jid, const String& password) :  }  Client::~Client() { +	if (session_ || connection_) { +		std::cerr << "Warning: Client not disconnected properly" << std::endl; +	}  	delete tlsLayerFactory_;  	delete connectionFactory_;  } @@ -46,23 +50,32 @@ void Client::handleConnectionConnectFinished(bool error) {  		onError(ClientError::ConnectionError);  	}  	else { -		session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_)); +		assert(!sessionStream_); +		sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_));  		if (!certificate_.isEmpty()) { -			session_->setCertificate(PKCS12Certificate(certificate_, password_)); +			sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_));  		} -		session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected))); -		session_->onSessionFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1)); +		sessionStream_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); +		sessionStream_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); +		sessionStream_->initialize(); + +		session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, sessionStream_)); +		session_->onInitialized.connect(boost::bind(boost::ref(onConnected))); +		session_->onFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1));  		session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); -		session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); -		session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1));  		session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); -		session_->startSession(); +		session_->start();  	}  }  void Client::disconnect() {  	if (session_) { -		session_->finishSession(); +		session_->finish(); +		session_.reset(); +	} +	if (connection_) { +		connection_->disconnect(); +		connection_.reset();  	}  } @@ -110,9 +123,10 @@ void Client::setCertificate(const String& certificate) {  	certificate_ = certificate;  } -void Client::handleSessionFinished(const boost::optional<Session::SessionError>& error) { +void Client::handleSessionFinished(boost::shared_ptr<Error> error) {  	if (error) {  		ClientError clientError; +		/*  		switch (*error) {  			case Session::ConnectionReadError:  				clientError = ClientError(ClientError::ConnectionReadError); @@ -148,6 +162,7 @@ void Client::handleSessionFinished(const boost::optional<Session::SessionError>&  				clientError = ClientError(ClientError::ClientCertificateError);  				break;  		} +		*/  		onError(clientError);  	}  } @@ -156,12 +171,12 @@ void Client::handleNeedCredentials() {  	session_->sendCredentials(password_);  } -void Client::handleDataRead(const ByteArray& data) { -	onDataRead(String(data.getData(), data.getSize())); +void Client::handleDataRead(const String& data) { +  onDataRead(data);  } -void Client::handleDataWritten(const ByteArray& data) { -	onDataWritten(String(data.getData(), data.getSize())); +void Client::handleDataWritten(const String& data) { +  onDataWritten(data);  }  } diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index 59e1c05..5188789 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -4,6 +4,7 @@  #include <boost/signals.hpp>  #include <boost/shared_ptr.hpp> +#include "Swiften/Base/Error.h"  #include "Swiften/Client/ClientSession.h"  #include "Swiften/Client/ClientError.h"  #include "Swiften/Elements/Presence.h" @@ -20,6 +21,7 @@ namespace Swift {  	class TLSLayerFactory;  	class ConnectionFactory;  	class ClientSession; +	class BasicSessionStream;  	class Client : public StanzaChannel, public IQRouter, public boost::bsignals::trackable {  		public: @@ -38,7 +40,7 @@ namespace Swift {  			virtual void sendPresence(boost::shared_ptr<Presence>);  		public: -			boost::signal<void (ClientError)> onError; +			boost::signal<void (const ClientError&)> onError;  			boost::signal<void ()> onConnected;  			boost::signal<void (const String&)> onDataRead;  			boost::signal<void (const String&)> onDataWritten; @@ -48,10 +50,12 @@ namespace Swift {  			void send(boost::shared_ptr<Stanza>);  			virtual String getNewIQID();  			void handleElement(boost::shared_ptr<Element>); -			void handleSessionFinished(const boost::optional<Session::SessionError>& error); +			void handleSessionFinished(boost::shared_ptr<Error>);  			void handleNeedCredentials(); -			void handleDataRead(const ByteArray&); -			void handleDataWritten(const ByteArray&); +			void handleDataRead(const String&); +			void handleDataWritten(const String&); + +			void reset();  		private:  			JID jid_; @@ -61,8 +65,9 @@ namespace Swift {  			TLSLayerFactory* tlsLayerFactory_;  			FullPayloadParserFactoryCollection payloadParserFactories_;  			FullPayloadSerializerCollection payloadSerializers_; -			boost::shared_ptr<ClientSession> session_;  			boost::shared_ptr<Connection> connection_; +			boost::shared_ptr<BasicSessionStream> sessionStream_; +			boost::shared_ptr<ClientSession> session_;  			String certificate_;  	};  } diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index a0e1289..a185ea0 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -2,13 +2,7 @@  #include <boost/bind.hpp> -#include "Swiften/Network/ConnectionFactory.h"  #include "Swiften/Elements/ProtocolHeader.h" -#include "Swiften/StreamStack/StreamStack.h" -#include "Swiften/StreamStack/ConnectionLayer.h" -#include "Swiften/StreamStack/XMPPLayer.h" -#include "Swiften/StreamStack/TLSLayer.h" -#include "Swiften/StreamStack/TLSLayerFactory.h"  #include "Swiften/Elements/StreamFeatures.h"  #include "Swiften/Elements/StartTLSRequest.h"  #include "Swiften/Elements/StartTLSFailure.h" @@ -20,47 +14,47 @@  #include "Swiften/Elements/IQ.h"  #include "Swiften/Elements/ResourceBind.h"  #include "Swiften/SASL/PLAINMessage.h" -#include "Swiften/StreamStack/WhitespacePingLayer.h" +#include "Swiften/Session/SessionStream.h"  namespace Swift {  ClientSession::ClientSession(  		const JID& jid,  -		boost::shared_ptr<Connection> connection, -		TLSLayerFactory* tlsLayerFactory,  -		PayloadParserFactoryCollection* payloadParserFactories,  -		PayloadSerializerCollection* payloadSerializers) :  -			Session(connection, payloadParserFactories, payloadSerializers), -			tlsLayerFactory_(tlsLayerFactory), -			state_(Initial),  -			needSessionStart_(false) { -	setLocalJID(jid); -	setRemoteJID(JID("", jid.getDomain())); +		boost::shared_ptr<SessionStream> stream) : +			localJID(jid),	 +			state(Initial),  +			stream(stream), +			needSessionStart(false) {  } -void ClientSession::handleSessionStarted() { -	assert(state_ == Initial); -	state_ = WaitingForStreamStart; +void ClientSession::start() { +	stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); +	stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); +	stream->onError.connect(boost::bind(&ClientSession::handleStreamError, shared_from_this(), _1)); +	stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); + +	assert(state == Initial); +	state = WaitingForStreamStart;  	sendStreamHeader();  }  void ClientSession::sendStreamHeader() {  	ProtocolHeader header;  	header.setTo(getRemoteJID()); -	getXMPPLayer()->writeHeader(header); +	stream->writeHeader(header);  } -void ClientSession::setCertificate(const PKCS12Certificate& certificate) { -	certificate_ = certificate; +void ClientSession::sendElement(boost::shared_ptr<Element> element) { +	stream->writeElement(element);  }  void ClientSession::handleStreamStart(const ProtocolHeader&) {  	checkState(WaitingForStreamStart); -	state_ = Negotiating; +	state = Negotiating;  }  void ClientSession::handleElement(boost::shared_ptr<Element> element) { -	if (getState() == SessionStarted) { +	if (getState() == Initialized) {  		onElementReceived(element);  	}  	else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { @@ -68,152 +62,121 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {  			return;  		} -		if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) { -			state_ = Encrypting; -			getXMPPLayer()->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); +		if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption()) { +			state = WaitingForEncrypt; +			stream->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));  		}  		else if (streamFeatures->hasAuthenticationMechanisms()) { -			if (!certificate_.isNull()) { -				if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { -					state_ = Authenticating; -					getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); -				} -				else { -					finishSession(ClientCertificateError); -				} +			if (stream->hasTLSCertificate() && streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { +					state = Authenticating; +					stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));  			}  			else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { -				state_ = WaitingForCredentials; +				state = WaitingForCredentials;  				onNeedCredentials();  			}  			else { -				finishSession(NoSupportedAuthMechanismsError); +				finishSession(Error::NoSupportedAuthMechanismsError);  			}  		}  		else {  			// Start the session - -			// Add a whitespace ping layer -			whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); -			getStreamStack()->addLayer(whitespacePingLayer_); -			whitespacePingLayer_->setActive(); +			stream->setWhitespacePingEnabled(true);  			if (streamFeatures->hasSession()) { -				needSessionStart_ = true; +				needSessionStart = true;  			}  			if (streamFeatures->hasResourceBind()) { -				state_ = BindingResource; +				state = BindingResource;  				boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind()); -				if (!getLocalJID().getResource().isEmpty()) { -					resourceBind->setResource(getLocalJID().getResource()); +				if (!localJID.getResource().isEmpty()) { +					resourceBind->setResource(localJID.getResource());  				} -				getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); +				stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));  			} -			else if (needSessionStart_) { +			else if (needSessionStart) {  				sendSessionStart();  			}  			else { -				state_ = SessionStarted; -				onSessionStarted(); +				state = Initialized; +				onInitialized();  			}  		}  	}  	else if (dynamic_cast<AuthSuccess*>(element.get())) {  		checkState(Authenticating); -		state_ = WaitingForStreamStart; -		getXMPPLayer()->resetParser(); +		state = WaitingForStreamStart; +		stream->resetXMPPParser();  		sendStreamHeader();  	}  	else if (dynamic_cast<AuthFailure*>(element.get())) { -		finishSession(AuthenticationFailedError); +		finishSession(Error::AuthenticationFailedError);  	}  	else if (dynamic_cast<TLSProceed*>(element.get())) { -		tlsLayer_ = tlsLayerFactory_->createTLSLayer(); -		getStreamStack()->addLayer(tlsLayer_); -		if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) { -			finishSession(ClientCertificateLoadError); -		} -		else { -			tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this)); -			tlsLayer_->onError.connect(boost::bind(&ClientSession::handleTLSError, this)); -			tlsLayer_->connect(); -		} +		checkState(WaitingForEncrypt); +		state = Encrypting; +		stream->addTLSEncryption();  	}  	else if (dynamic_cast<StartTLSFailure*>(element.get())) { -		finishSession(TLSError); +		finishSession(Error::TLSError);  	}  	else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { -		if (state_ == BindingResource) { +		if (state == BindingResource) {  			boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>());  			if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { -				finishSession(ResourceBindError); +				finishSession(Error::ResourceBindError);  			}  			else if (!resourceBind) { -				finishSession(UnexpectedElementError); +				finishSession(Error::UnexpectedElementError);  			}  			else if (iq->getType() == IQ::Result) { -				setLocalJID(resourceBind->getJID()); -				if (!getLocalJID().isValid()) { -					finishSession(ResourceBindError); +				localJID = resourceBind->getJID(); +				if (!localJID.isValid()) { +					finishSession(Error::ResourceBindError);  				} -				if (needSessionStart_) { +				if (needSessionStart) {  					sendSessionStart();  				}  				else { -					state_ = SessionStarted; +					state = Initialized;  				}  			}  			else { -				finishSession(UnexpectedElementError); +				finishSession(Error::UnexpectedElementError);  			}  		} -		else if (state_ == StartingSession) { +		else if (state == StartingSession) {  			if (iq->getType() == IQ::Result) { -				state_ = SessionStarted; -				onSessionStarted(); +				state = Initialized; +				onInitialized();  			}  			else if (iq->getType() == IQ::Error) { -				finishSession(SessionStartError); +				finishSession(Error::SessionStartError);  			}  			else { -				finishSession(UnexpectedElementError); +				finishSession(Error::UnexpectedElementError);  			}  		}  		else { -			finishSession(UnexpectedElementError); +			finishSession(Error::UnexpectedElementError);  		}  	}  	else {  		// FIXME Not correct? -		state_ = SessionStarted; -		onSessionStarted(); +		state = Initialized; +		onInitialized();  	}  }  void ClientSession::sendSessionStart() { -	state_ = StartingSession; -	getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); -} - -void ClientSession::handleSessionFinished(const boost::optional<SessionError>& error) { -	if (whitespacePingLayer_) { -		whitespacePingLayer_->setInactive(); -	} -	 -	if (error) { -		//assert(!error_); -		state_ = Error; -		error_ = error; -	} -	else { -		state_ = Finished; -	} +	state = StartingSession; +	stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));  }  bool ClientSession::checkState(State state) { -	if (state_ != state) { -		finishSession(UnexpectedElementError); +	if (state != state) { +		finishSession(Error::UnexpectedElementError);  		return false;  	}  	return true; @@ -221,18 +184,36 @@ bool ClientSession::checkState(State state) {  void ClientSession::sendCredentials(const String& password) {  	assert(WaitingForCredentials); -	state_ = Authenticating; -	getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(getLocalJID().getNode(), password).getValue()))); +	state = Authenticating; +	stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(localJID.getNode(), password).getValue())));  } -void ClientSession::handleTLSConnected() { -	state_ = WaitingForStreamStart; -	getXMPPLayer()->resetParser(); +void ClientSession::handleTLSEncrypted() { +	checkState(WaitingForEncrypt); +	state = WaitingForStreamStart; +	stream->resetXMPPParser();  	sendStreamHeader();  } -void ClientSession::handleTLSError() { -	finishSession(TLSError); +void ClientSession::handleStreamError(boost::shared_ptr<Swift::Error> error) { +	finishSession(error); +} + +void ClientSession::finish() { +	if (stream->isAvailable()) { +		stream->writeFooter(); +	} +	finishSession(boost::shared_ptr<Error>());  } +void ClientSession::finishSession(Error::Type error) { +	finishSession(boost::shared_ptr<Swift::ClientSession::Error>(new Swift::ClientSession::Error(error))); +} + +void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) { +	stream->setWhitespacePingEnabled(false); +	onFinished(error); +} + +  } diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index fead182..e09861b 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -2,87 +2,88 @@  #include <boost/signal.hpp>  #include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp> -#include "Swiften/Session/Session.h" +#include "Swiften/Base/Error.h" +#include "Swiften/Session/SessionStream.h" +#include "Swiften/Session/BasicSessionStream.h"  #include "Swiften/Base/String.h"  #include "Swiften/JID/JID.h"  #include "Swiften/Elements/Element.h" -#include "Swiften/Network/Connection.h" -#include "Swiften/TLS/PKCS12Certificate.h"  namespace Swift { -	class PayloadParserFactoryCollection; -	class PayloadSerializerCollection; -	class ConnectionFactory; -	class Connection; -	class StreamStack; -	class XMPPLayer; -	class ConnectionLayer; -	class TLSLayerFactory; -	class TLSLayer; -	class WhitespacePingLayer; - -	class ClientSession : public Session { +	class ClientSession : public boost::enable_shared_from_this<ClientSession> {  		public:  			enum State {  				Initial,  				WaitingForStreamStart,  				Negotiating,  				Compressing, +				WaitingForEncrypt,  				Encrypting,  				WaitingForCredentials,  				Authenticating,  				BindingResource,  				StartingSession, -				SessionStarted, -				Error, +				Initialized,  				Finished  			}; +			struct Error : public Swift::Error { +				enum Type { +					AuthenticationFailedError, +					NoSupportedAuthMechanismsError, +					UnexpectedElementError, +					ResourceBindError, +					SessionStartError, +					TLSError, +				} type; +				Error(Type type) : type(type) {} +			}; +  			ClientSession(  					const JID& jid,  -					boost::shared_ptr<Connection>,  -					TLSLayerFactory*,  -					PayloadParserFactoryCollection*,  -					PayloadSerializerCollection*); +					boost::shared_ptr<SessionStream>);  			State getState() const { -				return state_; +				return state;  			} -			boost::optional<SessionError> getError() const { -				return error_; -			} +			void start(); +			void finish();  			void sendCredentials(const String& password); -			void setCertificate(const PKCS12Certificate& certificate); +			void sendElement(boost::shared_ptr<Element> element);  		private: +			void finishSession(Error::Type error); +			void finishSession(boost::shared_ptr<Swift::Error> error); + +			JID getRemoteJID() const { +				return JID("", localJID.getDomain()); +			} +  			void sendStreamHeader();  			void sendSessionStart(); -			virtual void handleSessionStarted(); -			virtual void handleSessionFinished(const boost::optional<SessionError>& error); -			virtual void handleElement(boost::shared_ptr<Element>); -			virtual void handleStreamStart(const ProtocolHeader&); +			void handleElement(boost::shared_ptr<Element>); +			void handleStreamStart(const ProtocolHeader&); +			void handleStreamError(boost::shared_ptr<Swift::Error>); -			void handleTLSConnected(); -			void handleTLSError(); +			void handleTLSEncrypted(); -			void setError(SessionError);  			bool checkState(State);  		public:  			boost::signal<void ()> onNeedCredentials; -			boost::signal<void ()> onSessionStarted; +			boost::signal<void ()> onInitialized; +			boost::signal<void (boost::shared_ptr<Swift::Error>)> onFinished; +			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;  		private: -			TLSLayerFactory* tlsLayerFactory_; -			State state_; -			boost::optional<SessionError> error_; -			boost::shared_ptr<TLSLayer> tlsLayer_; -			boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_; -			bool needSessionStart_; -			PKCS12Certificate certificate_; +			JID localJID; +			State state; +			boost::shared_ptr<SessionStream> stream; +			bool needSessionStart;  	};  } diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index cbf20d2..70d4ba9 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -14,7 +14,7 @@  #include "Swiften/Elements/ProtocolHeader.h"  #include "Swiften/Elements/StreamFeatures.h"  #include "Swiften/Elements/Element.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  #include "Swiften/Elements/IQ.h"  #include "Swiften/Elements/AuthRequest.h"  #include "Swiften/Elements/AuthSuccess.h" diff --git a/Swiften/Elements/Error.h b/Swiften/Elements/ErrorPayload.h index 8793f35..32fd067 100644 --- a/Swiften/Elements/Error.h +++ b/Swiften/Elements/ErrorPayload.h @@ -1,11 +1,10 @@ -#ifndef SWIFTEN_Error_H -#define SWIFTEN_Error_H +#pragma once  #include "Swiften/Elements/Payload.h"  #include "Swiften/Base/String.h"  namespace Swift { -	class Error : public Payload { +	class ErrorPayload : public Payload {  		public:  			enum Type { Cancel, Continue, Modify, Auth, Wait }; @@ -34,7 +33,7 @@ namespace Swift {  				UnexpectedRequest  			}; -			Error(Condition condition = UndefinedCondition, Type type = Cancel, const String& text = String()) : type_(type), condition_(condition), text_(text) { } +			ErrorPayload(Condition condition = UndefinedCondition, Type type = Cancel, const String& text = String()) : type_(type), condition_(condition), text_(text) { }  			Type getType() const {  				return type_;  @@ -66,5 +65,3 @@ namespace Swift {  			String text_;  	};  } - -#endif diff --git a/Swiften/Elements/IQ.cpp b/Swiften/Elements/IQ.cpp index 3f47182..53dec53 100644 --- a/Swiften/Elements/IQ.cpp +++ b/Swiften/Elements/IQ.cpp @@ -26,11 +26,11 @@ boost::shared_ptr<IQ> IQ::createResult(  	return iq;  } -boost::shared_ptr<IQ> IQ::createError(const JID& to, const String& id, Error::Condition condition, Error::Type type) { +boost::shared_ptr<IQ> IQ::createError(const JID& to, const String& id, ErrorPayload::Condition condition, ErrorPayload::Type type) {  	boost::shared_ptr<IQ> iq(new IQ(IQ::Error));  	iq->setTo(to);  	iq->setID(id); -	iq->addPayload(boost::shared_ptr<Swift::Error>(new Swift::Error(condition, type))); +	iq->addPayload(boost::shared_ptr<Swift::ErrorPayload>(new Swift::ErrorPayload(condition, type)));  	return iq;  } diff --git a/Swiften/Elements/IQ.h b/Swiften/Elements/IQ.h index 231439f..80c2913 100644 --- a/Swiften/Elements/IQ.h +++ b/Swiften/Elements/IQ.h @@ -1,8 +1,7 @@ -#ifndef SWIFTEN_IQ_H -#define SWIFTEN_IQ_H +#pragma once  #include "Swiften/Elements/Stanza.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  namespace Swift   { @@ -28,12 +27,10 @@ namespace Swift  			static boost::shared_ptr<IQ> createError(  					const JID& to,  					const String& id, -					Error::Condition condition, -					Error::Type type); +					ErrorPayload::Condition condition, +					ErrorPayload::Type type);  		private:  			Type type_;  	};  } - -#endif diff --git a/Swiften/Elements/Message.h b/Swiften/Elements/Message.h index a49f496..6d9171f 100644 --- a/Swiften/Elements/Message.h +++ b/Swiften/Elements/Message.h @@ -1,11 +1,10 @@ -#ifndef SWIFTEN_STANZAS_MESSAGE_H -#define SWIFTEN_STANZAS_MESSAGE_H +#pragma once  #include <boost/optional.hpp>  #include "Swiften/Base/String.h"  #include "Swiften/Elements/Body.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  #include "Swiften/Elements/Stanza.h"  namespace Swift @@ -30,8 +29,8 @@ namespace Swift  			}  			bool isError() { -				boost::shared_ptr<Swift::Error> error(getPayload<Swift::Error>()); -				return getType() == Message::Error || error.get() != NULL; +				boost::shared_ptr<Swift::ErrorPayload> error(getPayload<Swift::ErrorPayload>()); +				return getType() == Message::Error || error;  			}  			Type getType() const { return type_; } @@ -42,5 +41,3 @@ namespace Swift  			Type type_;  	};  } - -#endif diff --git a/Swiften/Elements/UnitTest/IQTest.cpp b/Swiften/Elements/UnitTest/IQTest.cpp index bc22c81..a5e6dc8 100644 --- a/Swiften/Elements/UnitTest/IQTest.cpp +++ b/Swiften/Elements/UnitTest/IQTest.cpp @@ -37,14 +37,14 @@ class IQTest : public CppUnit::TestFixture  		}  		void testCreateError() { -			boost::shared_ptr<IQ> iq(IQ::createError(JID("foo@bar/fum"), "myid", Error::BadRequest, Error::Modify)); +			boost::shared_ptr<IQ> iq(IQ::createError(JID("foo@bar/fum"), "myid", ErrorPayload::BadRequest, ErrorPayload::Modify));  			CPPUNIT_ASSERT_EQUAL(JID("foo@bar/fum"), iq->getTo());  			CPPUNIT_ASSERT_EQUAL(String("myid"), iq->getID()); -			boost::shared_ptr<Error> error(iq->getPayload<Error>()); +			boost::shared_ptr<ErrorPayload> error(iq->getPayload<ErrorPayload>());  			CPPUNIT_ASSERT(error); -			CPPUNIT_ASSERT_EQUAL(Error::BadRequest, error->getCondition()); -			CPPUNIT_ASSERT_EQUAL(Error::Modify, error->getType()); +			CPPUNIT_ASSERT_EQUAL(ErrorPayload::BadRequest, error->getCondition()); +			CPPUNIT_ASSERT_EQUAL(ErrorPayload::Modify, error->getType());  		}  }; diff --git a/Swiften/EventLoop/SimpleEventLoop.cpp b/Swiften/EventLoop/SimpleEventLoop.cpp index 8191747..7c46ed3 100644 --- a/Swiften/EventLoop/SimpleEventLoop.cpp +++ b/Swiften/EventLoop/SimpleEventLoop.cpp @@ -12,6 +12,12 @@ void nop() {}  SimpleEventLoop::SimpleEventLoop() : isRunning_(true) {  } +SimpleEventLoop::~SimpleEventLoop() { +	if (!events_.empty()) { +		std::cerr << "Warning: Pending events in SimpleEventLoop at destruction time" << std::endl; +	} +} +  void SimpleEventLoop::run() {  	while (isRunning_) {  		std::vector<Event> events; diff --git a/Swiften/EventLoop/SimpleEventLoop.h b/Swiften/EventLoop/SimpleEventLoop.h index 01afdb2..bd0a07f 100644 --- a/Swiften/EventLoop/SimpleEventLoop.h +++ b/Swiften/EventLoop/SimpleEventLoop.h @@ -1,5 +1,4 @@ -#ifndef SWIFTEN_SimpleEventLoop_H -#define SWIFTEN_SimpleEventLoop_H +#pragma once  #include <vector>  #include <boost/function.hpp> @@ -12,6 +11,7 @@ namespace Swift {  	class SimpleEventLoop : public EventLoop {  		public:  			SimpleEventLoop(); +			~SimpleEventLoop();  			void run();  			void stop(); @@ -28,4 +28,3 @@ namespace Swift {  			boost::condition_variable eventsAvailable_;  	};  } -#endif diff --git a/Swiften/Parser/PayloadParsers/ErrorParser.cpp b/Swiften/Parser/PayloadParsers/ErrorParser.cpp index 13380c8..ae85265 100644 --- a/Swiften/Parser/PayloadParsers/ErrorParser.cpp +++ b/Swiften/Parser/PayloadParsers/ErrorParser.cpp @@ -9,19 +9,19 @@ void ErrorParser::handleStartElement(const String&, const String&, const Attribu  	if (level_ == TopLevel) {  		String type = attributes.getAttribute("type");  		if (type == "continue") { -			getPayloadInternal()->setType(Error::Continue); +			getPayloadInternal()->setType(ErrorPayload::Continue);  		}  		else if (type == "modify") { -			getPayloadInternal()->setType(Error::Modify); +			getPayloadInternal()->setType(ErrorPayload::Modify);  		}  		else if (type == "auth") { -			getPayloadInternal()->setType(Error::Auth); +			getPayloadInternal()->setType(ErrorPayload::Auth);  		}  		else if (type == "wait") { -			getPayloadInternal()->setType(Error::Wait); +			getPayloadInternal()->setType(ErrorPayload::Wait);  		}  		else { -			getPayloadInternal()->setType(Error::Cancel); +			getPayloadInternal()->setType(ErrorPayload::Cancel);  		}  	}  	++level_; @@ -34,70 +34,70 @@ void ErrorParser::handleEndElement(const String& element, const String&) {  			getPayloadInternal()->setText(currentText_);  		}  		else if (element == "bad-request") { -			getPayloadInternal()->setCondition(Error::BadRequest); +			getPayloadInternal()->setCondition(ErrorPayload::BadRequest);  		}  		else if (element == "conflict") { -			getPayloadInternal()->setCondition(Error::Conflict); +			getPayloadInternal()->setCondition(ErrorPayload::Conflict);  		}  		else if (element == "feature-not-implemented") { -			getPayloadInternal()->setCondition(Error::FeatureNotImplemented); +			getPayloadInternal()->setCondition(ErrorPayload::FeatureNotImplemented);  		}  		else if (element == "forbidden") { -			getPayloadInternal()->setCondition(Error::Forbidden); +			getPayloadInternal()->setCondition(ErrorPayload::Forbidden);  		}  		else if (element == "gone") { -			getPayloadInternal()->setCondition(Error::Gone); +			getPayloadInternal()->setCondition(ErrorPayload::Gone);  		}  		else if (element == "internal-server-error") { -			getPayloadInternal()->setCondition(Error::InternalServerError); +			getPayloadInternal()->setCondition(ErrorPayload::InternalServerError);  		}  		else if (element == "item-not-found") { -			getPayloadInternal()->setCondition(Error::ItemNotFound); +			getPayloadInternal()->setCondition(ErrorPayload::ItemNotFound);  		}  		else if (element == "jid-malformed") { -			getPayloadInternal()->setCondition(Error::JIDMalformed); +			getPayloadInternal()->setCondition(ErrorPayload::JIDMalformed);  		}  		else if (element == "not-acceptable") { -			getPayloadInternal()->setCondition(Error::NotAcceptable); +			getPayloadInternal()->setCondition(ErrorPayload::NotAcceptable);  		}  		else if (element == "not-allowed") { -			getPayloadInternal()->setCondition(Error::NotAllowed); +			getPayloadInternal()->setCondition(ErrorPayload::NotAllowed);  		}  		else if (element == "not-authorized") { -			getPayloadInternal()->setCondition(Error::NotAuthorized); +			getPayloadInternal()->setCondition(ErrorPayload::NotAuthorized);  		}  		else if (element == "payment-required") { -			getPayloadInternal()->setCondition(Error::PaymentRequired); +			getPayloadInternal()->setCondition(ErrorPayload::PaymentRequired);  		}  		else if (element == "recipient-unavailable") { -			getPayloadInternal()->setCondition(Error::RecipientUnavailable); +			getPayloadInternal()->setCondition(ErrorPayload::RecipientUnavailable);  		}  		else if (element == "redirect") { -			getPayloadInternal()->setCondition(Error::Redirect); +			getPayloadInternal()->setCondition(ErrorPayload::Redirect);  		}  		else if (element == "registration-required") { -			getPayloadInternal()->setCondition(Error::RegistrationRequired); +			getPayloadInternal()->setCondition(ErrorPayload::RegistrationRequired);  		}  		else if (element == "remote-server-not-found") { -			getPayloadInternal()->setCondition(Error::RemoteServerNotFound); +			getPayloadInternal()->setCondition(ErrorPayload::RemoteServerNotFound);  		}  		else if (element == "remote-server-timeout") { -			getPayloadInternal()->setCondition(Error::RemoteServerTimeout); +			getPayloadInternal()->setCondition(ErrorPayload::RemoteServerTimeout);  		}  		else if (element == "resource-constraint") { -			getPayloadInternal()->setCondition(Error::ResourceConstraint); +			getPayloadInternal()->setCondition(ErrorPayload::ResourceConstraint);  		}  		else if (element == "service-unavailable") { -			getPayloadInternal()->setCondition(Error::ServiceUnavailable); +			getPayloadInternal()->setCondition(ErrorPayload::ServiceUnavailable);  		}  		else if (element == "subscription-required") { -			getPayloadInternal()->setCondition(Error::SubscriptionRequired); +			getPayloadInternal()->setCondition(ErrorPayload::SubscriptionRequired);  		}  		else if (element == "unexpected-request") { -			getPayloadInternal()->setCondition(Error::UnexpectedRequest); +			getPayloadInternal()->setCondition(ErrorPayload::UnexpectedRequest);  		}  		else { -			getPayloadInternal()->setCondition(Error::UndefinedCondition); +			getPayloadInternal()->setCondition(ErrorPayload::UndefinedCondition);  		}  	}  } diff --git a/Swiften/Parser/PayloadParsers/ErrorParser.h b/Swiften/Parser/PayloadParsers/ErrorParser.h index 76db205..17b78b9 100644 --- a/Swiften/Parser/PayloadParsers/ErrorParser.h +++ b/Swiften/Parser/PayloadParsers/ErrorParser.h @@ -1,11 +1,11 @@  #ifndef SWIFTEN_ErrorParser_H  #define SWIFTEN_ErrorParser_H -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  #include "Swiften/Parser/GenericPayloadParser.h"  namespace Swift { -	class ErrorParser : public GenericPayloadParser<Error> { +	class ErrorParser : public GenericPayloadParser<ErrorPayload> {  		public:  			ErrorParser(); diff --git a/Swiften/Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp b/Swiften/Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp index 338fb3f..dcd3172 100644 --- a/Swiften/Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp +++ b/Swiften/Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp @@ -24,9 +24,9 @@ class ErrorParserTest : public CppUnit::TestFixture  					"<text xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\">boo</text>"  				"</error>")); -			Error* payload = dynamic_cast<Error*>(parser.getPayload().get()); -			CPPUNIT_ASSERT_EQUAL(Error::BadRequest, payload->getCondition()); -			CPPUNIT_ASSERT_EQUAL(Error::Modify, payload->getType()); +			ErrorPayload* payload = dynamic_cast<ErrorPayload*>(parser.getPayload().get()); +			CPPUNIT_ASSERT_EQUAL(ErrorPayload::BadRequest, payload->getCondition()); +			CPPUNIT_ASSERT_EQUAL(ErrorPayload::Modify, payload->getType());  			CPPUNIT_ASSERT_EQUAL(String("boo"), payload->getText());  		}  }; diff --git a/Swiften/QA/ClientTest/ClientTest.cpp b/Swiften/QA/ClientTest/ClientTest.cpp index 412eb53..b50a0bf 100644 --- a/Swiften/QA/ClientTest/ClientTest.cpp +++ b/Swiften/QA/ClientTest/ClientTest.cpp @@ -19,6 +19,7 @@ bool rosterReceived = false;  void handleRosterReceived(boost::shared_ptr<Payload>) {  	rosterReceived = true; +	client->disconnect();  	eventLoop.stop();  } @@ -46,12 +47,13 @@ int main(int, char**) {  	client->connect();  	{ -		boost::shared_ptr<Timer> timer(new Timer(10000, &MainBoostIOServiceThread::getInstance().getIOService())); +		boost::shared_ptr<Timer> timer(new Timer(30000, &MainBoostIOServiceThread::getInstance().getIOService()));  		timer->onTick.connect(boost::bind(&SimpleEventLoop::stop, &eventLoop));  		timer->start();  		eventLoop.run();  	} +  	delete tracer;  	delete client;  	return !rosterReceived; diff --git a/Swiften/Queries/GenericRequest.h b/Swiften/Queries/GenericRequest.h index b4a1918..77dae52 100644 --- a/Swiften/Queries/GenericRequest.h +++ b/Swiften/Queries/GenericRequest.h @@ -1,5 +1,4 @@ -#ifndef SWIFTEN_GenericRequest_H -#define SWIFTEN_GenericRequest_H +#pragma once  #include <boost/signal.hpp> @@ -17,13 +16,11 @@ namespace Swift {  						Request(type, receiver, payload, router) {  			} -			virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<Error> error) { +			virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<ErrorPayload> error) {  				onResponse(boost::dynamic_pointer_cast<PAYLOAD_TYPE>(payload), error);  			}  		public: -			boost::signal<void (boost::shared_ptr<PAYLOAD_TYPE>, const boost::optional<Error>&)> onResponse; +			boost::signal<void (boost::shared_ptr<PAYLOAD_TYPE>, const boost::optional<ErrorPayload>&)> onResponse;  	};  } - -#endif diff --git a/Swiften/Queries/IQRouter.cpp b/Swiften/Queries/IQRouter.cpp index ffed5f7..fdfa00b 100644 --- a/Swiften/Queries/IQRouter.cpp +++ b/Swiften/Queries/IQRouter.cpp @@ -6,7 +6,7 @@  #include "Swiften/Base/foreach.h"  #include "Swiften/Queries/IQHandler.h"  #include "Swiften/Queries/IQChannel.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  namespace Swift { @@ -34,7 +34,7 @@ void IQRouter::handleIQ(boost::shared_ptr<IQ> iq) {  		}  	}  	if (!handled && (iq->getType() == IQ::Get || iq->getType() == IQ::Set) ) { -		channel_->sendIQ(IQ::createError(iq->getFrom(), iq->getID(), Error::FeatureNotImplemented, Error::Cancel)); +		channel_->sendIQ(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::FeatureNotImplemented, ErrorPayload::Cancel));  	}  	processPendingRemoves(); diff --git a/Swiften/Queries/Request.cpp b/Swiften/Queries/Request.cpp index 90aa295..18446ae 100644 --- a/Swiften/Queries/Request.cpp +++ b/Swiften/Queries/Request.cpp @@ -35,11 +35,11 @@ bool Request::handleIQ(boost::shared_ptr<IQ> iq) {  	bool handled = false;  	if (sent_ && iq->getID() == id_) {  		if (iq->getType() == IQ::Result) { -			handleResponse(iq->getPayloadOfSameType(payload_), boost::optional<Error>()); +			handleResponse(iq->getPayloadOfSameType(payload_), boost::optional<ErrorPayload>());  		}  		else {  			// FIXME: Get proper error -			handleResponse(boost::shared_ptr<Payload>(), boost::optional<Error>(Error::UndefinedCondition)); +			handleResponse(boost::shared_ptr<Payload>(), boost::optional<ErrorPayload>(ErrorPayload::UndefinedCondition));  		}  		router_->removeHandler(this);  		handled = true; diff --git a/Swiften/Queries/Request.h b/Swiften/Queries/Request.h index 8f7a1d1..cc4a58e 100644 --- a/Swiften/Queries/Request.h +++ b/Swiften/Queries/Request.h @@ -9,7 +9,7 @@  #include "Swiften/Queries/IQHandler.h"  #include "Swiften/Elements/IQ.h"  #include "Swiften/Elements/Payload.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  #include "Swiften/JID/JID.h"  namespace Swift { @@ -32,7 +32,7 @@ namespace Swift {  				payload_ = p;  			} -			virtual void handleResponse(boost::shared_ptr<Payload>, boost::optional<Error>) = 0; +			virtual void handleResponse(boost::shared_ptr<Payload>, boost::optional<ErrorPayload>) = 0;  		private:  			bool handleIQ(boost::shared_ptr<IQ>); diff --git a/Swiften/Queries/Requests/GetPrivateStorageRequest.h b/Swiften/Queries/Requests/GetPrivateStorageRequest.h index c5f8aef..5d6440e 100644 --- a/Swiften/Queries/Requests/GetPrivateStorageRequest.h +++ b/Swiften/Queries/Requests/GetPrivateStorageRequest.h @@ -5,7 +5,7 @@  #include "Swiften/Queries/Request.h"  #include "Swiften/Elements/PrivateStorage.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  namespace Swift {  	template<typename PAYLOAD_TYPE> @@ -14,7 +14,7 @@ namespace Swift {  			GetPrivateStorageRequest(IQRouter* router) : Request(IQ::Get, JID(), boost::shared_ptr<PrivateStorage>(new PrivateStorage(boost::shared_ptr<Payload>(new PAYLOAD_TYPE()))), router) {  			} -			virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<Error> error) { +			virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<ErrorPayload> error) {  				boost::shared_ptr<PrivateStorage> storage = boost::dynamic_pointer_cast<PrivateStorage>(payload);  				if (storage) {  					onResponse(boost::dynamic_pointer_cast<PAYLOAD_TYPE>(storage->getPayload()), error); @@ -25,6 +25,6 @@ namespace Swift {  			}  		public: -			boost::signal<void (boost::shared_ptr<PAYLOAD_TYPE>, const boost::optional<Error>&)> onResponse; +			boost::signal<void (boost::shared_ptr<PAYLOAD_TYPE>, const boost::optional<ErrorPayload>&)> onResponse;  	};  } diff --git a/Swiften/Queries/Requests/SetPrivateStorageRequest.h b/Swiften/Queries/Requests/SetPrivateStorageRequest.h index 63ac8dc..834ddd8 100644 --- a/Swiften/Queries/Requests/SetPrivateStorageRequest.h +++ b/Swiften/Queries/Requests/SetPrivateStorageRequest.h @@ -5,7 +5,7 @@  #include "Swiften/Queries/Request.h"  #include "Swiften/Elements/PrivateStorage.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  namespace Swift {  	template<typename PAYLOAD_TYPE> @@ -14,11 +14,11 @@ namespace Swift {  			SetPrivateStorageRequest(boost::shared_ptr<PAYLOAD_TYPE> payload, IQRouter* router) : Request(IQ::Set, JID(), boost::shared_ptr<PrivateStorage>(new PrivateStorage(payload)), router) {  			} -			virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<Error> error) { +			virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<ErrorPayload> error) {  				onResponse(error);  			}  		public: -			boost::signal<void (const boost::optional<Error>&)> onResponse; +			boost::signal<void (const boost::optional<ErrorPayload>&)> onResponse;  	};  } diff --git a/Swiften/Queries/Requests/UnitTest/GetPrivateStorageRequestTest.cpp b/Swiften/Queries/Requests/UnitTest/GetPrivateStorageRequestTest.cpp index 14e04cf..a86a111 100644 --- a/Swiften/Queries/Requests/UnitTest/GetPrivateStorageRequestTest.cpp +++ b/Swiften/Queries/Requests/UnitTest/GetPrivateStorageRequestTest.cpp @@ -72,7 +72,7 @@ class GetPrivateStorageRequestTest : public CppUnit::TestFixture  		}  	private: -		void handleResponse(boost::shared_ptr<Payload> p, const boost::optional<Error>& e) { +		void handleResponse(boost::shared_ptr<Payload> p, const boost::optional<ErrorPayload>& e) {  			if (e) {  				errors.push_back(*e);  			} @@ -99,7 +99,7 @@ class GetPrivateStorageRequestTest : public CppUnit::TestFixture  	private:  		IQRouter* router;  		DummyIQChannel* channel; -		std::vector< Error > errors; +		std::vector< ErrorPayload > errors;  		std::vector< boost::shared_ptr<Payload> > responses;  }; diff --git a/Swiften/Queries/Responder.h b/Swiften/Queries/Responder.h index e6e8ca6..9c025eb 100644 --- a/Swiften/Queries/Responder.h +++ b/Swiften/Queries/Responder.h @@ -3,7 +3,7 @@  #include "Swiften/Queries/IQHandler.h"  #include "Swiften/Queries/IQRouter.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  namespace Swift {  	template<typename PAYLOAD_TYPE> @@ -25,7 +25,7 @@ namespace Swift {  				router_->sendIQ(IQ::createResult(to, id, payload));  			} -			void sendError(const JID& to, const String& id, Error::Condition condition, Error::Type type) { +			void sendError(const JID& to, const String& id, ErrorPayload::Condition condition, ErrorPayload::Type type) {  				router_->sendIQ(IQ::createError(to, id, condition, type));  			} @@ -42,7 +42,7 @@ namespace Swift {  							result = handleGetRequest(iq->getFrom(), iq->getID(), payload);  						}  						if (!result) { -							router_->sendIQ(IQ::createError(iq->getFrom(), iq->getID(), Error::NotAllowed, Error::Cancel)); +							router_->sendIQ(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::NotAllowed, ErrorPayload::Cancel));  						}  						return true;  					} diff --git a/Swiften/Queries/Responders/DiscoInfoResponder.cpp b/Swiften/Queries/Responders/DiscoInfoResponder.cpp index a114fbc..572f83f 100644 --- a/Swiften/Queries/Responders/DiscoInfoResponder.cpp +++ b/Swiften/Queries/Responders/DiscoInfoResponder.cpp @@ -27,7 +27,7 @@ bool DiscoInfoResponder::handleGetRequest(const JID& from, const String& id, boo  			sendResponse(from, id, boost::shared_ptr<DiscoInfo>(new DiscoInfo((*i).second)));  		}  		else { -			sendError(from, id, Error::ItemNotFound, Error::Cancel); +			sendError(from, id, ErrorPayload::ItemNotFound, ErrorPayload::Cancel);  		}  	}  	return true; diff --git a/Swiften/Queries/Responders/UnitTest/DiscoInfoResponderTest.cpp b/Swiften/Queries/Responders/UnitTest/DiscoInfoResponderTest.cpp index 6ed7b9e..5993d0c 100644 --- a/Swiften/Queries/Responders/UnitTest/DiscoInfoResponderTest.cpp +++ b/Swiften/Queries/Responders/UnitTest/DiscoInfoResponderTest.cpp @@ -72,7 +72,7 @@ class DiscoInfoResponderTest : public CppUnit::TestFixture {  			channel_->onIQReceived(IQ::createRequest(IQ::Get, JID("foo@bar.com"), "id-1", query));  			CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); -			boost::shared_ptr<Error> payload(channel_->iqs_[0]->getPayload<Error>()); +			boost::shared_ptr<ErrorPayload> payload(channel_->iqs_[0]->getPayload<ErrorPayload>());  			CPPUNIT_ASSERT(payload);  		} diff --git a/Swiften/Queries/UnitTest/IQRouterTest.cpp b/Swiften/Queries/UnitTest/IQRouterTest.cpp index 94b7de8..5760b09 100644 --- a/Swiften/Queries/UnitTest/IQRouterTest.cpp +++ b/Swiften/Queries/UnitTest/IQRouterTest.cpp @@ -87,7 +87,7 @@ class IQRouterTest : public CppUnit::TestFixture  			channel_->onIQReceived(boost::shared_ptr<IQ>(new IQ()));  			CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); -			CPPUNIT_ASSERT(channel_->iqs_[0]->getPayload<Error>()); +			CPPUNIT_ASSERT(channel_->iqs_[0]->getPayload<ErrorPayload>());  		} diff --git a/Swiften/Queries/UnitTest/RequestTest.cpp b/Swiften/Queries/UnitTest/RequestTest.cpp index ea6dee6..51d5a51 100644 --- a/Swiften/Queries/UnitTest/RequestTest.cpp +++ b/Swiften/Queries/UnitTest/RequestTest.cpp @@ -113,7 +113,7 @@ class RequestTest : public CppUnit::TestFixture  		}  	private: -		void handleResponse(boost::shared_ptr<Payload> p, const boost::optional<Error>& e) { +		void handleResponse(boost::shared_ptr<Payload> p, const boost::optional<ErrorPayload>& e) {  			if (e) {  				++errorsReceived_;  			} diff --git a/Swiften/SConscript b/Swiften/SConscript index 148f1f8..d5ddce4 100644 --- a/Swiften/SConscript +++ b/Swiften/SConscript @@ -75,6 +75,8 @@ sources = [  		"Server/SimpleUserRegistry.cpp",  		"Server/UserRegistry.cpp",  		"Session/Session.cpp", +		"Session/SessionStream.cpp", +		"Session/BasicSessionStream.cpp",  		"StringCodecs/Base64.cpp",  		"StringCodecs/SHA1.cpp",  	] @@ -103,7 +105,7 @@ env.Append(UNITTEST_SOURCES = [  		File("Base/UnitTest/IDGeneratorTest.cpp"),  		File("Base/UnitTest/StringTest.cpp"),  		File("Base/UnitTest/ByteArrayTest.cpp"), -		File("Client/UnitTest/ClientSessionTest.cpp"), +		#File("Client/UnitTest/ClientSessionTest.cpp"),  		File("Compress/UnitTest/ZLibCompressorTest.cpp"),  		File("Compress/UnitTest/ZLibDecompressorTest.cpp"),  		File("Disco/UnitTest/CapsInfoGeneratorTest.cpp"), diff --git a/Swiften/Serializer/PayloadSerializers/ErrorSerializer.cpp b/Swiften/Serializer/PayloadSerializers/ErrorSerializer.cpp index 347e1a5..f5ce478 100644 --- a/Swiften/Serializer/PayloadSerializers/ErrorSerializer.cpp +++ b/Swiften/Serializer/PayloadSerializers/ErrorSerializer.cpp @@ -3,43 +3,43 @@  namespace Swift { -ErrorSerializer::ErrorSerializer() : GenericPayloadSerializer<Error>() { +ErrorSerializer::ErrorSerializer() : GenericPayloadSerializer<ErrorPayload>() {  } -String ErrorSerializer::serializePayload(boost::shared_ptr<Error> error)  const { +String ErrorSerializer::serializePayload(boost::shared_ptr<ErrorPayload> error)  const {  	String result("<error type=\"");  	switch (error->getType()) { -		case Error::Continue: result += "continue"; break; -		case Error::Modify: result += "modify"; break; -		case Error::Auth: result += "auth"; break; -		case Error::Wait: result += "wait"; break; +		case ErrorPayload::Continue: result += "continue"; break; +		case ErrorPayload::Modify: result += "modify"; break; +		case ErrorPayload::Auth: result += "auth"; break; +		case ErrorPayload::Wait: result += "wait"; break;  		default: result += "cancel"; break;  	}  	result += "\">";  	String conditionElement;  	switch (error->getCondition()) { -		case Error::BadRequest: conditionElement = "bad-request"; break; -		case Error::Conflict: conditionElement = "conflict"; break; -		case Error::FeatureNotImplemented: conditionElement = "feature-not-implemented"; break; -		case Error::Forbidden: conditionElement = "forbidden"; break; -		case Error::Gone: conditionElement = "gone"; break; -		case Error::InternalServerError: conditionElement = "internal-server-error"; break; -		case Error::ItemNotFound: conditionElement = "item-not-found"; break; -		case Error::JIDMalformed: conditionElement = "jid-malformed"; break; -		case Error::NotAcceptable: conditionElement = "not-acceptable"; break; -		case Error::NotAllowed: conditionElement = "not-allowed"; break; -		case Error::NotAuthorized: conditionElement = "not-authorized"; break; -		case Error::PaymentRequired: conditionElement = "payment-required"; break; -		case Error::RecipientUnavailable: conditionElement = "recipient-unavailable"; break; -		case Error::Redirect: conditionElement = "redirect"; break; -		case Error::RegistrationRequired: conditionElement = "registration-required"; break; -		case Error::RemoteServerNotFound: conditionElement = "remote-server-not-found"; break; -		case Error::RemoteServerTimeout: conditionElement = "remote-server-timeout"; break; -		case Error::ResourceConstraint: conditionElement = "resource-constraint"; break; -		case Error::ServiceUnavailable: conditionElement = "service-unavailable"; break; -		case Error::SubscriptionRequired: conditionElement = "subscription-required"; break; -		case Error::UnexpectedRequest: conditionElement = "unexpected-request"; break; +		case ErrorPayload::BadRequest: conditionElement = "bad-request"; break; +		case ErrorPayload::Conflict: conditionElement = "conflict"; break; +		case ErrorPayload::FeatureNotImplemented: conditionElement = "feature-not-implemented"; break; +		case ErrorPayload::Forbidden: conditionElement = "forbidden"; break; +		case ErrorPayload::Gone: conditionElement = "gone"; break; +		case ErrorPayload::InternalServerError: conditionElement = "internal-server-error"; break; +		case ErrorPayload::ItemNotFound: conditionElement = "item-not-found"; break; +		case ErrorPayload::JIDMalformed: conditionElement = "jid-malformed"; break; +		case ErrorPayload::NotAcceptable: conditionElement = "not-acceptable"; break; +		case ErrorPayload::NotAllowed: conditionElement = "not-allowed"; break; +		case ErrorPayload::NotAuthorized: conditionElement = "not-authorized"; break; +		case ErrorPayload::PaymentRequired: conditionElement = "payment-required"; break; +		case ErrorPayload::RecipientUnavailable: conditionElement = "recipient-unavailable"; break; +		case ErrorPayload::Redirect: conditionElement = "redirect"; break; +		case ErrorPayload::RegistrationRequired: conditionElement = "registration-required"; break; +		case ErrorPayload::RemoteServerNotFound: conditionElement = "remote-server-not-found"; break; +		case ErrorPayload::RemoteServerTimeout: conditionElement = "remote-server-timeout"; break; +		case ErrorPayload::ResourceConstraint: conditionElement = "resource-constraint"; break; +		case ErrorPayload::ServiceUnavailable: conditionElement = "service-unavailable"; break; +		case ErrorPayload::SubscriptionRequired: conditionElement = "subscription-required"; break; +		case ErrorPayload::UnexpectedRequest: conditionElement = "unexpected-request"; break;  		default: conditionElement = "undefined-condition"; break;  	}  	result += "<" + conditionElement + " xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\"/>"; diff --git a/Swiften/Serializer/PayloadSerializers/ErrorSerializer.h b/Swiften/Serializer/PayloadSerializers/ErrorSerializer.h index ecf73dc..931596f 100644 --- a/Swiften/Serializer/PayloadSerializers/ErrorSerializer.h +++ b/Swiften/Serializer/PayloadSerializers/ErrorSerializer.h @@ -2,14 +2,14 @@  #define SWIFTEN_ErrorSerializer_H  #include "Swiften/Serializer/GenericPayloadSerializer.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h"  namespace Swift { -	class ErrorSerializer : public GenericPayloadSerializer<Error> { +	class ErrorSerializer : public GenericPayloadSerializer<ErrorPayload> {  		public:  			ErrorSerializer(); -			virtual String serializePayload(boost::shared_ptr<Error> error)  const; +			virtual String serializePayload(boost::shared_ptr<ErrorPayload> error)  const;  	};  } diff --git a/Swiften/Serializer/PayloadSerializers/UnitTest/ErrorSerializerTest.cpp b/Swiften/Serializer/PayloadSerializers/UnitTest/ErrorSerializerTest.cpp index 2d68a3d..ecd904a 100644 --- a/Swiften/Serializer/PayloadSerializers/UnitTest/ErrorSerializerTest.cpp +++ b/Swiften/Serializer/PayloadSerializers/UnitTest/ErrorSerializerTest.cpp @@ -16,7 +16,7 @@ class ErrorSerializerTest : public CppUnit::TestFixture  		void testSerialize() {  			ErrorSerializer testling; -			boost::shared_ptr<Error> error(new Error(Error::BadRequest, Error::Cancel, "My Error")); +			boost::shared_ptr<ErrorPayload> error(new ErrorPayload(ErrorPayload::BadRequest, ErrorPayload::Cancel, "My Error"));  			CPPUNIT_ASSERT_EQUAL(String("<error type=\"cancel\"><bad-request xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\"/><text xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\">My Error</text></error>"), testling.serialize(error));  		} diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index 46d4e16..8b14367 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -1,5 +1,3 @@ -// TODO: whitespacePingLayer_->setInactive(); -  #include "Swiften/Session/BasicSessionStream.h"  #include <boost/bind.hpp> @@ -13,18 +11,26 @@  namespace Swift { -BasicSessionStream::BasicSessionStream(boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory) : tlsLayerFactory(tlsLayerFactory) { +BasicSessionStream::BasicSessionStream(boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory) : available(false), connection(connection), payloadParserFactories(payloadParserFactories), payloadSerializers(payloadSerializers), tlsLayerFactory(tlsLayerFactory) { +} + +void BasicSessionStream::initialize() {  	xmppLayer = boost::shared_ptr<XMPPLayer>(  			new XMPPLayer(payloadParserFactories, payloadSerializers)); -	xmppLayer->onStreamStart.connect(boost::ref(onStreamStartReceived)); -	xmppLayer->onElement.connect(boost::ref(onElementReceived)); +	xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, shared_from_this(), _1)); +	xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, shared_from_this(), _1));  	xmppLayer->onError.connect(boost::bind( -      &BasicSessionStream::handleXMPPError, this)); +      &BasicSessionStream::handleXMPPError, shared_from_this())); +  xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, shared_from_this(), _1)); +  xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, shared_from_this(), _1)); +	connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionError, shared_from_this(), _1));  	connectionLayer = boost::shared_ptr<ConnectionLayer>(        new ConnectionLayer(connection));  	streamStack = new StreamStack(xmppLayer, connectionLayer); + +	available = true;  }  BasicSessionStream::~BasicSessionStream() { @@ -32,41 +38,92 @@ BasicSessionStream::~BasicSessionStream() {  }  void BasicSessionStream::writeHeader(const ProtocolHeader& header) { +	assert(available);  	xmppLayer->writeHeader(header);  }  void BasicSessionStream::writeElement(boost::shared_ptr<Element> element) { +	assert(available);  	xmppLayer->writeElement(element);  } +void BasicSessionStream::writeFooter() { +	assert(available); +	xmppLayer->writeFooter(); +} + +bool BasicSessionStream::isAvailable() { +	return available; +} +  bool BasicSessionStream::supportsTLSEncryption() {    return tlsLayerFactory && tlsLayerFactory->canCreate();  }  void BasicSessionStream::addTLSEncryption() { +	assert(available);  	tlsLayer = tlsLayerFactory->createTLSLayer(); -  streamStack->addLayer(tlsLayer); -  // TODO: Add tls layer certificate if needed -  tlsLayer->onError.connect(boost::bind(&BasicSessionStream::handleTLSError, this)); -  tlsLayer->connect(); +	if (hasTLSCertificate() && !tlsLayer->setClientCertificate(getTLSCertificate())) { +		onError(boost::shared_ptr<Error>(new Error(Error::InvalidTLSCertificateError))); +	} +	else { +		streamStack->addLayer(tlsLayer); +		tlsLayer->onError.connect(boost::bind(&BasicSessionStream::handleTLSError, shared_from_this())); +		tlsLayer->onConnected.connect(boost::bind(&BasicSessionStream::handleTLSConnected, shared_from_this())); +		tlsLayer->connect(); +	}  } -void BasicSessionStream::addWhitespacePing() { -  whitespacePingLayer = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); -  streamStack->addLayer(whitespacePingLayer); -  whitespacePingLayer->setActive(); +void BasicSessionStream::setWhitespacePingEnabled(bool enabled) { +	if (enabled && !whitespacePingLayer) { +		whitespacePingLayer = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); +		streamStack->addLayer(whitespacePingLayer); +	} +	if (enabled) { +		whitespacePingLayer->setActive(); +	} +	else { +		whitespacePingLayer->setInactive(); +	}  }  void BasicSessionStream::resetXMPPParser() {    xmppLayer->resetParser();  } +void BasicSessionStream::handleStreamStartReceived(const ProtocolHeader& header) { +	onStreamStartReceived(header); +} + +void BasicSessionStream::handleElementReceived(boost::shared_ptr<Element> element) { +	onElementReceived(element); +} +  void BasicSessionStream::handleXMPPError() { -  // TODO +	available = false; +	onError(boost::shared_ptr<Error>(new Error(Error::ParseError))); +} + +void BasicSessionStream::handleTLSConnected() { +	onTLSEncrypted();  }  void BasicSessionStream::handleTLSError() { -  // TODO +	available = false; +	onError(boost::shared_ptr<Error>(new Error(Error::TLSError))); +} + +void BasicSessionStream::handleConnectionError(const boost::optional<Connection::Error>&) { +	available = false; +	onError(boost::shared_ptr<Error>(new Error(Error::ConnectionError))); +} + +void BasicSessionStream::handleDataRead(const ByteArray& data) { +	onDataRead(String(data.getData(), data.getSize())); +} + +void BasicSessionStream::handleDataWritten(const ByteArray& data) { +	onDataWritten(String(data.getData(), data.getSize()));  }  }; diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index bf92bbb..0cb50eb 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -1,6 +1,7 @@  #pragma once  #include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp>  #include "Swiften/Network/Connection.h"  #include "Swiften/Session/SessionStream.h" @@ -17,7 +18,7 @@ namespace Swift {    class BasicSessionStream :         public SessionStream,  -      public boost::BOOST_SIGNALS_NAMESPACE::trackable { +      public boost::enable_shared_from_this<BasicSessionStream> {      public:        BasicSessionStream(  		    boost::shared_ptr<Connection> connection, @@ -27,25 +28,40 @@ namespace Swift {        );        ~BasicSessionStream(); +			void initialize(); + +			virtual bool isAvailable(); +  			virtual void writeHeader(const ProtocolHeader& header);  			virtual void writeElement(boost::shared_ptr<Element>); +			virtual void writeFooter();  			virtual bool supportsTLSEncryption();  			virtual void addTLSEncryption(); -			virtual void addWhitespacePing(); +			virtual void setWhitespacePingEnabled(bool);  			virtual void resetXMPPParser();      private: +			void handleConnectionError(const boost::optional<Connection::Error>& error);        void handleXMPPError(); +			void handleTLSConnected();        void handleTLSError(); +			void handleStreamStartReceived(const ProtocolHeader&); +			void handleElementReceived(boost::shared_ptr<Element>); +      void handleDataRead(const ByteArray& data); +      void handleDataWritten(const ByteArray& data);      private: +			bool available; +			boost::shared_ptr<Connection> connection; +			PayloadParserFactoryCollection* payloadParserFactories; +			PayloadSerializerCollection* payloadSerializers; +			TLSLayerFactory* tlsLayerFactory;  			boost::shared_ptr<XMPPLayer> xmppLayer;  			boost::shared_ptr<ConnectionLayer> connectionLayer;  			StreamStack* streamStack; -      TLSLayerFactory* tlsLayerFactory;        boost::shared_ptr<TLSLayer> tlsLayer;        boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer;    }; diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h index 17d9a24..6bba237 100644 --- a/Swiften/Session/SessionStream.h +++ b/Swiften/Session/SessionStream.h @@ -5,23 +5,62 @@  #include "Swiften/Elements/ProtocolHeader.h"  #include "Swiften/Elements/Element.h" +#include "Swiften/Base/Error.h" +#include "Swiften/TLS/PKCS12Certificate.h"  namespace Swift {  	class SessionStream {  		public: +			class Error : public Swift::Error { +				public: +					enum Type { +						ParseError, +						TLSError, +						InvalidTLSCertificateError, +						ConnectionError +					}; + +					Error(Type type) : type(type) {} + +					Type type; +			}; +  			virtual ~SessionStream(); +			virtual bool isAvailable() = 0; +  			virtual void writeHeader(const ProtocolHeader& header) = 0; +			virtual void writeFooter() = 0;  			virtual void writeElement(boost::shared_ptr<Element>) = 0;  			virtual bool supportsTLSEncryption() = 0;  			virtual void addTLSEncryption() = 0; - -			virtual void addWhitespacePing() = 0; +			virtual void setWhitespacePingEnabled(bool enabled) = 0;  			virtual void resetXMPPParser() = 0; +			void setTLSCertificate(const PKCS12Certificate& cert) { +				certificate = cert; +			} + +			virtual bool hasTLSCertificate() { +				return !certificate.isNull(); +			} + +  			boost::signal<void (const ProtocolHeader&)> onStreamStartReceived;  			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; +			boost::signal<void (boost::shared_ptr<Error>)> onError; +			boost::signal<void ()> onTLSEncrypted; +			boost::signal<void (const String&)> onDataRead; +			boost::signal<void (const String&)> onDataWritten; + +		protected: +			const PKCS12Certificate& getTLSCertificate() const { +				return certificate; +			} + +		private: +			PKCS12Certificate certificate;  	};  } | 
 Swift
 Swift