From 93357bd358a336a85826779f84cf15d8dab5fa8c Mon Sep 17 00:00:00 2001 From: Olle Lukowski <63189113+Olle-Lukowski@users.noreply.github.com> Date: Fri, 15 Sep 2023 19:38:46 +0200 Subject: [PATCH] math: add initial Quaternion implementation (#1000) --- src/math/quat.zig | 420 +++++++++++++++++++++++++++++++++++++++++++++- src/math/vec.zig | 2 +- 2 files changed, 420 insertions(+), 2 deletions(-) diff --git a/src/math/quat.zig b/src/math/quat.zig index c5c09648..9a980e95 100644 --- a/src/math/quat.zig +++ b/src/math/quat.zig @@ -4,6 +4,7 @@ const mach = @import("../main.zig"); const testing = mach.testing; const math = mach.math; const vec = @import("vec.zig"); +const mat = @import("mat.zig"); pub fn Quat(comptime Scalar: type) type { return extern struct { @@ -15,9 +16,266 @@ pub fn Quat(comptime Scalar: type) type { /// The underlying Vec type, e.g. math.Vec4, math.Vec4h, math.Vec4d pub const Vec = vec.Vec(4, Scalar); - pub inline fn init(x: T, y: T, z: T, w: T) Quat(Scalar) { + /// The Vec type used to represent axes, e.g. math.Vec3 + pub const Axis = vec.Vec(3, Scalar); + + /// Creates a quaternion from the given x, y, z, and w values + pub inline fn init(x: T, y: T, z: T, w: T) Quat(T) { return .{ .v = math.vec4(x, y, z, w) }; } + + /// Returns the identity quaternion. + pub inline fn identity() Quat(T) { + return init(0, 0, 0, 1); + } + + /// Returns the inverse of the quaternion. + pub inline fn inverse(q: *const Quat(T)) Quat(T) { + const s = 1 / q.len2(); + return init(-q.v.x() * s, -q.v.y() * s, -q.v.z() * s, q.v.w() * s); + } + + /// Creates a Quaternion based on the given `axis` and `angle`, and returns it. + pub inline fn fromAxisAngle(axis: Axis, angle: T) Quat(T) { + const halfAngle = angle * 0.5; + const s = std.math.sin(halfAngle); + + return init(s * axis.x(), s * axis.y(), s * axis.z(), std.math.cos(halfAngle)); + } + + /// Calculates the angle between two given quaternions. + pub inline fn angleBetween(a: *const Quat(T), b: *const Quat(T)) T { + const d = Vec.dot(&a.v, &b.v); + return std.math.acos(2 * d * d - 1); + } + + /// Multiplies two quaternions + pub inline fn mul(a: *const Quat(T), b: *const Quat(T)) Quat(T) { + const ax = a.v.x(); + const ay = a.v.y(); + const az = a.v.z(); + const aw = a.v.w(); + const bx = b.v.x(); + const by = b.v.y(); + const bz = b.v.z(); + const bw = b.v.w(); + + const x = aw * bx + ax * bw + ay * bz - az * by; + const y = aw * by + ay * bw + az * bx - ax * bz; + const z = aw * bz + az * bw + ax * by - ay * bx; + const w = aw * bw - ax * bx - ay * by - az * bz; + + return init(x, y, z, w); + } + + /// Adds two quaternions + pub inline fn add(a: *const Quat(T), b: *const Quat(T)) Quat(T) { + return init(a.v.x() + b.v.x(), a.v.y() + b.v.y(), a.v.z() + b.v.z(), a.v.w() + b.v.w()); + } + + /// Subtracts two quaternions + pub inline fn sub(a: *const Quat(T), b: *const Quat(T)) Quat(T) { + return init(a.v.x() - b.v.x(), a.v.y() - b.v.y(), a.v.z() - b.v.z(), a.v.w() - b.v.w()); + } + + /// Multiplies a Quaternion by a scalar + pub inline fn mulScalar(q: *const Quat(T), s: T) Quat(T) { + return init(q.v.x() * s, q.v.y() * s, q.v.z() * s, q.v.w() * s); + } + + /// Divides a Quaternion by a scalar + pub inline fn divScalar(q: *const Quat(T), s: T) Quat(T) { + return init(q.v.x() / s, q.v.y() / s, q.v.z() / s, q.v.w() / s); + } + + /// Rotates the give quaternion by the given angle, around the x-axis. + pub inline fn rotateX(q: *const Quat(T), angle: T) Quat(T) { + const halfAngle = angle * 0.5; + + const qx = q.v.x(); + const qy = q.v.y(); + const qz = q.v.z(); + const qw = q.v.w(); + + const bx = std.math.sin(halfAngle); + const bw = std.math.cos(halfAngle); + + return init(qx * bw + qw * bx, qy * bw + qz * bx, qz * bw - qy * bx, qw * bw - qx * bx); + } + + /// Rotates the give quaternion by the given angle, around the y-axis. + pub inline fn rotateY(q: *const Quat(T), angle: T) Quat(T) { + const halfAngle = angle * 0.5; + + const qx = q.v.x(); + const qy = q.v.y(); + const qz = q.v.z(); + const qw = q.v.w(); + + const by = std.math.sin(halfAngle); + const bw = std.math.cos(halfAngle); + + return init(qx * bw - qz * by, qy * bw + qw * by, qz * bw + qx * by, qw * bw - qy * by); + } + + /// Rotates the give quaternion by the given angle, around the z-axis. + pub inline fn rotateZ(q: *const Quat(T), angle: T) Quat(T) { + const halfAngle = angle * 0.5; + + const qx = q.v.x(); + const qy = q.v.y(); + const qz = q.v.z(); + const qw = q.v.w(); + + const bz = std.math.sin(halfAngle); + const bw = std.math.cos(halfAngle); + + return init(qx * bw - qy * bz, qy * bw + qx * bz, qz * bw + qw * bz, qw * bw - qz * bz); + } + + /// Calculates the spherical linear interpolation between two quaternions. + pub inline fn slerp(a: *const Quat(T), b: *const Quat(T), t: T) Quat(T) { + const ax = a.v.x(); + const ay = a.v.y(); + const az = a.v.z(); + const aw = a.v.w(); + + var bx = b.v.x(); + var by = b.v.y(); + var bz = b.v.z(); + var bw = b.v.w(); + + var cosOmega = ax * bx + ay * by + az * bz + aw * bw; + if (cosOmega < 0) { + cosOmega = -cosOmega; + bx = -bx; + by = -by; + bz = -bz; + bw = -bw; + } + + var scale0: T = 0.0; + var scale1: T = 0.0; + + if (1.0 - cosOmega > math.eps(T)) { + const omega = std.math.acos(cosOmega); + const sinOmega = std.math.sin(omega); + scale0 = std.math.sin((1.0 - t) * omega) / sinOmega; + scale1 = std.math.sin(t * omega) / sinOmega; + } else { + scale0 = 1.0 - t; + scale1 = t; + } + + return init(scale0 * ax + scale1 * bx, scale0 * ay + scale1 * by, scale0 * az + scale1 * bz, scale0 * aw + scale1 * bw); + } + + /// Calculates the conjugate of the given quaternion. + pub inline fn conjugate(q: *const Quat(T)) Quat(T) { + return init(-q.v.x(), -q.v.y(), -q.v.z(), q.v.w()); + } + + /// Creates a quaternion from the given transformation matrix. + pub inline fn fromMat(comptime matT: type, m: *const matT) Quat(T) { + var dst = Quat(T).identity(); + const trace = m.v[0].v[0] + m.v[1].v[1] + m.v[2].v[2]; + + if (trace > 0) { + const root = std.math.sqrt(trace + 1.0); + dst.v.v[3] = 0.5 * root; + const rootInv = 0.5 / root; + + dst.v.v[0] = (m.v[1].v[2] - m.v[2].v[1]) * rootInv; + dst.v.v[1] = (m.v[2].v[0] - m.v[0].v[2]) * rootInv; + dst.v.v[2] = (m.v[0].v[1] - m.v[1].v[0]) * rootInv; + } else { + var i: usize = 0; + + if (m.v[1].v[1] > m.v[0].v[0]) { + i = 1; + } + + if (m.v[2].v[2] > m.v[i].v[i]) { + i = 2; + } + + const j = (i + 1) % 3; + const k = (i + 2) % 3; + + const root = std.math.sqrt(m.v[i].v[i] - m.v[j].v[j] - m.v[k].v[k] + 1.0); + dst.v.v[i] = 0.5 * root; + + const rootInv = 0.5 / root; + + dst.v.v[3] = (m.v[j].v[k] - m.v[k].v[j]) * rootInv; + dst.v.v[j] = (m.v[j].v[i] - m.v[i].v[j]) * rootInv; + dst.v.v[k] = (m.v[k].v[i] - m.v[i].v[k]) * rootInv; + } + + return dst; + } + + /// Creates a quaternion from the given Euler angles. + pub inline fn fromEuler(x: T, y: T, z: T) Quat(T) { + const xHalf = x * 0.5; + const yHalf = y * 0.5; + const zHalf = z * 0.5; + + const sx = std.math.sin(xHalf); + const cx = std.math.cos(xHalf); + const sy = std.math.sin(yHalf); + const cy = std.math.cos(yHalf); + const sz = std.math.sin(zHalf); + const cz = std.math.cos(zHalf); + + const xRet = sx * cy * cz + cx * sy * sz; + const yRet = cx * sy * cz - sx * cy * sz; + const zRet = cx * cy * sz + sx * sy * cz; + const wRet = cx * cy * cz - sx * sy * sz; + + return init(xRet, yRet, zRet, wRet); + } + + /// Returns the dot product of two quaternions. + pub inline fn dot(a: *const Quat(T), b: *const Quat(T)) T { + return a.v.x() * b.v.x() + a.v.y() * b.v.y() + a.v.z() * b.v.z() + a.v.w() * b.v.w(); + } + + /// Linearly interpolates between two quaternions. + pub inline fn lerp(a: *const Quat(T), b: *const Quat(T), t: T) Quat(T) { + const xRet = a.v.x() + t * (b.v.x() - a.v.x()); + const yRet = a.v.y() + t * (b.v.y() - a.v.y()); + const zRet = a.v.z() + t * (b.v.z() - a.v.z()); + const wRet = a.v.w() + t * (b.v.w() - a.v.w()); + + return init(xRet, yRet, zRet, wRet); + } + + /// Computes the squared length of a given quaternion. + pub inline fn len2(q: *const Quat(T)) T { + return q.v.x() * q.v.x() + q.v.y() * q.v.y() + q.v.z() * q.v.z() + q.v.w() * q.v.w(); + } + + /// Computes the length of a given quaternion. + pub inline fn len(q: *const Quat(T)) T { + return std.math.sqrt(q.v.x() * q.v.x() + q.v.y() * q.v.y() + q.v.z() * q.v.z() + q.v.w() * q.v.w()); + } + + /// Computes the normalized version of a given quaternion. + pub inline fn normalize(q: *const Quat(T)) Quat(T) { + const q0 = q.v.x(); + const q1 = q.v.y(); + const q2 = q.v.z(); + const q3 = q.v.w(); + + const length = std.math.sqrt(q0 * q0 + q1 * q1 + q2 * q2 + q3 * q3); + + if (length > 0.00001) { + return init(q0 / length, q1 / length, q2 / length, q3 / length); + } else { + return init(0, 0, 0, 0); + } + } }; } @@ -32,3 +290,163 @@ test "init" { .v = math.vec4(1, 2, 3, 4), }); } + +test "inverse" { + const q = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const expected = math.Quat.init(-0.1 / 3.0, -0.1 / 3.0 * 2.0, -0.1, 1.0 / 7.5); + const actual = q.inverse(); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "fromAxisAngle" { + const expected = math.Quat.identity().rotateX(math.pi / 4.0); + const actual = math.Quat.fromAxisAngle(math.vec3(1, 0, 0), math.pi / 4.0); // 45 degrees in radians (π/4) around the x-axis + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "angleBetween" { + const a = math.Quat.fromAxisAngle(math.vec3(1, 0, 0), 1.0); + const b = math.Quat.fromAxisAngle(math.vec3(1, 0, 0), -1.0); + + try testing.expect(f32, math.Quat.angleBetween(&a, &b)).eql(2.0); +} + +test "mul" { + const a = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const b = a.inverse(); + const expected = math.Quat.identity(); + const actual = math.Quat.mul(&a, &b); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "add" { + const a = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const b = math.Quat.init(5.0, 6.0, 7.0, 8.0); + const expected = math.Quat.init(6.0, 8.0, 10.0, 12.0); + const actual = math.Quat.add(&a, &b); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "sub" { + const a = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const b = math.Quat.init(5.0, 6.0, 7.0, 8.0); + const expected = math.Quat.init(-4.0, -4.0, -4.0, -4.0); + const actual = math.Quat.sub(&a, &b); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "mulScalar" { + const q = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const expected = math.Quat.init(2.0, 4.0, 6.0, 8.0); + const actual = math.Quat.mulScalar(&q, 2.0); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "divScalar" { + const q = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const expected = math.Quat.init(0.5, 1.0, 1.5, 2.0); + const actual = math.Quat.divScalar(&q, 2.0); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "rotateX" { + const expected = math.Quat.fromAxisAngle(math.vec3(1, 0, 0), math.pi / 4.0); + const actual = math.Quat.identity().rotateX(math.pi / 4.0); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "rotateY" { + const expected = math.Quat.fromAxisAngle(math.vec3(0, 1, 0), math.pi / 4.0); + const actual = math.Quat.identity().rotateY(math.pi / 4.0); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "rotateZ" { + const expected = math.Quat.fromAxisAngle(math.vec3(0, 0, 1), math.pi / 4.0); + const actual = math.Quat.identity().rotateZ(math.pi / 4.0); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "slerp" { + const a = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const b = math.Quat.init(5.0, 6.0, 7.0, 8.0); + const expected = math.Quat.init(3.0, 4.0, 5.0, 6.0); + const actual = math.Quat.slerp(&a, &b, 0.5); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "conjugate" { + const q = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const expected = math.Quat.init(-1.0, -2.0, -3.0, 4.0); + const actual = math.Quat.conjugate(&q); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "fromMat4" { + const m = math.Mat4x4.rotateX(math.pi / 4.0); + const q = math.Quat.fromMat(math.Mat4x4, &m); + const expected = math.Quat.identity().rotateX(math.pi / 4.0); + + try testing.expect(math.Vec4, expected.v).eql(q.v); +} + +test "fromEuler" { + const q = math.Quat.fromEuler(math.pi / 4.0, 0.0, 0.0); + const expected = math.Quat.identity().rotateX(math.pi / 4.0); + + try testing.expect(math.Vec4, expected.v).eql(q.v); +} + +test "dot" { + const a = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const b = math.Quat.init(5.0, 6.0, 7.0, 8.0); + const expected = 70.0; + const actual = math.Quat.dot(&a, &b); + + try testing.expect(f32, actual).eql(expected); +} + +test "lerp" { + const a = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const b = math.Quat.init(5.0, 6.0, 7.0, 8.0); + const expected = math.Quat.init(3.0, 4.0, 5.0, 6.0); + const actual = math.Quat.lerp(&a, &b, 0.5); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} + +test "len2" { + const q = math.Quat.init(1.0, 2.0, 3.0, 4.0); + const expected = 30.0; + const actual = math.Quat.len2(&q); + + try testing.expect(f32, actual).eql(expected); +} + +test "len" { + const q = math.Quat.init(0.0, 0.0, 3.0, 4.0); + const expected = 5.0; + const actual = math.Quat.len(&q); + + try testing.expect(f32, actual).eql(expected); +} + +test "normalize" { + const q = math.Quat.init(0.0, 0.0, 3.0, 4.0); + const expected = math.Quat.init(0.0, 0.0, 0.6, 0.8); + const actual = math.Quat.normalize(&q); + + try testing.expect(math.Vec4, expected.v).eql(actual.v); +} diff --git a/src/math/vec.zig b/src/math/vec.zig index bce0d93b..7babd553 100644 --- a/src/math/vec.zig +++ b/src/math/vec.zig @@ -217,7 +217,7 @@ pub fn Vec(comptime n_value: usize, comptime Scalar: type) type { /// Calculates the dot product between vector a and b and returns scalar. pub inline fn dot(a: *const VecN, b: *const VecN) Scalar { - return .{ .v = @reduce(.Add, a.v * b.v) }; + return @reduce(.Add, a.v * b.v); } }; }