LCOV - code coverage report
Current view: top level - source/network - StunClient.cpp (source / functions) Hit Total Coverage
Test: 0 A.D. test coverage report Lines: 10 150 6.7 %
Date: 2023-01-19 00:18:29 Functions: 1 13 7.7 %

          Line data    Source code
       1             : /* Copyright (C) 2022 Wildfire Games.
       2             :  * Copyright (C) 2013-2016 SuperTuxKart-Team.
       3             :  * This file is part of 0 A.D.
       4             :  *
       5             :  * 0 A.D. is free software: you can redistribute it and/or modify
       6             :  * it under the terms of the GNU General Public License as published by
       7             :  * the Free Software Foundation, either version 2 of the License, or
       8             :  * (at your option) any later version.
       9             :  *
      10             :  * 0 A.D. is distributed in the hope that it will be useful,
      11             :  * but WITHOUT ANY WARRANTY; without even the implied warranty of
      12             :  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      13             :  * GNU General Public License for more details.
      14             :  *
      15             :  * You should have received a copy of the GNU General Public License
      16             :  * along with 0 A.D.  If not, see <http://www.gnu.org/licenses/>.
      17             :  */
      18             : 
#include "precompiled.h"

#include "StunClient.h"

#include "lib/byte_order.h"
#include "lib/external_libraries/enet.h"
#include "ps/CLogger.h"
#include "ps/ConfigDB.h"
#include "ps/CStr.h"

