I recently watched a talk called rand() Considered Harmful, which advocates the engine-distribution paradigm of random number generation over the simple std::rand() plus modulus paradigm.
However, I wanted to see the failings of std::rand() firsthand, so I ran a quick experiment:
I wrote two functions, getRandNum_Old() and getRandNum_New(), that generate a random number between 0 and 5 inclusive using std::rand() and std::mt19937 + std::uniform_int_distribution, respectively.
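For context on what the modulus step does, here is a minimal sketch of how unevenly `% 6` can map raw generator values onto the six results. The helper name valuesMappingTo and the range size of 32768 used below are illustrative assumptions, not part of the original experiment; 32767 is only the smallest RAND_MAX the standard allows, and real implementations may use a larger one.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the original code): counts how many of the
// raw values 0 .. rangeSize-1 satisfy value % n == r.
// rangeSize plays the role of RAND_MAX + 1.
std::uint64_t valuesMappingTo(std::uint64_t rangeSize, std::uint64_t n, std::uint64_t r) {
    assert(n > 0 && r < n);
    // Every residue receives rangeSize / n raw values; the first
    // rangeSize % n residues receive one extra from the incomplete last block.
    return rangeSize / n + (r < rangeSize % n ? 1 : 0);
}
```

With an assumed rangeSize of 32768, valuesMappingTo(32768, 6, r) gives 5462 for r = 0 and r = 1 and 5461 for r = 2 through 5, a relative imbalance of roughly 0.02%. With a larger RAND_MAX the imbalance is smaller still, so a frequency count over 960,000 rolls may not be sensitive enough to show it.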
Then I generated 960,000 (divisible by 6) random numbers the "old" way and recorded the frequencies of the numbers 0-5. Next I calculated the standard deviation of these frequencies. I was looking for a standard deviation as low as possible, since I assumed that is what a truly uniform distribution would produce.
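As a point of comparison for this measurement, here is a rough sketch of the spread one would expect from an ideal die. It assumes independent, exactly uniform rolls, so each face count is binomial with variance N·p·(1-p) where p = 1/D. The helper name expectedRmsSpread is hypothetical and not part of the original code.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: root-mean-square value of the population standard
// deviation of the face counts, assuming independent and exactly uniform
// rolls. Each count then has variance rolls * p * (1 - p) with p = 1 / sides.
double expectedRmsSpread(double rolls, double sides) {
    assert(rolls >= 0.0 && sides >= 1.0);
    const double p = 1.0 / sides;
    return std::sqrt(rolls * p * (1.0 - p));
}
```

For 960,000 rolls of a 6-sided die, expectedRmsSpread(960000, 6) is about 365.15. The average of the standard deviation itself comes out somewhat lower, because the square root is concave. Under these assumptions, even an ideal generator would not drive the spread toward zero.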
I ran that simulation 1000 times, recording the standard deviation of each run and the time it took in milliseconds.
I then repeated the whole procedure, this time generating the random numbers the "new" way.
Finally, I calculated the mean and standard deviation of the list of standard deviations, and of the list of times taken, for both the old and the new way.
Here were the results:
[OLD WAY]
Spread
  mean:    346.554406
  std dev: 110.318361
Time Taken (ms)
  mean:    6.662910
  std dev: 0.366301

[NEW WAY]
Spread
  mean:    350.346792
  std dev: 110.449190
Time Taken (ms)
  mean:    28.053907
  std dev: 0.654964
Surprisingly, the aggregate spread of rolls was essentially the same for both methods; that is, std::mt19937 + std::uniform_int_distribution was not "more uniform" than simple std::rand() + %. I also observed that the new way was about 4x slower than the old way. Overall, it seemed like I was paying a huge cost in speed for almost no gain in quality.
Is my experiment flawed in some way? Or is std::rand() really not that bad, and maybe even better?
For reference, here is the code I used in its entirety:
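#include <algorithm>
#include <chrono>
#include <cmath>   // std::sqrt
#include <cstdio>  // std::printf
#include <cstdlib> // std::rand, std::srand
#include <ctime>   // std::time
#include <random>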

// "Old" way: std::rand() reduced to [0, 5] with the modulus operator.
int getRandNum_Old() {
    static bool init = false;
    if (!init) {
        std::srand(std::time(nullptr)); // Seed std::rand
        init = true;
    }
    return std::rand() % 6;
}

// "New" way: std::mt19937 engine feeding a std::uniform_int_distribution.
int getRandNum_New() {
    static bool init = false;
    static std::random_device rd;
    static std::mt19937 eng;
    static std::uniform_int_distribution<int> dist(0,5);
    if (!init) {
        eng.seed(rd()); // Seed random engine
        init = true;
    }
    return dist(eng);
}

// Arithmetic mean of the n values starting at data.
template <typename T>
double mean(T* data, int n) {
    double m = 0;
    std::for_each(data, data+n, [&](T x){ m += x; });
    m /= n;
    return m;
}

// Population standard deviation (divides by n) of the n values starting at data.
template <typename T>
double stdDev(T* data, int n) {
    double m = mean(data, n);
    double sd = 0.0;
    std::for_each(data, data+n, [&](T x){ sd += ((x-m) * (x-m)); });
    sd /= n;
    sd = std::sqrt(sd);
    return sd;
}

int main() {
    const int N = 960000; // Number of trials
    const int M = 1000;   // Number of simulations
    const int D = 6;      // Num sides on die

    /* Do the things the "old" way (blech) */
    int freqList_Old[D];
    double stdDevList_Old[M];
    double timeTakenList_Old[M];
    for (int j = 0; j < M; j++) {
        auto start = std::chrono::high_resolution_clock::now();
        std::fill_n(freqList_Old, D, 0);
        for (int i = 0; i < N; i++) {
            int roll = getRandNum_Old();
            freqList_Old[roll] += 1;
        }
        stdDevList_Old[j] = stdDev(freqList_Old, D);
        auto end = std::chrono::high_resolution_clock::now();
        auto dur = std::chrono::duration_cast<std::chrono::microseconds>(end-start);
        double timeTaken = dur.count() / 1000.0; // microseconds -> milliseconds
        timeTakenList_Old[j] = timeTaken;
    }

    /* Do the things the cool new way! */
    int freqList_New[D];
    double stdDevList_New[M];
    double timeTakenList_New[M];
    for (int j = 0; j < M; j++) {
        auto start = std::chrono::high_resolution_clock::now();
        std::fill_n(freqList_New, D, 0);
        for (int i = 0; i < N; i++) {
            int roll = getRandNum_New();
            freqList_New[roll] += 1;
        }
        stdDevList_New[j] = stdDev(freqList_New, D);
        auto end = std::chrono::high_resolution_clock::now();
        auto dur = std::chrono::duration_cast<std::chrono::microseconds>(end-start);
        double timeTaken = dur.count() / 1000.0; // microseconds -> milliseconds
        timeTakenList_New[j] = timeTaken;
    }

    /* Display Results */
    std::printf("[OLD WAY]\n");
    std::printf("Spread\n");
    std::printf(" mean: %.6f\n", mean(stdDevList_Old, M));
    std::printf(" std dev: %.6f\n", stdDev(stdDevList_Old, M));
    std::printf("Time Taken (ms)\n");
    std::printf(" mean: %.6f\n", mean(timeTakenList_Old, M));
    std::printf(" std dev: %.6f\n", stdDev(timeTakenList_Old, M));
    std::printf("\n");
    std::printf("[NEW WAY]\n");
    std::printf("Spread\n");
    std::printf(" mean: %.6f\n", mean(stdDevList_New, M));
    std::printf(" std dev: %.6f\n", stdDev(stdDevList_New, M));
    std::printf("Time Taken (ms)\n");
    std::printf(" mean: %.6f\n", mean(timeTakenList_New, M));
    std::printf(" std dev: %.6f\n", stdDev(timeTakenList_New, M));
}