diff --git a/src/dialogs/quickswitcher/quickswitcher.cpp b/src/dialogs/quickswitcher/quickswitcher.cpp index a493e16..ed6930e 100644 --- a/src/dialogs/quickswitcher/quickswitcher.cpp +++ b/src/dialogs/quickswitcher/quickswitcher.cpp @@ -55,12 +55,24 @@ void QuickSwitcher::IndexPrivateChannels() { void QuickSwitcher::IndexChannels() { auto &discord = Abaddon::Get().GetDiscordClient(); + const auto channels = discord.GetAllChannelData(); + // grab literally everything to do in memory otherwise we get a shit ton of IOs + auto overwrites = discord.GetAllPermissionOverwrites(); + + auto member_roles = discord.GetAllMemberRoles(discord.GetUserData().ID); + std::unordered_map roles; + for (const auto &[guild_id, guild_roles] : member_roles) { + for (const auto &role_data : guild_roles) { + roles.emplace(role_data.ID, role_data); + } + } + for (auto &channel : channels) { if (!channel.Name.has_value()) continue; if (!channel.IsText()) continue; - // might want to optimize this at some point - if (!discord.HasSelfChannelPermission(channel.ID, Permission::VIEW_CHANNEL)) continue; + if (channel.GuildID.has_value() && + !discord.HasSelfChannelPermission(channel, Permission::VIEW_CHANNEL, roles, member_roles[*channel.GuildID], overwrites[channel.ID])) continue; m_index[channel.ID] = { SwitcherEntry::ResultType::Channel, *channel.Name, static_cast(channel.ID), diff --git a/src/discord/discord.cpp b/src/discord/discord.cpp index ffbb4eb..bcf384d 100644 --- a/src/discord/discord.cpp +++ b/src/discord/discord.cpp @@ -362,6 +362,14 @@ std::vector DiscordClient::GetAllChannelData() const { return m_store.GetAllChannelData(); } +std::unordered_map> DiscordClient::GetAllPermissionOverwrites() const { + return m_store.GetAllPermissionOverwriteData(); +} + +std::unordered_map> DiscordClient::GetAllMemberRoles(Snowflake user_id) const { + return m_store.GetAllMemberRoles(user_id); +} + bool DiscordClient::IsThreadJoined(Snowflake thread_id) const { return std::find(m_joined_threads.begin(), m_joined_threads.end(), thread_id) != m_joined_threads.end(); } @@ -463,6 +471,88 @@ bool DiscordClient::CanManageMember(Snowflake guild_id, Snowflake actor, Snowfla if (!target_highest.has_value()) return true; return actor_highest->Position > target_highest->Position; } +bool DiscordClient::HasSelfChannelPermission(const ChannelData &channel, + Permission perm, + const std::unordered_map &roles, + const std::vector &member_roles, + const std::unordered_map &overwrites) const { + return HasChannelPermission(m_user_data.ID, channel, perm, roles, member_roles, overwrites); +} + +bool DiscordClient::HasChannelPermission(Snowflake user_id, + const ChannelData &channel, + Permission perm, + const std::unordered_map &roles, + const std::vector &member_roles, + const std::unordered_map &overwrites) const { + if (channel.IsDM()) return true; + if (!channel.GuildID.has_value()) return false; + const auto base = ComputePermissions(user_id, *channel.GuildID, roles, member_roles); + const auto computed_overwrites = ComputeOverwrites(base, user_id, channel, member_roles, overwrites); + return (computed_overwrites & perm) == perm; +} + +Permission DiscordClient::ComputePermissions(Snowflake member_id, + Snowflake guild_id, + const std::unordered_map &roles, + const std::vector &member_roles) const { + const auto guild_owner = m_store.GetGuildOwner(guild_id); + + if (guild_owner == member_id) + return Permission::ALL; + + if (const auto everyone_it = roles.find(guild_id); everyone_it != roles.end()) { + const auto &everyone = everyone_it->second; + + Permission perms = everyone.Permissions; + for (const auto &role : member_roles) { + perms |= role.Permissions; + } + + if ((perms & Permission::ADMINISTRATOR) == Permission::ADMINISTRATOR) + return Permission::ALL; + + return perms; + } + + return Permission::NONE; +} + +Permission DiscordClient::ComputeOverwrites(Permission base, + Snowflake member_id, + const ChannelData &channel, + const std::vector &member_roles, + const std::unordered_map &overwrites) const { + if ((base & Permission::ADMINISTRATOR) == Permission::ADMINISTRATOR) + return Permission::ALL; + + if (!channel.GuildID.has_value()) return Permission::NONE; + + Permission perms = base; + if (const auto overwrite_everyone = overwrites.find(*channel.GuildID); overwrite_everyone != overwrites.end()) { + perms &= ~overwrite_everyone->second.Deny; + perms |= overwrite_everyone->second.Allow; + } + + Permission allow = Permission::NONE; + Permission deny = Permission::NONE; + for (const auto &role : member_roles) { + if (const auto overwrite = overwrites.find(role.ID); overwrite != overwrites.end()) { + allow |= overwrite->second.Allow; + deny |= overwrite->second.Deny; + } + } + + perms &= ~deny; + perms |= allow; + + if (const auto member_overwrite = overwrites.find(member_id); member_overwrite != overwrites.end()) { + perms &= ~member_overwrite->second.Deny; + perms |= member_overwrite->second.Allow; + } + + return perms; +} void DiscordClient::ChatMessageCallback(const std::string &nonce, const http::response_type &response, const sigc::slot &callback) { if (!CheckCode(response)) { diff --git a/src/discord/discord.hpp b/src/discord/discord.hpp index 122e667..fc1407f 100644 --- a/src/discord/discord.hpp +++ b/src/discord/discord.hpp @@ -66,6 +66,8 @@ public: std::vector GetChildChannelIDs(Snowflake parent_id) const; std::optional GetWebhookMessageData(Snowflake message_id) const; std::vector GetAllChannelData() const; + std::unordered_map> GetAllPermissionOverwrites() const; + std::unordered_map> GetAllMemberRoles(Snowflake user_id) const; // get ids of given list of members for who we do not have the member data template @@ -88,6 +90,28 @@ public: Permission ComputeOverwrites(Permission base, Snowflake member_id, Snowflake channel_id) const; bool CanManageMember(Snowflake guild_id, Snowflake actor, Snowflake target) const; // kick, ban, edit nickname (cant think of a better name) + // IO-less calls + bool HasSelfChannelPermission(const ChannelData &channel, + Permission perm, + const std::unordered_map &roles, + const std::vector &member_roles, + const std::unordered_map &overwrites) const; + bool HasChannelPermission(Snowflake user_id, + const ChannelData &channel, + Permission perm, + const std::unordered_map &roles, + const std::vector &member_roles, + const std::unordered_map &overwrites) const; + Permission ComputePermissions(Snowflake member_id, + Snowflake guild_id, + const std::unordered_map &roles, + const std::vector &member_roles) const; + Permission ComputeOverwrites(Permission base, + Snowflake member_id, + const ChannelData &channel, + const std::vector &member_roles, + const std::unordered_map &overwrites) const; + void ChatMessageCallback(const std::string &nonce, const http::response_type &response, const sigc::slot &callback); void SendChatMessageNoAttachments(const ChatSubmitParams ¶ms, const sigc::slot &callback); void SendChatMessageAttachments(const ChatSubmitParams ¶ms, const sigc::slot &callback); diff --git a/src/discord/store.cpp b/src/discord/store.cpp index b725c78..671080b 100644 --- a/src/discord/store.cpp +++ b/src/discord/store.cpp @@ -1253,6 +1253,52 @@ std::vector Store::GetAllChannelData() const { return r; } +std::unordered_map> Store::GetAllPermissionOverwriteData() const { + auto &s = m_stmt_get_all_perms; + std::unordered_map> r; + + while (s->FetchOne()) { + PermissionOverwrite d; + Snowflake channel_id; + s->Get(0, d.ID); + s->Get(1, channel_id); + s->Get(2, d.Type); + s->Get(3, d.Allow); + s->Get(4, d.Deny); + r[channel_id][d.ID] = d; + } + + s->Reset(); + + return r; +} + +std::unordered_map> Store::GetAllMemberRoles(Snowflake user_id) const { + auto &s = m_stmt_get_self_member_roles; + std::unordered_map> r; + + s->Bind(1, user_id); + + while (s->FetchOne()) { + Snowflake guild_id; + RoleData role; + s->Get(0, role.ID); + s->Get(1, guild_id); + s->Get(2, role.Name); + s->Get(3, role.Color); + s->Get(4, role.IsHoisted); + s->Get(5, role.Position); + s->Get(6, role.Permissions); + s->Get(7, role.IsManaged); + s->Get(8, role.IsMentionable); + r[guild_id].push_back(std::move(role)); + } + + s->Reset(); + + return r; +} + void Store::ClearAll() { if (m_db.Execute(R"( DELETE FROM attachments; @@ -1983,6 +2029,14 @@ bool Store::CreateStatements() { return false; } + m_stmt_get_all_perms = std::make_unique(m_db, R"( + SELECT * FROM permissions + )"); + if (!m_stmt_get_all_perms->OK()) { + fprintf(stderr, "failed to prepare get all permissions statement: %s\n", m_db.ErrStr()); + return false; + } + m_stmt_set_ban = std::make_unique(m_db, R"( REPLACE INTO bans VALUES ( ?, ?, ? @@ -2062,6 +2116,18 @@ bool Store::CreateStatements() { return false; } + m_stmt_get_self_member_roles = std::make_unique(m_db, R"( + SELECT DISTINCT roles.* + FROM member_roles, roles + WHERE (member_roles.user = ? + AND member_roles.role = roles.id) + OR roles.id = roles.guild /* @everyone */ + )"); + if (!m_stmt_get_self_member_roles->OK()) { + fprintf(stderr, "failed to prepare get self member roles statement: %s\n", m_db.ErrStr()); + return false; + } + m_stmt_set_guild_emoji = std::make_unique(m_db, R"( REPLACE INTO guild_emojis VALUES ( ?, ? diff --git a/src/discord/store.hpp b/src/discord/store.hpp index 7612ad2..a5bb9f4 100644 --- a/src/discord/store.hpp +++ b/src/discord/store.hpp @@ -98,6 +98,8 @@ public: // this does NOT include recipients std::vector GetAllChannelData() const; + std::unordered_map> GetAllPermissionOverwriteData() const; + std::unordered_map> GetAllMemberRoles(Snowflake user_id) const; void ClearAll(); @@ -309,6 +311,7 @@ private: STMT(get_emoji); STMT(set_perm); STMT(get_perm); + STMT(get_all_perms); STMT(set_ban); STMT(get_ban); STMT(get_bans); @@ -317,6 +320,7 @@ private: STMT(set_member_roles); STMT(get_member_roles); STMT(clr_member_roles); + STMT(get_self_member_roles); STMT(set_guild_emoji); STMT(get_guild_emojis); STMT(clr_guild_emoji);