#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstring>
#include <random>
#include <thread>
#include <type_traits>
#include <vector>
      33             : 
      34             : namespace StunClient
      35             : {
      36             : 
      37             : /**
      38             :  * These constants are defined in Section 6 of RFC 5389.
      39             :  */
      40             : const u32 m_MagicCookie = 0x2112A442;
      41             : const u16 m_MethodTypeBinding = 0x01;
      42             : const u32 m_BindingSuccessResponse = 0x0101;
      43             : 
      44             : /**
      45             :  * Bit determining whether comprehension of an attribute is optional.
      46             :  * Described in Section 15 of RFC 5389.
      47             :  */
      48             : const u16 m_ComprehensionOptional = 0x1 << 15;
      49             : 
      50             : /**
      51             :  * Bit determining whether the bit was assigned by IETF Review.
      52             :  * Described in section 18.1. of  RFC 5389.
      53             :  */
      54             : const u16 m_IETFReview = 0x1 << 14;
      55             : 
      56             : /**
      57             :  * These constants are defined in Section 15.1 of RFC 5389.
      58             :  */
      59             : const u8 m_IPAddressFamilyIPv4 = 0x01;
      60             : 
      61             : /**
      62             :  * These constants are defined in Section 18.2 of RFC 5389.
      63             :  */
      64             : const u16 m_AttrTypeMappedAddress = 0x001;
      65             : const u16 m_AttrTypeXORMappedAddress = 0x0020;
      66             : 
      67             : /**
      68             :  * Described in section 3 of RFC 5389.
      69             :  */
      70             : u8 m_TransactionID[12];
      71             : 
      72             : ENetAddress m_StunServer;
      73             : 
      74             : /**
      75             :  * Public IP + port discovered via the STUN transaction.
      76             :  */
      77             : ENetAddress m_PublicAddress;
      78             : 
      79             : /**
      80             :  * Push POD data to a network-byte-order buffer.
      81             :  * TODO: this should be optimised & moved to byte_order.h
      82             :  */
      83             : template<typename T, size_t n = sizeof(T)>
      84           0 : void AddToBuffer(std::vector<u8>& buffer, const T value)
      85             : {
      86             :     static_assert(std::is_pod_v<T>, "T must be POD");
      87           0 :     buffer.reserve(buffer.size() + n);
      88             :     // std::byte* can alias anything so this is legal.
      89           0 :     const std::byte* ptr = reinterpret_cast<const std::byte*>(&value);
      90           0 :     for (size_t a = 0; a < n; ++a)
      91             : #if BYTE_ORDER == LITTLE_ENDIAN
      92           0 :         buffer.push_back(static_cast<u8>(*(ptr + n - 1 - a)));
      93             : #else
      94             :         buffer.push_back(static_cast<u8>(*(ptr + a)));
      95             : #endif
      96           0 : }
      97             : 
      98             : /**
      99             :  * Read POD data from a network-byte-order buffer.
     100             :  * TODO: this should be optimised & moved to byte_order.h
     101             :  */
     102             : template<typename T, size_t n = sizeof(T)>
     103           0 : bool GetFromBuffer(const std::vector<u8>& buffer, u32& offset, T& result)
     104             : {
     105             :     static_assert(std::is_pod_v<T>, "T must be POD");
     106           0 :     if (offset + n > buffer.size())
     107           0 :         return false;
     108             : 
     109             :     // std::byte* can alias anything so this is legal.
     110           0 :     std::byte* ptr = reinterpret_cast<std::byte*>(&result);
     111           0 :     for (size_t a = 0; a < n; ++a)
     112             : #if BYTE_ORDER == LITTLE_ENDIAN
     113           0 :         *ptr++ = static_cast<std::byte>(buffer[offset + n - 1 - a]);
     114             : #else
     115             :         *ptr++ = static_cast<std::byte>(buffer[offset + a]);
     116             : #endif
     117             : 
     118           0 :     offset += n;
     119           0 :     return true;
     120             : }
     121             : 
     122           0 : void SendStunRequest(ENetHost& transactionHost, ENetAddress addr)
     123             : {
     124           0 :     std::vector<u8> buffer;
     125           0 :     AddToBuffer<u16>(buffer, m_MethodTypeBinding);
     126           0 :     AddToBuffer<u16>(buffer, 0); // length
     127           0 :     AddToBuffer<u32>(buffer, m_MagicCookie);
     128             : 
     129           0 :     for (std::size_t i = 0; i < sizeof(m_TransactionID); ++i)
     130             :     {
     131           0 :         u8 random_byte = rand() % 256;
     132           0 :         buffer.push_back(random_byte);
     133           0 :         m_TransactionID[i] = random_byte;
     134             :     }
     135             : 
     136             :     ENetBuffer enetBuffer;
     137           0 :     enetBuffer.data = buffer.data();
     138           0 :     enetBuffer.dataLength = buffer.size();
     139           0 :     enet_socket_send(transactionHost.socket, &addr, &enetBuffer, 1);
     140           0 : }
     141             : 
     142             : /**
     143             :  * Creates a STUN request and sends it to a STUN server.
     144             :  * The request is sent through transactionHost, from which the answer
     145             :  * will be retrieved by ReceiveStunResponse and interpreted by ParseStunResponse.
     146             :  */
     147           0 : bool CreateStunRequest(ENetHost& transactionHost)
     148             : {
     149           0 :     CStr server_name;
     150             :     int port;
     151           0 :     CFG_GET_VAL("lobby.stun.server", server_name);
     152           0 :     CFG_GET_VAL("lobby.stun.port", port);
     153             : 
     154           0 :     LOGMESSAGE("StunClient: Using STUN server %s:%d\n", server_name.c_str(), port);
     155             : 
     156             :     ENetAddress addr;
     157           0 :     addr.port = port;
     158           0 :     if (enet_address_set_host(&addr, server_name.c_str()) == -1)
     159           0 :         return false;
     160             : 
     161           0 :     m_StunServer = addr;
     162             : 
     163           0 :     StunClient::SendStunRequest(transactionHost, addr);
     164             : 
     165           0 :     return true;
     166             : }
     167             : 
     168             : /**
     169             :  * Gets the response from the STUN server and checks it for its validity.
     170             :  */
     171           0 : bool ReceiveStunResponse(ENetHost& transactionHost, std::vector<u8>& buffer)
     172             : {
     173             :     // TransportAddress sender;
     174           0 :     const int LEN = 2048;
     175             :     char input_buffer[LEN];
     176             : 
     177           0 :     memset(input_buffer, 0, LEN);
     178             : 
     179             :     ENetBuffer enetBuffer;
     180           0 :     enetBuffer.data = input_buffer;
     181           0 :     enetBuffer.dataLength = LEN;
     182             : 
     183           0 :     ENetAddress sender = m_StunServer;
     184           0 :     int len = enet_socket_receive(transactionHost.socket, &sender, &enetBuffer, 1);
     185             : 
     186           0 :     int delay = 200;
     187           0 :     CFG_GET_VAL("lobby.stun.delay", delay);
     188             : 
     189             :     // Wait to receive the message because enet sockets are non-blocking
     190           0 :     const int max_tries = 5;
     191           0 :     for (int count = 0; len <= 0 && (count < max_tries || max_tries == -1); ++count)
     192             :     {
     193           0 :         std::this_thread::sleep_for(std::chrono::milliseconds(delay));
     194           0 :         len = enet_socket_receive(transactionHost.socket, &sender, &enetBuffer, 1);
     195             :     }
     196             : 
     197           0 :     if (len <= 0)
     198             :     {
     199           0 :         LOGERROR("ReceiveStunResponse: recvfrom error (%d): %s", errno, strerror(errno));
     200           0 :         return false;
     201             :     }
     202             : 
     203           0 :     if (memcmp(&sender, &m_StunServer, sizeof(m_StunServer)) != 0)
     204           0 :         LOGERROR("ReceiveStunResponse: Received stun response from different address: %d.%d.%d.%d:%d %s",
     205             :             (sender.host >> 24) & 0xff,
     206             :             (sender.host >> 16) & 0xff,
     207             :             (sender.host >>  8) & 0xff,
     208             :             (sender.host >>  0) & 0xff,
     209             :             sender.port,
     210             :             input_buffer);
     211             : 
     212             :     // Convert to network string.
     213           0 :     buffer.resize(len);
     214           0 :     memcpy(buffer.data(), reinterpret_cast<u8*>(input_buffer), len);
     215             : 
     216           0 :     return true;
     217             : }
     218             : 
     219           0 : bool ParseStunResponse(const std::vector<u8>& buffer)
     220             : {
     221           0 :     u32 offset = 0;
     222             : 
     223           0 :     u16 responseType = 0;
     224           0 :     if (!GetFromBuffer(buffer, offset, responseType) || responseType != m_BindingSuccessResponse)
     225             :     {
     226           0 :         LOGERROR("STUN response isn't a binding success response");
     227           0 :         return false;
     228             :     }
     229             : 
     230             :     // Ignore message size
     231           0 :     offset += 2;
     232             : 
     233           0 :     u32 cookie = 0;
     234           0 :     if (!GetFromBuffer(buffer, offset, cookie) || cookie != m_MagicCookie)
     235             :     {
     236           0 :         LOGERROR("STUN response doesn't contain the magic cookie");
     237           0 :         return false;
     238             :     }
     239             : 
     240           0 :     for (std::size_t i = 0; i < sizeof(m_TransactionID); ++i)
     241             :     {
     242           0 :         u8 transactionChar = 0;
     243           0 :         if (!GetFromBuffer(buffer, offset, transactionChar) || transactionChar != m_TransactionID[i])
     244             :         {
     245           0 :             LOGERROR("STUN response doesn't contain the transaction ID");
     246           0 :             return false;
     247             :         }
     248             :     }
     249             : 
     250           0 :     while (offset < buffer.size())
     251             :     {
     252           0 :         u16 type = 0;
     253           0 :         u16 size = 0;
     254           0 :         if (!GetFromBuffer(buffer, offset, type) ||
     255           0 :             !GetFromBuffer(buffer, offset, size))
     256             :         {
     257           0 :             LOGERROR("STUN response contains invalid attribute");
     258           0 :             return false;
     259             :         }
     260             : 
     261             :         // The first two bits are irrelevant to the type
     262           0 :         type &= ~(m_ComprehensionOptional | m_IETFReview);
     263             : 
     264           0 :         switch (type)
     265             :         {
     266           0 :         case m_AttrTypeMappedAddress:
     267             :         case m_AttrTypeXORMappedAddress:
     268             :         {
     269           0 :             if (size != 8)
     270             :             {
     271           0 :                 LOGERROR("Invalid STUN Mapped Address length");
     272           0 :                 return false;
     273             :             }
     274             : 
     275             :             // Ignore the first byte as mentioned in Section 15.1 of RFC 5389.
     276           0 :             ++offset;
     277             : 
     278           0 :             u8 ipFamily = 0;
     279           0 :             if (!GetFromBuffer(buffer, offset, ipFamily) || ipFamily != m_IPAddressFamilyIPv4)
     280             :             {
     281           0 :                 LOGERROR("Unsupported address family, IPv4 is expected");
     282           0 :                 return false;
     283             :             }
     284             : 
     285           0 :             u16 port = 0;
     286           0 :             u32 ip = 0;
     287           0 :             if (!GetFromBuffer(buffer, offset, port) ||
     288           0 :                 !GetFromBuffer(buffer, offset, ip))
     289             :             {
     290           0 :                 LOGERROR("Mapped address doesn't contain IP and port");
     291           0 :                 return false;
     292             :             }
     293             : 
     294             :             // Obfuscation is described in Section 15.2 of RFC 5389.
     295           0 :             if (type == m_AttrTypeXORMappedAddress)
     296             :             {
     297           0 :                 port ^= m_MagicCookie >> 16;
     298           0 :                 ip ^= m_MagicCookie;
     299             :             }
     300             : 
     301             :             // ENetAddress takes a host byte-order port and network byte-order IP.
     302             :             // Network byte order is big endian, so convert appropriately.
     303           0 :             m_PublicAddress.host = to_be32(ip);
     304           0 :             m_PublicAddress.port = port;
     305             : 
     306           0 :             break;
     307             :         }
     308           0 :         default:
     309             :         {
     310             :             // We don't care about other attributes at all
     311             : 
     312             :             // Skip attribute
     313           0 :             offset += size;
     314             : 
     315             :             // Skip padding
     316           0 :             int padding = size % 4;
     317           0 :             if (padding)
     318           0 :                 offset += 4 - padding;
     319           0 :             break;
     320             :         }
     321             :         }
     322             :     }
     323             : 
     324           0 :     return true;
     325             : }
     326             : 
     327           0 : bool STUNRequestAndResponse(ENetHost& transactionHost)
     328             : {
     329           0 :     if (!CreateStunRequest(transactionHost))
     330           0 :         return false;
     331             : 
     332           0 :     std::vector<u8> buffer;
     333           0 :     return ReceiveStunResponse(transactionHost, buffer) &&
     334           0 :            ParseStunResponse(buffer);
     335             : }
     336             : 
     337           0 : bool FindPublicIP(ENetHost& transactionHost, CStr& ip, u16& port)
     338             : {
     339           0 :     if (!STUNRequestAndResponse(transactionHost))
     340           0 :         return false;
     341             : 
     342             :     // Convert m_IP to string
     343           0 :     char ipStr[256] = "(error)";
     344           0 :     enet_address_get_host_ip(&m_PublicAddress, ipStr, ARRAY_SIZE(ipStr));
     345             : 
     346           0 :     ip = ipStr;
     347           0 :     port = m_PublicAddress.port;
     348             : 
     349           0 :     LOGMESSAGE("StunClient: external IP address is %s:%i", ip.c_str(), port);
     350             : 
     351           0 :     return true;
     352             : }
     353             : 
     354           0 : void SendHolePunchingMessages(ENetHost& enetClient, const std::string& serverAddress, u16 serverPort)
     355             : {
     356             :     // Convert ip string to int64
     357             :     ENetAddress addr;
     358           0 :     addr.port = serverPort;
     359           0 :     enet_address_set_host(&addr, serverAddress.c_str());
     360             : 
     361           0 :     int delay = 200;
     362           0 :     CFG_GET_VAL("lobby.stun.delay", delay);
     363             : 
     364             :     // Send an UDP message from enet host to ip:port
     365           0 :     for (int i = 0; i < 3; ++i)
     366             :     {
     367           0 :         SendStunRequest(enetClient, addr);
     368           0 :         std::this_thread::sleep_for(std::chrono::milliseconds(delay));
     369             :     }
     370           0 : }
     371             : 
     372           1 : bool FindLocalIP(CStr& ip)
     373             : {
     374             :     // Open an UDP socket.
     375           1 :     ENetSocket socket = enet_socket_create(ENET_SOCKET_TYPE_DATAGRAM);
     376             : 
     377             :     ENetAddress addr;
     378           1 :     addr.port = 9; // Use the debug port (which we pick does not matter).
     379             :     // Connect to a random address. It does not need to be valid, only to not be the loopback address.
     380           1 :     if (enet_address_set_host(&addr, "100.0.100.0") == -1)
     381           0 :         return false;
     382             : 
     383             :     // Connect the socket. Being UDP, there is no actual outgoing traffic, this just binds it
     384             :     // to a valid port locally, allowing us to get the local IP of the machine.
     385           1 :     if (enet_socket_connect(socket, &addr) == -1)
     386           0 :         return false;
     387             : 
     388             :     // Fetch the local port & IP.
     389           1 :     if (enet_socket_get_address(socket, &addr) == -1)
     390           0 :         return false;
     391             : 
     392           1 :     enet_socket_destroy(socket);
     393             : 
     394             :     // Convert to a human readable string.
     395             :     char buf[50];
     396           1 :     if (enet_address_get_host_ip(&addr, buf, ARRAY_SIZE(buf)) == -1)
     397           0 :         return false;
     398             : 
     399           1 :     ip = buf;
     400             : 
     401           1 :     return true;
     402             : }
     403             : }

Generated by: LCOV version 1.13