gpgpu

How does this terrible-looking kernel perform matrix-matrix multiplication?

Discussion created by gpgpu on Oct 30, 2008
Latest reply on Oct 30, 2008 by MicahVillmow

// optimized_matmult: Brook+ kernel that computes one tile of C = A * B.
//
// Each kernel instance handles one output position (x, y), taken from
// indexof(C1). It writes one float4 to each of the eight output streams
// C1..C8. For output stream Ci the value written is
//
//   Ci(x, y) = sum over k = 0 .. iterations-1 of
//                a.x * B1[(x, k)] + a.y * B2[(x, k)]
//              + a.z * B3[(x, k)] + a.w * B4[(x, k)],   with a = Ai[(k, y)]
//
// So each float4 read from Ai is treated as four consecutive elements of a
// row of A, and B1..B4 supply the four matching rows of B.
// NOTE(review): this reads as A being split across A1..A8 (eight A rows per
// output y) and B split across B1..B4 (four B rows per step k), with each
// float4 packing four adjacent columns. Confirm against the host code that
// builds these streams.
//
// Parameters:
//   loopVar0 - inner-loop trip count, passed as a float. The loop runs while
//              the counter is > 0 and subtracts 1 each time. It therefore runs
//              ceil(loopVar0) times for positive values and not at all for
//              values <= 0. Presumably the shared dimension K divided by 4
//              (one float4 per step). TODO: confirm.
//   A1..A8   - gather arrays, read at coordinate (k, y).
//   B1..B4   - gather arrays, read at coordinate (x, k).
//   C1..C8   - output streams. The position comes from indexof(C1), so
//              presumably all C streams share C1's shape.
kernel void
optimized_matmult(  float loopVar0,
        float4 A1[][], float4 A2[][], float4 A3[][], float4 A4[][],
        float4 A5[][], float4 A6[][], float4 A7[][], float4 A8[][],
        float4 B1[][], float4 B2[][], float4 B3[][], float4 B4[][],
        out float4 C1<>, out float4 C2<>, out float4  C3<>, out float4 C4<>,
        out float4 C5<>, out float4 C6<>, out float4  C7<>, out float4 C8<>)
{
  // vPos - position (x, y) of this instance in the output streams.
  float2 vPos = indexof(C1).xy;

  // Constant vector; only .z (1.0) and .w (0.0) are used in this kernel.
  // Swizzling from it rather than writing literals is presumably a hint to
  // the Brook+ compiler.
  float4 four210 = float4(4.0f, 2.0f, 1.0f, 0.0f);

  // index = (x, y, unused, k): x/y fix the output tile, w is the step k.
  float4 index = float4(vPos.x, vPos.y, four210.w, four210.w);

  // One float4 accumulator per output stream, all starting at zero.
  float4 accumulator1 = four210.wwww;
  float4 accumulator2 = four210.wwww;
  float4 accumulator3 = four210.wwww;
  float4 accumulator4 = four210.wwww;
  float4 accumulator5 = four210.wwww;
  float4 accumulator6 = four210.wwww;
  float4 accumulator7 = four210.wwww;
  float4 accumulator8 = four210.wwww;

  // Float loop counter: runs ceil(loopVar0) iterations for positive input.
  float i0 = loopVar0;

  while(i0 > 0.0f)
  {
    // Fetch from A at coordinate (k, y): one float4 per A stream.
    float4 A11 = A1[index.wy];
    float4 A22 = A2[index.wy];
    float4 A33 = A3[index.wy];
    float4 A44 = A4[index.wy];
    float4 A55 = A5[index.wy];
    float4 A66 = A6[index.wy];
    float4 A77 = A7[index.wy];
    float4 A88 = A8[index.wy];

    // Fetch from B at coordinate (x, k): one float4 per B stream.
    float4 B11 = B1[index.xw];
    float4 B22 = B2[index.xw];
    float4 B33 = B3[index.xw];
    float4 B44 = B4[index.xw];

    // Component j of each A float4 scales the float4 from B(j+1), so each
    // step contributes a 1x4 by 4x4 product per output stream.
    accumulator1 += A11.xxxx * B11.xyzw + A11.yyyy * B22.xyzw + A11.zzzz * B33.xyzw + A11.wwww * B44.xyzw;
    accumulator2 += A22.xxxx * B11.xyzw + A22.yyyy * B22.xyzw + A22.zzzz * B33.xyzw + A22.wwww * B44.xyzw;
    accumulator3 += A33.xxxx * B11.xyzw + A33.yyyy * B22.xyzw + A33.zzzz * B33.xyzw + A33.wwww * B44.xyzw;
    accumulator4 += A44.xxxx * B11.xyzw + A44.yyyy * B22.xyzw + A44.zzzz * B33.xyzw + A44.wwww * B44.xyzw;
    accumulator5 += A55.xxxx * B11.xyzw + A55.yyyy * B22.xyzw + A55.zzzz * B33.xyzw + A55.wwww * B44.xyzw;
    accumulator6 += A66.xxxx * B11.xyzw + A66.yyyy * B22.xyzw + A66.zzzz * B33.xyzw + A66.wwww * B44.xyzw;
    accumulator7 += A77.xxxx * B11.xyzw + A77.yyyy * B22.xyzw + A77.zzzz * B33.xyzw + A77.wwww * B44.xyzw;
    accumulator8 += A88.xxxx * B11.xyzw + A88.yyyy * B22.xyzw + A88.zzzz * B33.xyzw + A88.wwww * B44.xyzw;

    // Advance k: four210.wwwz == (0, 0, 0, 1), so only index.w changes.
    index += four210.wwwz;

    // Decrement the iteration counter.
    i0 = i0 - 1.0f;
  }

  C1 = accumulator1;
  C2 = accumulator2;
  C3 = accumulator3;
  C4 = accumulator4;
  C5 = accumulator5;
  C6 = accumulator6;
  C7 = accumulator7;
  C8 = accumulator8;
}

Outcomes