1BRC Challenge
One thing that recently nerd-sniped the hell out of me was the One Billion Row Challenge (1BRC). Quoting the original site:
Your mission, should you decide to accept it, is deceptively simple: write a Java program for retrieving temperature measurement values from a text file and calculating the min, mean, and max temperature per weather station. There’s just one caveat: the file has 1,000,000,000 rows!
I worked on it after hours, and one week after taking on the challenge I have several conclusions worth writing about.
Baseline
To make the dev loop faster, I decided to limit the task to 100 million rows.
The author's baseline is idiomatic streams code, but single-threaded:
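import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collector;
import static java.util.stream.Collectors.groupingBy;
public class CalculateAverage_baseline {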
private static final String FILE = "./measurements.txt";
private static record Measurement(String station, double value) {
private Measurement(String[] parts) {
this(parts[0], Double.parseDouble(parts[1]));
}
}
private static record ResultRow(double min, double mean, double max) {
public String toString() {
return round(min) + "/" + round(mean) + "/" + round(max);
}
private double round(double value) {
return Math.round(value * 10.0) / 10.0;
}
};
private static class MeasurementAggregator {
private double min = Double.POSITIVE_INFINITY;
private double max = Double.NEGATIVE_INFINITY;
private double sum;
private long count;
}
public static void main(String[] args) throws IOException {
Collector<Measurement, MeasurementAggregator, ResultRow> collector = Collector.of(
MeasurementAggregator::new,
(a, m) -> {
a.min = Math.min(a.min, m.value);
a.max = Math.max(a.max, m.value);
a.sum += m.value;
a.count++;
},
(agg1, agg2) -> {
var res = new MeasurementAggregator();
res.min = Math.min(agg1.min, agg2.min);
res.max = Math.max(agg1.max, agg2.max);
res.sum = agg1.sum + agg2.sum;
res.count = agg1.count + agg2.count;
return res;
},
agg -> {
return new ResultRow(agg.min, (Math.round(agg.sum * 10.0) / 10.0) / agg.count, agg.max);
});
Map<String, ResultRow> measurements = new TreeMap<>(Files.lines(Paths.get(FILE))
.map(l -> new Measurement(l.split(";")))
.collect(groupingBy(m -> m.station(), collector)));
System.out.println(measurements);
}
}
As expected, the performance is poor:
time ./calculate_average_baseline.sh
real 0m15,846s
user 0m16,002s
sys 0m0,769s
Parallelizing the code
My first thought was to parallelize the code and see how fast it would run. The obvious choice is to use streams the same way, but add a .parallel() call. All threads then update one central HashMap, so access to it has to be synchronized.
One way is to use locks; I tried it and it was slow. With 100 million rows, locking the whole map on every update is very wasteful. It is much better to lock only the map entry for the given key.
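For contrast, a minimal sketch of the coarse-grained approach (the class is hypothetical, not the code I benchmarked): a single ReentrantLock guards the whole map, so threads updating different stations still wait for each other.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical coarse-grained aggregator: one lock guards the whole map.
class CoarseLockedAggregator {
    // Per station: {min, max, sum, count}
    private final Map<String, double[]> stats = new HashMap<>();
    private final ReentrantLock lock = new ReentrantLock();

    void add(String station, double value) {
        lock.lock(); // every update contends here, even for different stations
        try {
            double[] s = stats.get(station);
            if (s == null) {
                stats.put(station, new double[] {value, value, value, 1});
            } else {
                s[0] = Math.min(s[0], value);
                s[1] = Math.max(s[1], value);
                s[2] += value;
                s[3]++;
            }
        } finally {
            lock.unlock();
        }
    }

    // Returns a copy of {min, max, sum, count}, or null for an unknown station.
    double[] get(String station) {
        lock.lock();
        try {
            double[] s = stats.get(station);
            return s == null ? null : s.clone();
        } finally {
            lock.unlock();
        }
    }
}
```

It is correct, but with 100 million rows every single update goes through the same lock.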
What surprised me was that Atomic* primitives and spinlocks were actually slower than wait/notify. Luckily, ConcurrentHashMap offers another way to do the same thing: the remapping function passed to compute() runs atomically, so updates to the same key are serialized while updates to other keys mostly proceed in parallel.
So the code reads the file lazily, and the stream's Spliterator hands off chunks of the work to other threads in the ForkJoinPool. Each forEach iteration then updates the ConcurrentHashMap:
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
public class CalculateAverage_dg2 {
private static final String FILE = "./measurements.txt";
private static class Measurement {
// All fields are set by the constructor from the first value seen
public double min;
public double max;
public double sum;
public long count;
public Measurement(double value) {
this.min = value;
this.max = value;
this.sum = value;
this.count = 1;
}
@Override
public String toString() {
// Output format required by the challenge: min/mean/max
return round(min) + "/" + round(sum / count) + "/" + round(max);
}
}
public static void main(String[] args) throws Exception {
ConcurrentHashMap<String, Measurement> measurements = new ConcurrentHashMap<>();
Files.lines(Path.of(FILE)).parallel().forEach(line -> {
int split = line.indexOf(";");
String key = line.substring(0, split);
double val = Double.parseDouble(line.substring(split + 1));
measurements.compute(key, (k, _v) -> {
if (_v == null) {
return new Measurement(val);
}
_v.min = Double.min(_v.min, val);
_v.max = Double.max(_v.max, val);
_v.sum += val;
_v.count++;
return _v;
});
});
var sorted = new TreeMap<>(measurements);
System.out.println(sorted);
}
private static double round(double value) {
return Math.round(value * 10.0) / 10.0;
}
}
The results were good: on a 4-core (8-thread) CPU it runs about 3 times faster (5.0 s instead of 15.8 s):
time ./calculate_average_dg2.sh
real 0m4,992s
user 0m33,384s
sys 0m1,207s
Spliterators perform well because they preserve data locality. Processors have limited L1/L2/L3 caches, so if you don't keep the data close to the computation, you pay for it in cache misses. The best approach is to process the data "semi-sequentially", one contiguous block per thread, just like spliterators do.
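A minimal sketch of that locality property (SplitDemo is a hypothetical name): for an ordered source such as an array, Spliterator.trySplit() must hand off a contiguous prefix of the elements and keep the rest, so each part is a consecutive block of memory. Exactly where the split happens is an implementation detail.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Spliterator;
import java.util.function.IntConsumer;

class SplitDemo {
    // Splits an array spliterator once and returns the elements of the prefix and of the rest.
    static List<List<Integer>> splitOnce(int[] data) {
        Spliterator.OfInt rest = Arrays.spliterator(data);
        // For an ORDERED source, trySplit() hands off a strict prefix (or returns null)
        Spliterator.OfInt prefix = rest.trySplit();
        List<Integer> first = new ArrayList<>();
        List<Integer> second = new ArrayList<>();
        IntConsumer toFirst = first::add;
        IntConsumer toSecond = second::add;
        if (prefix != null) {
            prefix.forEachRemaining(toFirst);
        }
        rest.forEachRemaining(toSecond);
        return List.of(first, second);
    }
}
```

Concatenating the two parts gives back the original order, which is what lets each worker scan its own contiguous range.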
But... a parallel stream splits the work dynamically, so you can't partition the data into well-known parts up front. That means all threads have to synchronize on one shared place, and this is the one big problem here.
That would have been the end of the story if I hadn't checked the leaderboard in the competition's GitHub repo. Entries there reached speedups of up to an astonishing 32x. So there is a lot to analyze. Below are the most important optimizations I tried in my playground.
Optimizations
O1. Reading longs instead of bytes
You may recall that the JVM handles byte values as ints: in local variables and on the operand stack, a byte is widened to a 4-byte int. In HotSpot, a byte[] array itself still stores one byte per element, so the data is not 4 times larger in memory. The cost lies in touching it one element at a time, with a load, a comparison and a branch for every byte, so it is wasteful to search for the newline byte by byte. The trick is to load 8 bytes at once as a long and find the newline in the whole batch.
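// buffer: a ByteBuffer in little-endian order (buffer.order(ByteOrder.LITTLE_ENDIAN)),
// so the first byte in memory ends up in the lowest 8 bits of the long
long word = buffer.getLong(i);
long match = word ^ 0x0a0a0a0a0a0a0a0aL; // bytes equal to '\n' become 0x00
long line = (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
if (line == 0) {
i += 8; // no newline among these 8 bytes
continue;
}
next = i + (Long.numberOfTrailingZeros(line) >>> 3) + 1; // index just past the first newline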
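The XOR zeroes every byte that holds a newline (0x0a). Subtracting 0x01 from every byte makes a zero byte underflow to 0xFF, which sets its high bit; the ~match & 0x8080808080808080L mask then clears everything except the high bits of bytes that were zero. With the word in little-endian order, the number of trailing zero bits divided by 8 (>>> 3) is the offset of the first newline within the 8-byte word.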
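The same trick as a small self-contained sketch (NewlineFinder is a hypothetical name). It wraps a byte[] in a little-endian ByteBuffer and, unlike the fragment above, reports every newline in a word by clearing the lowest flag with line &= line - 1. That is only safe if the input contains no 0x0b (vertical tab) bytes: the borrow from a zero byte can falsely flag a 0x0b byte right above it, while the lowest flag is always exact.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;

class NewlineFinder {
    // Returns the indexes of all '\n' bytes, testing 8 bytes per step.
    // Assumes the data contains no 0x0b bytes (see above).
    static List<Integer> find(byte[] data) {
        ByteBuffer buf = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
        List<Integer> result = new ArrayList<>();
        int i = 0;
        for (; i + 8 <= data.length; i += 8) {
            long word = buf.getLong(i);
            long match = word ^ 0x0a0a0a0a0a0a0a0aL; // '\n' bytes become 0x00
            long line = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
            while (line != 0) {
                // Lowest flagged byte = first remaining newline in memory order
                result.add(i + (Long.numberOfTrailingZeros(line) >>> 3));
                line &= line - 1; // clear that flag and look for the next one
            }
        }
        for (; i < data.length; i++) { // tail shorter than 8 bytes
            if (data[i] == '\n') {
                result.add(i);
            }
        }
        return result;
    }
}
```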
O2. Represent measurements as int / skip parsing
We can take advantage of the data format: every value has exactly one fractional digit, and temperatures always lie between -99.9 and 99.9. So the idea is to parse the number manually and multiply it by 10 so that it becomes an integer (for example, -12.3 becomes -123).
Switching to integer arithmetic by itself is not a huge improvement. The big improvement is skipping Double.parseDouble, because we can drop unnecessary complexity such as long mantissas and the exponent part.
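A minimal sketch of such a parser, assuming well-formed input in the challenge format (optional minus sign, one or two integer digits, a dot, exactly one fractional digit). TemperatureParser is a hypothetical name, and the code does no validation:

```java
class TemperatureParser {
    // Parses a temperature such as "-12.3" or "4.5" starting at pos
    // and returns it in tenths of a degree (-123, 45).
    static int parseTenths(byte[] buf, int pos) {
        boolean negative = buf[pos] == '-';
        if (negative) {
            pos++;
        }
        int value = buf[pos++] - '0';              // first integer digit
        if (buf[pos] != '.') {                      // optional second integer digit
            value = value * 10 + (buf[pos++] - '0');
        }
        pos++;                                      // skip '.'
        value = value * 10 + (buf[pos] - '0');      // the single fractional digit
        return negative ? -value : value;
    }
}
```

The aggregates (min, max, sum) then stay in int or long arithmetic, and only the final output divides by 10.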
O3. Defer creating Strings
Strings are a bit heavy: creating one from raw bytes means decoding them with a charset and copying the result into the String's internal array. Since Java 9 (compact strings) that array is a byte[] rather than a char[], but the decoding work remains. We can defer creating strings until they are needed for printing, but then we have to find some other key for the hash map.
Copying the raw bytes into a byte array is much faster, and this is what I used.
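Raw byte[] arrays can't be hash map keys on their own, because arrays inherit identity-based equals() and hashCode() from Object. A minimal sketch of a wrapper that compares contents instead (ByteKey is a hypothetical name, and it uses the JDK's Arrays.hashCode rather than a custom hash):

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Hypothetical map key wrapping the raw UTF-8 bytes of a station name.
final class ByteKey {
    private final byte[] bytes;
    private final int hash;

    ByteKey(byte[] source, int from, int to) {
        this.bytes = Arrays.copyOfRange(source, from, to); // plain copy, no charset decoding
        this.hash = Arrays.hashCode(bytes);                // content-based, computed once
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof ByteKey other && Arrays.equals(bytes, other.bytes);
    }

    @Override
    public int hashCode() {
        return hash;
    }

    @Override
    public String toString() {
        // The String is only built when needed, e.g. for the final output
        return new String(bytes, StandardCharsets.UTF_8);
    }
}
```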
O4. Faster hashes
We need a fast, non-cryptographic hash of the key bytes. One suitable algorithm for this task is FNV-1a: it is simple enough to implement and runs fast.
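@Override
public int hashCode() {
// 32-bit FNV-1a: xor in each byte, then multiply by the FNV prime.
// int arithmetic wraps modulo 2^32, exactly as the algorithm expects.
int hash = 0x811C9DC5; // FNV offset basis
for (byte b : bytes) {
hash ^= (b & 0xff); // treat the byte as unsigned
hash *= 0x01000193; // FNV prime
}
return hash;
}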
O5. Skipping synchronization
Under such a heavy load, synchronization is too slow. It is fine for more coarse-grained control flow, but 100 million synchronized updates are just too many. So the optimization is to give each thread its own hash map and merge the maps after the work is done. The merge only touches one entry per station per thread, several orders of magnitude fewer than the number of rows, so the merging time is negligible.
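A minimal sketch of this pattern, assuming the input is already split into one list of lines per task. The class name, the long[] {min, max, sum, count} layout and the use of Double.parseDouble are illustrative simplifications, not my actual code: each task fills a private HashMap without any locking, and the partial maps are combined once at the end with Map.merge.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class PerThreadAggregation {
    // Combines two {min, max, sum, count} records (values in tenths of a degree)
    static long[] combine(long[] a, long[] b) {
        return new long[] {Math.min(a[0], b[0]), Math.max(a[1], b[1]), a[2] + b[2], a[3] + b[3]};
    }

    // Runs on one thread; the map is private to it, so no synchronization is needed
    static Map<String, long[]> aggregateChunk(List<String> lines) {
        Map<String, long[]> local = new HashMap<>();
        for (String line : lines) {
            int sep = line.indexOf(';');
            long v = Math.round(Double.parseDouble(line.substring(sep + 1)) * 10);
            local.merge(line.substring(0, sep), new long[] {v, v, v, 1}, PerThreadAggregation::combine);
        }
        return local;
    }

    static Map<String, long[]> aggregate(List<List<String>> chunks) {
        ExecutorService pool = Executors.newFixedThreadPool(Math.max(1, chunks.size()));
        try {
            List<CompletableFuture<Map<String, long[]>>> parts = new ArrayList<>();
            for (List<String> chunk : chunks) {
                parts.add(CompletableFuture.supplyAsync(() -> aggregateChunk(chunk), pool));
            }
            // Merge step: touches (tasks x stations) entries, not one entry per row
            Map<String, long[]> result = new HashMap<>();
            for (CompletableFuture<Map<String, long[]>> part : parts) {
                part.join().forEach((station, stats) -> result.merge(station, stats, PerThreadAggregation::combine));
            }
            return result;
        } finally {
            pool.shutdown();
        }
    }
}
```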
O6. Off heap memory
We could use Unsafe, for example to skip array bounds checks. But a much better way is to avoid allocating the data on the heap altogether, which is easy to achieve by memory-mapping file chunks. FileChannel.map then gives us a MemorySegment (from the Foreign Function & Memory API) whose accessor methods read the mapped memory directly, without first copying it into a Java array.
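A minimal sketch of mapping a file and reading it through a MemorySegment, assuming Java 22 or later (where the Foreign Function & Memory API is final). The class name and the newline-counting task are only illustrative; closing the confined Arena unmaps the file:

```java
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

class MappedNewlineCounter {
    // Maps the whole file off-heap and counts '\n' bytes without copying them into a byte[]
    static long countLines(Path file) throws IOException {
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ);
             Arena arena = Arena.ofConfined()) {
            MemorySegment segment = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size(), arena);
            long count = 0;
            for (long i = 0; i < segment.byteSize(); i++) {
                if (segment.get(ValueLayout.JAVA_BYTE, i) == '\n') {
                    count++;
                }
            }
            return count;
        } // the arena closes first, which unmaps the segment, then the channel closes
    }
}
```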
O7. Aligning data to a multiple of 8 bytes
Accessing unaligned data carries a performance penalty, and reading a long through the aligned ValueLayout.JAVA_LONG layout requires an 8-byte-aligned address anyway. So we copy all the data of a chunk once into an 8-byte-aligned segment and then enjoy the benefit of reading it as aligned 8-byte longs.
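A minimal sketch of an aligned copy, again assuming Java 22+ and a hypothetical class name. SegmentAllocator.allocate(size) alone only guarantees 1-byte alignment, so the alignment has to be passed explicitly as the second argument; afterwards every long read at an offset that is a multiple of 8 is aligned:

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;

class AlignedCopy {
    // Copies data into a new off-heap segment whose base address is a multiple of 8,
    // so reading a long at any offset that is a multiple of 8 is an aligned access.
    static MemorySegment copyAligned(Arena arena, byte[] data) {
        MemorySegment segment = arena.allocate(data.length, 8); // (byteSize, byteAlignment)
        MemorySegment.copy(MemorySegment.ofArray(data), 0, segment, 0, data.length);
        return segment;
    }
}
```

The returned segment lives as long as the arena that allocated it.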
Final version
The final version uses all of the above techniques and gets decent performance. The file is split into 20 MB chunks; each chunk is memory-mapped and then processed by its own task. The tasks are submitted to the ForkJoinPool common pool, whose parallelism defaults to the number of available processors minus one.
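// Requires Java 22+ (java.lang.foreign is final there; Java 21 needs --enable-preview)
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
// Values are stored as integer tenths of a degree (see O2)
class Measurement {
public int min = Integer.MAX_VALUE;
public int max = Integer.MIN_VALUE;
public long sum = 0; // long: with a billion rows, an int sum of tenths could overflow
public int count = 0;
public Measurement(int value) {
this.min = value;
this.max = value;
this.sum = value;
this.count = 1;
}
@Override
public String toString() {
// Output format required by the challenge: min/mean/max
return round(min / 10.0) + "/" + round((sum / 10.0) / count) + "/" + round(max / 10.0);
}
private static double round(double value) {
return Math.round(value * 10.0) / 10.0;
}
}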
class FastKey implements Comparable<FastKey> {
private byte[] bytes;
private int hash;
private static final Charset charset = Charset.forName("UTF-8");
public FastKey(MemorySegment segment) {
bytes = segment.toArray(ValueLayout.JAVA_BYTE);
// 32-bit FNV-1a: int arithmetic wraps modulo 2^32 as the algorithm expects
int hash = 0x811C9DC5; // FNV offset basis
for (byte b : bytes) {
hash ^= (b & 0xff); // treat the byte as unsigned
hash *= 0x01000193; // FNV prime
}
this.hash = hash;
}
@Override
public int hashCode() {
return this.hash;
}
@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (!(obj instanceof FastKey)) {
return false;
}
return Arrays.equals(bytes, ((FastKey) obj).bytes);
}
@Override
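public int compareTo(FastKey o) {
// Unsigned comparison: UTF-8 byte order then follows code point order, so non-ASCII
// names sort like Strings instead of before all ASCII names
return Arrays.compareUnsigned(bytes, o.bytes);
}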
@Override
public String toString() {
return new String(bytes, charset);
}
}
class ComputeMeasurementsPartTask implements Callable<Map<FastKey, Measurement>> {
private int start;
private int end;
private MemorySegment segment;
private byte getByte(int index) {
return segment.get(ValueLayout.JAVA_BYTE, index);
}
private long getLong(int index) {
// JAVA_LONG uses native byte order and requires an 8-byte-aligned address
return segment.get(ValueLayout.JAVA_LONG, index);
}
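public ComputeMeasurementsPartTask(int start, int end, int limit, FileChannel channel) throws IOException {
// Map [start, limit): the slack past `end` lets the chunk finish its last line
this.segment = channel.map(FileChannel.MapMode.READ_ONLY, start, limit - start, Arena.global());
int s = start, s2 = start, e = end;
if (s != 0) {
// Skip the partial first line: the previous chunk finishes it
while (getByte(s - s2) != 0x0a) {
s++;
}
s++;
}
// Extend the chunk up to the newline that ends its last line
while (e < limit && getByte(e - s2) != 0x0a) {
e++;
}
// Copy the chunk once into an 8-byte-aligned buffer (O7), shifted by `prefix`
// bytes so that buffer offset 0 corresponds to file offset s - s % 8.
// allocate(size) alone only guarantees 1-byte alignment, hence the explicit 8.
// Arena.global() memory is never freed, which is fine for a one-shot program.
int prefix = s % 8;
MemorySegment padded = Arena.global().allocate(e - s + prefix, 8);
padded.asSlice(prefix).copyFrom(segment.asSlice(s - s2, e - s));
this.segment = padded;
this.start = prefix;
this.end = e - s + prefix;
}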
private void doActualWork(int start, int end, Map<FastKey, Measurement> measurements) {
int splitIndex = start;
while (getByte(splitIndex) != 0x3b) {
splitIndex++;
}
var key = new FastKey(segment.asSlice(start, splitIndex - start));
boolean negative = false;
int ind = splitIndex + 1;
if (getByte(ind) == (byte) '-') {
negative = true;
ind++;
}
int v = 0;
if (end - ind == 4) {
v = v * 10 + getByte(ind++) - '0';
v = v * 10 + getByte(ind++) - '0';
ind++; // '.'
v = v * 10 + getByte(ind) - '0';
}
else {
v = getByte(ind++) - '0';
ind++; // '.'
v = v * 10 + getByte(ind) - '0';
}
int val = negative ? -v : v;
var _v = measurements.get(key);
if (_v != null) {
if (val < _v.min) {
_v.min = val;
}
if (val > _v.max) {
_v.max = val;
}
_v.sum += val;
_v.count++;
}
else {
measurements.put(key, new Measurement(val));
}
}
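@Override
public Map<FastKey, Measurement> call() throws Exception {
var measurements = new HashMap<FastKey, Measurement>();
int prev = start;
int i = 0;
// Scan whole 8-byte words (O1); bytes before `start` are zero padding, never '\n'.
// Assumes a little-endian platform: JAVA_LONG reads in native byte order.
for (; i + 8 <= end; i += 8) {
long word = getLong(i);
long match = word ^ 0x0a0a0a0a0a0a0a0aL;
long line = (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
// A word may hold several newlines: handle every flagged byte, lowest first
// (a higher flag can be a false positive only for a 0x0b byte, which the data never contains)
while (line != 0) {
int next = i + (Long.numberOfTrailingZeros(line) >>> 3);
doActualWork(prev, next, measurements);
prev = next + 1;
line &= line - 1; // clear the lowest set bit
}
}
// Fewer than 8 bytes left: finish byte by byte
for (; i < end; i++) {
if (getByte(i) == 0x0a) {
doActualWork(prev, i, measurements);
prev = i + 1;
}
}
// The last line of a chunk has no newline inside the chunk
if (prev < end) {
doActualWork(prev, end, measurements);
}
return measurements;
}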
}
public class CalculateAverage_dg {
public static final String FILE = "./measurements.txt";
public static void main(String[] args) throws Exception {
FileChannel channel = FileChannel.open(Path.of(CalculateAverage_dg.FILE), StandardOpenOption.READ);
var parts = new ArrayList<Map<FastKey, Measurement>>();
var futures = new ArrayList<Future<Map<FastKey, Measurement>>>();
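// int offsets are enough for the 100M-row file used here (under 2 GB);
// the full 1B-row file would need long offsets throughout
int fileSize = (int) Files.size(Path.of(FILE));
int chunkSize = 20 * 1024 * 1024;
// Number of chunks (not threads); round up so the last partial chunk is not dropped
int noOfChunks = Math.ceilDiv(fileSize, chunkSize);
for (int i = 0; i < noOfChunks; i++) {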
int start = i * chunkSize;
int end = Math.min((i + 1) * chunkSize, fileSize);
int limit = Math.min((i + 1) * chunkSize + 1024, fileSize);
var task = new ComputeMeasurementsPartTask(start, end, limit, channel);
futures.add(ForkJoinPool.commonPool().submit(task));
}
for (var future : futures) {
parts.add(future.get());
}
var measurements = parts.stream().flatMap(map -> map.entrySet().stream())
.collect(
Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue,
(e1, e2) -> {
e1.min = Integer.min(e1.min, e2.min);
e1.max = Integer.max(e1.max, e2.max);
e1.sum += e2.sum;
e1.count += e2.count;
return e1;
}));
var sorted = new TreeMap<>(measurements);
System.out.println(sorted);
}
}
The time is impressive:
time ./calculate_average_dg.sh
real 0m1,465s
user 0m8,505s
sys 0m0,526s
For comparison, this is the time of the top solution from the leaderboard:
time ./calculate_average_thomaswue.sh
real 0m0,794s
user 0m5,032s
sys 0m0,199s
So mine is "only" about two times slower than the best. I don't know how that would translate to a leaderboard position. But the author of the top solution is the founder of GraalVM, so he definitely knows what he is doing.
Conclusion
- ConcurrentHashMap is surprisingly fast; without a solid background in concurrency it would be hard to write something faster
- Off-heap memory gives a considerable performance boost
- Reading the file as longs is a nice trick that I didn't know until now