2022年4月28日 星期四

初探 clang 14.0 中的 Matrix Type

先前 Clang 14.0 的文章中提到 Clang 14.0 新增了 Matrix Type, 出於個人的好奇, 於是就開始動手玩了一下, 這裡先說結論, Clang 14.0 中的 Matrix Type 還是開發階段, 尚不太好操作. 這裡嘗試使用 Matrix Type 的應用是 Tiled Matrix Multiplication.

比較基準的 vector type 實作

首先是先前已經以 Clang Vector Extension 所實作的 float 型別 8x8 矩陣乘法

typedef float float8 __attribute__((ext_vector_type(8)));
void gemm_vec(
    // buffer, stride
    float *a, int sa,  
    float *b, int sb,
    float *c, int sc
){
    float8 vb[8];
    for(int y = 0; y < 8; y++)
        vb[y] = *((float8*)(b + sb*y));
    }
    for(int y = 0; y < 8; y++){
        float8 vc = *((float8*)(c + sc*y));
        float8 va = *((float8*)(a + sa*y));
        for(int x = 0; x < 8; x++)
            vc += va[x]*vb[x];
        *((float8*)(c + sc*y)) = vc;
    }
}

以 matrix type 實作的過程

接著是參考 Clang Matrix Type 的說明嘗試定義 8x8 的 Matrix Type:

typedef float float8x8 __attribute__((matrix_type(8, 8)));

首先值得一提的是定義的 matrix type 所宣告的變數, 類似於 vector type 內作為 1D array, matrix type 變數能作為 2D array 的語法來存取個別數值. 這裡定義好 float8x8 型別之後, 必須宣告變數並將 2D 資料載入, 於是會發現文件中並沒有說明如何載入資料! 這部份必須參考目前的 Draft Specification, 目前 Clang 僅支援 3 個 function:

  • M2 __builtin_matrix_transpose(M1 matrix)
  • M __builtin_matrix_column_major_load(T *ptr, size_t row, size_t col, size_t columnStride)
  • void __builtin_matrix_column_major_store(M matrix, T *ptr, size_t columnStride)

總之, 資料載入的方式有了, 那麼 column major load /store 的意義為何? 這主要是載入資料的順序, 參考 Wikipedia 關於 Row- and column-major order 頁面的圖解, 可以得知對於 column major load 的意義在於:

有點不太理解 clang 第一時間所支援的竟然不是 row major load / store. 然而使用 column major load store 的結果是所得到的 matrix 為原本所需的 transposed matrix. Anyway, 由於我們只有這個讀取方式, 所以我們先讀取原先 function 中的 a, b, c 矩陣:

float8x8 ma, mb, mc;
ma = __builtin_matrix_column_major_load(a, 8, 8, sa);
mb = __builtin_matrix_column_major_load(b, 8, 8, sb);
mc = __builtin_matrix_column_major_load(c, 8, 8, sc);

然而我們還是可以動些手腳來達成原本的 Tiled Matrix Multiplication, 首先我們原本要計算的資料為

C += A x B

然而這當中的 A, B, C, 透過 __builtin_matrix_column_major_load 所得到的已經是 transposed matrix (這裡我們用 A', B', C' 代表), 因此我們可以透過預先透過 __builtin_matrix_transpose 來轉回 A, B, C. 正確計算後能再轉回原先的 column major matrix 來存回正確的數值. 

ma = __builtin_matrix_transpose(ma);
mb = __builtin_matrix_transpose(mb);
mc = __builtin_matrix_transpose(mc);
mc += ma*mb;
mc = __builtin_matrix_transpose(mc); //transposed again
__builtin_matrix_column_major_store(mc, c, sc);

這裡使用了四次 __builtin_matrix_transpose, 僅僅是為了達到計算上正確的 C=AxB, 由於在回存時必須保持 transposed 狀態, 所以如果計算時能產生對應 C' 的資料即可, 於是我們可以這麼計算:

C'+=B'xA'
於是我們整個實作就可以簡化如下:

void gemm_vec(
    float *a, int sa,
    float *b, int sb,
    float *c, int sc
){
    float8x8 ma, mb, mc;
    ma = __builtin_matrix_column_major_load(a, 8, 8, sa);
    mb = __builtin_matrix_column_major_load(b, 8, 8, sb);
    mc = __builtin_matrix_column_major_load(c, 8, 8, sc);
    mc += mb * ma;
    __builtin_matrix_column_major_store(mc, c, sc);
}

最後是編譯時要注意, clang 必須加入 -fenable-matrix 這個參數

以這種實作搭配 AVX2 在 Intel Core i5-8350U 測試, 對於 naive, vector type 與 matrix type 來測試用於兩個 512x512 matrix multiplication 的效能

因此 single thread 來執行 naive, vector type, matrix type 三種實作的時間分別為 129.9ms, 26.0ms, 27.0ms. 因此以 matrix type 可以得到十分接近的 vector extension 的性能, 但是實作上相當簡潔易懂.

到此許多人應該會關心 clang 未來是否支援 row-major 的形式? 個人認為是會的, 首先 Intel AMX 指令集所使用的即是 row major 的 load/store. 再者是在 2020 年 LLVM 開發者大會中 "Matrix Support in LLVM and Clang" 的演講的倒數第二張投影片 "Remaining Work" 明確標示了 Row-major support.


使用後個人想提的幾個 idea:

定義 matrix 的同時, 應同時定義出 row 與 column 的 vector 型別, 目前不知如何取得定義語法, 這裡暫時以 ".row_vec" 與 ".col_vec" 來表示. 也就是說當我們定義了:

typedef float float8x8 __attribute__((matrix_type(8, 8)));

那麼可以用此定義一併定義出對應用在 row 與 col 的 vector type (或是分開定義但是 compiler 偵測 vector length):

float8x8.row_vec vrow, vrow2;
float8x8.col_vec vcol, vcol2;

如此可以合併使用 vector extension 與 matrix type 來處理 row major 或是 column major 的數值設定:

float8x8 mat;
...
vrow = *((
float8x8.row_vec*)ptr_row);
vcol = *((float8x8.col_vec*)ptr_col);
mat.row[1] = vrow;        // set a row
mat.col[2] = vcol;        // set a column

另外可以做 row vector 與  column vector 與矩陣相乘運算:

float8x8 mat;
...
vrow = *((
float8x8.row_vec*)ptr_row);
vcol = *((float8x8.col_vec*)ptr_col);
vrow2 = vcol*mat;
vcol2 = mat*vrow;


 

沒有留言:

在 ARM 平台上使用 Function Multi-Versioning (FMV) - 以使用 Android NDK 為例

Function Multi-Versioning (FMV) 過往的 CPU 發展歷程中, x86 平台由於因應各種應用需求的提出, 而陸陸續續加入了不同的指令集, 此外也可能因為針對市場做等級區隔, 支援的數量與種類也不等. 在 Linux 平台上這些 CPU 資訊可以透過...