use zlib-stream for websocket

This commit is contained in:
ouwou 2020-08-30 18:40:17 -04:00
parent c90c777daa
commit 4e7e5a3063
2 changed files with 18 additions and 34 deletions

View File

@ -4,10 +4,7 @@
DiscordClient::DiscordClient()
: m_http(DiscordAPI)
#ifdef ABADDON_USE_COMPRESSED_SOCKET
, m_decompress_buf(InflateChunkSize)
#endif
{
, m_decompress_buf(InflateChunkSize) {
LoadEventMap();
}
@ -19,6 +16,9 @@ void DiscordClient::Start() {
assert(!m_client_connected);
assert(!m_websocket.IsOpen());
std::memset(&m_zstream, 0, sizeof(m_zstream));
inflateInit2(&m_zstream, MAX_WBITS + 32);
m_client_connected = true;
m_websocket.StartConnection(DiscordGateway);
m_websocket.SetMessageCallback(std::bind(&DiscordClient::HandleGatewayMessageRaw, this, std::placeholders::_1));
@ -28,6 +28,8 @@ void DiscordClient::Stop() {
std::scoped_lock<std::mutex> guard(m_mutex);
if (!m_client_connected) return;
inflateEnd(&m_zstream);
m_heartbeat_waiter.kill();
if (m_heartbeat_thread.joinable()) m_heartbeat_thread.join();
m_client_connected = false;
@ -180,51 +182,40 @@ std::string DiscordClient::DecompressGatewayMessage(std::string str) {
}
void DiscordClient::HandleGatewayMessageRaw(std::string str) {
#ifdef ABADDON_USE_COMPRESSED_SOCKET // fuck you work
// handles multiple zlib compressed messages, calling HandleGatewayMessage when a full message is received
std::vector<uint8_t> buf(str.begin(), str.end());
int len = buf.size();
bool has_suffix = buf[len - 4] == 0x00 && buf[len - 3] == 0x00 && buf[len - 2] == 0xFF && buf[len - 1] == 0xFF;
m_compressed_buf.insert(m_compressed_buf.begin(), buf.begin(), buf.end());
m_compressed_buf.insert(m_compressed_buf.end(), buf.begin(), buf.end());
if (!has_suffix) return;
z_stream z;
std::memset(&z, 0, sizeof(z));
assert(inflateInit2(&z, 15) == 0);
z.next_in = m_compressed_buf.data();
z.avail_in = m_compressed_buf.size();
m_zstream.next_in = m_compressed_buf.data();
m_zstream.avail_in = m_compressed_buf.size();
m_zstream.total_in = m_zstream.total_out = 0;
// loop in case of really big messages (e.g. READY)
while (true) {
z.next_out = m_decompress_buf.data() + z.total_out;
z.avail_out = m_decompress_buf.size() - z.total_out;
m_zstream.next_out = m_decompress_buf.data() + m_zstream.total_out;
m_zstream.avail_out = m_decompress_buf.size() - m_zstream.total_out;
int err = inflate(&z, Z_SYNC_FLUSH);
if ((err == Z_OK || err == Z_BUF_ERROR) && z.avail_in > 0) {
int err = inflate(&m_zstream, Z_SYNC_FLUSH);
if ((err == Z_OK || err == Z_BUF_ERROR) && m_zstream.avail_in > 0) {
m_decompress_buf.resize(m_decompress_buf.size() + InflateChunkSize);
} else {
if (err != Z_OK) {
fprintf(stderr, "Error decompressing input buffer %d (%d/%d)\n", err, z.avail_in, z.avail_out);
fprintf(stderr, "Error decompressing input buffer %d (%d/%d)\n", err, m_zstream.avail_in, m_zstream.avail_out);
} else {
HandleGatewayMessage(std::string(m_decompress_buf.begin(), m_decompress_buf.begin() + z.total_out));
HandleGatewayMessage(std::string(m_decompress_buf.begin(), m_decompress_buf.begin() + m_zstream.total_out));
if (m_decompress_buf.size() > InflateChunkSize)
m_decompress_buf.resize(InflateChunkSize);
}
inflateEnd(&z);
break;
}
}
m_compressed_buf.clear();
#else
HandleGatewayMessage(str);
#endif
}
void DiscordClient::HandleGatewayMessage(std::string str) {

View File

@ -8,9 +8,7 @@
#include <set>
#include <unordered_set>
#include <mutex>
#ifdef ABADDON_USE_COMPRESSED_SOCKET
#include <zlib.h>
#endif
#include <zlib.h>
// bruh
#ifdef GetMessage
@ -42,11 +40,7 @@ class DiscordClient {
friend class Abaddon;
public:
#ifdef ABADDON_USE_COMPRESSED_SOCKET
static const constexpr char *DiscordGateway = "wss://gateway.discord.gg/?v=6&encoding=json&compress=zlib-stream";
#else
static const constexpr char *DiscordGateway = "wss://gateway.discord.gg/?v=6&encoding=json";
#endif
static const constexpr char *DiscordAPI = "https://discord.com/api";
static const constexpr char *GatewayIdentity = "Discord";
@ -80,11 +74,10 @@ public:
void UpdateToken(std::string token);
private:
#ifdef ABADDON_USE_COMPRESSED_SOCKET
static const constexpr int InflateChunkSize = 0x10000;
std::vector<uint8_t> m_compressed_buf;
std::vector<uint8_t> m_decompress_buf;
#endif
z_stream m_zstream;
std::string DecompressGatewayMessage(std::string str);
void HandleGatewayMessageRaw(std::string str);
void HandleGatewayMessage(std::string str);