math: add matrix multiplication

Signed-off-by: Stephen Gutekanst <stephen@hexops.com>
This commit is contained in:
Stephen Gutekanst 2023-09-08 22:36:10 -07:00
parent 7e8e1c03b9
commit d814bb1527

View file

@ -277,96 +277,26 @@ pub fn Mat(
else => @compileError("Expected Mat3x3, Mat4x4 found '" ++ @typeName(Matrix) ++ "'"), else => @compileError("Expected Mat3x3, Mat4x4 found '" ++ @typeName(Matrix) ++ "'"),
}; };
/// Matrix multiplication a*b
// TODO: needs tests
pub fn mul(a: Matrix, b: Matrix) Matrix {
var result: Matrix = undefined;
inline for (0..Matrix.rows) |row| {
inline for (0..Matrix.cols) |col| {
var sum: RowVec.T = 0.0;
inline for (0..RowVec.n) |i| {
sum += a.row(row).mul(b.col(col)).v[i];
}
result.v[col].v[row] = sum;
}
}
return result;
}
// TODO: the below code was correct in our old implementation, it just needs to be updated // TODO: the below code was correct in our old implementation, it just needs to be updated
// to work with this new Mat approach, swapping f32 for the generic T float type, moving 3x3 // to work with this new Mat approach, swapping f32 for the generic T float type, moving 3x3
// and 4x4 specific functions into the mixin above, writing new tests, etc. // and 4x4 specific functions into the mixin above, writing new tests, etc.
// // Multiplies matrices a * b
// pub inline fn mul(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
// return if (@TypeOf(a) == Mat3x3) {
// const a00 = a[0][0];
// const a01 = a[0][1];
// const a02 = a[0][2];
// const a10 = a[1][0];
// const a11 = a[1][1];
// const a12 = a[1][2];
// const a20 = a[2][0];
// const a21 = a[2][1];
// const a22 = a[2][2];
// const b00 = b[0][0];
// const b01 = b[0][1];
// const b02 = b[0][2];
// const b10 = b[1][0];
// const b11 = b[1][1];
// const b12 = b[1][2];
// const b20 = b[2][0];
// const b21 = b[2][1];
// const b22 = b[2][2];
// return init(Mat3x3, .{
// a00 * b00 + a10 * b01 + a20 * b02,
// a01 * b00 + a11 * b01 + a21 * b02,
// a02 * b00 + a12 * b01 + a22 * b02,
// a00 * b10 + a10 * b11 + a20 * b12,
// a01 * b10 + a11 * b11 + a21 * b12,
// a02 * b10 + a12 * b11 + a22 * b12,
// a00 * b20 + a10 * b21 + a20 * b22,
// a01 * b20 + a11 * b21 + a21 * b22,
// a02 * b20 + a12 * b21 + a22 * b22,
// });
// } else if (@TypeOf(a) == Mat4x4) {
// const a00 = a[0][0];
// const a01 = a[0][1];
// const a02 = a[0][2];
// const a03 = a[0][3];
// const a10 = a[1][0];
// const a11 = a[1][1];
// const a12 = a[1][2];
// const a13 = a[1][3];
// const a20 = a[2][0];
// const a21 = a[2][1];
// const a22 = a[2][2];
// const a23 = a[2][3];
// const a30 = a[3][0];
// const a31 = a[3][1];
// const a32 = a[3][2];
// const a33 = a[3][3];
// const b00 = b[0][0];
// const b01 = b[0][1];
// const b02 = b[0][2];
// const b03 = b[0][3];
// const b10 = b[1][0];
// const b11 = b[1][1];
// const b12 = b[1][2];
// const b13 = b[1][3];
// const b20 = b[2][0];
// const b21 = b[2][1];
// const b22 = b[2][2];
// const b23 = b[2][3];
// const b30 = b[3][0];
// const b31 = b[3][1];
// const b32 = b[3][2];
// const b33 = b[3][3];
// return init(Mat4x4, .{
// a00 * b00 + a10 * b01 + a20 * b02 + a30 * b03,
// a01 * b00 + a11 * b01 + a21 * b02 + a31 * b03,
// a02 * b00 + a12 * b01 + a22 * b02 + a32 * b03,
// a03 * b00 + a13 * b01 + a23 * b02 + a33 * b03,
// a00 * b10 + a10 * b11 + a20 * b12 + a30 * b13,
// a01 * b10 + a11 * b11 + a21 * b12 + a31 * b13,
// a02 * b10 + a12 * b11 + a22 * b12 + a32 * b13,
// a03 * b10 + a13 * b11 + a23 * b12 + a33 * b13,
// a00 * b20 + a10 * b21 + a20 * b22 + a30 * b23,
// a01 * b20 + a11 * b21 + a21 * b22 + a31 * b23,
// a02 * b20 + a12 * b21 + a22 * b22 + a32 * b23,
// a03 * b20 + a13 * b21 + a23 * b22 + a33 * b23,
// a00 * b30 + a10 * b31 + a20 * b32 + a30 * b33,
// a01 * b30 + a11 * b31 + a21 * b32 + a31 * b33,
// a02 * b30 + a12 * b31 + a22 * b32 + a32 * b33,
// a03 * b30 + a13 * b31 + a23 * b32 + a33 * b33,
// });
// } else @compileError("Expected matrix, found '" ++ @typeName(@TypeOf(a)) ++ "'");
// }
// /// Check if two matrices are approximate equal. Returns true if the absolute difference between // /// Check if two matrices are approximate equal. Returns true if the absolute difference between
// /// each element in matrix them is less or equal than the specified tolerance. // /// each element in matrix them is less or equal than the specified tolerance.
// pub inline fn equals(a: anytype, b: @TypeOf(a), tolerance: f32) bool { // pub inline fn equals(a: anytype, b: @TypeOf(a), tolerance: f32) bool {
@ -816,55 +746,3 @@ test "Mat4x4_transpose" {
// } // }
// } // }
// } // }
// test "mat.mul" {
// {
// const tolerance = 1e-6;
// const t = Vec3{ 1, 2, -3 };
// const T = mat.translate3d(t);
// const s = Vec3{ 3, 1, -5 };
// const S = mat.scale3d(s);
// const r = Vec3{ 30, -40, 235 };
// const R_x = mat.rotateX(degreesToRadians(f32, r[0]));
// const R_y = mat.rotateY(degreesToRadians(f32, r[1]));
// const R_z = mat.rotateZ(degreesToRadians(f32, r[2]));
// const R_yz = mat.mul(R_y, R_z);
// // This values are calculated by hand with help of matrix calculator: https://matrix.reshish.com/multCalculation.php
// const expected_R_yz = mat.init(Mat4x4, .{
// -0.43938504177070496278, -0.8191520442889918, -0.36868782649461236545, 0,
// 0.62750687159713312638, -0.573576436351046, 0.52654078451836329713, 0,
// -0.6427876096865394, 0, 0.766044443118978, 0,
// 0, 0, 0, 1,
// });
// try expect(mat.equals(R_yz, expected_R_yz, tolerance));
// const R_xyz = mat.mul(R_x, R_yz);
// const expected_R_xyz = mat.init(Mat4x4, .{
// -0.439385041770705, -0.52506256666891627986, -0.72886904595489960019, 0,
// 0.6275068715971331, -0.76000215715133560834, 0.16920947734596765363, 0,
// -0.6427876096865394, -0.383022221559489, 0.66341394816893832989, 0,
// 0, 0, 0, 1,
// });
// try expect(mat.equals(R_xyz, expected_R_xyz, tolerance));
// const SR = mat.mul(S, R_xyz);
// const expected_SR = mat.init(Mat4x4, .{
// -1.318155125312115, -0.5250625666689163, 3.6443452297744985, 0,
// 1.8825206147913993, -0.7600021571513356, -0.8460473867298382, 0,
// -1.9283628290596182, -0.383022221559489, -3.3170697408446915, 0,
// 0, 0, 0, 1,
// });
// try expect(mat.equals(SR, expected_SR, tolerance));
// const TSR = mat.mul(T, SR);
// const expected_TSR = mat.init(Mat4x4, .{
// -1.318155125312115, -0.5250625666689163, 3.6443452297744985, 0,
// 1.8825206147913993, -0.7600021571513356, -0.8460473867298382, 0,
// -1.9283628290596182, -0.383022221559489, -3.3170697408446914, 0,
// 1, 2, -3, 1,
// });
// try expect(mat.equals(TSR, expected_TSR, tolerance));
// }
// }