So, it turns out that on an RTX 4090, a float32 matmul with tfloat32 compute reaches ~85 TFLOPs.
HOWEVER, when you build a custom kernel that loads the operands as float16, performs the compute in tfloat32, and downcasts the result again, you get 173 TFLOPs, which is only slightly less than
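The CUDA kernel itself isn't shown here, but the numerics of the trick (downcast the operands to float16 to halve the bytes moved, then multiply and accumulate in higher precision) can be sketched in numpy. This is only an accuracy check under assumed shapes and distributions, not the kernel, and it obviously says nothing about the speedup, which comes entirely from the reduced memory traffic:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((256, 256)).astype(np.float32)
b = rng.standard_normal((256, 256)).astype(np.float32)

# Reference: full float32 matmul.
ref = a @ b

# The trick: store/load operands as float16 (half the bytes),
# but do the multiply-accumulate in float32 (standing in for tf32 compute).
a16 = a.astype(np.float16)
b16 = b.astype(np.float16)
approx = a16.astype(np.float32) @ b16.astype(np.float32)

# Relative error introduced by the float16 round-trip on the operands.
rel_err = np.abs(approx - ref).max() / np.abs(ref).max()
print(f"max relative error: {rel_err:.2e}")
```

For well-scaled inputs like these, the float16 round-trip costs only a few bits of operand precision while the accumulation stays in higher precision, which is why the downcast is often an acceptable trade for the bandwidth savings.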