diff options
| -rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp | 4 | ||||
| -rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 118 | ||||
| -rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 49 | 
3 files changed, 100 insertions, 71 deletions
| diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp index 17ac8cc..8d2d965 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp @@ -30,7 +30,7 @@ OpenSSLCertificate::OpenSSLCertificate(const ByteArray& der) {  #else      const unsigned char* p = vecptr(der);  #endif -    cert = std::shared_ptr<X509>(d2i_X509(NULL, &p, der.size()), X509_free); +    cert = std::shared_ptr<X509>(d2i_X509(nullptr, &p, der.size()), X509_free);      if (!cert) {          SWIFT_LOG(warning) << "Error creating certificate from DER data" << std::endl;      } @@ -42,7 +42,7 @@ ByteArray OpenSSLCertificate::toDER() const {      if (!cert) {          return result;      } -    result.resize(i2d_X509(cert.get(), NULL)); +    result.resize(i2d_X509(cert.get(), nullptr));      unsigned char* p = vecptr(result);      i2d_X509(cert.get(), &p);      return result; diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index 0805917..6f15edf 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -1,5 +1,5 @@  /* - * Copyright (c) 2010-2016 Isode Limited. + * Copyright (c) 2010-2018 Isode Limited.   * All rights reserved.   * See the COPYING file for more information.   */ @@ -10,10 +10,12 @@  #include <wincrypt.h>  #endif +#include <cassert> +#include <memory>  #include <vector> +  #include <openssl/err.h>  #include <openssl/pkcs12.h> -#include <memory>  #if defined(SWIFTEN_PLATFORM_MACOSX)  #include <Security/Security.h> @@ -39,10 +41,32 @@ static void freeX509Stack(STACK_OF(X509)* stack) {      sk_X509_free(stack);  } -OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readBIO_(0), writeBIO_(0) { +namespace { +    class OpenSSLInitializerFinalizer { +        public: +            OpenSSLInitializerFinalizer() { +                SSL_load_error_strings(); +                SSL_library_init(); +                OpenSSL_add_all_algorithms(); + +                // Disable compression +                /* +                STACK_OF(SSL_COMP)* compressionMethods = SSL_COMP_get_compression_methods(); +                sk_SSL_COMP_zero(compressionMethods);*/ +            } + +            ~OpenSSLInitializerFinalizer() { +                EVP_cleanup(); +            } + +            OpenSSLInitializerFinalizer(const OpenSSLInitializerFinalizer &) = delete; +    }; +} + +OpenSSLContext::OpenSSLContext() : state_(State::Start) {      ensureLibraryInitialized(); -    context_ = SSL_CTX_new(SSLv23_client_method()); -    SSL_CTX_set_options(context_, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); +    context_ = std::unique_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method())); +    SSL_CTX_set_options(context_.get(), SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);      // TODO: implement CRL checking      // TODO: download CRL (HTTP transport) @@ -52,7 +76,7 @@ OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readB      // TODO: handle OCSP stapling see https://www.rfc-editor.org/rfc/rfc4366.txt      // Load system certs  #if defined(SWIFTEN_PLATFORM_WINDOWS) -    X509_STORE* store = SSL_CTX_get_cert_store(context_); +    X509_STORE* store = SSL_CTX_get_cert_store(context_.get());      HCERTSTORE systemStore = CertOpenSystemStore(0, "ROOT");      if (systemStore) {          PCCERT_CONTEXT certContext = NULL; @@ -68,7 +92,7 @@ OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readB          }      }  #elif !defined(SWIFTEN_PLATFORM_MACOSX) -    SSL_CTX_set_default_verify_paths(context_); +    SSL_CTX_set_default_verify_paths(context_.get());  #elif defined(SWIFTEN_PLATFORM_MACOSX) && !defined(SWIFTEN_PLATFORM_IPHONE)      // On Mac OS X 10.5 (OpenSSL < 0.9.8), OpenSSL does not automatically look in the system store.      // On Mac OS X 10.6 (OpenSSL >= 0.9.8), OpenSSL *does* look in the system store to determine trust. @@ -76,7 +100,7 @@ OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readB      // the certificates first. See      //        http://opensource.apple.com/source/OpenSSL098/OpenSSL098-27/src/crypto/x509/x509_vfy_apple.c      // to understand why. We therefore add all certs from the system store ourselves. -    X509_STORE* store = SSL_CTX_get_cert_store(context_); +    X509_STORE* store = SSL_CTX_get_cert_store(context_.get());      CFArrayRef anchorCertificates;      if (SecTrustCopyAnchorCertificates(&anchorCertificates) == 0) {          for (int i = 0; i < CFArrayGetCount(anchorCertificates); ++i) { @@ -99,51 +123,37 @@ OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readB  }  OpenSSLContext::~OpenSSLContext() { -    SSL_free(handle_); -    SSL_CTX_free(context_);  }  void OpenSSLContext::ensureLibraryInitialized() { -    static bool isLibraryInitialized = false; -    if (!isLibraryInitialized) { -        SSL_load_error_strings(); -        SSL_library_init(); -        OpenSSL_add_all_algorithms(); - -        // Disable compression -        /* -        STACK_OF(SSL_COMP)* compressionMethods = SSL_COMP_get_compression_methods(); -        sk_SSL_COMP_zero(compressionMethods);*/ - -        isLibraryInitialized = true; -    } +    static OpenSSLInitializerFinalizer openSSLInit;  }  void OpenSSLContext::connect() { -    handle_ = SSL_new(context_); -    if (handle_ == nullptr) { -        state_ = Error; +    handle_ = std::unique_ptr<SSL>(SSL_new(context_.get())); +    if (!handle_) { +        state_ = State::Error;          onError(std::make_shared<TLSError>());          return;      } -    // Ownership of BIOs is ransferred +    // Ownership of BIOs is transferred      readBIO_ = BIO_new(BIO_s_mem());      writeBIO_ = BIO_new(BIO_s_mem()); -    SSL_set_bio(handle_, readBIO_, writeBIO_); +    SSL_set_bio(handle_.get(), readBIO_, writeBIO_); -    state_ = Connecting; +    state_ = State::Connecting;      doConnect();  }  void OpenSSLContext::doConnect() { -    int connectResult = SSL_connect(handle_); -    int error = SSL_get_error(handle_, connectResult); +    int connectResult = SSL_connect(handle_.get()); +    int error = SSL_get_error(handle_.get(), connectResult);      switch (error) {          case SSL_ERROR_NONE: { -            state_ = Connected; +            state_ = State::Connected;              //std::cout << x->name << std::endl; -            //const char* comp = SSL_get_current_compression(handle_); +            //const char* comp = SSL_get_current_compression(handle_.get());              //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl;              onConnected();              break; @@ -152,7 +162,7 @@ void OpenSSLContext::doConnect() {              sendPendingDataToNetwork();              break;          default: -            state_ = Error; +            state_ = State::Error;              onError(std::make_shared<TLSError>());      }  } @@ -170,23 +180,23 @@ void OpenSSLContext::sendPendingDataToNetwork() {  void OpenSSLContext::handleDataFromNetwork(const SafeByteArray& data) {      BIO_write(readBIO_, vecptr(data), data.size());      switch (state_) { -        case Connecting: +        case State::Connecting:              doConnect();              break; -        case Connected: +        case State::Connected:              sendPendingDataToApplication();              break; -        case Start: assert(false); break; -        case Error: /*assert(false);*/ break; +        case State::Start: assert(false); break; +        case State::Error: /*assert(false);*/ break;      }  }  void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) { -    if (SSL_write(handle_, vecptr(data), data.size()) >= 0) { +    if (SSL_write(handle_.get(), vecptr(data), data.size()) >= 0) {          sendPendingDataToNetwork();      }      else { -        state_ = Error; +        state_ = State::Error;          onError(std::make_shared<TLSError>());      }  } @@ -194,15 +204,15 @@ void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) {  void OpenSSLContext::sendPendingDataToApplication() {      SafeByteArray data;      data.resize(SSL_READ_BUFFERSIZE); -    int ret = SSL_read(handle_, vecptr(data), data.size()); +    int ret = SSL_read(handle_.get(), vecptr(data), data.size());      while (ret > 0) {          data.resize(ret);          onDataForApplication(data);          data.resize(SSL_READ_BUFFERSIZE); -        ret = SSL_read(handle_, vecptr(data), data.size()); +        ret = SSL_read(handle_.get(), vecptr(data), data.size());      } -    if (ret < 0 && SSL_get_error(handle_, ret) != SSL_ERROR_WANT_READ) { -        state_ = Error; +    if (ret < 0 && SSL_get_error(handle_.get(), ret) != SSL_ERROR_WANT_READ) { +        state_ = State::Error;          onError(std::make_shared<TLSError>());      }  } @@ -216,16 +226,16 @@ bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) {      // Create a PKCS12 structure      BIO* bio = BIO_new(BIO_s_mem());      BIO_write(bio, vecptr(pkcs12Certificate->getData()), pkcs12Certificate->getData().size()); -    std::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free); +    std::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, nullptr), PKCS12_free);      BIO_free(bio);      if (!pkcs12) {          return false;      }      // Parse PKCS12 -    X509 *certPtr = 0; -    EVP_PKEY* privateKeyPtr = 0; -    STACK_OF(X509)* caCertsPtr = 0; +    X509 *certPtr = nullptr; +    EVP_PKEY* privateKeyPtr = nullptr; +    STACK_OF(X509)* caCertsPtr = nullptr;      SafeByteArray password(pkcs12Certificate->getPassword());      password.push_back(0);      int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(password)), &privateKeyPtr, &certPtr, &caCertsPtr); @@ -237,21 +247,21 @@ bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) {      std::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack);      // Use the key & certificates -    if (SSL_CTX_use_certificate(context_, cert.get()) != 1) { +    if (SSL_CTX_use_certificate(context_.get(), cert.get()) != 1) {          return false;      } -    if (SSL_CTX_use_PrivateKey(context_, privateKey.get()) != 1) { +    if (SSL_CTX_use_PrivateKey(context_.get(), privateKey.get()) != 1) {          return false;      }      for (int i = 0;  i < sk_X509_num(caCerts.get()); ++i) { -        SSL_CTX_add_extra_chain_cert(context_, sk_X509_value(caCerts.get(), i)); +        SSL_CTX_add_extra_chain_cert(context_.get(), sk_X509_value(caCerts.get(), i));      }      return true;  }  std::vector<Certificate::ref> OpenSSLContext::getPeerCertificateChain() const {      std::vector<Certificate::ref> result; -    STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_); +    STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_.get());      for (int i = 0; i < sk_X509_num(chain); ++i) {          std::shared_ptr<X509> x509Cert(X509_dup(sk_X509_value(chain, i)), X509_free); @@ -262,7 +272,7 @@ std::vector<Certificate::ref> OpenSSLContext::getPeerCertificateChain() const {  }  std::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const { -    int verifyResult = SSL_get_verify_result(handle_); +    int verifyResult = SSL_get_verify_result(handle_.get());      if (verifyResult != X509_V_OK) {          return std::make_shared<CertificateVerificationError>(getVerificationErrorTypeForResult(verifyResult));      } @@ -274,7 +284,7 @@ std::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificate  ByteArray OpenSSLContext::getFinishMessage() const {      ByteArray data;      data.resize(MAX_FINISHED_SIZE); -    size_t size = SSL_get_finished(handle_, vecptr(data), data.size()); +    size_t size = SSL_get_finished(handle_.get(), vecptr(data), data.size());      data.resize(size);      return data;  } diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index e75b3c9..49ada51 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -1,11 +1,13 @@  /* - * Copyright (c) 2010-2016 Isode Limited. + * Copyright (c) 2010-2018 Isode Limited.   * All rights reserved.   * See the COPYING file for more information.   */  #pragma once +#include <memory> +  #include <boost/noncopyable.hpp>  #include <boost/signals2.hpp> @@ -15,23 +17,40 @@  #include <Swiften/TLS/CertificateWithKey.h>  #include <Swiften/TLS/TLSContext.h> -namespace Swift { +namespace std { +    template<> +    class default_delete<SSL_CTX> { +    public: +        void operator()(SSL_CTX *ptr) { +            SSL_CTX_free(ptr); +        } +    }; +    template<> +    class default_delete<SSL> { +    public: +        void operator()(SSL *ptr) { +            SSL_free(ptr); +        } +    }; +} + +namespace Swift {      class OpenSSLContext : public TLSContext, boost::noncopyable {          public:              OpenSSLContext(); -            virtual ~OpenSSLContext(); +            virtual ~OpenSSLContext() override final; -            void connect(); -            bool setClientCertificate(CertificateWithKey::ref cert); +            void connect() override final; +            bool setClientCertificate(CertificateWithKey::ref cert) override final; -            void handleDataFromNetwork(const SafeByteArray&); -            void handleDataFromApplication(const SafeByteArray&); +            void handleDataFromNetwork(const SafeByteArray&) override final; +            void handleDataFromApplication(const SafeByteArray&) override final; -            std::vector<Certificate::ref> getPeerCertificateChain() const; -            std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; +            std::vector<Certificate::ref> getPeerCertificateChain() const override final; +            std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const override final; -            virtual ByteArray getFinishMessage() const; +            virtual ByteArray getFinishMessage() const override final;          private:              static void ensureLibraryInitialized(); @@ -43,12 +62,12 @@ namespace Swift {              void sendPendingDataToApplication();          private: -            enum State { Start, Connecting, Connected, Error }; +            enum class State { Start, Connecting, Connected, Error };              State state_; -            SSL_CTX* context_; -            SSL* handle_; -            BIO* readBIO_; -            BIO* writeBIO_; +            std::unique_ptr<SSL_CTX> context_; +            std::unique_ptr<SSL> handle_; +            BIO* readBIO_ = nullptr; +            BIO* writeBIO_ = nullptr;      };  } | 
 Swift
 Swift