Java 8 Streams API:对Stream分组和分区

澳门新葡亰手机版 1

这篇文章展示了如何使用 Streams API 中的 Collector 及 groupingBy 和
partitioningBy 来对流中的元素进行分组和分区。

**本系列文章经补充和完善,已修订整理成书《Java编程的逻辑》,由机械工业出版社华章分社出版,于2018年1月上市热销,读者好评如潮!各大网店和书店有售,欢迎购买,京东自营链接:**

思考一下 Employee
对象流,每个对象对应一个名字、城市和销售数量,如下表所示:

澳门新葡亰手机版 1

+----------+------------+-----------------+
| Name     | City       | Number of Sales |
+----------+------------+-----------------+
| Alice    | London     | 200             |
| Bob      | London     | 150             |
| Charles  | New York   | 160             |
| Dorothy  | Hong Kong  | 190             |
+----------+------------+-----------------+

分组

首先,我们利用(lambda表达式出现之前的)命令式风格Java
程序对流中的雇员按城市进行分组:

Map<String, List<Employee>> result = new HashMap<>();
for (Employee e : employees) {
  String city = e.getCity();
  List<Employee> empsInCity = result.get(city);
  if (empsInCity == null) {
    empsInCity = new ArrayList<>();
    result.put(city, empsInCity);
  }
  empsInCity.add(e);
}

你可能很熟悉写这样的代码,你也看到了,一个如此简单的任务就需要这么多代码!

而在 Java 8 中,你可以使用 groupingBy
收集器,一条语句就能完成相同的功能,像这样:

Map<String, List<Employee>> employeesByCity =
  employees.stream().collect(groupingBy(Employee::getCity));

结果如下面的 map 所示:

{New York=[Charles], Hong Kong=[Dorothy], London=[Alice, Bob]}

还可以计算每个城市中雇员的数量,只需传递一个计数收集器给 groupingBy
收集器。第二个收集器的作用是在流分类的同一个组中对每个元素进行递归操作。

Map<String, Long> numEmployeesByCity =
  employees.stream().collect(groupingBy(Employee::getCity, counting()));

结果如下面的 map 所示:

{New York=1, Hong Kong=1, London=2}

顺便提一下,该功能与下面的 SQL 语句是等同的:

select city, count(*) from Employee group by city

另一个例子是计算每个城市的平均年龄,这可以联合使用 averagingInt 和
groupingBy 收集器:

Map<String, Double> avgSalesByCity =
  employees.stream().collect(groupingBy(Employee::getCity,
                               averagingInt(Employee::getNumSales)));

结果如下 map 所示:

{New York=160.0, Hong Kong=190.0, London=175.0}

 

分区

分区是一种特殊的分组,结果 map
至少包含两个不同的分组——一个true,一个false。例如,如果想找出最优秀的员工,你可以将所有雇员分为两组,一组销售量大于
N,另一组小于 N,使用 partitioningBy 收集器:

Map<Boolean, List<Employee>> partitioned =
  employees.stream().collect(partitioningBy(e -> e.getNumSales() > 150));

输出如下结果:

{false=[Bob], true=[Alice, Charles, Dorothy]}

你也可以将 groupingBy 收集器传递给 partitioningBy
收集器来将联合使用分区和分组。例如,你可以统计每个分区中的每个城市的雇员人数:

Map<Boolean, Map<String, Long>> result =
  employees.stream().collect(partitioningBy(e -> e.getNumSales() > 150,
                               groupingBy(Employee::getCity, counting())));

这样会生成一个二级 Map:

{false={London=1}, true={New York=1, Hong Kong=1, London=1}}

上节初步介绍了Java
8中的函数式数据处理,对于collect方法,我们只是演示了其最基本的应用,它还有很多强大的功能,比如,可以分组统计汇总,实现类似数据库查询语言SQL中的group
by功能。

具体都有哪些功能?有什么用?如何使用?基本原理是什么?本节进行详细讨论,我们先来进一步理解下collect方法。

理解collect

在上节中,过滤得到90分以上的学生列表,代码是这样的:

List<Student> above90List = students.stream()
        .filter(t->t.getScore()>90)
        .collect(Collectors.toList());

最后的collect调用看上去很神奇,它到底是怎么把Stream转换为List<Student>的呢?先看下collect方法的定义:

<R, A> R collect(Collector<? super T, A, R> collector)

它接受一个收集器collector作为参数,类型是Collector,这是一个接口,它的定义基本是:

public interface Collector<T, A, R> {
    Supplier<A> supplier();
    BiConsumer<A, T> accumulator();
    BinaryOperator<A> combiner();
    Function<A, R> finisher();
    Set<Characteristics> characteristics();
}

在顺序流中,collect方法与这些接口方法的交互大概是这样的:

//首先调用工厂方法supplier创建一个存放处理状态的容器container,类型为A
A container = collector.supplier().get();

//然后对流中的每一个元素t,调用累加器accumulator,参数为累计状态container和当前元素t
for (T t : data)
   collector.accumulator().accept(container, t);

