Python

以后就玩Python了。交流QQ群:250329766

分治法举例之矩阵乘法

直观上感觉矩阵乘法是没法优化。但事实上可以。 这说明真没有那么多做不到的事,可能只是现在的你做不来,说不定有人能做到,说不定做到的那个人就是未来的你。


分治法举例之矩阵乘法

前言

矩阵按定义直接实现是比较直接简单的。时间复杂度也可以直接得出来是\(O(n^3)\)

直接实现

class matrix:
    '''
    为了简单起见,没有对数据做校验。
    假设矩阵就是nxn的。
    '''
    def __init__(self, data):
        if not data or not hasattr(data, '__getitem__'):
            raise ValueError("data not valid! %s" % data)
        self.data = data
        self.rows = len(data)
        self.cols = max(map(lambda row: len(row), data))

    def __mul__(self, another):
        if self.cols != another.rows:
            raise ValueError("not valid ddata ,only support mxn * nxp")
        ret = matrix([[0 for _ in range(another.cols)] for _ in range(self.rows)])
        for i in range(self.rows):
            for j in range(another.cols):
                num = 0
                for k in range(self.cols):
                    num += self._getitem(i, k) * another._getitem(k, j)
                ret._setitem(i, j, num)
        return ret

    def _getitem(self, i, j):
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d, %s" % (
                i, j, self.data))
        try:
            return self.data[i][j]
        except Exception:
            return 0

    def _setitem(self, i, j, value):
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d,value=%s, %s" % (
                i, j, str(value), self.data))
        if j >= len(self.data[i]):
            fill = self.cols - len(self.data[i])
            self.data[i].extend([0 for _ in range(fill)])
        self.data[i][j] = value

    def __str__(self):
        return "(rows:%d, cols:%d)->%s" % (self.rows, self.cols, self.data)

直接应用分治法

很简单的想法是把矩阵分成n/2的4块。于是
A*B = C
变成
\[
\begin{pmatrix}
a & b \\
c & d
\end{pmatrix}
*
\begin{pmatrix}
e & f \\
g & h
\end{pmatrix}
=
\begin{pmatrix}
r & s \\
t & u
\end{pmatrix}
\]

r = ae + bg
s = af + bh
t = ce + dg
u = cf + dh

于是有\(T(n) = 8T(n/2)+O(n^2)\)
\(O(n^{log_ba}) = O(n^{log_28})=O(n^3) > f(n)\)
满足主定理的第一种情况,于是\(T(n) = \Theta(n^3)\),并没有比直接实现快。

斯特拉森算法

斯特拉森于1969年提出的算法,运用分治策略并加上一些处理技巧设计出的一种矩阵乘法。
他巧妙地8变成了7。于是达到了\(T(n) = \Theta(n^{log_27})\approx\Theta(n^{2.81})\)
这还不是目前理论上最好的,暂时最好的达到了\(T(n) =\Theta(n^{2.376})\)

看一下他是怎么玩的。

P1 = (a+d) * (e+h)
P2 = (c+d) * e
P3 = a * (f-h)
P4 = d * (g-e)
P5 = (a+b) * h
P6 = (c-a) * (e+f)
P7 = (b-d) * (g+h)

利用此7个式子即可得到原来的r,s,t,u
r = P1 + P4 - P5 + P7
s = P3 + P5
t = P2 + P4
u = P1 + P3 -P2 + P6
验证一下u看看

u = P1 + P3 -P2 + P6
= (a+d) * (e+h) + a * (f-h) -((c+d) * e) + (c-a) * (e+f)
=ae + ah + de + dh + af - ah -ce - de + ce + cf -ae -af
= dh + cf

正确

代码实现如下:

#!/usr/bin/env python
from enum import Enum, IntEnum, unique
import sys

