Skip to content

Commit

Permalink
refactor!: rework authorization no longer pass entry
Browse files Browse the repository at this point in the history
  • Loading branch information
SamMousa committed Dec 31, 2021
1 parent 81c59cc commit 15a7c4b
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 75 deletions.
1 change: 1 addition & 0 deletions LibEventSourcing.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
<Script file="source\StateManager.lua"/>
<Script file="source\Message.lua"/>
<Script file="source\AdvertiseHashMessage.lua"/>
<Script file="source\SubscribeMessage.lua"/>
<Script file="source\WeekDataMessage.lua"/>
<Script file="source\BulkDataMessage.lua"/>
<Script file="source\RequestWeekMessage.lua"/>
Expand Down
16 changes: 3 additions & 13 deletions source/LedgerFactory.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ Params
table: table -- Reference to the data, should be a saved variable
send: function(tableData, distribution, target, progressCallback): void -- function the sync will use to send outgoing data
sendLargeMessage: function(tableData, distribution, target, progressCallback): void -- function the sync will use to send large messages
authorizationHandler: function(entry, sender): bool -- Authorization handler, called before sending outgoing entries and before
committing incoming entries
authorizationHandler: function(sender): bool -- Authorization handler, called before sending outgoing entries and before trusting incoming entries
registerReceiveHandler: function(receiveCallback: function(message, distribution, sender)): void
Expand Down Expand Up @@ -77,22 +76,13 @@ LedgerFactory.createLedger = function(table, send, registerReceiveHandler, autho
submitEntry = function(entry)
-- not applying timetravel before auth, because from an addon perspective it is the current time.
-- check authorization
if not authorizationHandler(entry, UnitName("player")) then
error("Attempted to submit entries for which you are not authorized")
return
end

stateManager:addEvent(entry)
listSync:transmitViaGuild(entry)
end,
ignoreEntry = function(entry)
local ignoreEntry = stateManager:createIgnoreEntry(entry)
if listSync:transmitViaGuild(ignoreEntry, entry) then
-- only commit locally if we are authorized to send
sortedList:uniqueInsert(ignoreEntry)
else
error("Attempted to submit entries for which you are not authorized")
end
sortedList:uniqueInsert(ignoreEntry)
listSync:transmitViaGuild(ignoreEntry)
end,
catchup = function(limit)
stateManager:catchup(limit)
Expand Down
67 changes: 35 additions & 32 deletions source/ListSync.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ local RequestWeekMessage = LibStub("EventSourcing/Message/RequestWeek")
local RequestStateMessage = LibStub("EventSourcing/Message/RequestState")
local StateMessage = LibStub("EventSourcing/Message/State")
local BulkDataMessage = LibStub("EventSourcing/Message/BulkData")
local SubscribeMessage = LibStub("EventSourcing/Message/Subscribe")
local Message = LibStub("EventSourcing/Message")


Expand Down Expand Up @@ -146,7 +147,12 @@ local function requestWeekInhibitorCheck(listSync, week)
or listSync.inhibitors[messageType][week] < Util.time()
end

local function handleAdvertiseMessage(message, sender, distribution, stateManager, listSync)
local function handleSubscribeMessage(message, sender, distribution, stateManager, listSync)
-- This message should only be received on a whisper channel.
listSync.subscribers[#listSync.subscribers + 1] = sender
end

local function handleAdvertiseMessage(message, sender, _, stateManager, listSync)
-- This is the number of entries we expect to have after all data from advertisements in this message have been synced
local projectedEntries = stateManager:getSortedList():length()
local now = Util.time()
Expand Down Expand Up @@ -196,33 +202,33 @@ end


local function handleWeekDataMessage(message, sender, distribution, stateManager, listSync)
local count = 0
for _, v in ipairs(message.entries) do
local entry = stateManager:createLogEntryFromList(v)
-- Authorize each event
if listSync.authorizationHandler(entry, sender) then
if not listSync.authorizationHandler(sender) then
listSync.logger:Warning("Dropping week data message from unauthorized sender %s", sender)
else
local count = 0
for _, v in ipairs(message.entries) do
local entry = stateManager:createLogEntryFromList(v)
stateManager:queueRemoteEvent(entry)
count = count + 1
else
listSync.logger:Warning("Dropping event from unauthorized sender %s", sender)
end
listSync.logger:Info("Enqueued %d events for week %s from remote received from %s via %s", count, message.week, sender, distribution)
end
listSync.logger:Info("Enqueued %d events for week %s from remote received from %s via %s", count, message.week, sender, distribution)
end


local function handleBulkDataMessage(message, sender, distribution, stateManager, listSync)
local count = 0
for _, v in ipairs(message.entries) do
local entry = stateManager:createLogEntryFromList(v)
-- Authorize each event
if listSync.authorizationHandler(entry, sender) then

if not listSync.authorizationHandler(sender) then
listSync.logger:Warning("Dropping bulk data message from unauthorized sender %s", sender)
else
local count = 0
for _, v in ipairs(message.entries) do
local entry = stateManager:createLogEntryFromList(v)
stateManager:queueRemoteEvent(entry)
count = count + 1
else
listSync.logger:Warning("Dropping event from unauthorized sender %s", sender)
end
listSync.logger:Info("Enqueued %d events from remote received from %s via %s", count, sender, distribution)
end
listSync.logger:Info("Enqueued %d events from remote received from %s via %s", count, sender, distribution)
end

local function handleRequestWeekMessage(message, sender, distribution, stateManager, listSync)
Expand Down Expand Up @@ -289,14 +295,10 @@ local function advertiseWeekHashInhibitorCheckOrSet(listSync, week)
return false
end

local function transmitEntry(listSync, entry, authEntry, channel)
if listSync.authorizationHandler(authEntry or entry, UnitName("player")) then
local message = BulkDataMessage.create()
message:addEntry(listSync._stateManager:createListFromEntry(entry))
send(listSync, message, channel)
return true
end
return false
local function transmitEntry(listSync, entry, channel)
local message = BulkDataMessage.create()
message:addEntry(listSync._stateManager:createListFromEntry(entry))
send(listSync, message, channel)
end

function ListSync:new(stateManager, sendMessage, registerReceiveHandler, authorizationHandler, sendLargeMessage, logger)
Expand Down Expand Up @@ -327,6 +329,9 @@ function ListSync:new(stateManager, sendMessage, registerReceiveHandler, authori
entries = {}
}
o.playerName = UnitName("player")
-- A list of players that want our advertisements
o.subscribers = {}

registerReceiveHandler(function(message, distribution, sender)
handleMessage(o, message, distribution, sender)
end)
Expand All @@ -338,6 +343,7 @@ function ListSync:new(stateManager, sendMessage, registerReceiveHandler, authori
o.messageHandlers[RequestWeekMessage.type()] = { handleRequestWeekMessage }
o.messageHandlers[RequestStateMessage.type()] = { handleRequestStateMessage }
o.messageHandlers[StateMessage.type()] = { handleStateMessage }
o.messageHandlers[SubscribeMessage.type()] = { handleSubscribeMessage }
o.inhibitors = {}
-- Inhibitor for sending hash advertisements, format is week => timestamp inhibition ends
o.inhibitors[AdvertiseHashMessage.type()] = {}
Expand Down Expand Up @@ -373,17 +379,14 @@ end

--[[
Sends an entry out over the guild channel, if allowed
@param LogEntry authEntry, the entry to use for auth checking, defaults to the entry that is to be transmitted
@return bool whether we were authorized to send the message
]]--

function ListSync:transmitViaGuild(entry, authEntry)
return transmitEntry(self, entry, authEntry, CHANNEL_GUILD)
function ListSync:transmitViaGuild(entry)
return transmitEntry(self, entry, CHANNEL_GUILD)
end

function ListSync:transmitViaRaid(entry, authEntry)
return transmitEntry(self, entry, authEntry, CHANNEL_RAID)
function ListSync:transmitViaRaid(entry)
return transmitEntry(self, entry, CHANNEL_RAID)
end


Expand Down
60 changes: 32 additions & 28 deletions source/StateManager.lua
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,33 @@ local function entryToList(entry)
return result
end


local function castLogEntry(stateManager, entry)
-- Find which meta table we should use
local class = LogEntry.class(entry)
if stateManager.metatables[class] == nil then
error("Unknown class: " .. class)
end
setmetatable(entry, stateManager.metatables[class])
end


local function applyEntry(stateManager, entry, index)
local handler = stateManager.handlers[entry:class()] or stateManager.defaultHandler
local handler = stateManager.handlers[LogEntry.class(entry)] or stateManager.defaultHandler
if handler == nil then
error("Handler for class " .. entry:class() .. "is not registered and no default handler was set")
error(string.format("Handler for class %s is not registered and no default handler was set", LogEntry.class(entry)))
end

local result, hash

--[[ Check ignored entries ]]--
local uuid = entry:uuid();
local uuid = LogEntry.uuid(entry);
local numbersForHash = LogEntry.numbersForHash(entry);

if (stateManager.ignoredEntries[uuid] ~= nil) then
stateManager.ignoredEntries[uuid] = nil
else
castLogEntry(stateManager, entry)
handler(entry)
end

Expand Down Expand Up @@ -103,6 +115,7 @@ local function restartIfRequired(stateManager, ignoreThrottle)
end
return true
end

--[[
This function plays new entries, it is called repeatedly on a timer.
The goal of each call is to remain under the frame render time
Expand All @@ -112,7 +125,6 @@ local function updateState(stateManager, batchSize)
local applied = 0
while applied < batchSize and stateManager.lastAppliedIndex < #entries do
local entry = entries[stateManager.lastAppliedIndex + 1]
stateManager:castLogEntry(entry)
if (stateManager.timeTraveling ~= nil and entry:time() > stateManager.timeTraveling) then
if applied > 0 then
print(string.format("Stopping state updates due to time travel restriction, applied %d events", applied))
Expand All @@ -139,6 +151,17 @@ local function safeUpdateState(stateManager, limit)
return success, message
end

local function createLogEntryFromClass(stateManager, cls)
local table = {}
if stateManager.metatables[cls] == nil then
error("Unknown class: " .. cls)
end
setmetatable(table, stateManager.metatables[cls])
LogEntry.setClass(table, cls)
return table
end


-- END PRIVATE

function StateManager:new(list, logger)
Expand Down Expand Up @@ -187,15 +210,6 @@ function StateManager:isTimeTraveling()
return self.timeTraveling ~= nil
end

function StateManager:castLogEntry(table)
-- Find which meta table we should use
local class = LogEntry.class(table)
if self.metatables[class] == nil then
error("Unknown class: " .. class)
end
setmetatable(table, self.metatables[class])
end

function StateManager:queueRemoteEvent(entry)
table.insert(self.uncommittedEntries, entry)
end
Expand All @@ -220,31 +234,22 @@ end
function StateManager:createLogEntryFromList(list)
if list.cls ~= nil then
-- this is not really a list
self:castLogEntry(list)
castLogEntry(self, list)
return list
end
local class = table.remove(list)
local entry = self:createLogEntryFromClass(class)
local entry = createLogEntryFromClass(self, class)
hydrateEntryFromList(entry, list)
return entry
end

function StateManager:createListFromEntry(entry)
self:castLogEntry(entry)
castLogEntry(self, entry)
local result = entryToList(entry)
table.insert(result, entry:class())
table.insert(result, LogEntry.class(entry))
return result
end

function StateManager:createLogEntryFromClass(cls)
local table = {}
if self.metatables[cls] == nil then
error("Unknown class: " .. cls)
end
setmetatable(table, self.metatables[cls])
LogEntry.setClass(table, cls)
return table
end


function StateManager:registerHandler(eventType, handler)
Expand Down Expand Up @@ -352,8 +357,7 @@ end
function StateManager:lag()
if self.timeTraveling ~= nil and #self.list:entries() > self.lastAppliedIndex then
local nextEntry = self.list:entries()[self.lastAppliedIndex + 1]
self:castLogEntry(nextEntry)
if (nextEntry:time() > self.timeTraveling) then
if (LogEntry.time(nextEntry) > self.timeTraveling) then
return 0, 0
else
return 1, 0 -- this is a hack, during time travel lag will be binary
Expand Down
26 changes: 26 additions & 0 deletions source/SubscribeMessage.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
local Factory, _ = LibStub:NewLibrary("EventSourcing/Message/Subscribe", 1)
if not Factory then
return
end


local Message = LibStub("EventSourcing/Message")
local SubscribeMessage = Message:extend('SB')

function SubscribeMessage:new()
local o = Message.new(self)
return o
end

function SubscribeMessage:validate()
Message.validate(self)
end


function Factory.create()
return SubscribeMessage:new()
end

function Factory.type()
return SubscribeMessage._type
end
2 changes: 1 addition & 1 deletion tests/StateManagerTest.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ assertTrue(type(StateManager) == 'table')
assertSame(0, stateManager:stateHash())
assertEmpty(messages)

sortedList:uniqueInsert(TestEntry:new('test'))
assertTrue(sortedList:uniqueInsert(TestEntry:new('test')))
assertSame(1, sortedList:length())
stateManager:catchup()
assertCount(1, messages)
Expand Down
6 changes: 5 additions & 1 deletion tests/_bootstrap.lua
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,13 @@ function assertEmpty(table)
assertCount(0, table)
end

function assertTable(table)
return assert(type(table) =="table")
end
function assertCount(expected, table)
assertTable(table)
assertionStatistics["total"] = assertionStatistics["total"] + 1
assert(#table == expected, string.format("failed assert that table has length %d", expected))
assert(#table == expected, string.format("failed assert that table has length %d, got %d", expected, #table))
assertionStatistics["passed"] = assertionStatistics["passed"] + 1
end
function assertError(cb)
Expand Down

0 comments on commit 15a7c4b

Please sign in to comment.