diff options
| author | Alexey Melnikov <alexey.melnikov@isode.com> | 2012-02-13 17:54:23 (GMT) | 
|---|---|---|
| committer | Kevin Smith <git@kismith.co.uk> | 2012-02-22 14:08:13 (GMT) | 
| commit | 110eb87e848b85dd74a6f19413c775520a75ea35 (patch) | |
| tree | b10236387180fca676a29f24c747c9d0fd94d8dd | |
| parent | 64fc103d0d5d1d523d00dcc5b231715160475f7e (diff) | |
| download | swift-contrib-110eb87e848b85dd74a6f19413c775520a75ea35.zip swift-contrib-110eb87e848b85dd74a6f19413c775520a75ea35.tar.bz2 | |
Initial implementation of using CAPI certificates with Schannel.
Introduced a new parent class for all certificates with keys
(class CertificateWithKey is the new parent for PKCS12Certificate.)
Switched to using "CertificateWithKey *" instead of "const CertificateWithKey&"
Added calling of a Windows dialog for certificate selection when Schannel
TLS implementation is used.
This compiles, but is not tested.
License: This patch is BSD-licensed, see Documentation/Licenses/BSD-simplified.txt for details.
| -rw-r--r-- | Swift/QtUI/CAPICertificateSelector.cpp | 138 | ||||
| -rw-r--r-- | Swift/QtUI/CAPICertificateSelector.h | 13 | ||||
| -rw-r--r-- | Swift/QtUI/QtLoginWindow.cpp | 19 | ||||
| -rw-r--r-- | Swift/QtUI/SConscript | 3 | ||||
| -rw-r--r-- | Swiften/Client/CoreClient.cpp | 27 | ||||
| -rw-r--r-- | Swiften/Client/CoreClient.h | 2 | ||||
| -rw-r--r-- | Swiften/Session/SessionStream.cpp | 1 | ||||
| -rw-r--r-- | Swiften/Session/SessionStream.h | 12 | ||||
| -rw-r--r-- | Swiften/StreamStack/TLSLayer.cpp | 2 | ||||
| -rw-r--r-- | Swiften/StreamStack/TLSLayer.h | 4 | ||||
| -rw-r--r-- | Swiften/TLS/CAPICertificate.h | 196 | ||||
| -rw-r--r-- | Swiften/TLS/CertificateWithKey.h | 32 | ||||
| -rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 14 | ||||
| -rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 4 | ||||
| -rw-r--r-- | Swiften/TLS/PKCS12Certificate.h | 27 | ||||
| -rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.cpp | 82 | ||||
| -rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.h | 11 | ||||
| -rw-r--r-- | Swiften/TLS/TLSContext.h | 4 | 
18 files changed, 559 insertions, 32 deletions
| diff --git a/Swift/QtUI/CAPICertificateSelector.cpp b/Swift/QtUI/CAPICertificateSelector.cpp new file mode 100644 index 0000000..44f5793 --- /dev/null +++ b/Swift/QtUI/CAPICertificateSelector.cpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2012 Isode Limited, London, England. + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#include <string> + +#define SECURITY_WIN32 +#include <Windows.h> +#include <WinCrypt.h> +#include <cryptuiapi.h> + +#include "CAPICertificateSelector.h" + +namespace Swift { + +#define cert_dlg_title L"TLS Client Certificate Selection" +#define cert_dlg_prompt L"Select a certificate to use for authentication" +/////Hmm, maybe we should not exlude the "location" column +#define exclude_columns	 CRYPTUI_SELECT_LOCATION_COLUMN \ +			|CRYPTUI_SELECT_INTENDEDUSE_COLUMN + + + +static std::string getCertUri(PCCERT_CONTEXT cert, const char * cert_store_name) { +	DWORD required_size; +	char * comma; +	char * p_in; +	char * p_out; +	char * subject_name; +	std::string ret = std::string("certstore:") + cert_store_name + ":"; + +	required_size = CertNameToStrA(cert->dwCertEncodingType, +				&cert->pCertInfo->Subject, +				/* Discard attribute names: */ +				CERT_SIMPLE_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, +				NULL, +				0); + +	subject_name = static_cast<char *>(malloc(required_size+1)); + +	if (!CertNameToStrA(cert->dwCertEncodingType, +			    &cert->pCertInfo->Subject, +			    /* Discard attribute names: */ +			    CERT_SIMPLE_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, +			    subject_name, +			    required_size)) { +		return ""; +	} + +	/* Now search for the "," (ignoring escapes) +	    and truncate the rest of the string */ +	if (subject_name[0] == '"') { +		for (comma = subject_name + 1; comma[0]; comma++) { +			if (comma[0] == '"') { +				comma++; +				if (comma[0] != '"') { +					break; +				} +			} +		} +	} else { +		comma = strchr(subject_name, ','); +	} + +	if (comma != NULL) { +		*comma = '\0'; +	} + +	/* We now need to unescape the returned RDN */ +	if (subject_name[0] == '"') { +		for (p_in = subject_name + 1, p_out = subject_name; p_in[0]; p_in++, p_out++) { +			if (p_in[0] == '"') { +				p_in++; +			} + +			p_out[0] = p_in[0]; +		} +		p_out[0] = '\0'; +	} + +	ret += subject_name; +	free(subject_name); + +	return ret; +} + +std::string selectCAPICertificate() { + +	const char * cert_store_name = "MY"; +	PCCERT_CONTEXT cert; +	DWORD store_flags; +	HCERTSTORE hstore; +	HWND hwnd; + +	store_flags = CERT_STORE_OPEN_EXISTING_FLAG | +		      CERT_STORE_READONLY_FLAG | +		      CERT_SYSTEM_STORE_CURRENT_USER; + +	hstore = CertOpenStore(CERT_STORE_PROV_SYSTEM_A, 0, 0, store_flags, cert_store_name); +	if (!hstore) { +		return ""; +	} + + +////Does this handle need to be freed as well? +	hwnd = GetForegroundWindow(); +	if (!hwnd) { +		hwnd = GetActiveWindow(); +	} + +	/* Call Windows dialog to select a suitable certificate */ +	cert = CryptUIDlgSelectCertificateFromStore(hstore, +						  hwnd, +						  cert_dlg_title, +						  cert_dlg_prompt, +						  exclude_columns, +						  0, +						  NULL); + +	if (hstore) { +		CertCloseStore(hstore, 0); +	} + +	if (cert) { +		std::string ret = getCertUri(cert, cert_store_name); + +		CertFreeCertificateContext(cert); + +		return ret; +	} else { +		return ""; +	} +} + + +} diff --git a/Swift/QtUI/CAPICertificateSelector.h b/Swift/QtUI/CAPICertificateSelector.h new file mode 100644 index 0000000..9a0ee92 --- /dev/null +++ b/Swift/QtUI/CAPICertificateSelector.h @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2012 Isode Limited, London, England. + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#include <string> + +namespace Swift { +	std::string selectCAPICertificate(); +} diff --git a/Swift/QtUI/QtLoginWindow.cpp b/Swift/QtUI/QtLoginWindow.cpp index 1cd3206..6b9d389 100644 --- a/Swift/QtUI/QtLoginWindow.cpp +++ b/Swift/QtUI/QtLoginWindow.cpp @@ -41,6 +41,10 @@  #include <QtMainWindow.h>  #include <QtUtilities.h> +#ifdef HAVE_SCHANNEL +#include "CAPICertificateSelector.h" +#endif +  namespace Swift{  QtLoginWindow::QtLoginWindow(UIEventStream* uiEventStream, SettingsProvider* settings) : QMainWindow(), settings_(settings) { @@ -357,10 +361,17 @@ void QtLoginWindow::setLoginAutomatically(bool loginAutomatically) {  void QtLoginWindow::handleCertficateChecked(bool checked) {  	if (checked) { -		 certificateFile_ = QFileDialog::getOpenFileName(this, tr("Select an authentication certificate"), QString(), QString("*.cert;*.p12;*.pfx")); -		 if (certificateFile_.isEmpty()) { -			 certificateButton_->setChecked(false); -		 } +#ifdef HAVE_SCHANNEL +		certificateFile_ = selectCAPICertificate(); +		if (certificateFile_.isEmpty()) { +			certificateButton_->setChecked(false); +		} +#else +		certificateFile_ = QFileDialog::getOpenFileName(this, tr("Select an authentication certificate"), QString(), QString("*.cert;*.p12;*.pfx")); +		if (certificateFile_.isEmpty()) { +			certificateButton_->setChecked(false); +		} +#endif  	}  	else {  		certificateFile_ = ""; diff --git a/Swift/QtUI/SConscript b/Swift/QtUI/SConscript index d37958f..a8b8c78 100644 --- a/Swift/QtUI/SConscript +++ b/Swift/QtUI/SConscript @@ -55,6 +55,8 @@ if env["PLATFORM"] == "win32" :    #myenv["LINKFLAGS"] = ["/SUBSYSTEM:CONSOLE"]    myenv.Append(LINKFLAGS = ["/SUBSYSTEM:WINDOWS"])    myenv.Append(LIBS = "qtmain") +  if myenv.get("HAVE_SCHANNEL", 0) : +    myenv.Append(LIBS = "Cryptui")  myenv.WriteVal("DefaultTheme.qrc", myenv.Value(generateDefaultTheme(myenv.Dir("#/Swift/resources/themes/Default")))) @@ -151,6 +153,7 @@ if env["PLATFORM"] == "win32" :  	# Adding it explicitly until i figure out why    myenv.Depends(res, "../Controllers/BuildVersion.h")    sources += [ +			"CAPICertificateSelector.cpp",  			"WindowsNotifier.cpp",  			"#/Swift/resources/Windows/Swift.res"  		] diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index de12fb7..36bfe35 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -126,6 +126,19 @@ void CoreClient::bindSessionToStream() {  	session_->start();  } +bool CoreClient::isCAPIURI() { +#ifdef HAVE_SCHANNEL +	if (!boost::iequals(certificate_.substr(0, 10), "certstore:")) { +		return false; +	} + +	return true; + +#else +	return false; +#endif +} +  /**   * Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream.   */ @@ -144,7 +157,19 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connectio  		assert(!sessionStream_);  		sessionStream_ = boost::make_shared<BasicSessionStream>(ClientStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory());  		if (!certificate_.empty()) { -			sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); +			CertificateWithKey* cert; + +#if defined(SWIFTEN_PLATFORM_WIN32) +			if (isCAPIURI()) { +				cert = new CAPICertificate(certificate_); +			} else { +				cert = new PKCS12Certificate(certificate_, password_); +			} +#else +			cert = new PKCS12Certificate(certificate_, password_); +#endif + +			sessionStream_->setTLSCertificate(cert);  		}  		sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1));  		sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h index c231fdc..6712e03 100644 --- a/Swiften/Client/CoreClient.h +++ b/Swiften/Client/CoreClient.h @@ -196,6 +196,8 @@ namespace Swift {  			 */  			virtual void handleConnected() {}; +			bool isCAPIURI(); +  		private:  			void handleConnectorFinished(boost::shared_ptr<Connection>);  			void handleStanzaChannelAvailableChanged(bool available); diff --git a/Swiften/Session/SessionStream.cpp b/Swiften/Session/SessionStream.cpp index 0d73b63..487ad8b 100644 --- a/Swiften/Session/SessionStream.cpp +++ b/Swiften/Session/SessionStream.cpp @@ -9,6 +9,7 @@  namespace Swift {  SessionStream::~SessionStream() { +	delete certificate;  }  }; diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h index 096f185..58015b3 100644 --- a/Swiften/Session/SessionStream.h +++ b/Swiften/Session/SessionStream.h @@ -14,7 +14,7 @@  #include <Swiften/Elements/Element.h>  #include <Swiften/Base/Error.h>  #include <Swiften/Base/SafeByteArray.h> -#include <Swiften/TLS/PKCS12Certificate.h> +#include <Swiften/TLS/CertificateWithKey.h>  #include <Swiften/TLS/Certificate.h>  #include <Swiften/TLS/CertificateVerificationError.h> @@ -36,6 +36,8 @@ namespace Swift {  					Type type;  			}; +			SessionStream(): certificate(0) {} +  			virtual ~SessionStream();  			virtual void close() = 0; @@ -56,12 +58,12 @@ namespace Swift {  			virtual void resetXMPPParser() = 0; -			void setTLSCertificate(const PKCS12Certificate& cert) { +			void setTLSCertificate(CertificateWithKey* cert) {  				certificate = cert;  			}  			virtual bool hasTLSCertificate() { -				return !certificate.isNull(); +				return certificate && !certificate->isNull();  			}  			virtual Certificate::ref getPeerCertificate() const = 0; @@ -77,11 +79,11 @@ namespace Swift {  			boost::signal<void (const SafeByteArray&)> onDataWritten;  		protected: -			const PKCS12Certificate& getTLSCertificate() const { +			CertificateWithKey * getTLSCertificate() const {  				return certificate;  			}  		private: -			PKCS12Certificate certificate; +			CertificateWithKey * certificate;  	};  } diff --git a/Swiften/StreamStack/TLSLayer.cpp b/Swiften/StreamStack/TLSLayer.cpp index 6f2223d..b7efbcb 100644 --- a/Swiften/StreamStack/TLSLayer.cpp +++ b/Swiften/StreamStack/TLSLayer.cpp @@ -37,7 +37,7 @@ void TLSLayer::handleDataRead(const SafeByteArray& data) {  	context->handleDataFromNetwork(data);  } -bool TLSLayer::setClientCertificate(const PKCS12Certificate& certificate) { +bool TLSLayer::setClientCertificate(CertificateWithKey * certificate) {  	return context->setClientCertificate(certificate);  } diff --git a/Swiften/StreamStack/TLSLayer.h b/Swiften/StreamStack/TLSLayer.h index a8693d5..6dc9135 100644 --- a/Swiften/StreamStack/TLSLayer.h +++ b/Swiften/StreamStack/TLSLayer.h @@ -14,7 +14,7 @@  namespace Swift {  	class TLSContext;  	class TLSContextFactory; -	class PKCS12Certificate; +	class CertificateWithKey;  	class TLSLayer : public StreamLayer {  		public: @@ -22,7 +22,7 @@ namespace Swift {  			~TLSLayer();  			void connect(); -			bool setClientCertificate(const PKCS12Certificate&); +			bool setClientCertificate(CertificateWithKey * cert);  			Certificate::ref getPeerCertificate() const;  			boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; diff --git a/Swiften/TLS/CAPICertificate.h b/Swiften/TLS/CAPICertificate.h new file mode 100644 index 0000000..fcdb4c2 --- /dev/null +++ b/Swiften/TLS/CAPICertificate.h @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2012 Isode Limited, London, England. + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#include <Swiften/Base/SafeByteArray.h> +#include <Swiften/TLS/CertificateWithKey.h> + +#include <boost/algorithm/string/predicate.hpp> + +#define SECURITY_WIN32 +#include <WinCrypt.h> + +namespace Swift { +	class CAPICertificate : public Swift::CertificateWithKey { +		public: +			CAPICertificate(const std::string& capiUri) +			    : valid_(false), uri_(capiUri), cert_store_handle_(0), cert_store_(NULL), cert_name_(NULL) { +				setUri(capiUri); +			} + +			virtual ~CAPICertificate() { +				if (cert_store_handle_ != NULL) +				{ +					CertCloseStore(cert_store_handle_, 0); +				} +			} + +			virtual bool isNull() const { +				return uri_.empty() || !valid_; +			} + +			virtual bool isPrivateKeyExportable() const { +				/* We can check with CAPI, but for now the answer is "no" */ +				return false; +			} + +			virtual const std::string& getCertStoreName() const { +			    return cert_store_; +			} + +			virtual const std::string& getCertName() const { +			    return cert_name_; +			} + +			const ByteArray& getData() const { +////Might need to throw an exception here, or really generate PKCS12 blob from CAPI data? +				assert(0); +			} + +			void setData(const ByteArray& data) { +				assert(0); +			} + +			const SafeByteArray& getPassword() const { +/////Can't pass NULL to createSafeByteArray! +/////Should this throw an exception instead? +				return createSafeByteArray(""); +			} + +		protected: +			void setUri (const std::string& capiUri) { + +				valid_ = false; + +				/* Syntax: "certstore:" [<cert_store> ":"] <cert_id> */ + +				if (!boost::iequals(capiUri.substr(0, 10), "certstore:")) { +					return; +				} + +				/* Substring of subject: uses "storename" */ +				std::string capi_identity = capiUri.substr(10); +				std::string new_cert_store_name; +				size_t pos = capi_identity.find_first_of (':'); + +				if (pos == std::string::npos) { +					/* Using the default certificate store */ +					new_cert_store_name = "MY"; +					cert_name_ = capi_identity; +				} else { +					new_cert_store_name = capi_identity.substr(0, pos); +					cert_name_ = capi_identity.substr(pos + 1); +				} + +				PCCERT_CONTEXT pCertContext = NULL; + +				if (cert_store_handle_ != NULL) +				{ +					if (new_cert_store_name != cert_store_) { +						CertCloseStore(cert_store_handle_, 0); +						cert_store_handle_ = NULL; +					} +				} + +				if (cert_store_handle_ == NULL) +				{ +					cert_store_handle_ = CertOpenSystemStore(0, cert_store_.c_str()); +					if (!cert_store_handle_) +					{ +						return; +					} +				} + +				cert_store_ = new_cert_store_name; + +				/* NB: This might have to change, depending on how we locate certificates */ + +				// Find client certificate. Note that this sample just searches for a  +				// certificate that contains the user name somewhere in the subject name. +				pCertContext = CertFindCertificateInStore(cert_store_handle_, +					X509_ASN_ENCODING, +					0,				// dwFindFlags +					CERT_FIND_SUBJECT_STR_A, +					cert_name_.c_str(),		// *pvFindPara +					NULL );				// pPrevCertContext + +				if (pCertContext == NULL) +				{ +					return; +				} + + +				/* Now verify that we can have access to the corresponding private key */ + +				DWORD len; +				CRYPT_KEY_PROV_INFO *pinfo; +				HCRYPTPROV hprov; +				HCRYPTKEY key; + +				if (!CertGetCertificateContextProperty(pCertContext, +								       CERT_KEY_PROV_INFO_PROP_ID, +								       NULL, +								       &len)) +				{ +					CertFreeCertificateContext(pCertContext); +					return; +				} + +				pinfo = static_cast<CRYPT_KEY_PROV_INFO *>(malloc(len)); +				if (!pinfo) { +					CertFreeCertificateContext(pCertContext); +					return; +				} + +				if (!CertGetCertificateContextProperty(pCertContext, +								       CERT_KEY_PROV_INFO_PROP_ID, +								       pinfo, +								       &len)) +				{ +					CertFreeCertificateContext(pCertContext); +					free(pinfo); +					return; +				} + +				CertFreeCertificateContext(pCertContext); + +				// Now verify if we have access to the private key +				if (!CryptAcquireContextW(&hprov, +							  pinfo->pwszContainerName, +							  pinfo->pwszProvName, +							  pinfo->dwProvType, +							  0)) +				{ +					free(pinfo); +					return; +				} + +				if (!CryptGetUserKey(hprov, pinfo->dwKeySpec, &key)) +				{ +					CryptReleaseContext(hprov, 0); +					free(pinfo); +					return; +				} + +				CryptDestroyKey(key); +				CryptReleaseContext(hprov, 0); +				free(pinfo); + +				valid_ = true; +			} + +		private: +			bool valid_; +			std::string uri_; + +			HCERTSTORE cert_store_handle_; + +			/* Parsed components of the uri_ */ +			std::string cert_store_; +			std::string cert_name_; +	}; +} diff --git a/Swiften/TLS/CertificateWithKey.h b/Swiften/TLS/CertificateWithKey.h new file mode 100644 index 0000000..6f6ea39 --- /dev/null +++ b/Swiften/TLS/CertificateWithKey.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2010-2012 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include <Swiften/Base/SafeByteArray.h> + +namespace Swift { +	class CertificateWithKey { +		public: +			CertificateWithKey() {} + +			virtual ~CertificateWithKey() {} + +			virtual bool isNull() const = 0; + +			virtual bool isPrivateKeyExportable() const = 0; + +			virtual const std::string& getCertStoreName() const = 0; + +			virtual const std::string& getCertName() const = 0; + +			virtual const ByteArray& getData() const = 0; + +			virtual void setData(const ByteArray& data) = 0; + +			virtual const SafeByteArray& getPassword() const = 0; +	}; +} diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index 220e7f9..dd3462f 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -21,7 +21,7 @@  #include <Swiften/TLS/OpenSSL/OpenSSLContext.h>  #include <Swiften/TLS/OpenSSL/OpenSSLCertificate.h> -#include <Swiften/TLS/PKCS12Certificate.h> +#include <Swiften/TLS/CertificateWithKey.h>  #pragma GCC diagnostic ignored "-Wold-style-cast" @@ -185,14 +185,18 @@ void OpenSSLContext::sendPendingDataToApplication() {  	}  } -bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate) { -	if (certificate.isNull()) { +bool OpenSSLContext::setClientCertificate(CertificateWithKey * certificate) { +	if (!certificate || certificate->isNull()) { +		return false; +	} + +	if (!certificate->isPrivateKeyExportable()) {  		return false;  	}  	// Create a PKCS12 structure  	BIO* bio = BIO_new(BIO_s_mem()); -	BIO_write(bio, vecptr(certificate.getData()), certificate.getData().size()); +	BIO_write(bio, vecptr(certificate->getData()), certificate->getData().size());  	boost::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free);  	BIO_free(bio);  	if (!pkcs12) { @@ -203,7 +207,7 @@ bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate)  	X509 *certPtr = 0;  	EVP_PKEY* privateKeyPtr = 0;  	STACK_OF(X509)* caCertsPtr = 0; -	int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(certificate.getPassword())), &privateKeyPtr, &certPtr, &caCertsPtr); +	int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(certificate->getPassword())), &privateKeyPtr, &certPtr, &caCertsPtr);  	if (result != 1) {   		return false;  	} diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index 04693a3..b53e715 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -14,7 +14,7 @@  #include <Swiften/Base/ByteArray.h>  namespace Swift { -	class PKCS12Certificate; +	class CertificateWithKey;  	class OpenSSLContext : public TLSContext, boost::noncopyable {  		public: @@ -22,7 +22,7 @@ namespace Swift {  			~OpenSSLContext();  			void connect(); -			bool setClientCertificate(const PKCS12Certificate& cert); +			bool setClientCertificate(CertificateWithKey * cert);  			void handleDataFromNetwork(const SafeByteArray&);  			void handleDataFromApplication(const SafeByteArray&); diff --git a/Swiften/TLS/PKCS12Certificate.h b/Swiften/TLS/PKCS12Certificate.h index c0e01d0..2f70456 100644 --- a/Swiften/TLS/PKCS12Certificate.h +++ b/Swiften/TLS/PKCS12Certificate.h @@ -7,9 +7,10 @@  #pragma once  #include <Swiften/Base/SafeByteArray.h> +#include <Swiften/TLS/CertificateWithKey.h>  namespace Swift { -	class PKCS12Certificate { +	class PKCS12Certificate : public Swift::CertificateWithKey {  		public:  			PKCS12Certificate() {} @@ -17,11 +18,29 @@ namespace Swift {  				readByteArrayFromFile(data_, filename);  			} -			bool isNull() const { +			virtual ~PKCS12Certificate() {} + +			virtual bool isNull() const {  				return data_.empty();  			} -			const ByteArray& getData() const { +			virtual bool isPrivateKeyExportable() const { +/////Hopefully a PKCS12 is never missing a private key +				return true; +			} + +			virtual const std::string& getCertStoreName() const { +/////				assert(0); +				throw std::exception(); +			} + +			virtual const std::string& getCertName() const { +				/* We can return the original filename instead, if we care */ +/////				assert(0); +				throw std::exception(); +			} + +			virtual const ByteArray& getData() const {  				return data_;  			} @@ -29,7 +48,7 @@ namespace Swift {  				data_ = data;  			} -			const SafeByteArray& getPassword() const { +			virtual const SafeByteArray& getPassword() const {  				return password_;  			} diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp index 6771d4a..6f50b3a 100644 --- a/Swiften/TLS/Schannel/SchannelContext.cpp +++ b/Swiften/TLS/Schannel/SchannelContext.cpp @@ -15,6 +15,9 @@ SchannelContext::SchannelContext()  : m_state(Start)  , m_secContext(0)  , m_verificationError(CertificateVerificationError::UnknownError) +, m_my_cert_store(NULL) +, m_cert_store_name("MY") +, m_cert_name(NULL)  {  	m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY |   				  ISC_REQ_CONFIDENTIALITY | @@ -30,6 +33,13 @@ SchannelContext::SchannelContext()  //------------------------------------------------------------------------ +SchannelContext::~SchannelContext() +{ +	if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0); +} + +//------------------------------------------------------------------------ +  void SchannelContext::determineStreamSizes()  {  	QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes); @@ -39,17 +49,65 @@ void SchannelContext::determineStreamSizes()  void SchannelContext::connect()   { +	PCCERT_CONTEXT   pCertContext = NULL; +  	m_state = Connecting; +	// If a user name is specified, then attempt to find a client +	// certificate. Otherwise, just create a NULL credential. +	if (!m_cert_name.empty()) +	{ +		if (m_my_cert_store == NULL) +		{ +			m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str()); +			if (!m_my_cert_store) +			{ +/////			printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() ); +				indicateError(); +				return; +			} +		} + +		// Find client certificate. Note that this sample just searches for a  +		// certificate that contains the user name somewhere in the subject name. +		pCertContext = CertFindCertificateInStore( m_my_cert_store, +			X509_ASN_ENCODING, +			0,				// dwFindFlags +			CERT_FIND_SUBJECT_STR_A, +			m_cert_name.c_str(),		// *pvFindPara +			NULL );				// pPrevCertContext + +		if (pCertContext == NULL) +		{ +/////		printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError()); +			indicateError(); +			return; +		} +	} +  	// We use an empty list for client certificates  	PCCERT_CONTEXT clientCerts[1] = {0};  	SCHANNEL_CRED sc = {0};  	sc.dwVersion = SCHANNEL_CRED_VERSION; -	sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us -	sc.paCred = clientCerts; + +/////SSL3?  	sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; -	sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | /*SCH_CRED_NO_DEFAULT_CREDS*/ SCH_CRED_USE_DEFAULT_CREDS | SCH_CRED_REVOCATION_CHECK_CHAIN; +/////Check SCH_CRED_REVOCATION_CHECK_CHAIN +	sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN; + +	if (pCertContext) +	{ +		sc.cCreds = 1; +		sc.paCred = &pCertContext; +		sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; +	} +	else +	{ +		sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us +		sc.paCred = clientCerts; +		sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS; +	}  	// Swiften performs the server name check for us  	sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; @@ -65,6 +123,9 @@ void SchannelContext::connect()  		m_credHandle.Reset(),  		NULL); +	// cleanup: Free the certificate context. Schannel has already made its own copy. +	if (pCertContext) CertFreeCertificateContext(pCertContext); +  	if (status != SEC_E_OK)   	{  		// We failed to obtain the credentials handle @@ -456,8 +517,21 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data)  //------------------------------------------------------------------------ -bool SchannelContext::setClientCertificate(const PKCS12Certificate& certificate)  +bool SchannelContext::setClientCertificate(CertificateWithKey * certificate)  { +	if (!certificate || certificate->isNull()) { +		return false; +	} + +	if (!certificate->isPrivateKeyExportable()) { +		// We assume that the Certificate Store Name/Certificate Name +		// are valid at this point +		m_cert_store_name = certificate->getCertStoreName(); +		m_cert_name = certificate->getCertName(); + +		return true; +	} +  	return false;  } diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h index 66467fe..0cdb3d7 100644 --- a/Swiften/TLS/Schannel/SchannelContext.h +++ b/Swiften/TLS/Schannel/SchannelContext.h @@ -10,6 +10,7 @@  #include "Swiften/TLS/TLSContext.h"  #include "Swiften/TLS/Schannel/SchannelUtil.h" +#include <Swiften/TLS/CertificateWithKey.h>  #include "Swiften/Base/ByteArray.h"  #define SECURITY_WIN32 @@ -28,13 +29,15 @@ namespace Swift  		typedef boost::shared_ptr<SchannelContext> sp_t;  	public: -						SchannelContext(); +		SchannelContext(); + +		~SchannelContext();  		//  		// TLSContext  		//  		virtual void	connect(); -		virtual bool	setClientCertificate(const PKCS12Certificate&); +		virtual bool	setClientCertificate(CertificateWithKey * cert);  		virtual void	handleDataFromNetwork(const SafeByteArray& data);  		virtual void	handleDataFromApplication(const SafeByteArray& data); @@ -77,5 +80,9 @@ namespace Swift  		SecPkgContext_StreamSizes m_streamSizes;  		std::vector<char>	m_receivedData; + +		HCERTSTORE		m_my_cert_store; +		std::string		m_cert_store_name; +		std::string		m_cert_name;  	};  } diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h index 1538863..ada813a 100644 --- a/Swiften/TLS/TLSContext.h +++ b/Swiften/TLS/TLSContext.h @@ -14,7 +14,7 @@  #include <Swiften/TLS/CertificateVerificationError.h>  namespace Swift { -	class PKCS12Certificate; +	class CertificateWithKey;  	class TLSContext {  		public: @@ -22,7 +22,7 @@ namespace Swift {  			virtual void connect() = 0; -			virtual bool setClientCertificate(const PKCS12Certificate& cert) = 0; +			virtual bool setClientCertificate(CertificateWithKey * cert) = 0;  			virtual void handleDataFromNetwork(const SafeByteArray&) = 0;  			virtual void handleDataFromApplication(const SafeByteArray&) = 0; | 
 Swift
 Swift