gpgpu

How does this kernel perform the matrix-matrix multiply?

kernel void
optimized_matmult(  float loopVar0,
        float4 A1[][], float4 A2[][], float4 A3[][], float4 A4[][],
        float4 A5[][], float4 A6[][], float4 A7[][], float4 A8[][],
        float4 B1[][], float4 B2[][], float4 B3[][], float4 B4[][],
        out float4 C1<>, out float4 C2<>, out float4 C3<>, out float4 C4<>,
        out float4 C5<>, out float4 C6<>, out float4 C7<>, out float4 C8<>)
{
 // vPos - Position of the output matrix i.e. (x,y)
 float2 vPos = indexof(C1).xy;
 
 // four210 holds the constants (4, 2, 1, 0); .w supplies 0 and .z supplies 1
 float4 four210 = float4(4.0f, 2.0f, 1.0f, 0.0f);
 
 // index - gather coordinates into A and B: x = output column, y = output row,
 // w = current step k along the shared dimension (z is unused)
 float4 index = float4(vPos.x, vPos.y, four210.w, four210.w);

 // Declaring and initializing accumulators
 float4 accumulator1 = four210.wwww;
 float4 accumulator2 = four210.wwww;
 float4 accumulator3 = four210.wwww;
 float4 accumulator4 = four210.wwww;
 float4 accumulator5 = four210.wwww;
 float4 accumulator6 = four210.wwww;
 float4 accumulator7 = four210.wwww;
 float4 accumulator8 = four210.wwww;
 
 // Loop counter: one iteration per float4 step along the shared dimension
 float i0 = loopVar0;
 
 while(i0 > 0.0f)
 {
   // Fetching values from A
   float4 A11 = A1[index.wy];
   float4 A22 = A2[index.wy];
   float4 A33 = A3[index.wy];
   float4 A44 = A4[index.wy];
   float4 A55 = A5[index.wy];
   float4 A66 = A6[index.wy];
   float4 A77 = A7[index.wy];
   float4 A88 = A8[index.wy];
   
   // Fetching values from B
   float4 B11 = B1[index.xw];
   float4 B22 = B2[index.xw];
   float4 B33 = B3[index.xw];
   float4 B44 = B4[index.xw]; 
   
   accumulator1 += A11.xxxx * B11.xyzw + A11.yyyy * B22.xyzw + A11.zzzz * B33.xyzw + A11.wwww * B44.xyzw;  
   accumulator2 += A22.xxxx * B11.xyzw + A22.yyyy * B22.xyzw + A22.zzzz * B33.xyzw + A22.wwww * B44.xyzw; 
   accumulator3 += A33.xxxx * B11.xyzw + A33.yyyy * B22.xyzw + A33.zzzz * B33.xyzw + A33.wwww * B44.xyzw; 
   accumulator4 += A44.xxxx * B11.xyzw + A44.yyyy * B22.xyzw + A44.zzzz * B33.xyzw + A44.wwww * B44.xyzw; 
   accumulator5 += A55.xxxx * B11.xyzw + A55.yyyy * B22.xyzw + A55.zzzz * B33.xyzw + A55.wwww * B44.xyzw; 
   accumulator6 += A66.xxxx * B11.xyzw + A66.yyyy * B22.xyzw + A66.zzzz * B33.xyzw + A66.wwww * B44.xyzw; 
   accumulator7 += A77.xxxx * B11.xyzw + A77.yyyy * B22.xyzw + A77.zzzz * B33.xyzw + A77.wwww * B44.xyzw; 
   accumulator8 += A88.xxxx * B11.xyzw + A88.yyyy * B22.xyzw + A88.zzzz * B33.xyzw + A88.wwww * B44.xyzw;
   
   // Advance index.w (the step k) by one: four210.wwwz = (0, 0, 0, 1)
   index += four210.wwwz;
   // Decrement the loop counter
   i0 = i0 - 1.0f;
 }
 
 C1 = accumulator1;
 C2 = accumulator2;
 C3 = accumulator3;
 C4 = accumulator4;
 C5 = accumulator5;
 C6 = accumulator6;
 C7 = accumulator7;
 C8 = accumulator8;
}
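To make the indexing easier to follow, here is a minimal CPU emulation of what one instance of the kernel computes. It is a Python sketch, not Brook+ code: the names `kernel_instance` and `madd4` are hypothetical, streams are modeled as 2D lists indexed `stream[row][col]`, each element is a 4-tuple standing in for a float4, and it assumes the Brook+ gather convention where `float2(x, y)` means column x, row y. The loop counter and index are plain integers here rather than floats.

```python
# Hypothetical CPU emulation of one instance of optimized_matmult.
# Streams are 2D lists: stream[row][col] holds a float4 modeled as a 4-tuple.
# Assumes Brook+ gather indexing where float2(x, y) means column x, row y.

def madd4(acc, s, v):
    """acc + s * v, with scalar s broadcast across float4 v (like A11.xxxx * B11)."""
    return tuple(a + s * b for a, b in zip(acc, v))

def kernel_instance(A, B, x, y, loop_var):
    """Return the 8 float4 outputs (C1..C8) for output position (x, y).

    A: list of 8 streams (A1..A8); B: list of 4 streams (B1..B4).
    loop_var: number of loop iterations (one per float4 step of the shared dimension).
    """
    acc = [(0.0, 0.0, 0.0, 0.0)] * 8   # accumulator1..accumulator8
    k = 0                              # plays the role of index.w
    i0 = loop_var
    while i0 > 0:
        # B1..B4 fetched at index.xw = (column x, row k)
        b = [B[j][k][x] for j in range(4)]
        for s in range(8):
            # A(s+1) fetched at index.wy = (column k, row y)
            a = A[s][y][k]
            # a.xxxx*B11 + a.yyyy*B22 + a.zzzz*B33 + a.wwww*B44
            for comp in range(4):
                acc[s] = madd4(acc[s], a[comp], b[comp])
        k += 1     # index += four210.wwwz
        i0 -= 1    # i0 = i0 - 1.0f
    return acc
```

Each pass through the `while` loop consumes one float4 of A per stream and one float4 row from each of the four B streams, which is why eight accumulators of four components each come out at the end.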

Section 3.5.3 of the user guide covers a kernel similar to this one.

Basically, the body of the loop performs an 8x4 * 4x4 matrix multiply.
There are 8 A input values and 4 B input values, each 4 elements wide, which is exactly the shape that multiply needs. The only thing to understand is that B is walked down its columns while A is walked across its rows, and each A value is the same row position taken from a different input stream, so A1..A8 together hold 8 consecutive rows of matrix A.
The input streams are organized so that each of every 8 consecutive rows goes into a separate input stream: row r of A is stored in stream r mod 8, at row r / 8 (integer division).
For example, input stream 0, row 0 is matrix A row 0; input stream 0, row 1 is matrix A row 8; and so on.
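The row layout described above can be sketched on the host side as a round-robin split. This is a minimal Python illustration with hypothetical helper names (`split_rows`, `pack_float4`); it assumes the same scheme with n = 4 applies to B1..B4, which is inferred from how the kernel multiplies A11.x by B11, A11.y by B22, and so on, not something stated in the thread.

```python
def split_rows(M, n):
    """Distribute the rows of M round-robin over n streams.

    Row r of M lands in stream r % n at row r // n, so for n = 8,
    stream 0 row 1 is matrix row 8.
    """
    if len(M) % n:
        raise ValueError("row count must be a multiple of n")
    return [[M[j * n + s] for j in range(len(M) // n)] for s in range(n)]

def pack_float4(row):
    """Group a row of scalars into float4 tuples: (r0..r3), (r4..r7), ..."""
    return [tuple(row[i:i + 4]) for i in range(0, len(row), 4)]
```

For A the split uses n = 8 (streams A1..A8); under the stated assumption, B would use n = 4 (streams B1..B4).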

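Now, normally the B matrix needs to be transposed; however, we do the transposing via the swizzles rather than via a traditional transpose. Broadcasting each component of an A value (A11.xxxx, A11.yyyy, ...) and multiplying it by a whole row of B builds up the result row by row, so B is effectively transposed on the fly and never has to be transposed in memory (a standalone transpose has poor performance; see matrix_transpose.exe in cal.h).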
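The equivalence behind the swizzle trick can be checked with a small Python sketch (hypothetical function names, plain floats instead of float4). `via_swizzles` mirrors the kernel's `a.xxxx*B11 + a.yyyy*B22 + a.zzzz*B33 + a.wwww*B44`, reading B row by row; `via_transpose` is the conventional approach that first builds B transposed and then takes a dot product with each of its rows.

```python
def via_swizzles(a, b_rows):
    """a.xxxx*B11 + a.yyyy*B22 + a.zzzz*B33 + a.wwww*B44: reads B row by row."""
    out = [0.0] * 4
    for k in range(4):          # k selects the broadcast component of a
        for c in range(4):      # c walks across row k of B
            out[c] += a[k] * b_rows[k][c]
    return tuple(out)

def via_transpose(a, b_rows):
    """Conventional approach: transpose B, then dot a with each row of B^T."""
    bt = [[b_rows[k][c] for k in range(4)] for c in range(4)]
    return tuple(sum(a[k] * bt[c][k] for k in range(4)) for c in range(4))
```

Both produce the same 1x4 result; the swizzle form simply never materializes `bt`.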

The output is then interleaved on the CPU side after the kernel has run.
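A minimal Python sketch of that CPU-side interleave, assuming (as the kernel's indexing suggests) that C1..C8 use the same round-robin layout as A1..A8, so output row r comes from stream r mod 8 at row r / 8. The name `interleave_rows` is hypothetical.

```python
def interleave_rows(streams):
    """Undo the round-robin split: output row r comes from stream r % n, row r // n."""
    n = len(streams)
    rows = len(streams[0])
    if any(len(s) != rows for s in streams):
        raise ValueError("all streams must have the same number of rows")
    return [streams[r % n][r // n] for r in range(n * rows)]
```

With eight output streams, `interleave_rows([C1, C2, ..., C8])` restores the rows of the result matrix to their original order.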

Hope this helps.