Commit 33001dec authored by Wit, Allard de

Parallel loading and progressbar work properly now

parent b8fca6fc
@@ -31,11 +31,11 @@ class CSVLoadingError(Exception):
         super().__init__()
 
 
-def load_parcel_info(dsn, pixcounts_file, shape_file, table_name):
+def load_parcel_info(dsn, counts_file, shape_file, table_name):
     """Loads the parcel info from the pixel counts CSV and the shapefile's DBF.
 
     :param dsn: Data source name where to write to
-    :param pixcounts_file: CSV file from which pixel counts should be read
+    :param counts_file: CSV file from which pixel counts should be read
     :param shape_file: shapefile whose .DBF file should be used as parcel info.
     :param table_name: name of the table to write parcel info into
@@ -46,7 +46,7 @@ def load_parcel_info(dsn, pixcounts_file, shape_file, table_name):
     df["area_ha"] = df.geometry.area/1e4
     df = df.set_index("fieldid")
-    fname_counts = Path(pixcounts_file)
+    fname_counts = Path(counts_file)
     df_counts = pd.read_csv(fname_counts)
     df_counts.set_index("field_ID", inplace=True)
     df["pixcount"] = df_counts["pixcount"]
@@ -113,13 +113,12 @@ def process_rows(rows):
     return df
 
 
-def write_to_database(engine, table_name, csv_readers, nlines, child_conn):
+def write_to_database(engine, dataset_name, csv_readers, child_conn):
     """Writes data from a set of CSV files into the database.
 
     :param engine: the database engine to be used
-    :param table_name: the name of the output table
+    :param dataset_name: the name of the dataset, will be used as output table name
     :param csv_readers: the set of CSV DictReaders (one per CSV file)
-    :param nlines: number of lines to process
     :param child_conn: The pipe to report progress in loading data from file
     :return:
     """
@@ -129,24 +128,22 @@ def write_to_database(engine, table_name, csv_readers, nlines, child_conn):
             rows = {column_name: next(reader) for column_name, reader in csv_readers.items()}
             df = process_rows(rows)
             try:
-                df.to_sql(table_name, engine, if_exists="append", index=False)
+                df.to_sql(dataset_name, engine, if_exists="append", index=False)
             except sa.exc.IntegrityError as e:
-                print(f"Field ID {df.fieldID.unique()} failed to insert in table {table_name}")
+                print(f"Field ID {df.fieldID.unique()} failed to insert in table {dataset_name}")
             this_line += 1
             if this_line % 100 == 0:
-                progress = this_line/nlines
-                child_conn.send(progress)
+                child_conn.send({dataset_name: this_line})
         except StopIteration:
             break
 
 
-def load_satellite_csv(child_conn, dataset_name, dsn, bands, nlines, ):
-    nlines = count_lines(bands)  # 803016
+def load_satellite_csv(child_conn, dataset_name, dsn, bands, **kwargs):
     mean_csv_readers = {}
     for column_name, csv_fname in bands.items():
         mean_csv_readers[column_name] = DictReader(open(csv_fname))
     engine = prepare_db(dsn, table_name=dataset_name, bands=mean_csv_readers.keys())
-    write_to_database(engine, dataset_name, mean_csv_readers, nlines, child_conn)
+    write_to_database(engine, dataset_name, mean_csv_readers, child_conn)
 
 
 class Process(mp.Process):
@@ -171,6 +168,10 @@ class Process(mp.Process):
         return self._exception
 
 
+def start_parallel_loading(datasets):
+    pass
+
+
 def load_data(yaml_file):
     """Loads data pointed to by the YAML config file.
@@ -179,36 +180,43 @@ def load_data(yaml_file):
     """
     grompy_conf = yaml.safe_load(open(yaml_file))
 
     # First load parcel info
     parcel_info = grompy_conf.pop("parcel_info")
     load_parcel_info(**parcel_info)
 
     # Start loading CSV files in parallel
     process_list = []
     parent_conn, child_conn = mp.Pipe()
+    lines_per_dataset = {}
     for dataset_name, description in grompy_conf["datasets"].items():
+        if "nlines" not in description:
+            print("You must run 'grompy check' before trying 'grompy load'! Aborting...")
+            sys.exit()
         print(f"Starting loading of: {dataset_name}")
+        lines_per_dataset[dataset_name] = description["nlines"]
         p = Process(target=load_satellite_csv, args=(child_conn, dataset_name,), kwargs=description)
         process_list.append(p)
     for p in process_list:
         p.start()
 
+    total_lines = sum(c for c in lines_per_dataset.values())
+    lines_per_dataset = {ds:0 for ds in lines_per_dataset}
     try:
-        printProgressBar(0, 1000, decimals=2, length=50)
-        while any([p.is_alive() for p in process_list]):
+        printProgressBar(0, total_lines, decimals=2, length=50)
+        processes = [p for p in process_list if p.is_alive()]
+        while processes:
             time.sleep(3)
-            progress = []
             for p in process_list:
                 if parent_conn.poll():
-                    progress.append(parent_conn.recv())
+                    lines_per_dataset.update(parent_conn.recv())
                 if p.exception:
                     error, traceback = p.exception
                     raise CSVLoadingError(error, traceback)
-            if progress:
-                p = math.floor(min(progress) * 1000)
-                # print(f"\rprogress: {min(progress):7.4f}")
-                printProgressBar(p, 1000, decimals=2, length=50)
-            else:
-                # print("no progress value")
-                pass
+            current_lines = sum(c for c in lines_per_dataset.values())
+            printProgressBar(current_lines, total_lines, decimals=2, length=50)
+            processes = [p for p in process_list if p.is_alive()]
     except KeyboardInterrupt:
         for p in process_list:
......
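The two sketches below illustrate the patterns this commit relies on; they are standalone approximations, not grompy code. First, the per-dataset progress protocol: each worker sends absolute line counts as `{dataset_name: lines_done}` dicts through a single shared `multiprocessing.Pipe`, and the parent merges them and redraws one aggregate bar. Here `worker` and `print_progress_bar` are hypothetical stand-ins for `load_satellite_csv` and grompy's `printProgressBar`, and the dataset names and line counts are made up.

```python
import multiprocessing as mp
import time


def print_progress_bar(done, total, length=50):
    # Stand-in for grompy's printProgressBar: draw a simple inline bar.
    filled = int(length * done / total)
    print(f"\r|{'#' * filled}{'-' * (length - filled)}| {100 * done / total:5.1f}%",
          end="", flush=True)


def worker(child_conn, dataset_name, nlines):
    # Stand-in for load_satellite_csv: report absolute progress every 100 rows
    # as a {dataset_name: lines_done} dict, as write_to_database() now does.
    for this_line in range(1, nlines + 1):
        time.sleep(0.001)  # simulate reading one CSV row and inserting it
        if this_line % 100 == 0:
            child_conn.send({dataset_name: this_line})
    child_conn.send({dataset_name: nlines})  # final count


if __name__ == "__main__":
    datasets = {"sentinel2_mean": 2000, "sentinel1_mean": 3000}  # made-up sizes
    # One pipe shared by all workers, mirroring the diff; the messages are
    # small dicts, so concurrent sends are safe in practice.
    parent_conn, child_conn = mp.Pipe()
    processes = [mp.Process(target=worker, args=(child_conn, name, nlines))
                 for name, nlines in datasets.items()]
    for p in processes:
        p.start()

    total_lines = sum(datasets.values())
    lines_per_dataset = {name: 0 for name in datasets}
    while any(p.is_alive() for p in processes):
        time.sleep(0.5)
        while parent_conn.poll():  # drain all queued reports, then redraw once
            lines_per_dataset.update(parent_conn.recv())
        print_progress_bar(sum(lines_per_dataset.values()), total_lines)
    while parent_conn.poll():  # pick up reports sent just before the workers exited
        lines_per_dataset.update(parent_conn.recv())
    print_progress_bar(sum(lines_per_dataset.values()), total_lines)
    print()
    for p in processes:
        p.join()
```

Because the counts are absolute rather than fractions of a per-dataset total, late messages only move the bar forward, and the workers no longer need `nlines` to report progress.

Second, the `p.exception` checks in `load_data()` imply a `Process` subclass that captures a child-side exception together with its formatted traceback; the subclass body is elided from this diff, so the following is only the common pipe-based pattern it appears to follow:

```python
import multiprocessing as mp
import traceback


class Process(mp.Process):
    """mp.Process exposing a child exception to the parent as (error, traceback)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None

    def run(self):
        try:
            super().run()
            self._cconn.send(None)
        except Exception as e:
            # Format the traceback in the child: traceback objects don't pickle.
            self._cconn.send((e, traceback.format_exc()))

    @property
    def exception(self):
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        return self._exception
```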