k3kaimu
9/16/2015 - 12:58 PM

SSEを使った複素数の内積 (complex-number dot product using SSE)

import core.simd;
import std.stdio;
import std.datetime;
import std.math;
import std.parallelism;
import std.range;
import std.random;


/**
 * Unconjugated dot product of two arrays of packed complex floats.
 *
 * Each `Vector!(float[4])` element is read as two interleaved complex
 * numbers `[re0, im0, re1, im1]`; the result is the sum of `a_k * b_k`
 * over every packed complex pair (no conjugation of either operand).
 *
 * Uses SHUFPS and DPPS via `core.simd.__simd`; DPPS is an SSE4.1
 * instruction, so the target CPU must support SSE4.1.
 *
 * Params:
 *     a = left operand.
 *     b = right operand; must have the same length as `a` (checked only
 *         by the `in` contract).
 * Returns: the accumulated complex sum as a `cfloat`.
 */
cfloat dotProduct(in Vector!(float[4])[] a, in Vector!(float[4])[] b)
in{
    assert(a.length == b.length);
}
body{
    // accDirect  accumulates lane-wise products [xr*hr, xi*hi, ...].
    // accSwapped accumulates cross products     [xi*hr, xr*hi, ...].
    Vector!(float[4]) accDirect;
    Vector!(float[4]) accSwapped;
    accDirect = 0;
    accSwapped = 0;

    foreach (k; 0 .. a.length)
    {
        // Access through .ptr, like the raw pointer walk, so no bounds check is added.
        Vector!(float[4]) lhs = a.ptr[k];
        Vector!(float[4]) rhs = b.ptr[k];

        accDirect += lhs * rhs;

        // 0xB1 == 0b10_11_00_01: swap re/im inside each pair -> [xi0, xr0, xi1, xr1].
        lhs = __simd(XMM.SHUFPS, lhs, lhs, 0xB1);
        accSwapped += lhs * rhs;
    }

    // Real part: horizontal sum with alternating signs (xr*hr - xi*hi).
    // Imaginary part: plain horizontal sum (xi*hr + xr*hi).
    Vector!(float[4]) signs, unit;
    signs.array = [1.0f, -1.0f, 1.0f, -1.0f];
    unit = 1.0f;

    // Mask 0xFF: multiply all four lanes, broadcast the sum to every lane.
    accDirect = __simd(XMM.DPPS, accDirect, signs, 0xFF);
    accSwapped = __simd(XMM.DPPS, accSwapped, unit, 0xFF);

    return accDirect.array[0] + accSwapped.array[0] * 1i;
}


/// Benchmarks five ways of computing the complex dot product of two
/// N*2-element `cfloat` arrays, each repeated `Times` times:
///   1. indexed scalar loop, 2. pointer-walking scalar loop,
///   3. pointer loop unrolled by 8, 4. SSE `dotProduct`,
///   5. SSE `dotProduct` inside a `parallel` foreach.
/// Prints elapsed time and throughput for each, then the five results.
void main()
{
    import core.time : MonoTime;

    enum size_t N = 1024;
    enum size_t Times = 1024*32;

    // Benchmark 3 is unrolled by 8 and has no remainder loop.
    static assert((N * 2) % 8 == 0,
                  "unrolled benchmark requires a multiple of 8 complex elements");

    Vector!(float[4])[] xs = new Vector!(float[4])[N],
                        hs = xs.dup;

    // Scalar views over the same memory: each float4 holds two complex values.
    cfloat[] cxs = (cast(cfloat*)xs.ptr)[0 .. N*2],
             chs = (cast(cfloat*)hs.ptr)[0 .. N*2];

    foreach(ref e; cxs)
        e = uniform01() + uniform01()*1i;

    foreach(ref e; chs)
        e = uniform01() + uniform01()*2i;

    cfloat[5] res;

    // MonoTime is a monotonic clock, unlike the wall-clock Clock.currTime,
    // so measured intervals cannot be skewed by system time adjustments.
    auto start = MonoTime.currTime;
    {
        foreach(times; 0 .. Times)
        {
            res[0] = 0+0i;
            foreach(i; 0 .. N*2)
                res[0] += cxs[i] * chs[i];
        }
    }
    auto bnch1 = MonoTime.currTime - start;

    start = MonoTime.currTime;
    {
        foreach(times; 0 .. Times)
        {
            res[1] = 0+0i;
            auto px = cxs.ptr,
                 qx = px + cxs.length,
                 ph = chs.ptr;

            while(px != qx)
            {
                res[1] += (*px) * (*ph);
                ++px;
                ++ph;
            }
        }
    }
    auto bnch2 = MonoTime.currTime - start;

    start = MonoTime.currTime;
    {
        foreach(times; 0 .. Times)
        {
            res[2] = 0+0i;
            auto px = cxs.ptr,
                 qx = px + cxs.length,
                 ph = chs.ptr;

            while(px != qx)
            {
                res[2] += *(px+0) * *(ph+0);
                res[2] += *(px+1) * *(ph+1);
                res[2] += *(px+2) * *(ph+2);
                res[2] += *(px+3) * *(ph+3);
                res[2] += *(px+4) * *(ph+4);
                res[2] += *(px+5) * *(ph+5);
                res[2] += *(px+6) * *(ph+6);
                res[2] += *(px+7) * *(ph+7);
                px += 8;
                ph += 8;
            }
        }
    }
    auto bnch3 = MonoTime.currTime - start;

    start = MonoTime.currTime;
    {
        foreach(times; 0 .. Times)
            res[3] = dotProduct(xs, hs);
    }
    auto bnch4 = MonoTime.currTime - start;

    // One output slot per iteration so parallel workers never write the same
    // memory location (writing res[4] from every worker was a data race).
    auto perIteration = new cfloat[Times];
    start = MonoTime.currTime;
    {
        foreach(times; parallel(iota(0, Times)))
            perIteration[times] = dotProduct(xs, hs);
    }
    auto bnch5 = MonoTime.currTime - start;
    res[4] = perIteration[$ - 1];

    /*
    Sample output recorded in the original source:
    459[ms], 146.067[Msps]
    257[ms], 260.873[Msps]
    212[ms], 316.248[Msps]
    92[ms], 728.747[Msps]
    21[ms], 3192.77[Msps]
    -556.742+1556.79i == -556.742+1556.79i == -556.742+1556.79i == -556.741+1556.79i == -556.741+1556.79i
    */
    // Complex multiply-accumulates performed per benchmark; divided by
    // microseconds this gives mega-samples per second.
    immutable double samples = cxs.length * Times;
    foreach(elapsed; [bnch1, bnch2, bnch3, bnch4, bnch5])
        writefln("%s[ms], %s[Msps]", elapsed.total!"msecs", samples / elapsed.total!"usecs");
    writefln("%(%s == %)", res[]);
}