LCOV - code coverage report
Current view: top level - source/network - StunClient.cpp (source / functions) Hit Total Coverage
Test: 0 A.D. test coverage report Lines: 12 159 7.5 %
Date: 2021-09-24 14:46:47 Functions: 1 10 10.0 %

          Line data    Source code
       1             : /* Copyright (C) 2021 Wildfire Games.
       2             :  * Copyright (C) 2013-2016 SuperTuxKart-Team.
       3             :  * This file is part of 0 A.D.
       4             :  *
       5             :  * 0 A.D. is free software: you can redistribute it and/or modify
       6             :  * it under the terms of the GNU General Public License as published by
       7             :  * the Free Software Foundation, either version 2 of the License, or
       8             :  * (at your option) any later version.
       9             :  *
      10             :  * 0 A.D. is distributed in the hope that it will be useful,
      11             :  * but WITHOUT ANY WARRANTY; without even the implied warranty of
      12             :  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      13             :  * GNU General Public License for more details.
      14             :  *
      15             :  * You should have received a copy of the GNU General Public License
      16             :  * along with 0 A.D.  If not, see <http://www.gnu.org/licenses/>.
      17             :  */
      18             : 
      19             : #include "precompiled.h"
      20             : 
      21             : #include "StunClient.h"
      22             : 
#include "lib/byte_order.h"
#include "ps/CLogger.h"
#include "ps/ConfigDB.h"
#include "ps/CStr.h"

#include "lib/external_libraries/enet.h"

