I discovered this while solving Problem 205 of Project Euler. The problem is as follows:
Peter has nine four-sided (pyramidal) dice, each with faces numbered 1, 2, 3, 4.
Colin has six six-sided (cubic) dice, each with faces numbered 1, 2, 3, 4, 5, 6.
Peter and Colin roll their dice and compare totals: the highest total wins. The result is a draw if the totals are equal.
What is the probability that Pyramidal Pete beats Cubic Colin? Give your answer rounded to seven decimal places in the form 0.abcdefg
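Stated as a formula, the problem asks for a ratio of counts. All 4^9 outcomes for Peter and all 6^6 outcomes for Colin are equally likely, so the answer is the fraction of (Peter, Colin) pairs in which Peter's total is strictly higher. In the expression below, N_P(s) and N_C(t) are notation introduced here for illustration: they are the number of ways Peter can roll total s (9 to 36) and the number of ways Colin can roll total t (6 to 36). The denominator is the same value the program divides by.

$$
P(\text{Pete wins}) = \frac{\sum_{s=9}^{36} \sum_{t=6}^{s-1} N_P(s)\, N_C(t)}{4^9 \cdot 6^6},
\qquad 4^9 \cdot 6^6 = 262144 \cdot 46656 = 12230590464
$$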
I wrote a naive solution using Guava:
```java
import com.google.common.collect.Sets;
import com.google.common.collect.ImmutableSet;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;
import java.util.stream.Collectors;

public class Problem205 {
    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        List<Integer> peter = Sets.cartesianProduct(Collections.nCopies(9, ImmutableSet.of(1, 2, 3, 4)))
                .stream()
                .map(l -> l
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum())
                .collect(Collectors.toList());
        List<Integer> colin = Sets.cartesianProduct(Collections.nCopies(6, ImmutableSet.of(1, 2, 3, 4, 5, 6)))
                .stream()
                .map(l -> l
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum())
                .collect(Collectors.toList());

        long startTime2 = System.currentTimeMillis();
        // IMPORTANT BIT HERE! v
        long solutions = peter
                .stream()
                .mapToLong(p -> colin
                        .stream()
                        .filter(c -> p > c)
                        .count())
                .sum();
        // IMPORTANT BIT HERE! ^
        System.out.println("Counting solutions took " + (System.currentTimeMillis() - startTime2) + "ms");

        System.out.println("Solution: " + BigDecimal
                .valueOf(solutions)
                .divide(BigDecimal
                                .valueOf((long) Math.pow(4, 9) * (long) Math.pow(6, 6)),
                        7,
                        RoundingMode.HALF_UP));
        System.out.println("Found in: " + (System.currentTimeMillis() - startTime) + "ms");
    }
}
```
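For comparison only, and not the approach used in the question: because the totals are small (Peter's range from 9 to 36, Colin's from 6 to 36), the same count can be obtained from tables of how many ways each total occurs, built by repeatedly convolving with one die. This avoids materializing 262144 × 46656 comparisons. The class and method names below (Problem205Counts, sumCounts, pairsWhereFirstWins) are hypothetical; this is a minimal sketch, assuming plain long arithmetic is enough (the largest count, 4^9 · 6^6, fits comfortably in a long).

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

public class Problem205Counts {
    // ways[s] = number of ordered rolls of `dice` dice (faces 1..faces) that sum to s.
    static long[] sumCounts(int dice, int faces) {
        long[] ways = new long[dice * faces + 1];
        ways[0] = 1; // with zero dice there is exactly one way to reach total 0
        for (int d = 0; d < dice; d++) {
            long[] next = new long[ways.length];
            for (int s = 0; s < ways.length; s++) {
                if (ways[s] == 0) continue; // unreachable total after d dice
                for (int f = 1; f <= faces; f++) {
                    next[s + f] += ways[s]; // add one more die showing face f
                }
            }
            ways = next;
        }
        return ways;
    }

    // Number of outcome pairs in which a total from `first` is strictly higher than one from `second`.
    static long pairsWhereFirstWins(long[] first, long[] second) {
        long wins = 0;
        long secondBelow = 0; // running count of `second` outcomes with total < s
        for (int s = 0; s < first.length; s++) {
            wins += first[s] * secondBelow;
            if (s < second.length) secondBelow += second[s];
        }
        return wins;
    }

    public static void main(String[] args) {
        long[] peter = sumCounts(9, 4);
        long[] colin = sumCounts(6, 6);
        long wins = pairsWhereFirstWins(peter, colin);
        long total = (long) Math.pow(4, 9) * (long) Math.pow(6, 6);
        System.out.println("Winning pairs: " + wins);
        System.out.println("Probability: " + BigDecimal.valueOf(wins)
                .divide(BigDecimal.valueOf(total), 7, RoundingMode.HALF_UP));
    }
}
```

The winning-pair count printed here should equal the `solutions` value from the stream version, since both count the same set of pairs.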
The code I have highlighted, which uses a simple filter(), count() and sum(), seems to run much faster in Java 9 than in Java 8. Specifically, Java 8 counts the solutions in 37465ms on my machine, while Java 9 does it in about 16000ms. The Java 9 time is the same whether I run a class file compiled with Java 8 or one compiled with Java 9.
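The timings above come from a single wall-clock measurement, which also includes JIT warm-up. One way to check whether the gap persists after warm-up is to isolate the highlighted pipeline and time it several times in one JVM. The sketch below does that on smaller, pseudo-random lists so that each round takes well under a second; the list sizes, the seed values and the class name StreamCountTiming are assumptions for illustration, not taken from the question. For a rigorous comparison, a harness such as JMH is the usual tool.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class StreamCountTiming {
    // The pipeline highlighted in the question, unchanged apart from being a method.
    static long countWithStreams(List<Integer> peter, List<Integer> colin) {
        return peter.stream()
                .mapToLong(p -> colin.stream().filter(c -> p > c).count())
                .sum();
    }

    // n pseudo-random totals in [min, max]; a stand-in for the real dice sums.
    static List<Integer> randomTotals(int n, int min, int max, long seed) {
        Random rnd = new Random(seed);
        List<Integer> out = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            out.add(min + rnd.nextInt(max - min + 1));
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical sizes, far smaller than the real 262144 x 46656.
        List<Integer> peter = randomTotals(20_000, 9, 36, 1L);
        List<Integer> colin = randomTotals(5_000, 6, 36, 2L);
        for (int round = 0; round < 5; round++) {
            // Early rounds give the JIT a chance to compile the lambdas before later rounds are timed.
            long start = System.nanoTime();
            long result = countWithStreams(peter, colin);
            long ms = (System.nanoTime() - start) / 1_000_000;
            System.out.println("round " + round + ": " + ms + " ms (count = " + result + ")");
        }
        System.out.println("java.version = " + System.getProperty("java.version"));
    }
}
```

Running the same class file under both JDKs and comparing the later rounds separates steady-state speed from start-up effects.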
If I replace the streams code with what would seem to be the exact pre-streams equivalent:
```java
long solutions = 0;
for (Integer p : peter) {
    long count = 0;
    for (Integer c : colin) {
        if (p > c) {
            count++;
        }
    }
    solutions += count;
}
```
It counts the solutions in about 35000ms, with no measurable difference between Java 8 and Java 9.
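Both the stream pipeline and this loop compare Integer objects, so every comparison involves unboxing. As an experiment that the question did not run, a variant that copies the totals into int[] arrays once and compares primitives in the inner loop can show how much of the time is spent on boxed values. The class name PrimitiveLoopCount and its methods are hypothetical; this is a sketch, and no timing claims are made for it.

```java
import java.util.List;

public class PrimitiveLoopCount {
    // Copies boxed totals into an int[] once, so the inner loop compares primitives.
    static int[] toIntArray(List<Integer> values) {
        int[] out = new int[values.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = values.get(i);
        }
        return out;
    }

    // Same count as the for-each loop in the question: pairs (p, c) with p > c.
    static long countPairs(int[] peter, int[] colin) {
        long solutions = 0;
        for (int p : peter) {
            long count = 0;
            for (int c : colin) {
                if (p > c) {
                    count++;
                }
            }
            solutions += count;
        }
        return solutions;
    }

    public static void main(String[] args) {
        int[] peter = {1, 5, 9};
        int[] colin = {2, 3, 4};
        System.out.println(countPairs(peter, colin)); // prints 6
    }
}
```

In the original program this would be called as countPairs(toIntArray(peter), toIntArray(colin)).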
What am I missing here? Why is the streams code so much faster in Java 9, and why doesn't the for loop get faster too?
I am running Ubuntu 16.04 LTS 64-bit. My Java 8 version:
java version "1.8.0_131"
Java(TM) SE Runtime Environment (build 1.8.0_131-b11)
Java HotSpot(TM) 64-Bit Server VM (build 25.131-b11, mixed mode)
My Java 9 version:
java version "9"
Java(TM) SE Runtime Environment (build 9+181)
Java HotSpot(TM) 64-Bit Server VM (build 9+181, mixed mode)