diff --git a/src/fancymath/vector.py b/src/fancymath/vector.py index 6c10466..c027e5b 100644 --- a/src/fancymath/vector.py +++ b/src/fancymath/vector.py @@ -36,6 +36,14 @@ class Vector(object): z = self.z + other.z return self.__class__(x, y, z) + def __mul__(self, other): + if not isinstance(other, Vector): + return NotImplemented + x = self.y * other.z - self.z * other.y + y = self.z * other.x - self.x * other.z + z = self.x * other.y - self.y * other.x + return self.__class__(x, y, z) + if __name__ == "__main__": v1 = Vector(3, 4, 0) diff --git a/test/test_vector.py b/test/test_vector.py index d3b0f65..53938f9 100644 --- a/test/test_vector.py +++ b/test/test_vector.py @@ -4,6 +4,28 @@ from fancymath.vector import Vector class TestVector(unittest.TestCase): + def test_cross_product(self): + # given + v1 = Vector(1, 2, 3) + v2 = Vector(3, 2, 1) + + # when + v3 = v1 * v2 + + # then + self.assertIsInstance(v3, Vector) + self.assertEqual(v3.x, -4) + self.assertEqual(v3.y, 8) + self.assertEqual(v3.z, -4) + + def test_cross_product_failure(self): + v1 = Vector(1, 2, 3) + v2 = (3, 2, 1) + + with self.assertRaises(TypeError): + v1 * v2 + + def test_add_vector(self): v1 = Vector(1, 2, 3) v2 = Vector(-3, -2, -1)