#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstring>
#include <random>
#include <thread>
#include <type_traits>
#include <vector>
      33             : 
      34             : namespace StunClient
      35             : {
      36             : 
      37             : /**
      38             :  * These constants are defined in Section 6 of RFC 5389.
      39             :  */
      40             : const u32 m_MagicCookie = 0x2112A442;
      41             : const u16 m_MethodTypeBinding = 0x01;
      42             : const u32 m_BindingSuccessResponse = 0x0101;
      43             : 
      44             : /**
      45             :  * Bit determining whether comprehension of an attribute is optional.
      46             :  * Described in Section 15 of RFC 5389.
      47             :  */
      48             : const u16 m_ComprehensionOptional = 0x1 << 15;
      49             : 
      50             : /**
      51             :  * Bit determining whether the bit was assigned by IETF Review.
      52             :  * Described in section 18.1. of  RFC 5389.
      53             :  */
      54             : const u16 m_IETFReview = 0x1 << 14;
      55             : 
      56             : /**
      57             :  * These constants are defined in Section 15.1 of RFC 5389.
      58             :  */
      59             : const u8 m_IPAddressFamilyIPv4 = 0x01;
      60             : 
      61             : /**
      62             :  * These constants are defined in Section 18.2 of RFC 5389.
      63             :  */
      64             : const u16 m_AttrTypeMappedAddress = 0x001;
      65             : const u16 m_AttrTypeXORMappedAddress = 0x0020;
      66             : 
      67             : /**
      68             :  * Described in section 3 of RFC 5389.
      69             :  */
      70             : u8 m_TransactionID[12];
      71             : 
      72             : ENetAddress m_StunServer;
      73             : 
      74             : /**
      75             :  * Public IP + port discovered via the STUN transaction.
      76             :  */
      77             : ENetAddress m_PublicAddress;
      78             : 
      79             : /**
      80             :  * Push POD data to a network-byte-order buffer.
      81             :  * TODO: this should be optimised & moved to byte_order.h
      82             :  */
      83             : template<typename T, size_t n = sizeof(T)>
      84           0 : void AddToBuffer(std::vector<u8>& buffer, const T value)
      85             : {
      86             :     static_assert(std::is_pod_v<T>, "T must be POD");
      87           0 :     buffer.reserve(buffer.size() + n);
      88             :     // std::byte* can alias anything so this is legal.
      89           0 :     const std::byte* ptr = reinterpret_cast<const std::byte*>(&value);
      90           0 :     for (size_t a = 0; a < n; ++a)
      91             : #if BYTE_ORDER == LITTLE_ENDIAN
      92           0 :         buffer.push_back(static_cast<u8>(*(ptr + n - 1 - a)));
      93             : #else
      94             :         buffer.push_back(static_cast<u8>(*(ptr + a)));
      95             : #endif
      96           0 : }
      97           0 : 
      98             : /**
      99             :  * Read POD data from a network-byte-order buffer.
     100           0 :  * TODO: this should be optimised & moved to byte_order.h
     101             :  */
     102           0 : template<typename T, size_t n = sizeof(T)>
     103           0 : bool GetFromBuffer(const std::vector<u8>& buffer, u32& offset, T& result)
     104             : {
     105           0 :     static_assert(std::is_pod_v<T>, "T must be POD");
     106             :     if (offset + n > buffer.size())
     107             :         return false;
     108             : 
     109           0 :     // std::byte* can alias anything so this is legal.
     110           0 :     std::byte* ptr = reinterpret_cast<std::byte*>(&result);
     111             :     for (size_t a = 0; a < n; ++a)
     112             : #if BYTE_ORDER == LITTLE_ENDIAN
     113           0 :         *ptr++ = static_cast<std::byte>(buffer[offset + n - 1 - a]);
     114             : #else
     115           0 :         *ptr++ = static_cast<std::byte>(buffer[offset + a]);
     116           0 : #endif
     117             : 
     118           0 :     offset += n;
     119             :     return true;
     120             : }
     121             : 
     122           0 : void SendStunRequest(ENetHost& transactionHost, ENetAddress addr)
     123             : {
     124             :     std::vector<u8> buffer;
     125             :     AddToBuffer<u16>(buffer, m_MethodTypeBinding);
     126             :     AddToBuffer<u16>(buffer, 0); // length
     127             :     AddToBuffer<u32>(buffer, m_MagicCookie);
     128             : 
     129             :     for (std::size_t i = 0; i < sizeof(m_TransactionID); ++i)
     130             :     {
     131             :         u8 random_byte = rand() % 256;
     132           0 :         buffer.push_back(random_byte);
     133             :         m_TransactionID[i] = random_byte;
     134             :     }
     135             : 
     136             :     ENetBuffer enetBuffer;
     137           0 :     enetBuffer.data = buffer.data();
     138             :     enetBuffer.dataLength = buffer.size();
     139           0 :     enet_socket_send(transactionHost.socket, &addr, &enetBuffer, 1);
     140             : }
     141             : 
     142             : /**
     143             :  * Creates a STUN request and sends it to a STUN server.
     144           0 :  * The request is sent through transactionHost, from which the answer
     145             :  * will be retrieved by ReceiveStunResponse and interpreted by ParseStunResponse.
     146             :  */
     147             : bool CreateStunRequest(ENetHost& transactionHost)
     148           0 : {
     149             :     CStr server_name;
     150           0 :     int port;
     151           0 :     CFG_GET_VAL("lobby.stun.server", server_name);
     152           0 :     CFG_GET_VAL("lobby.stun.port", port);
     153           0 : 
     154             :     LOGMESSAGE("StunClient: Using STUN server %s:%d\n", server_name.c_str(), port);
     155           0 : 
     156             :     ENetAddress addr;
     157           0 :     addr.port = port;
     158           0 :     if (enet_address_set_host(&addr, server_name.c_str()) == -1)
     159           0 :         return false;
     160             : 
     161             :     m_StunServer = addr;
     162           0 : 
     163           0 :     StunClient::SendStunRequest(transactionHost, addr);
     164           0 : 
     165           0 :     return true;
     166           0 : }
     167             : 
     168             : /**
     169             :  * Gets the response from the STUN server and checks it for its validity.
     170             :  */
     171             : bool ReceiveStunResponse(ENetHost& transactionHost, std::vector<u8>& buffer)
     172             : {
     173           0 :     // TransportAddress sender;
     174             :     const int LEN = 2048;
     175           0 :     char input_buffer[LEN];
     176           0 : 
     177           0 :     memset(input_buffer, 0, LEN);
     178           0 : 
     179             :     ENetBuffer enetBuffer;
     180           0 :     enetBuffer.data = input_buffer;
     181             :     enetBuffer.dataLength = LEN;
     182           0 : 
     183           0 :     ENetAddress sender = m_StunServer;
     184           0 :     int len = enet_socket_receive(transactionHost.socket, &sender, &enetBuffer, 1);
     185             : 
     186             :     int delay = 200;
     187           0 :     CFG_GET_VAL("lobby.stun.delay", delay);
     188             : 
     189           0 :     // Wait to receive the message because enet sockets are non-blocking
     190             :     const int max_tries = 5;
     191             :     for (int count = 0; len <= 0 && (count < max_tries || max_tries == -1); ++count)
     192             :     {
     193             :         std::this_thread::sleep_for(std::chrono::milliseconds(delay));
     194             :         len = enet_socket_receive(transactionHost.socket, &sender, &enetBuffer, 1);
     195             :     }
     196             : 
     197           0 :     if (len <= 0)
     198             :     {
     199             :         LOGERROR("ReceiveStunResponse: recvfrom error (%d): %s", errno, strerror(errno));
     200           0 :         return false;
     201           0 :     }
     202             : 
     203           0 :     if (memcmp(&sender, &m_StunServer, sizeof(m_StunServer)) != 0)
     204             :         LOGERROR("ReceiveStunResponse: Received stun response from different address: %d.%d.%d.%d:%d %s",
     205           0 :             (sender.host >> 24) & 0xff,
     206           0 :             (sender.host >> 16) & 0xff,
     207           0 :             (sender.host >>  8) & 0xff,
     208             :             (sender.host >>  0) & 0xff,
     209           0 :             sender.port,
     210           0 :             input_buffer);
     211             : 
     212           0 :     // Convert to network string.
     213           0 :     buffer.resize(len);
     214             :     memcpy(buffer.data(), reinterpret_cast<u8*>(input_buffer), len);
     215             : 
     216           0 :     return true;
     217           0 : }
     218             : 
     219           0 : bool ParseStunResponse(const std::vector<u8>& buffer)
     220           0 : {
     221             :     u32 offset = 0;
     222             : 
     223           0 :     u16 responseType = 0;
     224             :     if (!GetFromBuffer(buffer, offset, responseType) || responseType != m_BindingSuccessResponse)
     225           0 :     {
     226           0 :         LOGERROR("STUN response isn't a binding success response");
     227             :         return false;
     228             :     }
     229           0 : 
     230           0 :     // Ignore message size
     231             :     offset += 2;
     232             : 
     233             :     u32 cookie = 0;
     234             :     if (!GetFromBuffer(buffer, offset, cookie) || cookie != m_MagicCookie)
     235             :     {
     236             :         LOGERROR("STUN response doesn't contain the magic cookie");
     237             :         return false;
     238             :     }
     239           0 : 
     240           0 :     for (std::size_t i = 0; i < sizeof(m_TransactionID); ++i)
     241             :     {
     242           0 :         u8 transactionChar = 0;
     243             :         if (!GetFromBuffer(buffer, offset, transactionChar) || transactionChar != m_TransactionID[i])
     244             :         {
     245           0 :             LOGERROR("STUN response doesn't contain the transaction ID");
     246             :             return false;
     247           0 :         }
     248             :     }
     249           0 : 
     250           0 :     while (offset < buffer.size())
     251             :     {
     252           0 :         u16 type = 0;
     253           0 :         u16 size = 0;
     254             :         if (!GetFromBuffer(buffer, offset, type) ||
     255             :             !GetFromBuffer(buffer, offset, size))
     256             :         {
     257           0 :             LOGERROR("STUN response contains invalid attribute");
     258             :             return false;
     259           0 :         }
     260           0 : 
     261             :         // The first two bits are irrelevant to the type
     262           0 :         type &= ~(m_ComprehensionOptional | m_IETFReview);
     263           0 : 
     264             :         switch (type)
     265             :         {
     266           0 :         case m_AttrTypeMappedAddress:
     267             :         case m_AttrTypeXORMappedAddress:
     268           0 :         {
     269           0 :             if (size != 8)
     270             :             {
     271           0 :                 LOGERROR("Invalid STUN Mapped Address length");
     272           0 :                 return false;
     273             :             }
     274             : 
     275             :             // Ignore the first byte as mentioned in Section 15.1 of RFC 5389.
     276           0 :             ++offset;
     277             : 
     278           0 :             u8 ipFamily = 0;
     279           0 :             if (!GetFromBuffer(buffer, offset, ipFamily) || ipFamily != m_IPAddressFamilyIPv4)
     280           0 :             {
     281           0 :                 LOGERROR("Unsupported address family, IPv4 is expected");
     282             :                 return false;
     283           0 :             }
     284           0 : 
     285             :             u16 port = 0;
     286             :             u32 ip = 0;
     287             :             if (!GetFromBuffer(buffer, offset, port) ||
     288           0 :                 !GetFromBuffer(buffer, offset, ip))
     289             :             {
     290           0 :                 LOGERROR("Mapped address doesn't contain IP and port");
     291             :                 return false;
     292           0 :             }
     293           0 : 
     294           0 :             // Obfuscation is described in Section 15.2 of RFC 5389.
     295           0 :             if (type == m_AttrTypeXORMappedAddress)
     296             :             {
     297           0 :                 port ^= m_MagicCookie >> 16;
     298           0 :                 ip ^= m_MagicCookie;
     299             :             }
     300             : 
     301             :             // ENetAddress takes a host byte-order port and network byte-order IP.
     302           0 :             // Network byte order is big endian, so convert appropriately.
     303             :             m_PublicAddress.host = to_be32(ip);
     304           0 :             m_PublicAddress.port = port;
     305           0 : 
     306             :             break;
     307           0 :         }
     308           0 :         default:
     309             :         {
     310             :             // We don't care about other attributes at all
     311           0 : 
     312           0 :             // Skip attribute
     313           0 :             offset += size;
     314           0 : 
     315             :             // Skip padding
     316           0 :             int padding = size % 4;
     317           0 :             if (padding)
     318             :                 offset += 4 - padding;
     319             :             break;
     320             :         }
     321           0 :         }
     322             :     }
     323           0 : 
     324           0 :     return true;
     325             : }
     326             : 
     327             : bool STUNRequestAndResponse(ENetHost& transactionHost)
     328             : {
     329           0 :     if (!CreateStunRequest(transactionHost))
     330           0 :         return false;
     331             : 
     332           0 :     std::vector<u8> buffer;
     333             :     return ReceiveStunResponse(transactionHost, buffer) &&
     334           0 :            ParseStunResponse(buffer);
     335           0 : }
     336             : 
     337             : bool FindPublicIP(ENetHost& transactionHost, CStr& ip, u16& port)
     338             : {
     339           0 :     if (!STUNRequestAndResponse(transactionHost))
     340             :         return false;
     341             : 
     342           0 :     // Convert m_IP to string
     343           0 :     char ipStr[256] = "(error)";
     344           0 :     enet_address_get_host_ip(&m_PublicAddress, ipStr, ARRAY_SIZE(ipStr));
     345             : 
     346             :     ip = ipStr;
     347             :     port = m_PublicAddress.port;
     348             : 
     349             :     LOGMESSAGE("StunClient: external IP address is %s:%i", ip.c_str(), port);
     350             : 
     351             :     return true;
     352             : }
     353           0 : 
     354             : void SendHolePunchingMessages(ENetHost& enetClient, const std::string& serverAddress, u16 serverPort)
     355           0 : {
     356             :     // Convert ip string to int64
     357             :     ENetAddress addr;
     358           0 :     addr.port = serverPort;
     359           0 :     enet_address_set_host(&addr, serverAddress.c_str());
     360           0 : 
     361             :     int delay = 200;
     362             :     CFG_GET_VAL("lobby.stun.delay", delay);
     363           0 : 
     364             :     // Send an UDP message from enet host to ip:port
     365           0 :     for (int i = 0; i < 3; ++i)
     366             :     {
     367             :         SendStunRequest(enetClient, addr);
     368             :         std::this_thread::sleep_for(std::chrono::milliseconds(delay));
     369           0 :     }
     370           0 : }
     371             : 
     372           0 : bool FindLocalIP(CStr& ip)
     373           0 : {
     374             :     // Open an UDP socket.
     375           0 :     ENetSocket socket = enet_socket_create(ENET_SOCKET_TYPE_DATAGRAM);
     376             : 
     377           0 :     ENetAddress addr;
     378             :     addr.port = 9; // Use the debug port (which we pick does not matter).
     379             :     // Connect to a random address. It does not need to be valid, only to not be the loopback address.
     380           0 :     if (enet_address_set_host(&addr, "100.0.100.0") == -1)
     381             :         return false;
     382             : 
     383           0 :     // Connect the socket. Being UDP, there is no actual outgoing traffic, this just binds it
     384           0 :     // to a valid port locally, allowing us to get the local IP of the machine.
     385           0 :     if (enet_socket_connect(socket, &addr) == -1)
     386             :         return false;
     387           0 : 
     388           0 :     // Fetch the local port & IP.
     389             :     if (enet_socket_get_address(socket, &addr) == -1)
     390             :         return false;
     391           0 : 
     392             :     enet_socket_destroy(socket);
     393           0 : 
     394           0 :     // Convert to a human readable string.
     395             :     char buf[50];
     396           0 :     if (enet_address_get_host_ip(&addr, buf, ARRAY_SIZE(buf)) == -1)
     397             :         return false;
     398           1 : 
     399             :     ip = buf;
     400             : 
     401           1 :     return true;
     402             : }
     403           1 : }

Generated by: LCOV version 1.13