.. _program_listing_file_include_beluga_views_sample.hpp:

Program Listing for File sample.hpp
===================================

|exhale_lsh| :ref:`Return to documentation for file <file_include_beluga_views_sample.hpp>` (``include/beluga/views/sample.hpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   // Copyright 2024 Ekumen, Inc.
   //
   // Licensed under the Apache License, Version 2.0 (the "License");
   // you may not use this file except in compliance with the License.
   // You may obtain a copy of the License at
   //
   //     http://www.apache.org/licenses/LICENSE-2.0
   //
   // Unless required by applicable law or agreed to in writing, software
   // distributed under the License is distributed on an "AS IS" BASIS,
   // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   // See the License for the specific language governing permissions and
   // limitations under the License.
   
   #ifndef BELUGA_VIEWS_SAMPLE_HPP
   #define BELUGA_VIEWS_SAMPLE_HPP
   
   #include <cassert>
   #include <functional>
   #include <memory>
   #include <random>
   #include <type_traits>
   #include <utility>
   
   #include <beluga/primitives.hpp>
   #include <beluga/type_traits/particle_traits.hpp>
   #include <beluga/views/particles.hpp>
   
   #include <range/v3/functional/bind_back.hpp>
   #include <range/v3/utility/random.hpp>
   #include <range/v3/view/all.hpp>
   #include <range/v3/view/common.hpp>
   #include <range/v3/view/facade.hpp>
   #include <range/v3/view/generate.hpp>
   #include <range/v3/view/transform.hpp>
   #include <range/v3/view/view.hpp>
   
   /**
    * \file
    * \brief Implementation of the `sample` range adaptor object.
    */
   
   namespace beluga::views {
   
   namespace detail {
   
   /// Infinite view that repeatedly draws elements, with replacement, from an underlying range.
   /**
    * Each element of the view is `*(ranges::begin(range) + distribution(engine))`. Advancing the
    * cursor draws a fresh offset, so every element is a new draw.
    *
    * \tparam Range Sized, random access range to sample from.
    * \tparam Distribution Distribution of offsets into `Range`; its `result_type` must be the range's difference type.
    * \tparam URNG Random engine type fed to `Distribution`.
    */
   template <class Range, class Distribution, class URNG>
   struct sample_view : public ranges::view_facade<sample_view<Range, Distribution, URNG>, ranges::infinite> {
    public:
     /// Default constructor.
     sample_view() = default;
   
     /// Constructs the view.
     /**
      * \param range Non-empty range to sample from.
      * \param distribution Offset distribution; must satisfy `min() == 0` and `max() == ranges::size(range) - 1`.
      * \param engine Random engine; only its address is stored, so it must outlive the view.
      */
     constexpr sample_view(Range range, Distribution distribution, URNG& engine = ranges::detail::get_random_engine())
         : range_{std::move(range)}, distribution_{std::move(distribution)}, engine_{std::addressof(engine)} {
       // Check the member, not the `range` parameter: the parameter has already been moved from.
       assert(ranges::size(range_) > 0);
       assert(distribution_.min() == 0);
       assert(distribution_.max() == static_cast<typename Distribution::result_type>(ranges::size(range_)) - 1);
     }
   
    private:
     // `ranges::range_access` needs access to the cursor members.
     friend ranges::range_access;
   
     static_assert(ranges::sized_range<Range>);
     static_assert(ranges::random_access_range<Range>);
     static_assert(std::is_same_v<typename Distribution::result_type, ranges::range_difference_t<Range>>);
   
     /// Cursor over the sampled elements; advancing it draws a new offset.
     struct cursor {
      public:
       cursor() = default;
   
       /// Constructs a cursor positioned at a freshly drawn element of `view`.
       constexpr explicit cursor(sample_view* view)
           : view_(view), first_{ranges::begin(view_->range_)}, it_{first_ + view_->compute_offset()} {}
   
       /// Returns the currently sampled element.
       [[nodiscard]] constexpr decltype(auto) read() const noexcept(noexcept(*this->it_)) { return *it_; }
   
       /// Draws a new offset and moves to the corresponding element.
       constexpr void next() { it_ = first_ + view_->compute_offset(); }
   
      private:
       sample_view* view_{nullptr};  // Non-owning; null only for default-constructed cursors.
       ranges::iterator_t<Range> first_{};
       ranges::iterator_t<Range> it_{};
     };
   
     /// Returns a cursor at the first (already drawn) sample.
     [[nodiscard]] constexpr auto begin_cursor() { return cursor{this}; }
   
     /// The view is infinite, so its end is unreachable.
     [[nodiscard]] constexpr auto end_cursor() const noexcept { return ranges::unreachable_sentinel_t{}; }
   
     /// Draws an offset into `range_` using the stored engine.
     [[nodiscard]] constexpr auto compute_offset() { return distribution_(*engine_); }
   
     Range range_;
     Distribution distribution_;
     URNG* engine_{nullptr};  // Non-owning; null only for default-constructed views.
   };
   
   /// Type trait that is true when `T` can be invoked with a `std::mt19937` lvalue, like a random distribution.
   template <class T, class = void>
   struct is_random_distribution : public std::false_type {};
   
   /// Specialization selected when `std::declval<T&>()(engine)` is well-formed.
   template <class T>
   struct is_random_distribution<T, std::void_t<decltype(std::declval<T&>()(std::declval<std::mt19937&>()))>>
       : std::true_type {};
   
   /// Convenience variable template for `is_random_distribution<T>::value`.
   template <class T>
   inline constexpr bool is_random_distribution_v = is_random_distribution<T>::value;
   
   /// Sampling helpers shared by the `sample_fn` overloads.
   struct sample_base_fn {
    protected:
     /// Samples elements of `range` with probabilities proportional to `weights`.
     template <class Range, class Weights, class URNG>
     constexpr auto sample_from_range(Range&& range, Weights&& weights, URNG& engine) const {
       static_assert(ranges::sized_range<Range>);
       static_assert(ranges::random_access_range<Range>);
       static_assert(ranges::input_range<Weights>);
       using result_type = ranges::range_difference_t<Range>;
       // `std::discrete_distribution` takes an iterator pair of a single type, hence the common view.
       auto w = ranges::views::common(weights);
       auto distribution = std::discrete_distribution<result_type>{ranges::begin(w), ranges::end(w)};
       return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
     }
   
     /// Samples `range` uniformly, or by particle weight when `range` is a particle range.
     template <class Range, class URNG>
     constexpr auto sample_from_range(Range&& range, URNG& engine) const {
       static_assert(ranges::sized_range<Range>);
       static_assert(ranges::random_access_range<Range>);
       if constexpr (beluga::is_particle_range_v<Range>) {
         // Sample states by weight, then map each sampled state back to the particle type.
         // NOTE(review): `range` is used as an lvalue here; confirm rvalue particle ranges cannot dangle.
         return sample_from_range(beluga::views::states(range), beluga::views::weights(range), engine) |
                ranges::views::transform(beluga::make_from_state<ranges::range_value_t<Range>>);
       } else {
         using result_type = ranges::range_difference_t<Range>;
         // An empty range would make `size - 1` wrap and build an invalid `{0, -1}` distribution
         // before `sample_view` gets a chance to assert, so check up front.
         assert(ranges::size(range) > 0);
         auto distribution =
             std::uniform_int_distribution<result_type>{0, static_cast<result_type>(ranges::size(range) - 1)};
         return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
       }
     }
   
     /// Returns an infinite view of draws from `distribution`.
     template <class Distribution, class URNG>
     constexpr auto sample_from_distribution(Distribution distribution, URNG& engine) const {
       // The engine is captured by reference, so it must outlive the returned view.
       return ranges::views::generate(
           [distribution = std::move(distribution), &engine]() mutable { return distribution(engine); });
     }
   };
   
   /// Implementation of the `sample` range adaptor object.
   struct sample_fn : public sample_base_fn {
     /// Overload for `sample(range, weights, engine)`.
     template <class T, class U, class V>
     constexpr auto operator()(T&& t, U&& u, V& v) const {
       static_assert(ranges::range<T>);
       static_assert(ranges::range<U>);
       return sample_from_range(std::forward<T>(t), std::forward<U>(u), v);  // Assume V is a URNG
     }
   
     /// Overload for `sample(range, weights)`, `sample(distribution, engine)` and `sample(range, engine)`.
     template <class T, class U>
     constexpr auto operator()(T&& t, U&& u) const {
       if constexpr (ranges::range<T> && ranges::range<U>) {
         auto& engine = ranges::detail::get_random_engine();
         return sample_from_range(std::forward<T>(t), std::forward<U>(u), engine);
       } else if constexpr (is_random_distribution_v<std::decay_t<T>>) {
         static_assert(std::is_lvalue_reference_v<U>);  // Assume U is a URNG
         return sample_from_distribution(std::forward<T>(t), u);
       } else {
         static_assert(ranges::range<T>);
         static_assert(std::is_lvalue_reference_v<U>);  // Assume U is a URNG
         return sample_from_range(std::forward<T>(t), u);
       }
     }
   
     /// Overload for `sample(range)`, `sample(distribution)` and the pipeable `sample(engine)`.
     template <class T>
     constexpr auto operator()(T&& t) const {
       if constexpr (ranges::range<T>) {
         auto& engine = ranges::detail::get_random_engine();
         return sample_from_range(std::forward<T>(t), engine);
       } else if constexpr (is_random_distribution_v<std::decay_t<T>>) {
         auto& engine = ranges::detail::get_random_engine();
         return sample_from_distribution(std::forward<T>(t), engine);
       } else {
         static_assert(std::is_lvalue_reference_v<T>);  // Assume T is a URNG
         // Bind the engine by reference so `range | sample(engine)` draws from the caller's engine.
         return ranges::make_view_closure(ranges::bind_back(sample_fn{}, std::ref(t)));
       }
     }
   
     /// Overload reached through the closure built by `sample(engine)`.
     template <class Range, class URNG>
     constexpr auto operator()(Range&& range, std::reference_wrapper<URNG> engine) const {
       static_assert(ranges::range<Range>);
       return sample_from_range(std::forward<Range>(range), engine.get());
     }
   };
   
   }  // namespace detail
   
   /// Range adaptor object that samples, with replacement, from a range or a random distribution.
   /**
    * Supported forms:
    * - `sample(range)`, `sample(range, engine)`, `range | sample(engine)`: uniform sampling,
    *   or weighted sampling when `range` is a particle range.
    * - `sample(range, weights)`, `sample(range, weights, engine)`: sampling proportional to `weights`.
    * - `sample(distribution)`, `sample(distribution, engine)`: infinite view of draws from `distribution`.
    *
    * When no engine is given, `ranges::detail::get_random_engine()` is used. A given engine is
    * referenced, not copied, and must outlive the resulting view.
    */
   inline constexpr ranges::views::view_closure<detail::sample_fn> sample;
   
   }  // namespace beluga::views
   
   #endif