class matrix:
    '''
    为了简单起见,没有对数据做校验。
    假设矩阵就是nxn的。
    '''
    def __init__(self, data):
        if not data or not hasattr(data, '__getitem__'):
            raise ValueError("data not valid! %s" % data)
        self.data = data
        self.rows = len(data)
        self.cols = max(map(lambda row: len(row), data))
        if self.rows != self.cols:
            raise ValueError("only support nxn matrix, and n can continue divide by 2 util 1")

    def __add__(self, another):
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        ret = matrix([[0]*self.rows for _ in range(self.rows)])
        for i in range(self.rows):
            for j in range(self.rows):
                ret._setitem(i, j, self._getitem(i, j) + another._getitem(i, j))

        return ret

    def __sub__(self, another):
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        ret = matrix([[0]*self.rows for _ in range(self.rows)])
        for i in range(self.rows):
            for j in range(self.rows):
                ret._setitem(i, j, self._getitem(i, j) - another._getitem(i, j))

        return ret

    def __mul__(self, another):
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        ret = matrix([[0]*self.rows for _ in range(self.rows)])
        if self.rows == 2:
            for i in range(self.rows):
                for j in range(another.cols):
                    num = 0
                    for k in range(self.cols):
                        num += self._getitem(i, k) * another._getitem(k, j)
                    ret._setitem(i, j, num)
        else:
            a = self._divide(matrix.DIRECTION.LEFT_TOP)
            b = self._divide(matrix.DIRECTION.RIGHT_TOP)
            c = self._divide(matrix.DIRECTION.LEFT_BOTTOM)
            d = self._divide(matrix.DIRECTION.RIGHT_BOTTOM)

            e = another._divide(matrix.DIRECTION.LEFT_TOP)
            f = another._divide(matrix.DIRECTION.RIGHT_TOP)
            g = another._divide(matrix.DIRECTION.LEFT_BOTTOM)
            h = another._divide(matrix.DIRECTION.RIGHT_BOTTOM)

            p1 = (a+d)*(e+h)
            p2 = (c+d)*e
            p3 = a * (f-h)
            p4 = d * (g-e)
            p5 = (a+b)*h
            p6 = (c-a)*(e+f)
            p7 = (b-d)*(g+h)

            r = p1 + p4 - p5 + p7
            s = p3 + p5
            t = p2 + p4
            u = p1 + p3 - p2 + p6

            ret._merge(matrix.DIRECTION.LEFT_TOP, r)
            ret._merge(matrix.DIRECTION.RIGHT_TOP, s)
            ret._merge(matrix.DIRECTION.LEFT_BOTTOM, t)
            ret._merge(matrix.DIRECTION.RIGHT_BOTTOM, u)

        return ret

    @unique
    class DIRECTION (IntEnum):
        LEFT_TOP = 1
        LEFT_BOTTOM = 2
        RIGHT_TOP = 3
        RIGHT_BOTTOM = 4

    def _divide(self, direction):
        ret = matrix([[0]*int(self.rows/2) for _ in range(int(self.rows/2))])
        row_start = col_start = 0
        if direction == matrix.DIRECTION.LEFT_TOP:
            row_start = 0
            col_start = 0
        elif direction == matrix.DIRECTION.LEFT_BOTTOM:
            row_start = int(self.rows/2)
            col_start = 0
        elif direction == matrix.DIRECTION.RIGHT_TOP:
            row_start = 0
            col_start = int(self.cols/2)
        else:
            row_start = int(self.rows/2)
            col_start = int(self.cols/2)

        for i in range(ret.rows):
            for j in range(ret.cols):
                item = self._getitem(i+row_start, j+col_start)
                ret._setitem(i, j, item)

        return ret

    def _merge(self, direction, another):
        row_start = col_start = 0
        if direction == matrix.DIRECTION.LEFT_TOP:
            row_start = 0
            col_start = 0
        elif direction == matrix.DIRECTION.LEFT_BOTTOM:
            row_start = int(self.rows/2)
            col_start = 0
        elif direction == matrix.DIRECTION.RIGHT_TOP:
            row_start = 0
            col_start = int(self.cols/2)
        else:
            row_start = int(self.rows/2)
            col_start = int(self.cols/2)

        for i in range(another.rows):
            for j in range(another.cols):
                item = another._getitem(i, j)
                self._setitem(i+row_start, j+col_start, item)



    def _getitem(self, i, j):
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d, %s" % (
                i, j, self.data))
        try:
            return self.data[i][j]
        except Exception:
            return 0

    def _setitem(self, i, j, value):
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d,value=%s, %s" % (
                i, j, str(value), self.data))
        if j >= len(self.data[i]):
            fill = self.cols - len(self.data[i])
            self.data[i].extend([0 for _ in range(fill)])
        self.data[i][j] = value

    def __str__(self):
        return "(rows:%d, cols:%d)->%s" % (self.rows, self.cols, self.data)

测试结果

方法 规模 时间
直接计算 8x8 0.054
拉特斯森 8x8 0.095
直接计算 16x16 0.063
拉特斯森 16x16 0.117
直接计算 32x32 0.090
拉特斯森 32x32 0.454
直接计算 64x64 0.419
拉特斯森 64x64 2.953
直接计算 128x128 2.946
拉特斯森 128x128 20.547
直接计算 256x256 24.835
拉特斯森 256x256 2:15.15
直接计算 512x512 3:15.98

总结

估计是我实现的问题,比预期结果要差,看视频里说的是到32就差不多了。
不过从上也可以看出来,拉特斯森增长的速度没有直接计算的快,迟早性能会更好。

另外,直观上感觉矩阵乘法是没法优化。但事实上可以。
这说明真没有那么多做不到的事,可能只是现在的你做不来,说不定有人能做到,说不定做到的那个人就是未来的你。

博文最后更新时间:


评论

  • 暂无评论

发表评论