diff --git a/tests/test_clustering.py b/tests/test_clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..8a30878c4ae9e9fcfdae93b7e4af4ac59659d424 --- /dev/null +++ b/tests/test_clustering.py @@ -0,0 +1,43 @@ +from pathlib import Path +from unittest.mock import Mock + +import numpy as np +import pandas as pd +import pytest +import rasterio as rio +import scipy + +from mnp.species_models.clustering import ClusterLocal + + +@pytest.fixture +def parameters(): + params = Mock() + params.geospatial_profile = {"transform": Mock(a=25.0)} + params.folders = {"clustering": "."} + params.species_traits = pd.read_csv( + "./data/parameters/species_traits.csv", index_col=0 + ) + return params + + +@pytest.fixture +def hsi(): + with rio.open(Path("./data/rasters/hsi_60002.tif")) as src: + return scipy.sparse.csr_array(src.read(1)) + + +@pytest.fixture +def expected_clustering(): + with rio.open(Path("./data/rasters/clustering_60002.tif")) as src: + return scipy.sparse.csr_array(src.read(1)) + + +class TestClusterLocal: + def test_run(self, parameters, hsi, expected_clustering): + clustering = ClusterLocal(parameters, 60002) + clustering.run(hsi) + assert clustering.array.sum() > 0 + np.testing.assert_array_equal( + clustering.array.todense(), expected_clustering.todense() + )