How to sample from $\{1, 2, …, K\}$ for $n$ random variables, each with different mass functions, in R?

In R, I have an $N \times K$ matrix $P$ where the $i$’th row of $P$ corresponds to a distribution on $\{1, …, K\}$. Essentially, I need to sample from each row efficiently. A naive implementation is:

X = rep(0, N);
for(i in 1:N){
    X[i] = sample(1:K, 1, prob = P[i, ]);
}

This is much too slow. In principle I could move this to C but I’m sure there must be an existing way of doing this. I would like something in the spirit of the following code (which does not work):

X = sample(1:K, N, replace = TRUE, prob = P)

EDIT: For motivation, take $N = 10000$ and $K = 100$. I have $P_1, …, P_{5000}$ matrices all $N \times K$ and I need to sample a vector from each of them.


We can do this in a couple of simple ways. The first is easy to code, easy to understand and reasonably fast. The second is a little trickier, but much more efficient for this size of problem than the first method or other approaches mentioned here.

Method 1: Quick and dirty.

To get a single observation from the probability distribution of each row, we can simply do the following.

# Q is the cumulative distribution of each row.
Q <- t(apply(P,1,cumsum))

# Get a sample with one observation from the distribution of each row.
X <- rowSums(runif(N) > Q) + 1

This produces the cumulative distribution of each row of $P$ and then samples one observation from each distribution. Notice that if we can reuse $P$ then we can calculate $Q$ once and store it for later use. However, the question needs something that works for a different $P$ at each iteration.
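As a small sanity check, the two lines above can be traced by hand on a toy problem. The $2 \times 3$ matrix and the fixed values in `u` below are made up for illustration (they stand in for `runif(N)`), so this is only a sketch of how the comparison works, not part of the original answer.

```r
# Toy 2 x 3 matrix of row distributions (hypothetical values).
P <- rbind(c(0.2, 0.3, 0.5),
           c(0.6, 0.3, 0.1))
N <- nrow(P)

# Row-wise cdfs: row 1 is (0.2, 0.5, 1.0), row 2 is (0.6, 0.9, 1.0).
Q <- t(apply(P, 1, cumsum))

# Fixed "uniforms" in place of runif(N), so the result is reproducible.
u <- c(0.35, 0.95)

# u is recycled down the columns of Q, so row i is compared with u[i].
# Counting the cdf values below u[i] and adding 1 inverts the cdf.
X <- rowSums(u > Q) + 1
X   # 2 3
```

Row 1 has one cdf value below 0.35, so it yields category 2; row 2 has two cdf values below 0.95, so it yields category 3.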

If you need multiple ($n$) observations from each row, then replace the last line with the following one.

# Returns an N x n matrix
X <- replicate(n, rowSums(runif(N) > Q)+1)

In general this is not an especially efficient approach, but it takes good advantage of R's vectorization, which is usually the primary determinant of execution speed. It is also straightforward to understand.
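One way to gain some confidence in the multi-observation version is to compare the empirical frequencies in each row with the corresponding row of $P$. The toy matrix, the seed, and the choice of `n` below are assumptions made for this sketch only.

```r
set.seed(1)  # arbitrary seed, only for reproducibility of the sketch

# Toy 2 x 3 matrix of row distributions (hypothetical values).
P <- rbind(c(0.2, 0.3, 0.5),
           c(0.6, 0.3, 0.1))
N <- nrow(P)
K <- ncol(P)
n <- 20000   # observations per row

Q <- t(apply(P, 1, cumsum))

# N x n matrix: column j holds one draw from every row's distribution.
X <- replicate(n, rowSums(runif(N) > Q) + 1)

# Empirical frequency of each category, row by row (N x K matrix).
freq <- t(apply(X, 1, function(x) tabulate(x, nbins = K) / n))

# Largest absolute deviation from the target probabilities.
max(abs(freq - P))
```

With this many draws the deviation should be small, on the order of a few thousandths.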

Method 2: Concatenating the cdfs.

Suppose we had a function that took two vectors, the second sorted in nondecreasing order, and returned, for each element of the first, the index of the largest element of the second that does not exceed it. Then we could use this function and a slick trick: concatenate the rows of $P$ into one long vector and take its cumulative sum. Because each row sums to one, this stacks the row cdfs end to end, giving a nondecreasing vector with values in $[0,N]$ in which the cdf of row $i$ occupies the range $(i-1, i]$.
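Base R's findInterval behaves like the function just described. The following minimal sketch, with made-up inputs, shows what it returns, including the value 0 for an input that lies below every break.

```r
# Sorted break points (hypothetical values).
vec <- c(1, 2, 3)

# Values to locate among the breaks.
x <- c(0.5, 1.2, 2.9, 3.0)

# For each x, the number of elements of vec that are <= x, which is the
# index of the largest break not exceeding x (0 if there is none).
idx <- findInterval(x, vec)
idx   # 0 1 2 3
```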

Here is the code.

i <- 0:(N-1)

# Cumulative sum over the concatenated rows of P (the row cdfs stacked end to end).
Q <- cumsum(t(P))

# Find the interval and then back adjust
findInterval(runif(N)+i, Q)-i*K+1

Notice what the last line does: it creates random variables uniformly distributed on $(0,1), (1,2), \dots, (N-1,N)$ and then calls findInterval to find the index of the greatest lower bound of each entry. So the first element of runif(N)+i falls within the first $K$ entries of Q (indices 1 through $K$), the second within indices $K+1$ through $2K$, and so on, each according to the distribution of the corresponding row of $P$. Then we need to back transform to get each of the indices back in the range $\{1,\ldots,K\}$.
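To make the shift and back adjustment concrete, here is the same computation traced on a toy problem. The $2 \times 3$ matrix and the fixed values in `u` (standing in for `runif(N)`) are assumptions chosen for illustration.

```r
# Toy 2 x 3 matrix of row distributions (hypothetical values).
P <- rbind(c(0.2, 0.3, 0.5),
           c(0.6, 0.3, 0.1))
N <- nrow(P)
K <- ncol(P)
i <- 0:(N - 1)

# Stacked cdfs: approximately (0.2, 0.5, 1.0, 1.6, 1.9, 2.0).
Q <- cumsum(t(P))

# Fixed "uniforms" in place of runif(N); shifting by i places the
# draw for row 1 in (0, 1) and the draw for row 2 in (1, 2).
u <- c(0.35, 0.95)
shifted <- u + i            # 0.35 1.95

# Position of each shifted draw within the stacked cdfs.
idx <- findInterval(shifted, Q)
idx                         # 1 5

# Remove the offset of K entries per preceding row, then add 1.
X <- idx - i * K + 1
X                           # 2 3
```

The second draw lands at position 5 of the stacked vector; subtracting the offset $K = 3$ for the one preceding row and adding 1 gives category 3 of row 2.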

Because findInterval is fast both algorithmically and implementation-wise, this method turns out to be extremely efficient.

A benchmark

On my old laptop (MacBook Pro, 2.66 GHz, 8GB RAM), I tried this with $N = 10000$ and $K = 100$ and generating 5000 samples of size $N$, exactly as suggested in the updated question, for a total of 50 million random variates.

The code for Method 1 took almost exactly 15 minutes to execute, or about 55K random variates per second. The code for Method 2 took about four and a half minutes to execute, or about 183K random variates per second.

Here is the code for the sake of reproducibility. (Note that, as indicated in a comment, $Q$ is recalculated for each of the 5000 iterations to simulate the OP’s situation.)

# Benchmark code
N <- 10000
K <- 100

P <- matrix(runif(N*K),N,K)
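P <- P / rowSums(P)

method.one <- function(P)
{
    # Q is recalculated on every call to simulate the OP's situation.
    Q <- t(apply(P,1,cumsum))
    X <- rowSums(runif(nrow(P)) > Q) + 1
}

method.two <- function(P)
{
    n <- nrow(P)
    i <- 0:(n-1)
    # Q is recalculated on every call to simulate the OP's situation.
    Q <- cumsum(t(P))
    findInterval(runif(n)+i, Q)-i*ncol(P)+1
}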

Here is the output.

# Method 1: Timing
> system.time(replicate(5e3, method.one(P)))
   user  system elapsed 
691.693 195.812 899.246 

# Method 2: Timing
> system.time(replicate(5e3, method.two(P)))
   user  system elapsed 
182.325  82.430 273.021 

Postscript: By looking at the code for findInterval, we can see that it does some checks on the input to see if there are NA entries or if the second argument is not sorted. Hence, if we wanted to squeeze more performance out of this, we could create our own modified version of findInterval which strips out these checks which are unnecessary in our case.

Source : Link , Question Author : guy , Answer Author : cardinal