//最后调用finisher对累计状态container进行可能的调整,类型转换(A转换为R),并返回结果
return collector.finisher().apply(container);

combiner只在并行流中有用,用于合并部分结果。characteristics用于标示收集器的特征,Collector接口的调用者可以利用这些特征进行一些优化,Characteristics是一个枚举,有三个值:CONCURRENT,
UNORDERED和IDENTITY_FINISH,它们的含义我们后面通过例子简要说明,目前可以忽略。

Collectors.toList()具体是什么呢?看下代码:

public static <T>
Collector<T, ?, List<T>> toList() {
    return new CollectorImpl<>((Supplier<List<T>>) ArrayList::new, List::add,
                               (left, right) -> { left.addAll(right); return left; },
                               CH_ID);
}

它的实现类是CollectorImpl,这是Collectors内部的一个私有类,实现很简单,主要就是定义了两个构造方法,接受函数式参数并赋值给内部变量。对toList来说:

  • supplier的实现是ArrayList::new,也就是创建一个ArrayList作为容器
  • accumulator的实现是List::add,也就是将碰到的每一个元素加到列表中,
  • 第三个参数是combiner,表示合并结果
  • 第四个参数CH_ID是一个静态变量,只有一个特征IDENTITY_FINISH,表示finisher没有什么事情可以做,就是把累计状态container直接返回

也就是说,collect(Collectors.toList())背后的伪代码如下所示:

List<T> container = new ArrayList<>();
for (T t : data)
   container.add(t);
return container;

与toList类似的容器收集器还有toSet, toCollection, toMap等,我们来看下。

容器收集器

toSet

toSet的使用与toList类似,只是它可以排重,就不举例了。toList背后的容器是ArrayList,toSet背后的容器是HashSet,其代码为:

public static <T>
Collector<T, ?, Set<T>> toSet() {
    return new CollectorImpl<>((Supplier<Set<T>>) HashSet::new, Set::add,
                               (left, right) -> { left.addAll(right); return left; },
                               CH_UNORDERED_ID);
}

CH_UNORDERED_ID是一个静态变量,它的特征有两个,一个是IDENTITY_FINISH,表示返回结果即为Supplier创建的HashSet,另一个是UNORDERED,表示收集器不会保留顺序,这也容易理解,因为背后容器是HashSet。

toCollection

toCollection是一个通用的容器收集器,可以用于任何Collection接口的实现类,它接受一个工厂方法Supplier作为参数,具体代码为:

public static <T, C extends Collection<T>>
Collector<T, ?, C> toCollection(Supplier<C> collectionFactory) {
    return new CollectorImpl<>(collectionFactory, Collection<T>::add,
                               (r1, r2) -> { r1.addAll(r2); return r1; },
                               CH_ID);
}

比如,如果希望排重但又希望保留出现的顺序,可以使用LinkedHashSet,Collector可以这么创建:

Collectors.toCollection(LinkedHashSet::new)

toMap

toMap将元素流转换为一个Map,我们知道,Map有键和值两部分,toMap至少需要两个函数参数,一个将元素转换为键,另一个将元素转换为值,其基本定义为:

public static <T, K, U> Collector<T, ?, Map<K,U>> toMap(
    Function<? super T, ? extends K> keyMapper,
    Function<? super T, ? extends U> valueMapper)

返回结果为Map<K,U>,keyMapper将元素转换为键,valueMapper将元素转换为值。比如,将学生流转换为学生名称和分数的Map,代码可以为:

Map<String,Double> nameScoreMap = students.stream().collect(
        Collectors.toMap(Student::getName, Student::getScore));

这里,Student::getName是keyMapper,Student::getScore是valueMapper。

实践中,经常需要将一个对象列表按主键转换为一个Map,以便以后按照主键进行快速查找,比如,假定Student的主键是id,希望转换学生流为学生id和学生对象的Map,代码可以为:

Map<String, Student> byIdMap = students.stream().collect(
        Collectors.toMap(Student::getId, t -> t));

t->t是valueMapper,表示值就是元素本身,这个函数用的比较多,接口Function定义了一个静态函数identity表示它,也就是说,上面的代码可以替换为:

Map<String, Student> byIdMap = students.stream().collect(
        Collectors.toMap(Student::getId, Function.identity()));

上面的toMap假定元素的键不能重复,如果有重复的,会抛出异常,比如:

Map<String,Integer> strLenMap = Stream.of("abc","hello","abc").collect(
        Collectors.toMap(Function.identity(), t->t.length()));

希望得到字符串与其长度的Map,但由于包含重复字符串”abc”,程序会抛出异常。这种情况下,我们希望的是程序忽略后面重复出现的元素,这时,可以使用另一个toMap函数:

public static <T, K, U> Collector<T, ?, Map<K,U>> toMap(
    Function<? super T, ? extends K> keyMapper,
    Function<? super T, ? extends U> valueMapper,
    BinaryOperator<U> mergeFunction)    

