itertoolsのgroupbyやchainで集計やってみる

Pythonにはイテレータを生成する関数がいろいろあるね。
先日MySQLで行った集計（「集計メモ - 偏った言語信者の垂れ流し」）を、同様にPythonでやってみる。

from datetime import date
from itertools import groupby, chain

# Sample access log: a list of (user_id, num_count, cdate) tuples.
# Every row falls in August 2010, so only the day of month is written out
# below; a generator (not a list comprehension) is used so that no loop
# variables leak into module scope on Python 2.
access_count = list(
    (user, hits, date(2010, 8, day))
    for user, hits, day in (
        (1, 10, 19),
        (2, 1, 19),
        (4, 5, 19),
        (1, 4, 20),
        (3, 3, 20),
        (4, 2, 20),
        (2, 4, 21),
        (1, 7, 22),
        (3, 20, 23),
        (2, 100, 24),
    )
)

def _user_id(row):
    """Return the user_id column (index 0) of a row; used as the sort/group key."""
    return row[0]


def _sum_by_user(rows):
    """Return [(user_id, SUM(num_count)), ...] ordered by user_id.

    Mirrors ``SELECT user_id, SUM(num_count) ... GROUP BY user_id``.
    ``groupby`` only merges *adjacent* rows with equal keys, so the rows are
    sorted by the same key first.
    """
    return [
        (user_id, sum(num_count for _, num_count, _ in group))
        for user_id, group in groupby(sorted(rows, key=_user_id), _user_id)
    ]


def _between(rows, start, end):
    """Return the rows whose cdate (index 2) lies in [start, end].

    Both ends are inclusive, like SQL ``BETWEEN``; input order is preserved.
    """
    return [row for row in rows if start <= row[2] <= end]


def _union_all(rows, start, end):
    """Return per-user totals over all rows, then per-user totals over the
    rows in [start, end], as (user_id, total_all, total_in_range) tuples.

    Each half fills only its own count column and puts 0 in the other one,
    like the two SELECTs of a SQL ``UNION ALL``.
    """
    all_rows = [(user_id, total, 0) for user_id, total in _sum_by_user(rows)]
    range_rows = [(user_id, 0, total)
                  for user_id, total in _sum_by_user(_between(rows, start, end))]
    return list(chain(all_rows, range_rows))


def _group_union(union_rows):
    """GROUP BY user_id over ``_union_all`` output, summing both count columns.

    Summing the columns (rather than unpacking exactly two rows per user)
    also handles users that have no rows in the date range: they get 0.
    """
    results = []
    for user_id, group in groupby(sorted(union_rows, key=_user_id), _user_id):
        group_rows = list(group)  # the group iterator is single-pass; read it twice
        results.append((
            user_id,
            sum(row[1] for row in group_rows),
            sum(row[2] for row in group_rows),
        ))
    return results


def main(rows=None):
    """Print four SQL-like aggregations of an access log.

    :param rows: iterable of (user_id, num_count, cdate) tuples; when omitted,
        the module-level ``access_count`` sample data is used.
    """
    if rows is None:
        rows = access_count
    rows = list(rows)  # iterated several times below
    start, end = date(2010, 8, 19), date(2010, 8, 21)

    print('group by user_id')
    for row in _sum_by_user(rows):
        print(row)

    # str(date) is ISO format, so this prints "between 2010-08-19 and 2010-08-21".
    print('between %s and %s' % (start, end))
    for row in _between(rows, start, end):
        print(row)

    union_rows = _union_all(rows, start, end)
    print('union all')
    for row in union_rows:
        print(row)

    print('union all + group by user_id')
    for row in _group_union(union_rows):
        print(row)


if __name__ == '__main__':
    main()

まあ、さすがに改行なしで書くと読みづらいことこの上ないけど。

>python grouping.py
group by user_id
(1, 21)
(2, 105)
(3, 23)
(4, 7)
between 2010-08-19 and 2010-08-21
(1, 10, datetime.date(2010, 8, 19))
(2, 1, datetime.date(2010, 8, 19))
(4, 5, datetime.date(2010, 8, 19))
(1, 4, datetime.date(2010, 8, 20))
(3, 3, datetime.date(2010, 8, 20))
(4, 2, datetime.date(2010, 8, 20))
(2, 4, datetime.date(2010, 8, 21))
union all
(1, 21, 0)
(2, 105, 0)
(3, 23, 0)
(4, 7, 0)
(1, 0, 14)
(2, 0, 5)
(3, 0, 3)
(4, 0, 7)
union all + group by user_id
(1, 21, 14)
(2, 105, 5)
(3, 23, 3)
(4, 7, 7)