target/arm: Implement bfloat16 matrix multiply accumulate

This is BFMMLA for both AArch64 AdvSIMD and SVE,
and VMMLA.BF16 for AArch32 NEON.

Reviewed-by: Peter Maydell <peter.maydell@linaro.org>
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
Message-id: 20210525225817.400336-9-richard.henderson@linaro.org
Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
This commit is contained in:
Richard Henderson 2021-05-25 15:58:13 -07:00 committed by Peter Maydell
parent 839144784b
commit 81266a1f58
7 changed files with 81 additions and 3 deletions

View file

@ -2385,7 +2385,7 @@ static void do_mmla_b(void *vd, void *vn, void *vm, void *va, uint32_t desc,
* Process the entire segment at once, writing back the
* results only after we've consumed all of the inputs.
*
* Key to indicies by column:
* Key to indices by column:
* i j i j
*/
sum0 = a[H4(0 + 0)];
@ -2472,3 +2472,43 @@ void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm,
}
clear_tail(d, opr_sz, simd_maxsz(desc));
}
void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va, uint32_t desc)
{
intptr_t s, opr_sz = simd_oprsz(desc);
float32 *d = vd, *a = va;
uint32_t *n = vn, *m = vm;
for (s = 0; s < opr_sz / 4; s += 4) {
float32 sum00, sum01, sum10, sum11;
/*
* Process the entire segment at once, writing back the
* results only after we've consumed all of the inputs.
*
* Key to indicies by column:
* i j i k j k
*/
sum00 = a[s + H4(0 + 0)];
sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)]);
sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)]);
sum01 = a[s + H4(0 + 1)];
sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)]);
sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)]);
sum10 = a[s + H4(2 + 0)];
sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)]);
sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)]);
sum11 = a[s + H4(2 + 1)];
sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)]);
sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)]);
d[s + H4(0 + 0)] = sum00;
d[s + H4(0 + 1)] = sum01;
d[s + H4(2 + 0)] = sum10;
d[s + H4(2 + 1)] = sum11;
}
clear_tail(d, opr_sz, simd_maxsz(desc));
}