相比前面的toMap,它接受一个额外的参数mergeFunction,它用于处理冲突,在收集一个新元素时,如果新元素的键已经存在了,系统会将新元素的值与键对应的旧值一起传递给mergeFunction得到一个值,然后用这个值给键赋值。

对于前面字符串长度的例子,新值与旧值其实是一样的,我们可以用任意一个值,代码可以为:

Map<String,Integer> strLenMap = Stream.of("abc","hello","abc").collect(
        Collectors.toMap(Function.identity(),
                t->t.length(), (oldValue,value)->value));

有时,我们可能希望合并新值与旧值,比如一个联系人列表,对于相同的联系人,我们希望合并电话号码,mergeFunction可以定义为:

BinaryOperator<String> mergeFunction = (oldPhone,phone)->oldPhone+","+phone;

toMap还有一个更为通用的形式:

public static <T, K, U, M extends Map<K, U>> Collector<T, ?, M> toMap(
    Function<? super T, ? extends K> keyMapper,
    Function<? super T, ? extends U> valueMapper,
    BinaryOperator<U> mergeFunction,
    Supplier<M> mapSupplier) 

相比前面的toMap,多了一个mapSupplier,它是Map的工厂方法,对于前面两个toMap,其mapSupplier其实是HashMap::new。我们知道,HashMap是没有任何顺序的,如果希望保持元素出现的顺序,可以替换为LinkedHashMap,如果希望收集的结果排序,可以使用TreeMap。

toMap主要用于顺序流,对于并发流,Collectors有专门的名称为toConcurrentMap的收集器,它内部使用ConcurrentHashMap,用法类似,具体我们就不讨论了。

字符串收集器

除了将元素流收集到容器中,另一个常见的操作是收集为一个字符串。比如,获取所有的学生名称,用逗号连接起来,传统上,代码看上去像这样:

StringBuilder sb = new StringBuilder();
for(Student t : students){
    if(sb.length()>0){
        sb.append(",");
    }
    sb.append(t.getName());
}
return sb.toString();

针对这种常见的需求,Collectors提供了joining收集器:

public static Collector<CharSequence, ?, String> joining()
public static Collector<CharSequence, ?, String> joining(CharSequence delimiter)
public static Collector<CharSequence, ?, String> joining(
    CharSequence delimiter, CharSequence prefix, CharSequence suffix) 

第一个就是简单的把元素连接起来,第二个支持一个分隔符,第三个更为通用,可以给整个结果字符串加个前缀和后缀。比如:

String result = Stream.of("abc","老马","hello")
        .collect(Collectors.joining(",", "[", "]"));
System.out.println(result);                                                  

输出为:

[abc,老马,hello]

joining的内部也利用了StringBuilder澳门新葡亰手机版,,比如,第一个joining函数的代码为:

public static Collector<CharSequence, ?, String> joining() {
    return new CollectorImpl<CharSequence, StringBuilder, String>(
            StringBuilder::new, StringBuilder::append,
            (r1, r2) -> { r1.append(r2); return r1; },
            StringBuilder::toString, CH_NOID);
}

supplier是StringBuilder::new,accumulator是StringBuilder::append,finisher是StringBuilder::toString,CH_NOID表示特征集为空。

分组

分组类似于数据库查询语言SQL中的group
by语句,它将元素流中的每个元素分到一个组,可以针对分组再进行处理和收集,分组的功能比较强大,我们逐步来说明。

为便于举例,我们先修改下学生类Student,增加一个字段grade,表示年级,改下构造方法:

public Student(String name, String grade, double score) {
    this.name = name;
    this.grade = grade;
    this.score = score;
}

示例学生列表students改为:

static List<Student> students = Arrays.asList(new Student[] {
        new Student("zhangsan", "1", 91d),
        new Student("lisi", "2", 89d),
        new Student("wangwu", "1", 50d),
        new Student("zhaoliu", "2", 78d),
        new Student("sunqi", "1", 59d)});            

基本用法

最基本的分组收集器为:

public static <T, K> Collector<T, ?, Map<K, List<T>>>
    groupingBy(Function<? super T, ? extends K> classifier)

参数是一个类型为Function的分组器classifier,它将类型为T的元素转换为类型为K的一个值,这个值表示分组值,所有分组值一样的元素会被归为同一个组,放到一个列表中,所以返回值类型是Map<K,
List<T>>。 比如,将学生流按照年级进行分组,代码为:

Map<String, List<Student>> groups = students.stream()
        .collect(Collectors.groupingBy(Student::getGrade));

学生会分为两组,第一组键为”1″,分组学生包括”zhangsan”,
“wangwu”和”sunqi”,第二组键为”2″,分组学生包括”lisi”, “zhaoliu”。

这段代码基本等同于如下代码:

Map<String, List<Student>> groups = new HashMap<>();
for (Student t : students) {
    String key = t.getGrade();
    List<Student> container = groups.get(key);
    if (container == null) {
        container = new ArrayList<>();
        groups.put(key, container);
    }
    container.add(t);
}
System.out.println(groups);