tls rework

This commit is contained in:
miodragpop
2020-10-23 21:17:20 +02:00
parent be16f80abc
commit eaed7355c6
6 changed files with 706 additions and 220 deletions

View File

@@ -69,7 +69,13 @@ using namespace hush;
#endif
#define USE_TLS
#define COMPAT_NON_TLS // enables compatibility with nodes, that still doesn't support TLS connections
#if defined(USE_TLS) && !defined(TLS1_2_VERSION)
// minimum secure protocol is 1.2
// TLS1_2_VERSION is defined in openssl/tls1.h
#error "ERROR: Your OpenSSL version does not support TLS v1.2"
#endif
using namespace std;
@@ -359,6 +365,9 @@ void AddressCurrentlyConnected(const CService& addr)
}
CNode::eTlsOption CNode::tlsFallbackNonTls = CNode::eTlsOption::FALLBACK_UNSET;
CNode::eTlsOption CNode::tlsValidate = CNode::eTlsOption::FALLBACK_UNSET;
uint64_t CNode::nTotalBytesRecv = 0;
uint64_t CNode::nTotalBytesSent = 0;
CCriticalSection CNode::cs_totalBytesRecv;
@@ -435,54 +444,84 @@ CNode* ConnectNode(CAddress addrConnect, const char *pszDest)
addrman.Attempt(addrConnect);
SSL *ssl = NULL;
#ifdef USE_TLS
/* TCP connection is ready. Do client side SSL. */
#ifdef COMPAT_NON_TLS
if (CNode::GetTlsFallbackNonTls())
{
LOCK(cs_vNonTLSNodesOutbound);
NODE_ADDR nodeAddr(addrConnect.ToStringIP());
bool bUseTLS = ((GetBoolArg("-tls", true) || GetArg("-tls", "") == "only") && find(vNonTLSNodesOutbound.begin(),
vNonTLSNodesOutbound.end(),
nodeAddr) == vNonTLSNodesOutbound.end());
if (bUseTLS)
{
ssl = tlsmanager.connect(hSocket, addrConnect);
if (!ssl)
LOCK(cs_vNonTLSNodesOutbound);
LogPrint("tls", "%s():%d - handling connection to %s\n", __func__, __LINE__, addrConnect.ToString());
NODE_ADDR nodeAddr(addrConnect.ToStringIP());
bool bUseTLS = (find(vNonTLSNodesOutbound.begin(),
vNonTLSNodesOutbound.end(),
nodeAddr) == vNonTLSNodesOutbound.end());
unsigned long err_code = 0;
if (bUseTLS)
{
if (GetArg("-tls", "") != "only")
ssl = tlsmanager.connect(hSocket, addrConnect, err_code);
if (!ssl)
{
// Further reconnection will be made in non-TLS (unencrypted) mode if mandatory tls is not set
vNonTLSNodesOutbound.push_back(NODE_ADDR(addrConnect.ToStringIP(), GetTimeMillis()));
if (err_code == TLSManager::SELECT_TIMEDOUT)
{
// can fail for timeout in select on fd, that is not a ssl error and we should not
// consider this node as non TLS
LogPrint("tls", "%s():%d - Connection to %s timedout\n",
__func__, __LINE__, addrConnect.ToStringIP());
}
else
{
// Further reconnection will be made in non-TLS (unencrypted) mode
vNonTLSNodesOutbound.push_back(NODE_ADDR(addrConnect.ToStringIP(), GetTimeMillis()));
LogPrint("tls", "%s():%d - err_code %x, adding connection to %s vNonTLSNodesOutbound list (sz=%d)\n",
__func__, __LINE__, err_code, addrConnect.ToStringIP(), vNonTLSNodesOutbound.size());
}
CloseSocket(hSocket);
return NULL;
}
CloseSocket(hSocket);
return NULL;
}
else
{
LogPrintf ("Connection to %s will be unencrypted\n", addrConnect.ToString());
vNonTLSNodesOutbound.erase(
remove(
vNonTLSNodesOutbound.begin(),
vNonTLSNodesOutbound.end(),
nodeAddr),
vNonTLSNodesOutbound.end());
}
}
else
}
else
{
unsigned long err_code = 0;
ssl = tlsmanager.connect(hSocket, addrConnect, err_code);
if(!ssl)
{
LogPrintf ("Connection to %s will be unencrypted\n", addrConnect.ToString());
vNonTLSNodesOutbound.erase(
remove(
vNonTLSNodesOutbound.begin(),
vNonTLSNodesOutbound.end(),
nodeAddr),
vNonTLSNodesOutbound.end());
LogPrint("tls", "%s():%d - err_code %x, connection to %s failed)\n",
__func__, __LINE__, err_code, addrConnect.ToStringIP());
CloseSocket(hSocket);
return NULL;
}
}
#else
ssl = TLSManager::connect(hSocket, addrConnect);
if(!ssl)
// certificate validation is disabled by default
if (CNode::GetTlsValidate())
{
CloseSocket(hSocket);
return NULL;
if (ssl && !ValidatePeerCertificate(ssl))
{
LogPrintf ("TLS: ERROR: Wrong server certificate from %s. Connection will be closed.\n", addrConnect.ToString());
SSL_shutdown(ssl);
CloseSocket(hSocket);
SSL_free(ssl);
return NULL;
}
}
#endif // COMPAT_NON_TLS
#endif // USE_TLS
// Add node
@@ -509,15 +548,15 @@ CNode* ConnectNode(CAddress addrConnect, const char *pszDest)
void CNode::CloseSocketDisconnect()
{
fDisconnect = true;
{
LOCK(cs_hSocket);
if (hSocket != INVALID_SOCKET)
{
if (hSocket != INVALID_SOCKET)
{
try
{
LogPrint("net", "disconnecting peer=%d\n", id);
LogPrint("net", "disconnecting peer=%d\n", id);
}
catch(std::bad_alloc&)
{
@@ -528,13 +567,13 @@ void CNode::CloseSocketDisconnect()
if (ssl)
{
tlsmanager.waitFor(SSL_SHUTDOWN, hSocket, ssl, (DEFAULT_CONNECT_TIMEOUT / 1000));
unsigned long err_code = 0;
tlsmanager.waitFor(SSL_SHUTDOWN, hSocket, ssl, (DEFAULT_CONNECT_TIMEOUT / 1000), err_code);
SSL_free(ssl);
ssl = NULL;
}
CloseSocket(hSocket);
}
CloseSocket(hSocket);
}
}
// in case this fails, we'll empty the recv buffer when the CNode is deleted
@@ -708,6 +747,7 @@ void CNode::copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap)
{
LOCK(cs_hSocket);
stats.fTLSEstablished = (ssl != NULL) && (SSL_get_state(ssl) == TLS_ST_OK);
stats.fTLSVerified = (ssl != NULL) && ValidatePeerCertificate(ssl);
}
}
@@ -812,15 +852,15 @@ void SocketSendData(CNode *pnode)
int nBytes = 0, nRet = 0;
{
LOCK(pnode->cs_hSocket);
if (pnode->hSocket == INVALID_SOCKET)
{
LogPrint("net", "Send: connection with %s is already closed\n", pnode->addr.ToString());
break;
}
bIsSSL = (pnode->ssl != NULL);
if (bIsSSL)
{
ERR_clear_error(); // clear the error queue, otherwise we may be reading an old error that occurred previously in the current thread
@@ -856,15 +896,15 @@ void SocketSendData(CNode *pnode)
if (nRet != SSL_ERROR_WANT_READ && nRet != SSL_ERROR_WANT_WRITE)
{
LogPrintf("ERROR: SSL_write %s; closing connection\n", ERR_error_string(nRet, NULL));
pnode->CloseSocketDisconnect();
}
pnode->CloseSocketDisconnect();
}
else
{
// preventive measure from exhausting CPU usage
//
MilliSleep(1); // 1 msec
}
}
}
else
{
if (nRet != WSAEWOULDBLOCK && nRet != WSAEMSGSIZE && nRet != WSAEINTR && nRet != WSAEINPROGRESS)
@@ -1155,29 +1195,40 @@ static void AcceptConnection(const ListenSocket& hListenSocket) {
#endif
SSL *ssl = NULL;
SetSocketNonBlocking(hSocket, true);
#ifdef USE_TLS
/* TCP connection is ready. Do server side SSL. */
#ifdef COMPAT_NON_TLS
if (CNode::GetTlsFallbackNonTls())
{
LOCK(cs_vNonTLSNodesInbound);
LogPrint("tls", "%s():%d - handling connection from %s\n", __func__, __LINE__, addr.ToString());
NODE_ADDR nodeAddr(addr.ToStringIP());
bool bUseTLS = ((GetBoolArg("-tls", true) || GetArg("-tls", "") == "only") && find(vNonTLSNodesInbound.begin(),
bool bUseTLS = (find(vNonTLSNodesInbound.begin(),
vNonTLSNodesInbound.end(),
nodeAddr) == vNonTLSNodesInbound.end());
unsigned long err_code = 0;
if (bUseTLS)
{
ssl = tlsmanager.accept( hSocket, addr);
ssl = tlsmanager.accept( hSocket, addr, err_code);
if(!ssl)
{
if (GetArg("-tls", "") != "only")
if (err_code == TLSManager::SELECT_TIMEDOUT)
{
// Further reconnection will be made in non-TLS (unencrypted) mode if mandatory tls is not set
// can fail also for timeout in select on fd, that is not a ssl error and we should not
// consider this node as non TLS
LogPrint("tls", "%s():%d - Connection from %s timedout\n", __func__, __LINE__, addr.ToStringIP());
}
else
{
// Further reconnection will be made in non-TLS (unencrypted) mode
vNonTLSNodesInbound.push_back(NODE_ADDR(addr.ToStringIP(), GetTimeMillis()));
LogPrint("tls", "%s():%d - err_code %x, adding connection from %s vNonTLSNodesInbound list (sz=%d)\n",
__func__, __LINE__, err_code, addr.ToStringIP(), vNonTLSNodesInbound.size());
}
CloseSocket(hSocket);
return;
@@ -1185,8 +1236,8 @@ static void AcceptConnection(const ListenSocket& hListenSocket) {
}
else
{
LogPrintf ("TLS: Connection from %s will be unencrypted\n", addr.ToString());
LogPrintf ("TLS: Connection from %s will be unencrypted\n", addr.ToStringIP());
vNonTLSNodesInbound.erase(
remove(
vNonTLSNodesInbound.begin(),
@@ -1196,14 +1247,32 @@ static void AcceptConnection(const ListenSocket& hListenSocket) {
vNonTLSNodesInbound.end());
}
}
#else
ssl = TLSManager::accept( hSocket, addr);
if(!ssl)
else
{
CloseSocket(hSocket);
return;
unsigned long err_code = 0;
ssl = tlsmanager.accept( hSocket, addr, err_code);
if(!ssl)
{
LogPrint("tls", "%s():%d - err_code %x, failure accepting connection from %s\n",
__func__, __LINE__, err_code, addr.ToStringIP());
CloseSocket(hSocket);
return;
}
}
// certificate validation is disabled by default
if (CNode::GetTlsValidate())
{
if (ssl && !ValidatePeerCertificate(ssl))
{
LogPrintf ("TLS: ERROR: Wrong client certificate from %s. Connection will be closed.\n", addr.ToString());
SSL_shutdown(ssl);
CloseSocket(hSocket);
SSL_free(ssl);
return;
}
}
#endif // COMPAT_NON_TLS
#endif // USE_TLS
CNode* pnode = new CNode(hSocket, addr, "", true, ssl);
@@ -1218,7 +1287,7 @@ static void AcceptConnection(const ListenSocket& hListenSocket) {
}
}
#if defined(USE_TLS) && defined(COMPAT_NON_TLS)
#if defined(USE_TLS)
void ThreadNonTLSPoolsCleaner()
{
while (true)
@@ -1228,7 +1297,9 @@ void ThreadNonTLSPoolsCleaner()
MilliSleep(DEFAULT_CONNECT_TIMEOUT); // sleep and sleep_for are interruption points, which will throw boost::thread_interrupted
}
}
#endif // USE_TLS && COMPAT_NON_TLS
#endif // USE_TLS
void ThreadSocketHandler()
{
@@ -1325,9 +1396,10 @@ void ThreadSocketHandler()
BOOST_FOREACH(CNode* pnode, vNodes)
{
LOCK(pnode->cs_hSocket);
if (pnode->hSocket == INVALID_SOCKET)
continue;
FD_SET(pnode->hSocket, &fdsetError);
hSocketMax = max(hSocketMax, pnode->hSocket);
have_fds = true;
@@ -1347,6 +1419,7 @@ void ThreadSocketHandler()
// * We send some data.
// * We wait for data to be received (and disconnect after timeout).
// * We process a message in the buffer (message handler thread).
{
TRY_LOCK(pnode->cs_vSend, lockSend);
if (lockSend && !pnode->vSendMsg.empty()) {
@@ -1407,8 +1480,9 @@ void ThreadSocketHandler()
{
boost::this_thread::interruption_point();
if (tlsmanager.threadSocketHandler(pnode,fdsetRecv,fdsetSend,fdsetError)==-1)
if (tlsmanager.threadSocketHandler(pnode,fdsetRecv,fdsetSend,fdsetError)==-1){
continue;
}
//
// Inactivity checking
@@ -1538,6 +1612,7 @@ void ThreadOpenConnections()
{
CAddress addr;
OpenNetworkConnection(addr, NULL, strAddr.c_str());
for (int i = 0; i < 10 && i < nLoop; i++)
{
MilliSleep(500);
@@ -1721,30 +1796,32 @@ bool OpenNetworkConnection(const CAddress& addrConnect, CSemaphoreGrant *grantOu
return false;
} else if (FindNode(std::string(pszDest)))
return false;
CNode* pnode = ConnectNode(addrConnect, pszDest);
boost::this_thread::interruption_point();
#if defined(USE_TLS) && defined(COMPAT_NON_TLS)
if (!pnode)
#if defined(USE_TLS)
if (CNode::GetTlsFallbackNonTls())
{
string strDest;
int port;
if (!pszDest)
strDest = addrConnect.ToStringIP();
else
SplitHostPort(string(pszDest), port, strDest);
if (tlsmanager.isNonTLSAddr(strDest, vNonTLSNodesOutbound, cs_vNonTLSNodesOutbound))
if (!pnode)
{
// Attempt to reconnect in non-TLS mode
pnode = ConnectNode(addrConnect, pszDest);
boost::this_thread::interruption_point();
string strDest;
int port;
if (!pszDest)
strDest = addrConnect.ToStringIP();
else
SplitHostPort(string(pszDest), port, strDest);
if (tlsmanager.isNonTLSAddr(strDest, vNonTLSNodesOutbound, cs_vNonTLSNodesOutbound))
{
// Attempt to reconnect in non-TLS mode
pnode = ConnectNode(addrConnect, pszDest);
boost::this_thread::interruption_point();
}
}
}
#endif
if (!pnode)
@@ -1836,6 +1913,7 @@ bool BindListenPort(const CService &addrBind, string& strError, bool fWhiteliste
// Create socket for listening for incoming connections
struct sockaddr_storage sockaddr;
socklen_t len = sizeof(sockaddr);
if (!addrBind.GetSockAddr((struct sockaddr*)&sockaddr, &len))
{
strError = strprintf("Error: Bind address family for %s not supported", addrBind.ToString());
@@ -2008,13 +2086,13 @@ void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler)
Discover(threadGroup);
#ifdef USE_TLS
if (!tlsmanager.prepareCredentials())
{
LogPrintf("TLS: ERROR: %s: %s: Credentials weren't loaded. Node can't be started.\n", __FILE__, __func__);
return;
}
if (!tlsmanager.initialize())
{
LogPrintf("TLS: ERROR: %s: %s: TLS initialization failed. Node can't be started.\n", __FILE__, __func__);
@@ -2051,11 +2129,14 @@ void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler)
// Process messages
threadGroup.create_thread(boost::bind(&TraceThread<void (*)()>, "msghand", &ThreadMessageHandler));
#if defined(USE_TLS) && defined(COMPAT_NON_TLS)
// Clean pools of addresses for non-TLS connections
threadGroup.create_thread(boost::bind(&TraceThread<void (*)()>, "poolscleaner", &ThreadNonTLSPoolsCleaner));
#if defined(USE_TLS)
if (CNode::GetTlsFallbackNonTls())
{
// Clean pools of addresses for non-TLS connections
threadGroup.create_thread(boost::bind(&TraceThread<void (*)()>, "poolscleaner", &ThreadNonTLSPoolsCleaner));
}
#endif
// Dump network addresses
scheduler.scheduleEvery(&DumpAddresses, DUMP_ADDRESSES_INTERVAL);
}
@@ -2076,42 +2157,34 @@ bool StopNode()
return true;
}
static class CNetCleanup
void CNode::NetCleanup()
{
public:
CNetCleanup() {}
// Close sockets
BOOST_FOREACH(CNode* pnode, vNodes)
pnode->CloseSocketDisconnect();
BOOST_FOREACH(ListenSocket& hListenSocket, vhListenSocket)
if (hListenSocket.socket != INVALID_SOCKET)
if (!CloseSocket(hListenSocket.socket))
LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError()));
~CNetCleanup()
{
// Close sockets
BOOST_FOREACH(CNode* pnode, vNodes)
if (pnode->hSocket != INVALID_SOCKET)
CloseSocket(pnode->hSocket);
BOOST_FOREACH(ListenSocket& hListenSocket, vhListenSocket)
if (hListenSocket.socket != INVALID_SOCKET)
if (!CloseSocket(hListenSocket.socket))
LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError()));
// clean up some globals (to help leak detection)
BOOST_FOREACH(CNode *pnode, vNodes)
delete pnode;
BOOST_FOREACH(CNode *pnode, vNodesDisconnected)
delete pnode;
vNodes.clear();
vNodesDisconnected.clear();
vhListenSocket.clear();
delete semOutbound;
semOutbound = NULL;
delete pnodeLocalHost;
pnodeLocalHost = NULL;
// clean up some globals (to help leak detection)
BOOST_FOREACH(CNode *pnode, vNodes)
delete pnode;
BOOST_FOREACH(CNode *pnode, vNodesDisconnected)
delete pnode;
vNodes.clear();
vNodesDisconnected.clear();
vhListenSocket.clear();
delete semOutbound;
semOutbound = NULL;
delete pnodeLocalHost;
pnodeLocalHost = NULL;
#ifdef _WIN32
// Shutdown Windows Sockets
WSACleanup();
#ifdef WIN32
// Shutdown Windows Sockets
WSACleanup();
#endif
}
}
instance_of_cnetcleanup;
void RelayTransaction(const CTransaction& tx)
{
@@ -2372,22 +2445,65 @@ CNode::CNode(SOCKET hSocketIn, const CAddress& addrIn, const std::string& addrNa
GetNodeSignals().InitializeNode(GetId(), this);
}
bool CNode::GetTlsFallbackNonTls()
{
if (tlsFallbackNonTls == eTlsOption::FALLBACK_UNSET)
{
// one time only setting of static class attribute
if ( GetBoolArg("-tlsfallbacknontls", true))
{
LogPrint("tls", "%s():%d - Non-TLS connections will be used in case of failure of TLS\n",
__func__, __LINE__);
tlsFallbackNonTls = eTlsOption::FALLBACK_TRUE;
}
else
{
LogPrint("tls", "%s():%d - Non-TLS connections will NOT be used in case of failure of TLS\n",
__func__, __LINE__);
tlsFallbackNonTls = eTlsOption::FALLBACK_FALSE;
}
}
return (tlsFallbackNonTls == eTlsOption::FALLBACK_TRUE);
}
bool CNode::GetTlsValidate()
{
if (tlsValidate == eTlsOption::FALLBACK_UNSET)
{
// one time only setting of static class attribute
if ( GetBoolArg("-tlsvalidate", false))
{
LogPrint("tls", "%s():%d - TLS certificates will be validated\n",
__func__, __LINE__);
tlsValidate = eTlsOption::FALLBACK_TRUE;
}
else
{
LogPrint("tls", "%s():%d - TLS certificates will NOT be validated\n",
__func__, __LINE__);
tlsValidate = eTlsOption::FALLBACK_FALSE;
}
}
return (tlsValidate == eTlsOption::FALLBACK_TRUE);
}
CNode::~CNode()
{
// No need to make a lock on cs_hSocket, because before deletion CNode object is removed from the vNodes vector, so any other thread hasn't access to it.
// Removal is synchronized with read and write routines, so all of them will be completed to this moment.
if (hSocket != INVALID_SOCKET)
{
if (ssl)
{
tlsmanager.waitFor(SSL_SHUTDOWN, hSocket, ssl, (DEFAULT_CONNECT_TIMEOUT / 1000));
unsigned long err_code = 0;
tlsmanager.waitFor(SSL_SHUTDOWN, hSocket, ssl, (DEFAULT_CONNECT_TIMEOUT / 1000), err_code);
SSL_free(ssl);
ssl = NULL;
}
CloseSocket(hSocket);
CloseSocket(hSocket);
}
if (pfilter)