diff --git a/core/math/random_number_generator.cpp b/core/math/random_number_generator.cpp index c3f36b32a5a..226d748c526 100644 --- a/core/math/random_number_generator.cpp +++ b/core/math/random_number_generator.cpp @@ -42,6 +42,7 @@ void RandomNumberGenerator::_bind_methods() { ClassDB::bind_method(D_METHOD("randfn", "mean", "deviation"), &RandomNumberGenerator::randfn, DEFVAL(0.0), DEFVAL(1.0)); ClassDB::bind_method(D_METHOD("randf_range", "from", "to"), &RandomNumberGenerator::randf_range); ClassDB::bind_method(D_METHOD("randi_range", "from", "to"), &RandomNumberGenerator::randi_range); + ClassDB::bind_method(D_METHOD("rand_weighted", "weights"), &RandomNumberGenerator::rand_weighted); ClassDB::bind_method(D_METHOD("randomize"), &RandomNumberGenerator::randomize); ADD_PROPERTY(PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed"); diff --git a/core/math/random_number_generator.h b/core/math/random_number_generator.h index e1c353d439f..bedeb56ce4f 100644 --- a/core/math/random_number_generator.h +++ b/core/math/random_number_generator.h @@ -57,6 +57,8 @@ public: _FORCE_INLINE_ real_t randfn(real_t p_mean = 0.0, real_t p_deviation = 1.0) { return randbase.randfn(p_mean, p_deviation); } _FORCE_INLINE_ int randi_range(int p_from, int p_to) { return randbase.random(p_from, p_to); } + _FORCE_INLINE_ int rand_weighted(const Vector &p_weights) { return randbase.rand_weighted(p_weights); } + RandomNumberGenerator() { randbase.randomize(); } }; diff --git a/core/math/random_pcg.cpp b/core/math/random_pcg.cpp index 45a9285ddd5..e754a342717 100644 --- a/core/math/random_pcg.cpp +++ b/core/math/random_pcg.cpp @@ -31,6 +31,7 @@ #include "random_pcg.h" #include "core/os/os.h" +#include "core/templates/vector.h" RandomPCG::RandomPCG(uint64_t p_seed, uint64_t p_inc) : pcg(), @@ -42,6 +43,26 @@ void RandomPCG::randomize() { seed(((uint64_t)OS::get_singleton()->get_unix_time() + OS::get_singleton()->get_ticks_usec()) * pcg.state + PCG_DEFAULT_INC_64); } +int RandomPCG::rand_weighted(const Vector &p_weights) { + ERR_FAIL_COND_V_MSG(p_weights.is_empty(), -1, "Weights array is empty."); + int64_t weights_size = p_weights.size(); + const float *weights = p_weights.ptr(); + float weights_sum = 0.0; + for (int64_t i = 0; i < weights_size; ++i) { + weights_sum += weights[i]; + } + + float remaining_distance = Math::randf() * weights_sum; + for (int64_t i = 0; i < weights_size; ++i) { + remaining_distance -= weights[i]; + if (remaining_distance < 0) { + return i; + } + } + + return -1; +} + double RandomPCG::random(double p_from, double p_to) { return randd() * (p_to - p_from) + p_from; } diff --git a/core/math/random_pcg.h b/core/math/random_pcg.h index cc22b23b70e..fa8ad3cfb3b 100644 --- a/core/math/random_pcg.h +++ b/core/math/random_pcg.h @@ -59,6 +59,9 @@ static int __bsr_clz32(uint32_t x) { #define LDEXPF(s, e) ldexp(s, e) #endif +template +class Vector; + class RandomPCG { pcg32_random_t pcg; uint64_t current_seed = 0; // The seed the current generator state started from. @@ -87,6 +90,8 @@ public: return pcg32_boundedrand_r(&pcg, bounds); } + int rand_weighted(const Vector &p_weights); + // Obtaining floating point numbers in [0, 1] range with "good enough" uniformity. // These functions sample the output of rand() as the fraction part of an infinite binary number, // with some tricks applied to reduce ops and branching: diff --git a/doc/classes/RandomNumberGenerator.xml b/doc/classes/RandomNumberGenerator.xml index 2b6e6af596d..cb1e68d374d 100644 --- a/doc/classes/RandomNumberGenerator.xml +++ b/doc/classes/RandomNumberGenerator.xml @@ -17,6 +17,25 @@ $DOCS_URL/tutorials/math/random_number_generation.html + + + + + Returns a random index with non-uniform weights. Prints an error and returns [code]-1[/code] if the array is empty. + [codeblocks] + [gdscript] + var rnd = RandomNumberGenerator.new() + + var my_array = ["one", "two", "three, "four"] + var weights = PackedFloat32Array([0.5, 1, 1, 2]) + + # Prints one of the four elements in `my_array`. + # It is more likely to print "four", and less likely to print "two". + print(my_array[rng.rand_weighted(weights)]) + [/gdscript] + [/codeblocks] + +