Defining a custom Collector in Java 8


Reasons to write your own Collector

There are many reasons. Here are the main ones:
– what we want to collect forces us to write complex logic in the stream operations or in the collect step, and the result looks neither natural nor readable.
That complexity or clumsiness may show up in several places: grouping functions, instantiation and population of a result that is an instance of a custom class, stateful or non-associative functions used in stream operations or in collect, and so on.
– we want to reuse the collector in other places.
– we want to gather the collector logic in a dedicated component, to avoid side effects on other members of the enclosing class or, more generally, to respect the single responsibility principle.
– we want to unit test the collector logic.
– we want it to run faster when used with parallel streams.

The requirement to illustrate: a collect that cannot be done directly with built-in Stream/Collector methods

We have a list of Rating objects, where a Rating represents the rating of a movie as an integer between 0 and 10.
We want to compute the percentage of good, average and bad movies with this rule:

mark > 6  => Good
mark >= 4 && mark <= 6 => Average
mark < 4 => Bad
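
As a quick sanity check of the rule, here is a minimal sketch that classifies a few sample marks with a plain loop, before any stream is involved. The class name RatingRuleExample, the helper names and the sample marks are hypothetical and only serve the illustration. As in the rest of the article, the "percentages" are expressed as fractions between 0 and 1.

```java
public class RatingRuleExample {

    // Returns the level name for a mark, following the rule stated above.
    static String levelOf(int mark) {
        if (mark > 6) {
            return "Good";
        } else if (mark >= 4) {
            return "Average";
        }
        return "Bad";
    }

    // Returns {good, average, bad} fractions for the given marks.
    static float[] percentages(int... marks) {
        int good = 0, average = 0, bad = 0;
        for (int mark : marks) {
            switch (levelOf(mark)) {
                case "Good": good++; break;
                case "Average": average++; break;
                default: bad++;
            }
        }
        float total = marks.length;
        return new float[] {good / total, average / total, bad / total};
    }

    public static void main(String[] args) {
        float[] p = percentages(8, 5, 2, 9); // hypothetical sample marks
        // prints good=0.5, average=0.25, bad=0.25
        System.out.println("good=" + p[0] + ", average=" + p[1] + ", bad=" + p[2]);
    }
}
```

With the four sample marks, two are good (8 and 9), one is average (5) and one is bad (2), which is the kind of result we expect from the collect.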

The Rating class :

public class Rating {
 
    private final int mark;
 
    public Rating(int mark) {
        this.mark = mark;
    }
 
    public int getMark() {
        return mark;
    }
 
    @Override
    public String toString() {
        return "Rating{" +
                "mark=" + mark +
                '}';
    }
}

Here is the SummarizedRating class that represents the information we want to get from the collect :

public class SummarizedRating {
 
    private float prctGood;
    private float prctAverage;
    private float prctBad;
 
    public SummarizedRating(float prctGood, float prctAverage, float prctBad) {
        this.prctGood = prctGood;
        this.prctAverage = prctAverage;
        this.prctBad = prctBad;
    }
 
    public float getPrctGood() {
        return prctGood;
    }
 
    public float getPrctAverage() {
        return prctAverage;
    }
 
    public float getPrctBad() {
        return prctBad;
    }
 
    @Override
    public String toString() {
        return "SummarizedRating{" +
                "prctGood=" + prctGood +
                ", prctAverage=" + prctAverage +
                ", prctBad=" + prctBad +
                '}';
    }
}

Implementing it with built-in Collectors but without extracted functions is terrible

Whatever the approach, the idea is broadly the same: we count the ratings by rating level (good, average and bad), which lets us compute the percentage of each level.
In a solution with built-in collectors and no intermediary variables to hold the state, we first need to group the elements in a Map. The Map values are no problem: they are the count of each rating level. But what type should we use for the Map keys? A String? Probably not, as it is not specific enough. An enum seems better, so let us introduce one:

private enum RatingLevel {
        GOOD, AVERAGE, BAD
}

Here is the stream processing :

SummarizedRating summarizedRating =
        ratings.stream()
               .collect(collectingAndThen(groupingBy(r -> {
                                              int mark = r.getMark();
                                              if (mark >= 4 && mark <= 6) {
                                                  return RatingLevel.AVERAGE;
                                              } else if (mark < 4) {
                                                  return RatingLevel.BAD;
                                              }
                                              return RatingLevel.GOOD;
                                          }, counting()),
                                          m -> {
                                              float prctGood = m.getOrDefault(RatingLevel.GOOD,
                                                                              0L) / (float) ratings.size();
                                              float prctAverage = m.getOrDefault(RatingLevel.AVERAGE,
                                                                                 0L) / (float) ratings.size();
                                              float prctBad = m.getOrDefault(RatingLevel.BAD,
                                                                             0L) / (float) ratings.size();
                                              return new SummarizedRating(prctGood, prctAverage, prctBad);
                                          })
               );

It works, but it is not very readable. It looks like a mix of functional programming (we have stream operations) and imperative programming (the lambdas contain many statements, including conditional ones). The length of the lambdas also makes it harder to spot the main strategy used in the collectors, that is:
collectingAndThen(groupingBy(groupingByFunction, counting()),
thenFunction)
We could make it clearer by extracting the statements located in the lambdas into methods.
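
To make that strategy easier to see in isolation, here is a minimal sketch of the same collector shape applied to a deliberately trivial problem: the fraction of even numbers in a list. The class name CollectingAndThenShape and the method evenRatio are hypothetical; only collectingAndThen, groupingBy and counting come from java.util.stream.Collectors.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.counting;
import static java.util.stream.Collectors.groupingBy;

public class CollectingAndThenShape {

    // Groups numbers by parity, counts each group, then turns the counts into a fraction.
    static double evenRatio(List<Integer> numbers) {
        return numbers.stream()
                      .collect(collectingAndThen(
                              // first step: build a Map<Boolean, Long> of counts
                              groupingBy((Integer n) -> n % 2 == 0, counting()),
                              // second step: transform the map into the final result
                              (Map<Boolean, Long> counts) ->
                                      counts.getOrDefault(true, 0L) / (double) numbers.size()));
    }

    public static void main(String[] args) {
        System.out.println(evenRatio(Arrays.asList(1, 2, 3, 4))); // prints 0.5
    }
}
```

The rating example follows exactly this shape, with a longer classifier function and a longer finishing function.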

Implementing it with built-in Collectors and extracted functions is better but still has some drawbacks

Here is the refactored stream processing :

SummarizedRating summarizedRating =
        ratings.stream()
               .collect(collectingAndThen(groupingBy(r -> getRatingLevel(r), counting()),
                                          m -> getSummarizedRating(ratings, m)
                        )
               );

And here the extracted functions :

private static RatingLevel getRatingLevel(Rating r) {
	int mark = r.getMark();
	if (mark >= 4 && mark <= 6) {
		return RatingLevel.AVERAGE;
	} else if (mark < 4) {
		return RatingLevel.BAD;
	}
	return RatingLevel.GOOD;
}
 
 
private static SummarizedRating getSummarizedRating(List<Rating> ratings, Map<RatingLevel, Long> m) {
	float prctGood = m.getOrDefault(RatingLevel.GOOD,
									0L) / (float) ratings.size();
	float prctAverage = m.getOrDefault(RatingLevel.AVERAGE,
									   0L) / (float) ratings.size();
	float prctBad = m.getOrDefault(RatingLevel.BAD,
								   0L) / (float) ratings.size();
	return new SummarizedRating(prctGood, prctAverage, prctBad);
}

It is definitely clearer. But is it satisfactory? Here are some limitations:
– we introduced an enum "only" for the grouping operation, while the client of the stream does not need it.
– we introduced two additional methods specific to the stream processing in the class where the stream is created. These methods are not necessarily in the right place if the class already has other distinct responsibilities: we do not want the class to bloat or its members to have low cohesion.
Extracting these methods into a dedicated class is a possibility, but does it make sense to separate them from the stream operations when they are designed only for that stream?
– we cannot directly reuse the overall stream logic elsewhere.
– would the current collect be efficient with a parallel stream? Very probably. Could it be even more efficient with a custom collector? Very probably too.

As you have probably guessed, a custom collector can address these limitations.

Implementing it with a custom Collector has many advantages

The stream processing

SummarizedRating summarizedRating =
                ratings.stream()
                       .collect(new RatingCollector());

The collector :

import java.util.Collections;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
 
public class RatingCollector implements Collector<Rating, RatingCollector.GroupedRatings, SummarizedRating> {
 
    static class GroupedRatings {
 
        private int goodRatings;
        private int averageRatings;
        private int badRatings;
 
        void incrementAverageRating() {
            averageRatings++;
        }
 
        void incrementGoodRating() {
            goodRatings++;
        }
 
        void incrementBadRating() {
            badRatings++;
        }
 
        long size() {
            return averageRatings + goodRatings + badRatings;
        }
 
        @Override
        public String toString() {
            return "GroupedRatings{" +
                    "goodRatings=" + goodRatings +
                    ", averageRatings=" + averageRatings +
                    ", badRatings=" + badRatings +
                    '}';
        }
    }
    // End GroupedRatings class declaration
 
    @Override
    public Supplier<GroupedRatings> supplier() {
        return GroupedRatings::new;
    }
 
    @Override
    public BiConsumer<GroupedRatings, Rating> accumulator() {
 
        return (groupedRatings, rating) -> {
            final int mark = rating.getMark();
            if (mark >= 4 && mark <= 6) {
                groupedRatings.incrementAverageRating();
            } else if (mark < 4) {
                groupedRatings.incrementBadRating();
            } else {
                groupedRatings.incrementGoodRating();
            }
        };
    }
 
    @Override
    public Function<GroupedRatings, SummarizedRating> finisher() {
        return groupedRatings -> {
            long size = groupedRatings.size();
            float prctAverage = groupedRatings.averageRatings / (float) size;
            float prctGood = groupedRatings.goodRatings / (float) size;
            float prctBad = groupedRatings.badRatings / (float) size;
            return new SummarizedRating(prctGood, prctAverage, prctBad);
        };
    }
 
    @Override
    public Set<Characteristics> characteristics() {
        return Collections.emptySet();
    }
 
    @Override
    public BinaryOperator<GroupedRatings> combiner() {
        return (g, otherG) -> {
            throw new UnsupportedOperationException("Parallel stream not supported");
        };
    }
 
}

Advantages of the collector:

– the whole logic lives in a single place: the collector.
– we can reuse it anywhere.
– we can document, allow, disallow or optimize parallel stream processing.
– mapping/helper functions are part of the collector logic. Indeed, we decide how the stream elements are processed in the accumulator() method of our collector.
– the intermediary collect into a map is no longer required, since a custom collector accumulates directly into a mutable object that contains only the required values. Previously, the map served as an intermediary structure because, without a custom collector and mutable state, it is the only way to group elements by rating range. With a mutable object whose fields hold the count for each rating range, the map is useless: we only need to update the count fields.
– as a consequence of the previous point, performance is better: no map lookup or insertion, and a much smaller memory footprint. Besides, we no longer need Long objects to represent the count for each range. They hurt performance because of the unboxing and boxing operations performed for each processed element. We can now replace them with three primitive int fields.
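
One of the reasons listed at the start, unit testing the collector logic, also becomes straightforward: the supplier, accumulator and finisher can be driven by hand without creating any stream. The sketch below assumes a simplified stand-in for RatingCollector (the hypothetical MarkCollector), where marks are plain Integer values, the container is an int[] {good, average, bad} and the result is a float[]; it is not the article's class, only the same idea in fewer lines.

```java
import java.util.Collections;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;

public class CollectorPartsCheck {

    // Simplified stand-in for RatingCollector (hypothetical): Integer marks in,
    // int[] {good, average, bad} as container, float[] of fractions out.
    static class MarkCollector implements Collector<Integer, int[], float[]> {
        @Override public Supplier<int[]> supplier() { return () -> new int[3]; }

        @Override public BiConsumer<int[], Integer> accumulator() {
            return (counts, mark) -> {
                if (mark > 6) counts[0]++;
                else if (mark >= 4) counts[1]++;
                else counts[2]++;
            };
        }

        @Override public BinaryOperator<int[]> combiner() {
            return (a, b) -> { throw new UnsupportedOperationException("Parallel stream not supported"); };
        }

        @Override public Function<int[], float[]> finisher() {
            return counts -> {
                float total = counts[0] + counts[1] + counts[2];
                return new float[] {counts[0] / total, counts[1] / total, counts[2] / total};
            };
        }

        @Override public Set<Characteristics> characteristics() { return Collections.emptySet(); }
    }

    // Drives the collector by hand, the way a unit test could, without a stream.
    static float[] collectByHand(int... marks) {
        MarkCollector collector = new MarkCollector();
        int[] container = collector.supplier().get();
        for (int mark : marks) {
            collector.accumulator().accept(container, mark);
        }
        return collector.finisher().apply(container);
    }

    public static void main(String[] args) {
        float[] result = collectByHand(8, 5, 2, 9);
        System.out.println(result[0] + " " + result[1] + " " + result[2]); // prints 0.5 0.25 0.25
    }
}
```

Driving the finisher by hand also exposes an edge case worth a test: with no element at all, the division is a float 0/0, so the fractions come out as NaN rather than throwing an exception.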

Making the custom Collector support parallel streams

The combiner() method, which is invoked when processing parallel streams, was not supported in the previous implementation. Here, we update the collector to implement it.

import java.util.Collections;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
 
public class ConcurrentRatingCollector implements Collector<Rating, ConcurrentRatingCollector.GroupedRatings, SummarizedRating> {
 
    static class GroupedRatings {
 
        private int goodRatings;
        private int averageRatings;
        private int badRatings;
 
        void incrementAverageRating() {
            averageRatings++;
        }
 
        void incrementGoodRating() {
            goodRatings++;
        }
 
        void incrementBadRating() {
            badRatings++;
        }
 
        long size() {
            return averageRatings + goodRatings + badRatings;
        }
 
        @Override
        public String toString() {
            return "GroupedRatings{" +
                    "goodRatings=" + goodRatings +
                    ", averageRatings=" + averageRatings +
                    ", badRatings=" + badRatings +
                    '}';
        }
    }
    // End GroupedRatings class declaration
 
    @Override
    public Supplier<GroupedRatings> supplier() {
        return GroupedRatings::new;
    }
 
    @Override
    public BiConsumer<GroupedRatings, Rating> accumulator() {
 
        return (groupedRatings, rating) -> {
            final int mark = rating.getMark();
            if (mark >= 4 && mark <= 6) {
                groupedRatings.incrementAverageRating();
            } else if (mark < 4) {
                groupedRatings.incrementBadRating();
            } else {
                groupedRatings.incrementGoodRating();
            }
        };
    }
 
    @Override
    public Function<GroupedRatings, SummarizedRating> finisher() {
        return groupedRatings -> {
            long size = groupedRatings.size();
            float prctAverage = groupedRatings.averageRatings / (float) size;
            float prctGood = groupedRatings.goodRatings / (float) size;
            float prctBad = groupedRatings.badRatings / (float) size;
            return new SummarizedRating(prctGood, prctAverage, prctBad);
        };
    }
 
    @Override
    public Set<Characteristics> characteristics() {
        Set<Characteristics> characteristics = Collections.singleton(Characteristics.UNORDERED);
        return characteristics;
    }
 
    @Override
    public BinaryOperator<GroupedRatings> combiner() {
        return (g, otherG) -> {
            g.averageRatings += otherG.averageRatings;
            g.goodRatings += otherG.goodRatings;
            g.badRatings += otherG.badRatings;
            return g;
        };
    }
 
}
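
To see what the combiner is expected to guarantee, here is a minimal sketch that splits the input by hand: two partial containers are accumulated separately, then merged, and the merged counts match a single-pass accumulation. With a real parallel stream the framework decides how to split the source; the class CombinerSketch, its helper methods and the int[] {good, average, bad} container are hypothetical simplifications of GroupedRatings.

```java
import java.util.Arrays;

public class CombinerSketch {

    // Same rule as the accumulator of the collector, applied to a plain int mark.
    static void accumulate(int[] counts, int mark) {
        if (mark > 6) counts[0]++;
        else if (mark >= 4) counts[1]++;
        else counts[2]++;
    }

    // Same idea as combiner(): fold the right container into the left one.
    static int[] combine(int[] left, int[] right) {
        left[0] += right[0];
        left[1] += right[1];
        left[2] += right[2];
        return left;
    }

    // Accumulates marks[from..to) into a fresh container.
    static int[] countRange(int[] marks, int from, int to) {
        int[] counts = new int[3];
        for (int i = from; i < to; i++) {
            accumulate(counts, marks[i]);
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] marks = {8, 5, 2, 9, 4, 10}; // hypothetical sample marks
        int[] whole = countRange(marks, 0, marks.length);
        int[] merged = combine(countRange(marks, 0, 3), countRange(marks, 3, marks.length));
        // prints [3, 2, 1] [3, 2, 1]
        System.out.println(Arrays.toString(whole) + " " + Arrays.toString(merged));
    }
}
```

Because the counts are simple sums, the merge gives the same result whatever the split point, which is why declaring the collector UNORDERED is safe here.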

Benchmark of some solutions with JMH

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import owncollector.rating.*;
 
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
 
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.groupingBy;
 
@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public class CollectorBenchmark {
 
    private List<Rating> ratings;
 
    public static void main(String[] args) throws RunnerException {
        Options opt = new OptionsBuilder().include(CollectorBenchmark.class.getSimpleName())
                                          .warmupIterations(5)
                                          .measurementIterations(5)
                                          .forks(1)
                                          .build();
        new Runner(opt).run();
    }
 
    @Setup(Level.Iteration)
    public void doSetup() {
        ratings = new ArrayList<>();
        Random random = new Random();
        for (int i = 0; i < 100_000; i++) {
            ratings.add(new Rating(random.nextInt(11)));
        }
    }
 
 
    @Benchmark
    public void _1_sequential_stream_with_builtin_collector() {
        ratings.stream()
               .collect(collectingAndThen(groupingBy(r -> getRatingLevel(r), Collectors.counting()),
                                          m -> getSummarizedRating(ratings, m)
                        )
               );
    }

    @Benchmark
    public void _2_parallel_stream_with_builtin_collector() {
        ratings.parallelStream()
               .collect(collectingAndThen(groupingBy(r -> getRatingLevel(r), Collectors.counting()),
                                          m -> getSummarizedRating(ratings, m)
                        )
               );
    }

    @Benchmark
    public void _3_sequential_stream_with_custom_collector() {
        ratings.stream()
               .collect(new RatingCollector());
    }

    @Benchmark
    public void _4_parallel_stream_with_custom_collector() {
        ratings.parallelStream()
               .collect(new ConcurrentRatingCollector());
    }
 
 
    private static CollectingWithProvidedCollector.RatingLevel getRatingLevel(Rating r) {
        int mark = r.getMark();
        if (mark >= 4 && mark <= 6) {
            return CollectingWithProvidedCollector.RatingLevel.AVERAGE;
        } else if (mark < 4) {
            return CollectingWithProvidedCollector.RatingLevel.BAD;
        }
        return CollectingWithProvidedCollector.RatingLevel.GOOD;
    }
 
 
    private static SummarizedRating getSummarizedRating(List<Rating> ratings, Map<CollectingWithProvidedCollector.RatingLevel, Long> m) {
        float prctGood = m.getOrDefault(CollectingWithProvidedCollector.RatingLevel.GOOD,
                                        0L) / (float) ratings.size();
        float prctAverage = m.getOrDefault(CollectingWithProvidedCollector.RatingLevel.AVERAGE,
                                           0L) / (float) ratings.size();
        float prctBad = m.getOrDefault(CollectingWithProvidedCollector.RatingLevel.BAD,
                                       0L) / (float) ratings.size();
        return new SummarizedRating(prctGood, prctAverage, prctBad);
    }
 
 
}

Here is the JMH result (lower score is better):

# Run complete. Total time: 00:04:41
 
REMEMBER: The numbers below are just data. To gain reusable insights, you need to follow up on
why the numbers are the way they are. Use profilers (see -prof, -lprof), design factorial
experiments, perform baseline and negative tests that provide experimental control, make sure
the benchmarking environment is safe on JVM/OS/HW level, ask for reviews from the domain experts.
Do not assume the numbers tell you what you want them to tell.
 
Benchmark                                                         Mode   Cnt        Score       Error        Units
CollectorBenchmark._1_sequential_stream_with_builtin_collector    avgt    5  2273981,498 ± 27320,072  ns/op
CollectorBenchmark._2_parallel_stream_with_builtin_collector      avgt    5   471996,477 ±  2249,508  ns/op
CollectorBenchmark._3_sequential_stream_with_custom_collector     avgt    5   624875,239 ± 86014,917  ns/op
CollectorBenchmark._4_parallel_stream_with_custom_collector       avgt    5   196311,640 ±  1364,517  ns/op

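Lessons to take away from this benchmark:
– the parallel stream performs much better: about 4.8 times faster with the built-in collectors and about 3.2 times faster with the custom collector.
– the custom collector performs much better: about 3.6 times faster sequentially and about 2.4 times faster in parallel.
– so using both versus using neither gives the two extremes of the benchmark: about 196,312 ns/op versus about 2,273,981 ns/op, roughly a factor of 11.6.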

Conclusion:

As shown, a custom collector has many advantages: it respects the single responsibility principle, it is often more efficient, and it relies on a container class tailored to the requirements, which generally makes more sense than combining built-in collectors.
But we could also argue that a custom collector has some drawbacks: it is verbose and it adds a level of indirection when reading (you cannot understand the stream logic without navigating to the collector class and reading its code).
So, as a rule of thumb, a custom collector makes sense for:
– non-trivial and repeated stream/collect operations.
– unreadable stream code that we twist, or would have to twist, to make it fit our logic.
– cases where performance matters.
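
When verbosity is the main concern, the JDK factory method Collector.of can build a collector from the four functions without a dedicated class. This is a minimal sketch and not the code of the article: the class CollectorOfSketch and the method summarizing are hypothetical, marks are plain Integer values and the result is a float[] {good, average, bad} instead of a SummarizedRating.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collector;

public class CollectorOfSketch {

    // Builds a collector from four functions instead of a dedicated class.
    static Collector<Integer, int[], float[]> summarizing() {
        return Collector.of(
                () -> new int[3],                                   // supplier: {good, average, bad}
                (int[] counts, Integer mark) -> {                   // accumulator
                    if (mark > 6) counts[0]++;
                    else if (mark >= 4) counts[1]++;
                    else counts[2]++;
                },
                (int[] left, int[] right) -> {                      // combiner, used by parallel streams
                    for (int i = 0; i < left.length; i++) {
                        left[i] += right[i];
                    }
                    return left;
                },
                (int[] counts) -> {                                 // finisher
                    float total = counts[0] + counts[1] + counts[2];
                    return new float[] {counts[0] / total, counts[1] / total, counts[2] / total};
                },
                Collector.Characteristics.UNORDERED);
    }

    public static void main(String[] args) {
        List<Integer> marks = Arrays.asList(8, 5, 2, 9); // hypothetical sample marks
        float[] result = marks.parallelStream().collect(summarizing());
        System.out.println(Arrays.toString(result)); // prints [0.5, 0.25, 0.25]
    }
}
```

The trade-off is that the logic sits in lambdas again, so this form suits short collectors; a dedicated class remains easier to document and to test function by function.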

Source code

Git repository


3 responses to Defining a custom Collector in Java 8

  1. Tai says:

    Great writeup, helped me understand the abstract methods in the Collector interface much better. Thanks!

    • davidhxxx says:

      Hello Tai,

      Please to hear :)
      I try to lay quite simple examples to be able to re-understand it even a long time after.
      I checked that, that is still the case, haha :)

  2. tanmay says:

    You explained such a complex topic with this simplicity, its awesome!
