Skip to content
Snippets Groups Projects
Commit 0bacead2 authored by Roelofsen, Hans's avatar Roelofsen, Hans
Browse files

rewrite to v2

parent f2002d4d
No related branches found
No related tags found
No related merge requests found
import os
from prefect import Flow, unmapped, Parameter
from prefect.executors import LocalDaskExecutor
try:
from src.source_rasters import get_source_rasters
from src.combinations import read_write_combination_chunk, build_combination_dict
from src.windows import get_rw_windows, update_window_to_joint_dict
from src.attribute_table import make_attribute_table, get_max_value
from src.target_raster import map_chunks_to_joint
except ModuleNotFoundError:
from source_rasters import get_source_rasters
from combinations import read_write_combination_chunk, build_combination_dict
from windows import get_rw_windows, update_window_to_joint_dict
from attribute_table import make_attribute_table, get_max_value
from target_raster import map_chunks_to_joint
with Flow("COMBINE") as COMBINE:
destination = Parameter("destination")
source_rasters = Parameter("sources")
testing = Parameter("testing")
source_rasters = get_source_rasters(source_rasters)
rw_windows = get_rw_windows(source_rasters=source_rasters, sample=testing)
solo_combis = read_write_combination_chunk.map(
rw_windows, unmapped(source_rasters), unmapped(destination)
)
solo_combis.set_dependencies(upstream_tasks=[rw_windows, source_rasters])
combination_dict = build_combination_dict(read_windows=rw_windows)
combination_dict.set_dependencies(upstream_tasks=[solo_combis])
attribute_table = make_attribute_table(
combination_dict, source_rasters, destination
)
attribute_table.set_dependencies(upstream_tasks=[combination_dict])
updated_windows = update_window_to_joint_dict.map(
rw_windows, unmapped(combination_dict)
)
updated_windows.set_dependencies(upstream_tasks=[combination_dict])
max_value = get_max_value(combination_dict)
max_value.set_dependencies(upstream_tasks=[combination_dict])
chunks_to_joint = map_chunks_to_joint(
rw_windows,
source_rasters,
destination,
max_value,
)
chunks_to_joint.set_dependencies(upstream_tasks=[max_value])
if __name__ == "__main__":
"""Run the COMBINE DAG.
The COMBINE dag is implemented using Prefect (https://docs.prefect.io/)
Parameters
----------
destination: str
Path to destination *.tif file
source_raster01: Path to first source raster
source_raster02: Path to second source raster
...: path to subsequent source rasters
Returns
-------
"""
import argparse
parser = argparse.ArgumentParser(
prog="Open Source COMBINE alternative. By: Hans Roelofsen"
)
parser.add_argument("destination", type=str, help="path to destination *.tif file")
parser.add_argument(
"sources", type=str, help="path to destination source files", nargs="+"
)
parser.add_argument("--test", action="store_true", help="test run with 20 windows")
args = parser.parse_args()
print("\n\nStarting COMBINE\n\n")
COMBINE.executor = LocalDaskExecutor()
COMBINE.run(
parameters=dict(
destination=args.destination, sources=args.sources, testing=args.test
)
)
import os
import pickle
import numpy as np
import rasterio as rio
from rasterio.windows import Window
import prefect
from prefect import task
from prefect import task, Parameter
try:
from windows import CombineWindow
......@@ -41,6 +43,43 @@ def get_combinations(
read_window.set_unique_combinations(uc=unique_combinations)
@task
def read_write_combination_chunk(
read_window: CombineWindow,
source_rasters: SourceRasters,
destination: Parameter,
):
"""
Write a combined arrays + dictionary to file
"""
logger = prefect.context.get("logger")
logger.info(
f"Finding unique combinations in window {read_window.nr} out of {read_window.total}"
)
# Read window from all source-rasters and store in list of arrays
arrays = [
rio.open(source_raster).read(1, window=read_window.rio_window).flatten()
for source_raster in source_rasters.source_list
]
# Zip to get vertical stacks
zipped_arrays = list(zip(*arrays))
# Dictionary between unique-combinations and an enumerated output value
d = {combi: n for n, combi in enumerate(set(zipped_arrays), start=1)}
# Create combined array and save to file
with open(
f"{os.path.splitext(destination)[0]}_{read_window.nr:05}.npy", "wb"
) as dest:
np.save(dest, np.array([d[array] for array in zipped_arrays]))
# Save this specific dictionary as a property of the window
read_window.save_window_combi_dict(combi_dict=d)
@task
def make_combination_dict(read_windows: list) -> dict:
all_combinations = set()
......@@ -49,3 +88,13 @@ def make_combination_dict(read_windows: list) -> dict:
return {combi: n for n, combi in enumerate(all_combinations, start=1)}
@task
def build_combination_dict(read_windows: list) -> dict:
"""
Gather solo-combination dictionaries from all Windows and return joint dictionary
"""
u_combis = set()
for d in [window.solo_combi_dict for window in read_windows]:
u_combis.update(list(d.keys()))
return {combi: n for n, combi in enumerate(u_combis, start=1)}
......@@ -17,6 +17,7 @@ class SourceRasters:
VerifyMatchingBounds(source_list=self.source_list),
VerifyMatchingResolution(source_list=self.source_list),
]
self.geospatial_profile = getattr(rio.open(self.source_list[0]), "profile")
def verify_input(self):
for procedure in self.procedures:
......@@ -75,7 +76,6 @@ class VerifyMatchingBounds(VerificationProcedure):
class VerifyMatchingResolution(VerificationProcedure):
def __init__(self, source_list):
self.rasters = source_list
self.status = True
......@@ -83,7 +83,9 @@ class VerifyMatchingResolution(VerificationProcedure):
def execute(self):
logger = prefect.context.get("logger")
reference_resolution = getattr(getattr(rio.open(self.rasters[0]), 'transform'), 'a')
reference_resolution = getattr(
getattr(rio.open(self.rasters[0]), "transform"), "a"
)
resolutions = []
messages = []
for raster in self.rasters:
......@@ -92,9 +94,10 @@ class VerifyMatchingResolution(VerificationProcedure):
messages.append(msg)
if not all(resolutions):
logger.error('Rasters have unequal resolutions')
logger.error("Rasters have unequal resolutions")
self.status = False
def is_geospatial_raster(file_path: str, msg=False) -> (bool, str):
"""
Is a file a gdal compatible geospatial raster?
......
......@@ -13,6 +13,69 @@ except ModuleNotFoundError:
from src.windows import CombineWindow
@task
def map_chunks_to_joint(
rw_windows: list,
source_rasters: SourceRasters,
destination: str,
max_value: int,
):
"""
Load a npy array from file, apply the joint combination dict to it and write as chunk to the full output raster
"""
logger = prefect.context.get("logger")
with rio.open(
destination,
"w",
driver="GTiff",
width=source_rasters.geospatial_profile["width"],
height=source_rasters.geospatial_profile["height"],
count=1,
dtype=rio.dtypes.get_minimum_dtype([0, max_value]),
transform=source_rasters.geospatial_profile["transform"],
compress="LZW",
crs=source_rasters.geospatial_profile["crs"],
) as dest:
dest.update_tags(
creation_date=datetime.datetime.now().strftime("%d-%b-%Y_%H:%M:%S"),
creator=os.environ.get("USERNAME"),
created_with="WENR Combine tool",
**dict(
zip(
[
f"sourceraster{i:2}"
for i, _ in enumerate(source_rasters.source_list, start=1)
],
source_rasters.source_list,
)
),
)
for rw_window in rw_windows:
logger.info(
f"Writing chunk {rw_window.nr} out of {rw_window.total} to file."
)
with open(
f"{os.path.splitext(destination)[0]}_{rw_window.nr:05}.npy", "rb"
) as f:
combi_array = np.load(f)
target_values = np.vectorize(rw_window.joint_combi_dict.__getitem__)(
combi_array
)
dest.write(
np.array(target_values).reshape(
(rw_window.rio_window.height, rw_window.rio_window.width)
),
indexes=1,
window=rw_window.rio_window,
)
@task
def write_target_raster_chunk(
rw_window: CombineWindow,
......@@ -22,7 +85,6 @@ def write_target_raster_chunk(
source_rasters: SourceRasters,
):
logger = prefect.context.get("logger")
with rio.open(
......
......@@ -33,6 +33,8 @@ class CombineWindow:
(row_start, row_stop), (column_start, column_stop)
)
self.unique_combinations = None
self.solo_combi_dict = None
self.joint_combi_dict = None
self.nr = nr
self.total = totals
......@@ -40,12 +42,35 @@ class CombineWindow:
a=resolution, b=0, c=top_left_x, d=0, e=resolution * -1, f=top_left_y
)
def save_window_combi_dict(self, combi_dict: dict):
self.solo_combi_dict = combi_dict
def set_unique_combinations(self, uc: set):
self.unique_combinations = uc
def get_unique_combinations(self):
return self.unique_combinations
def map_to_joint_combination_dict(self, joint_combination_dict: dict):
"""
self.u_combi_dict = {(1,1,2): 1,
(1,2,1): 2,
(0,0,1): 3}
joint_combination_dict = {(1,1,1): 1,
(1,1,2): 2,
(1,2,2): 3,
(1,2,1): 4,
(2,2,2): 5,
(0,0,0): 6,
(0,0,1): 7}
self.joint_combi_dict = {1: 2,
2: 4,
3: 7}
"""
self.joint_combi_dict = {
v: joint_combination_dict[k] for k, v in self.solo_combi_dict.items()
}
@task
def get_rw_windows(
......@@ -98,3 +123,10 @@ def get_rw_windows(
return random.sample(windows, 20)
else:
return windows
@task
def update_window_to_joint_dict(rw_window: CombineWindow, joint_dict: dict):
""" """
rw_window.map_to_joint_combination_dict(joint_combination_dict=joint_dict)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment