From 6208636d69b5fde1d01d9361748c565d35816e20 Mon Sep 17 00:00:00 2001 From: Kae <80987908+Novaenia@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:26:38 +1100 Subject: [PATCH] Fixes --- source/core/StarDataStreamDevices.cpp | 2 + source/core/StarDataStreamDevices.hpp | 4 +- source/core/StarZSTDCompression.cpp | 88 ++++++++++++--------------- source/core/StarZSTDCompression.hpp | 2 + source/game/StarNetPacketSocket.cpp | 58 +++++++++--------- 5 files changed, 76 insertions(+), 78 deletions(-) diff --git a/source/core/StarDataStreamDevices.cpp b/source/core/StarDataStreamDevices.cpp index b769167..6cf05ea 100644 --- a/source/core/StarDataStreamDevices.cpp +++ b/source/core/StarDataStreamDevices.cpp @@ -132,6 +132,8 @@ void DataStreamBuffer::writeData(char const* data, size_t len) { DataStreamExternalBuffer::DataStreamExternalBuffer() : m_buffer() {} +DataStreamExternalBuffer::DataStreamExternalBuffer(ByteArray const& byteArray) : DataStreamExternalBuffer(byteArray.ptr(), byteArray.size()) {} + DataStreamExternalBuffer::DataStreamExternalBuffer(DataStreamBuffer const& buffer) : DataStreamExternalBuffer(buffer.ptr(), buffer.size()) {} DataStreamExternalBuffer::DataStreamExternalBuffer(char const* externalData, size_t len) : DataStreamExternalBuffer() { diff --git a/source/core/StarDataStreamDevices.hpp b/source/core/StarDataStreamDevices.hpp index 5d404ab..3f34a72 100644 --- a/source/core/StarDataStreamDevices.hpp +++ b/source/core/StarDataStreamDevices.hpp @@ -126,7 +126,9 @@ private: class DataStreamExternalBuffer : public DataStream { public: DataStreamExternalBuffer(); - DataStreamExternalBuffer(DataStreamBuffer const& buffer); + explicit DataStreamExternalBuffer(ByteArray const& byteArray); + explicit DataStreamExternalBuffer(DataStreamBuffer const& buffer); + DataStreamExternalBuffer(DataStreamExternalBuffer const& buffer) = default; DataStreamExternalBuffer(char const* externalData, size_t len); diff --git a/source/core/StarZSTDCompression.cpp b/source/core/StarZSTDCompression.cpp index 733b182..5c66c9a 100644 --- a/source/core/StarZSTDCompression.cpp +++ b/source/core/StarZSTDCompression.cpp @@ -7,6 +7,7 @@ CompressionStream::CompressionStream() : m_cStream(ZSTD_createCStream()) { ZSTD_CCtx_setParameter(m_cStream, ZSTD_c_enableLongDistanceMatching, 1); ZSTD_CCtx_setParameter(m_cStream, ZSTD_c_windowLog, 24); ZSTD_initCStream(m_cStream, 2); + m_output.resize(ZSTD_CStreamOutSize()); } CompressionStream::~CompressionStream() { ZSTD_freeCStream(m_cStream); } @@ -14,39 +15,33 @@ CompressionStream::~CompressionStream() { ZSTD_freeCStream(m_cStream); } ByteArray CompressionStream::compress(const char* in, size_t inLen) { size_t const cInSize = ZSTD_CStreamInSize (); size_t const cOutSize = ZSTD_CStreamOutSize(); - ByteArray output(cOutSize, 0); - size_t written = 0, read = 0; - while (read < inLen) { - ZSTD_inBuffer inBuffer = {in + read, min(cInSize, inLen - read), 0}; - ZSTD_outBuffer outBuffer = {output.ptr() + written, output.size() - written, 0}; - bool finished = false; - do { - size_t ret = ZSTD_compressStream2(m_cStream, &outBuffer, &inBuffer, ZSTD_e_flush); - if (ZSTD_isError(ret)) { - throw IOException(strf("ZSTD compression error {}", ZSTD_getErrorName(ret))); - break; - } + ZSTD_inBuffer inBuffer = {in, inLen, 0}; + size_t written = 0; + bool finished = false; + do { + ZSTD_outBuffer outBuffer = {m_output.ptr() + written, min(cOutSize, m_output.size() - written), 0}; + size_t ret = ZSTD_compressStream2(m_cStream, &outBuffer, &inBuffer, ZSTD_e_flush); + if (ZSTD_isError(ret)) { + throw IOException(strf("ZSTD compression error {}", ZSTD_getErrorName(ret))); + break; + } - if (outBuffer.pos == outBuffer.size) { - output.resize(output.size() * 2); - outBuffer.dst = output.ptr(); - outBuffer.size = output.size(); - continue; - } - - finished = ret == 0 && inBuffer.pos == inBuffer.size; - } while (!finished); - - read += inBuffer.pos; written += outBuffer.pos; - } - output.resize(written); - return output; + if (outBuffer.pos == outBuffer.size) { + if (written >= m_output.size()) + m_output.resize(m_output.size() * 2); + continue; + } + + finished = ret == 0 && inBuffer.pos == inBuffer.size; + } while (!finished); + return ByteArray(m_output.ptr(), written); } DecompressionStream::DecompressionStream() : m_dStream(ZSTD_createDStream()) { ZSTD_DCtx_setParameter(m_dStream, ZSTD_d_windowLogMax, 25); ZSTD_initDStream(m_dStream); + m_output.resize(ZSTD_DStreamOutSize()); } DecompressionStream::~DecompressionStream() { ZSTD_freeDStream(m_dStream); } @@ -54,31 +49,26 @@ DecompressionStream::~DecompressionStream() { ZSTD_freeDStream(m_dStream); } ByteArray DecompressionStream::decompress(const char* in, size_t inLen) { size_t const dInSize = ZSTD_DStreamInSize (); size_t const dOutSize = ZSTD_DStreamOutSize(); - ByteArray output(dOutSize, 0); - size_t written = 0, read = 0; - while (read < inLen) { - ZSTD_inBuffer inBuffer = {in + read, min(dInSize, inLen - read), 0}; - ZSTD_outBuffer outBuffer = {output.ptr() + written, output.size() - written, 0}; - do { - size_t ret = ZSTD_decompressStream(m_dStream, &outBuffer, &inBuffer); - if (ZSTD_isError(ret)) { - throw IOException(strf("ZSTD decompression error {}", ZSTD_getErrorName(ret))); - break; - } + ZSTD_inBuffer inBuffer = {in, inLen, 0}; + size_t written = 0; + bool finished = false; + do { + ZSTD_outBuffer outBuffer = {m_output.ptr() + written, min(dOutSize, m_output.size() - written), 0}; + size_t ret = ZSTD_decompressStream(m_dStream, &outBuffer, &inBuffer); + if (ZSTD_isError(ret)) { + throw IOException(strf("ZSTD decompression error {}", ZSTD_getErrorName(ret))); + break; + } - if (outBuffer.pos == outBuffer.size) { - output.resize(output.size() * 2); - outBuffer.dst = output.ptr(); - outBuffer.size = output.size(); - continue; - } - } while (inBuffer.pos < inBuffer.size); - - read += inBuffer.pos; written += outBuffer.pos; - } - output.resize(written); - return output; + if (outBuffer.pos == outBuffer.size) { + if (written >= m_output.size()) + m_output.resize(m_output.size() * 2); + continue; + } + finished = inBuffer.pos == inBuffer.size; + } while (!finished); + return ByteArray(m_output.ptr(), written); } } \ No newline at end of file diff --git a/source/core/StarZSTDCompression.hpp b/source/core/StarZSTDCompression.hpp index 77719bf..9c296c1 100644 --- a/source/core/StarZSTDCompression.hpp +++ b/source/core/StarZSTDCompression.hpp @@ -19,6 +19,7 @@ public: private: ZSTD_CStream* m_cStream; + ByteArray m_output; }; inline ByteArray CompressionStream::compress(ByteArray const& in) { @@ -35,6 +36,7 @@ public: private: ZSTD_DStream* m_dStream; + ByteArray m_output; }; inline ByteArray DecompressionStream::decompress(ByteArray const& in) { diff --git a/source/game/StarNetPacketSocket.cpp b/source/game/StarNetPacketSocket.cpp index 2ebaadb..e1feb1a 100644 --- a/source/game/StarNetPacketSocket.cpp +++ b/source/game/StarNetPacketSocket.cpp @@ -212,8 +212,8 @@ List TcpPacketSocket::receivePackets() { uint64_t const PacketBatchLimit = 131072; List packets; try { - DataStreamExternalBuffer ds(m_inputBuffer.ptr(), m_inputBuffer.size()); - bool atLeastOne = false; + DataStreamExternalBuffer ds(m_inputBuffer); + size_t trimPos = 0; while (!ds.atEnd()) { PacketType packetType; uint64_t packetSize = 0; @@ -233,19 +233,19 @@ List TcpPacketSocket::receivePackets() { if (packetSize > PacketSizeLimit) throw IOException::format("{} bytes large {} exceeds max size!", packetSize, PacketTypeNames.getRight(packetType)); - if (packetSize > ds.size() - ds.pos()) + if (packetSize > ds.remaining()) break; - atLeastOne = true; m_incomingStats.mix(packetType, packetSize, !m_useCompressionStream); DataStreamExternalBuffer packetStream(ds.ptr() + ds.pos(), packetSize); ByteArray uncompressed; if (packetCompressed) { - uncompressed = uncompressData(packetStream.ptr() + packetStream.pos(), packetSize, PacketSizeLimit); + uncompressed = uncompressData(packetStream.ptr(), packetSize, PacketSizeLimit); packetStream.reset(uncompressed.ptr(), uncompressed.size()); } ds.seek(packetSize, IOSeek::Relative); + trimPos = ds.pos(); size_t count = 0; do { @@ -262,10 +262,10 @@ List TcpPacketSocket::receivePackets() { packets.append(std::move(packet)); } while (!packetStream.atEnd()); } - if (atLeastOne) - m_inputBuffer.trimLeft(ds.pos()); + if (trimPos) + m_inputBuffer.trimLeft(trimPos); } catch (IOException const& e) { - Logger::warn("I/O error in TcpPacketSocket::readPackets, closing: {}", outputException(e, false)); + Logger::warn("I/O error in TcpPacketSocket::receivePackets, closing: {}", outputException(e, false)); m_inputBuffer.clear(); m_socket->shutdown(); } @@ -282,30 +282,32 @@ bool TcpPacketSocket::writeData() { bool dataSent = false; try { - if (m_useCompressionStream) { - auto compressed = m_compressionStream.compress(m_outputBuffer); - m_outputBuffer.clear(); + if (!m_outputBuffer.empty()) { + if (m_useCompressionStream) { + auto compressed = m_compressionStream.compress(m_outputBuffer); + m_outputBuffer.clear(); - m_compressedBuffer.append(compressed.ptr(), compressed.size()); - size_t written = m_socket->send(m_compressedBuffer.ptr(), m_compressedBuffer.size()); - if (written > 0) { - dataSent = true; - m_compressedBuffer.trimLeft(written); - m_outgoingStats.mix(written); - } - } else { - while (!m_outputBuffer.empty()) { - size_t written = m_socket->send(m_outputBuffer.ptr(), m_outputBuffer.size()); - if (written == 0) - break; - dataSent = true; - m_outputBuffer.trimLeft(written); + m_compressedBuffer.append(compressed.ptr(), compressed.size()); + size_t written = m_socket->send(m_compressedBuffer.ptr(), m_compressedBuffer.size()); + if (written > 0) { + dataSent = true; + m_compressedBuffer.trimLeft(written); + m_outgoingStats.mix(written); + } + } else { + do { + size_t written = m_socket->send(m_outputBuffer.ptr(), m_outputBuffer.size()); + if (written == 0) + break; + dataSent = true; + m_outputBuffer.trimLeft(written); + } while (!m_outputBuffer.empty()); } } } catch (SocketClosedException const& e) { Logger::debug("TcpPacketSocket socket closed: {}", outputException(e, false)); } catch (IOException const& e) { - Logger::warn("I/O error in TcpPacketSocket::sendData: {}", outputException(e, false)); + Logger::warn("I/O error in TcpPacketSocket::writeData: {}", outputException(e, false)); m_socket->shutdown(); } return dataSent; @@ -320,8 +322,8 @@ bool TcpPacketSocket::readData() { if (readAmount == 0) break; dataReceived = true; - m_incomingStats.mix(readAmount); if (m_useCompressionStream) { + m_incomingStats.mix(readAmount); auto decompressed = m_decompressionStream.decompress(readBuffer, readAmount); m_inputBuffer.append(decompressed.ptr(), decompressed.size()); } else { @@ -435,7 +437,7 @@ List P2PPacketSocket::receivePackets() { } while (!packetStream.atEnd()); } } catch (IOException const& e) { - Logger::warn("I/O error in P2PPacketSocket::readPackets, closing: {}", outputException(e, false)); + Logger::warn("I/O error in P2PPacketSocket::receivePackets, closing: {}", outputException(e, false)); m_socket.reset(); } return packets;