#ifndef STAR_WEIGHTED_POOL_HPP #define STAR_WEIGHTED_POOL_HPP #include "StarRandom.hpp" namespace Star { template struct WeightedPool { public: typedef pair ItemsType; typedef List ItemsList; WeightedPool(); template explicit WeightedPool(Container container); void add(double weight, Item item); void clear(); ItemsList const& items() const; size_t size() const; pair const& at(size_t index) const; double weight(size_t index) const; Item const& item(size_t index) const; bool empty() const; // Return item using the given randomness source Item select(RandomSource& rand) const; // Return item using the global randomness source Item select() const; // Return item using fast static randomness from the given seed Item select(uint64_t seed) const; // Return a list of n items which are selected uniquely (by index), where // n is the lesser of the desiredCount and the size of the pool. // This INFLUENCES PROBABILITIES so it should not be used where a // correct statistical distribution is required. List selectUniques(size_t desiredCount) const; List selectUniques(size_t desiredCount, uint64_t seed) const; size_t selectIndex(RandomSource& rand) const; size_t selectIndex() const; size_t selectIndex(uint64_t seed) const; private: size_t selectIndex(double target) const; ItemsList m_items; double m_totalWeight; }; template WeightedPool::WeightedPool() : m_totalWeight(0.0) {} template template WeightedPool::WeightedPool(Container container) : WeightedPool() { for (auto const& pair : container) add(get<0>(pair), get<1>(pair)); } template void WeightedPool::add(double weight, Item item) { if (weight <= 0.0) return; m_items.append({weight, std::move(item)}); m_totalWeight += weight; } template void WeightedPool::clear() { m_items.clear(); m_totalWeight = 0.0; } template auto WeightedPool::items() const -> ItemsList const & { return m_items; } template size_t WeightedPool::size() const { return m_items.count(); } template pair const& WeightedPool::at(size_t index) const { return m_items.at(index); } template double WeightedPool::weight(size_t index) const { return at(index).first; } template Item const& WeightedPool::item(size_t index) const { return at(index).second; } template bool WeightedPool::empty() const { return m_items.empty(); } template Item WeightedPool::select(RandomSource& rand) const { if (m_items.empty()) return Item(); return m_items[selectIndex(rand)].second; } template Item WeightedPool::select() const { if (m_items.empty()) return Item(); return m_items[selectIndex()].second; } template Item WeightedPool::select(uint64_t seed) const { if (m_items.empty()) return Item(); return m_items[selectIndex(seed)].second; } template List WeightedPool::selectUniques(size_t desiredCount) const { return selectUniques(desiredCount, Random::randu64()); } template List WeightedPool::selectUniques(size_t desiredCount, uint64_t seed) const { size_t targetCount = std::min(desiredCount, size()); Set indices; while (indices.size() < targetCount) indices.add(selectIndex(++seed)); List result; for (size_t i : indices) result.append(m_items[i].second); return result; } template size_t WeightedPool::selectIndex(RandomSource& rand) const { return selectIndex(rand.randd()); } template size_t WeightedPool::selectIndex() const { return selectIndex(Random::randd()); } template size_t WeightedPool::selectIndex(uint64_t seed) const { return selectIndex(staticRandomDouble(seed)); } template size_t WeightedPool::selectIndex(double target) const { if (m_items.empty()) return NPos; // Test a randomly generated target against each weighted item in turn, and // see if that weighted item's weight value crosses the target. This way, a // random item is picked from the list, but (roughly) weighted to be // proportional to its weight over the weight of all entries. // // TODO: This is currently O(n), but can easily be made O(log(n)) by using a // tree. If this shows up in performance measurements, this is an obvious // improvement. double accumulatedWeight = 0.0f; for (size_t i = 0; i < m_items.size(); ++i) { accumulatedWeight += m_items[i].first / m_totalWeight; if (target <= accumulatedWeight) return i; } // If we haven't crossed the target, just assume floating point error has // caused us to not quite make it to the last item. return m_items.size() - 1; } } #endif