Utils = {}

--- Returns up to `n` distinct random values from the given array, in random order.
--- Will return fewer if `n > #fromArr`: the result holds exactly `min(n, #fromArr)` values,
--- each array position being picked at most once. `fromArr` itself is not modified.
---
--- Implemented as a sparse partial Fisher-Yates shuffle: O(n) time and memory no matter how
--- close `n` is to `#fromArr` (rejection sampling slows down as fewer unpicked indexes remain).
---@generic T
---@param fromArr T[] sequence to sample from (length taken with `#`, so it should have no nil holes)
---@param n number how many values to pick; must be >= 0
---@return T[] a new sequence containing the picked values
function Utils.getNDifferentValues(fromArr, n)
    assert(n >= 0, "n must be a non-negative integer")
    local size = #fromArr
    -- math.ceil preserves the historical result size for a fractional `n`
    -- (values used to be drawn until `found >= n`).
    local count = math.min(math.ceil(n), size)

    -- Virtual shuffle of positions 1..size without copying the array:
    -- `swapped[p]` is the index currently sitting at position p; a missing entry
    -- means position p still holds index p.
    local swapped = {}
    local randoms = {}
    for i = 1, count do
        local j = math.random(i, size)
        local picked = swapped[j] or j
        -- Move whatever sits at position i into position j so it stays selectable
        -- for later draws; position i is consumed and never read again.
        swapped[j] = swapped[i] or i
        randoms[i] = fromArr[picked]
    end
    return randoms
end

--- Track the number of instances of a given element, instead of needing multiple copies.
--- Works like a multiset: adding the same element twice stores it once with a count of 2.
---@class CountSet
---@field private data table<any, integer> element -> number of copies held (entries never drop below 1)
---@field private elementCount integer total number of copies across all elements
CountSet = {}
-- The class table doubles as the shared metatable for every instance, instead of
-- allocating a new `{ __index = CountSet }` table on each `new()` call.
CountSet.__index = CountSet

--- Creates an empty CountSet.
---@return CountSet
function CountSet.new()
    return setmetatable({ data = {}, elementCount = 0 }, CountSet)
end

--- Adds one copy of `element`.
---@param element any value used as a table key; must not be nil or NaN
function CountSet:add(element)
    -- Fail with a clear message rather than Lua's generic "table index is nil/NaN".
    assert(element ~= nil and element == element, "element must not be nil or NaN")
    self.data[element] = (self.data[element] or 0) + 1
    self.elementCount = self.elementCount + 1
end

--- Removes one copy of a randomly chosen element and returns it.
--- Every stored copy is equally likely, so an element is chosen with probability
--- `count / total`. Returns no value when the set is empty.
---@return any|nil
function CountSet:balancedRandomPop()
    if self.elementCount == 0 then
        return
    end
    -- Choose the k-th copy overall (1-based), then walk the buckets subtracting their
    -- sizes until reaching the bucket that contains it. Traversal order does not
    -- matter: each copy owns exactly one slot in 1..elementCount.
    local toPop = math.random(self.elementCount)
    for element, count in pairs(self.data) do
        toPop = toPop - count
        if toPop <= 0 then
            -- Updating or clearing an existing key during `pairs` is permitted,
            -- and we return immediately, so the traversal is not resumed.
            if count == 1 then
                self.data[element] = nil
            else
                self.data[element] = count - 1
            end
            self.elementCount = self.elementCount - 1
            return element
        end
    end
end

--- Returns an iterator that repeatedly pops random elements (see `balancedRandomPop`)
--- until the set is empty. Iteration is destructive: every yielded copy is removed,
--- so a loop that runs to completion leaves the set empty.
---@return fun(): any
function CountSet:iterRandom()
    return function()
        return self:balancedRandomPop()
    end
end