Skip to content
Snippets Groups Projects
Commit 402762a6 authored by Cees Voesenek's avatar Cees Voesenek
Browse files

Parameterise skeleton fitter tests over animal names

parent 9e55bd9e
Branches
Tags 1.40.7
1 merge request!19Move to modern standard Python package structure
import copy, os
from matplotlib import pyplot as plt
from flitrak3d.skeleton_fitter import optimiser, utils
from flitrak3d.skeleton_fitter.modules.plotting import plot_limbs, plot_body_and_limbs
import copy
from importlib import import_module
import os
from matplotlib import pyplot as plt
import pytest
def test_skeleton_fly(test_data_directory: str, show_plots=True):
animal_names = ['fly', 'fly_geo', 'fly_hybrid']
for animal_name in animal_names:
if '_' in animal_name: skeleton_type = animal_name[animal_name.index('_'):]
else: skeleton_type = ''
init_yaml_path = os.path.join(test_data_directory,
'mosquito_escapes/initial_parameters/mosquito' + skeleton_type + '.yaml')
animal = import_module('flitrak3d.skeleton_fitter.animals.' + animal_name)
animal_sets = animal.AnimalSettings
body = import_module('flitrak3d.skeleton_fitter.modules.bodies.' + animal_sets.body_module_name)
body_sets = body.BodyModuleSettings(init_yaml_path)
body_params = copy.deepcopy(body_sets.params_init)
body_params['yaw_a'] = -55.0
body_params['pitch_a'] = 20.0
body_params['roll_a'] = -5.0
[body_params['x_com'], body_params['y_com'], body_params['z_com']] = [0.0, 5.0, 3.0]
body_skeleton3d_1 = body.rotate_skeleton3d(body_sets.skeleton3d_init, body_params)
body_skeleton3d_1 = body.translate_skeleton3d(body_skeleton3d_1, body_params)
#plot_body(body_skeleton3d_1)
limbs, limbs_sets, limbs_params = {}, {}, {}
for num_limb, limb_name in enumerate(animal_sets.limb_names):
limbs[limb_name] = import_module('flitrak3d.skeleton_fitter.modules.limbs.' + animal_sets.limb_module_names[num_limb])
limbs_sets[limb_name] = limbs[limb_name].LimbsModuleSettings(init_yaml_path)
limbs_params[limb_name] = copy.deepcopy(limbs_sets[limb_name].params_init)
for param_name in body_params.keys():
limbs_params[limb_name]['right'][param_name] = body_params[param_name]
limbs_params[limb_name]['left'][param_name] = body_params[param_name]
limbs_params['wings']['right']['stroke_a'] = -55.0
limbs_params['wings']['right']['deviation_a'] = 10.0
limbs_params['wings']['right']['rotation_a'] = 60.0
limbs_params['wings']['left']['stroke_a'] = 5.0
limbs_params['wings']['left']['deviation_a'] = -12.0
limbs_params['wings']['left']['rotation_a'] = -20.0
limbs_skeleton3d = dict()
limbs_skeleton3d['wings'] = copy.deepcopy(limbs_sets['wings'].skeleton3d_init)
from flitrak3d.skeleton_fitter import optimiser, utils
from flitrak3d.skeleton_fitter.modules.plotting import plot_body_and_limbs
@pytest.mark.parametrize('animal_name', ['fly', 'fly_geo', 'fly_hybrid'])
def test_skeleton_fly(test_data_directory: str, animal_name: str, show_plots: bool = True) -> None:
if '_' in animal_name:
skeleton_type = animal_name[animal_name.index('_'):]
else:
skeleton_type = ''
init_yaml_path = os.path.join(test_data_directory,
'mosquito_escapes/initial_parameters/mosquito' + skeleton_type + '.yaml')
animal = import_module('flitrak3d.skeleton_fitter.animals.' + animal_name)
animal_sets = animal.AnimalSettings
body = import_module('flitrak3d.skeleton_fitter.modules.bodies.' + animal_sets.body_module_name)
body_sets = body.BodyModuleSettings(init_yaml_path)
body_params = copy.deepcopy(body_sets.params_init)
body_params['yaw_a'] = -55.0
body_params['pitch_a'] = 20.0
body_params['roll_a'] = -5.0
[body_params['x_com'], body_params['y_com'], body_params['z_com']] = [0.0, 5.0, 3.0]
body_skeleton3d_1 = body.rotate_skeleton3d(body_sets.skeleton3d_init, body_params)
body_skeleton3d_1 = body.translate_skeleton3d(body_skeleton3d_1, body_params)
#plot_body(body_skeleton3d_1)
limbs, limbs_sets, limbs_params = {}, {}, {}
for num_limb, limb_name in enumerate(animal_sets.limb_names):
limbs[limb_name] = import_module('flitrak3d.skeleton_fitter.modules.limbs.' + animal_sets.limb_module_names[num_limb])
limbs_sets[limb_name] = limbs[limb_name].LimbsModuleSettings(init_yaml_path)
limbs_params[limb_name] = copy.deepcopy(limbs_sets[limb_name].params_init)
for param_name in body_params.keys():
limbs_params[limb_name]['right'][param_name] = body_params[param_name]
limbs_params[limb_name]['left'][param_name] = body_params[param_name]
limbs_params['wings']['right']['stroke_a'] = -55.0
limbs_params['wings']['right']['deviation_a'] = 10.0
limbs_params['wings']['right']['rotation_a'] = 60.0
limbs_params['wings']['left']['stroke_a'] = 5.0
limbs_params['wings']['left']['deviation_a'] = -12.0
limbs_params['wings']['left']['rotation_a'] = -20.0
limbs_skeleton3d = dict()
limbs_skeleton3d['wings'] = copy.deepcopy(limbs_sets['wings'].skeleton3d_init)
for side in limbs_sets['wings'].skeleton3d_init.keys():
hinge_label = '{0}_wing_hinge'.format(side)
[limbs_params['wings'][side]['x_hinge'],
limbs_params['wings'][side]['y_hinge'],
limbs_params['wings'][side]['z_hinge']] = body_skeleton3d_1[hinge_label]
[limbs_params['wings'][side]['x_com'],
limbs_params['wings'][side]['y_com'],
limbs_params['wings'][side]['z_com']] = [body_params['x_com'], body_params['y_com'], body_params['z_com']]
limbs_skeleton3d['wings'][side] = limbs['wings'].rotate_and_translate_skeleton3d(limbs_sets['wings'].skeleton3d_init[side],
limbs_params['wings'][side], side)
if show_plots:
# plot_limbs(limbs_skeleton3d, wait_show_plot=True)
plot_body_and_limbs(body_skeleton3d_1, limbs_skeleton3d, wait_show_plot=True)
# Estimate body and wings parameters
body_params1_est = body.estimate_params_from_skeleton3d(body_skeleton3d_1, body_sets.params_init)
limbs_params1_est = {}
for num_limb, limb_name in enumerate(animal_sets.limb_names):
limbs_params1_est[limb_name] = {}
for side in ['right', 'left']:
limbs_params1_est[limb_name][side] = (
limbs[limb_name].estimate_params_from_skeleton3d(limbs_skeleton3d['wings'][side], body_params1_est,
limbs_sets[limb_name].params_init[side], side))
if show_plots:
body_skeleton3d_est = body.rotate_skeleton3d(body_sets.skeleton3d_init, body_params1_est)
body_skeleton3d_est = body.translate_skeleton3d(body_skeleton3d_est, body_params1_est)
limbs_skeleton3d_est = dict()
limbs_skeleton3d_est['wings'] = copy.deepcopy(limbs_sets['wings'].skeleton3d_init)
for side in limbs_sets['wings'].skeleton3d_init.keys():
hinge_label = '{0}_wing_hinge'.format(side)
[limbs_params['wings'][side]['x_hinge'],
limbs_params['wings'][side]['y_hinge'],
limbs_params['wings'][side]['z_hinge']] = body_skeleton3d_1[hinge_label]
[limbs_params['wings'][side]['x_com'],
limbs_params['wings'][side]['y_com'],
limbs_params['wings'][side]['z_com']] = [body_params['x_com'], body_params['y_com'], body_params['z_com']]
limbs_skeleton3d['wings'][side] = limbs['wings'].rotate_and_translate_skeleton3d(limbs_sets['wings'].skeleton3d_init[side],
limbs_params['wings'][side], side)
[limbs_params1_est['wings'][side]['x_hinge'], limbs_params1_est['wings'][side]['y_hinge'],
limbs_params1_est['wings'][side]['z_hinge']] = body_skeleton3d_est[hinge_label]
if show_plots:
# plot_limbs(limbs_skeleton3d, wait_show_plot=True)
plot_body_and_limbs(body_skeleton3d_1, limbs_skeleton3d, wait_show_plot=True)
[limbs_params1_est['wings'][side]['x_com'], limbs_params1_est['wings'][side]['y_com'],
limbs_params1_est['wings'][side]['z_com']] = [body_params1_est['x_com'], body_params1_est['y_com'], body_params1_est['z_com']]
# Estimate body and wings parameters
body_params1_est = body.estimate_params_from_skeleton3d(body_skeleton3d_1, body_sets.params_init)
limbs_params1_est = {}
for num_limb, limb_name in enumerate(animal_sets.limb_names):
limbs_params1_est[limb_name] = {}
for side in ['right', 'left']:
limbs_params1_est[limb_name][side] = (
limbs[limb_name].estimate_params_from_skeleton3d(limbs_skeleton3d['wings'][side], body_params1_est,
limbs_sets[limb_name].params_init[side], side))
limbs_skeleton3d_est['wings'][side] = (
limbs['wings'].rotate_and_translate_skeleton3d(limbs_sets['wings'].skeleton3d_init[side],
limbs_params1_est['wings'][side], side))
if show_plots:
body_skeleton3d_est = body.rotate_skeleton3d(body_sets.skeleton3d_init, body_params1_est)
body_skeleton3d_est = body.translate_skeleton3d(body_skeleton3d_est, body_params1_est)
# plot_limbs(limbs_skeleton3d_est, wait_show_plot=True)
plot_body_and_limbs(body_skeleton3d_est, limbs_skeleton3d_est, wait_show_plot=True)
limbs_skeleton3d_est = dict()
limbs_skeleton3d_est['wings'] = copy.deepcopy(limbs_sets['wings'].skeleton3d_init)
for side in limbs_sets['wings'].skeleton3d_init.keys():
hinge_label = '{0}_wing_hinge'.format(side)
[limbs_params1_est['wings'][side]['x_hinge'], limbs_params1_est['wings'][side]['y_hinge'],
limbs_params1_est['wings'][side]['z_hinge']] = body_skeleton3d_est[hinge_label]
print('> differences between real value and estimate are:')
for label in ['yaw_a', 'pitch_a', 'roll_a']:
print(' {0} = {1} degrees'.format(label, body_params[label] - body_params1_est[label]))
[limbs_params1_est['wings'][side]['x_com'], limbs_params1_est['wings'][side]['y_com'],
limbs_params1_est['wings'][side]['z_com']] = [body_params1_est['x_com'], body_params1_est['y_com'], body_params1_est['z_com']]
for side in limbs_sets['wings'].skeleton3d_init.keys():
for label in ['stroke_a', 'deviation_a', 'rotation_a']:
print(' {0} {1} = {2} degrees'.format(label, side, limbs_params['wings'][side][label] -
limbs_params1_est['wings'][side][label]))
limbs_skeleton3d_est['wings'][side] = (
limbs['wings'].rotate_and_translate_skeleton3d(limbs_sets['wings'].skeleton3d_init[side],
limbs_params1_est['wings'][side], side))
if show_plots:
plt.show()
# plot_limbs(limbs_skeleton3d_est, wait_show_plot=True)
plot_body_and_limbs(body_skeleton3d_est, limbs_skeleton3d_est, wait_show_plot=True)
print('> differences between real value and estimate are:')
for label in ['yaw_a', 'pitch_a', 'roll_a']:
print(' {0} = {1} degrees'.format(label, body_params[label] - body_params1_est[label]))
for side in limbs_sets['wings'].skeleton3d_init.keys():
for label in ['stroke_a', 'deviation_a', 'rotation_a']:
print(' {0} {1} = {2} degrees'.format(label, side, limbs_params['wings'][side][label] -
limbs_params1_est['wings'][side][label]))
if show_plots: plt.show()
def test_skeleton_fitter_fly(test_data_directory: str):
@pytest.mark.parametrize('animal_name', ['fly', 'fly_geo', 'fly_hybrid'])
def test_skeleton_fitter_fly(test_data_directory: str, animal_name: str) -> None:
max_rmse = 10
max_nb_iterations = 10000
......@@ -133,93 +138,93 @@ def test_skeleton_fitter_fly(test_data_directory: str):
sides = ['left', 'right']
animal_names = ['fly', 'fly_geo', 'fly_hybrid']
if '_' in animal_name:
skeleton_type = animal_name[animal_name.index('_'):]
else:
skeleton_type = ''
init_yaml_path = os.path.join(test_data_directory,
'mosquito_escapes/initial_parameters/mosquito' + skeleton_type + '.yaml')
for animal_name in animal_names:
if '_' in animal_name: skeleton_type = animal_name[animal_name.index('_'):]
else: skeleton_type = ''
init_yaml_path = os.path.join(test_data_directory,
'mosquito_escapes/initial_parameters/mosquito' + skeleton_type + '.yaml')
animal = import_module('flitrak3d.skeleton_fitter.animals.' + animal_name)
animal_sets = animal.AnimalSettings()
animal = import_module('flitrak3d.skeleton_fitter.animals.' + animal_name)
animal_sets = animal.AnimalSettings()
body = import_module('flitrak3d.skeleton_fitter.modules.bodies.' + animal_sets.body_module_name)
body_sets = body.BodyModuleSettings(init_yaml_path)
body_params = copy.deepcopy(body_sets.params_init)
body = import_module('flitrak3d.skeleton_fitter.modules.bodies.' + animal_sets.body_module_name)
body_sets = body.BodyModuleSettings(init_yaml_path)
body_params = copy.deepcopy(body_sets.params_init)
limbs, limbs_sets, limbs_params = {}, {}, {}
for num_limb, limb_name in enumerate(animal_sets.limb_names):
limbs[limb_name] = import_module('flitrak3d.skeleton_fitter.modules.limbs.' + animal_sets.limb_module_names[num_limb])
limbs_sets[limb_name] = limbs[limb_name].LimbsModuleSettings(init_yaml_path)
limbs_params[limb_name] = copy.deepcopy(limbs_sets[limb_name].params_init)
limbs, limbs_sets, limbs_params = {}, {}, {}
for param_name in body_params.keys():
limbs_params[limb_name]['right'][param_name] = body_params[param_name]
limbs_params[limb_name]['left'][param_name] = body_params[param_name]
# Parameters to generate fake 3d and 2d skeletons
body_params['yaw_a'] = 2.0
body_params['pitch_a'] = -1.0
body_params['roll_a'] = -1.0
[body_params['x_com'], body_params['y_com'], body_params['z_com']] = [0.01, 0.05, -0.02]
limbs_params['wings']['left']['stroke_a'], limbs_params['wings']['right']['stroke_a'] = 2.5, -1.0
limbs_params['wings']['left']['deviation_a'], limbs_params['wings']['right']['deviation_a'] = -1.0, 3.0
limbs_params['wings']['left']['rotation_a'], limbs_params['wings']['right']['rotation_a'] = -2.0, 1.0
# Test all fit_methods and opt_methods
for fit_method in fit_methods:
body_skeleton = body.rotate_skeleton3d(body_sets.skeleton3d_init, body_params)
body_skeleton = body.translate_skeleton3d(body_skeleton, body_params)
limbs_skeleton = {}
for num_limb, limb_name in enumerate(animal_sets.limb_names):
limbs[limb_name] = import_module('flitrak3d.skeleton_fitter.modules.limbs.' + animal_sets.limb_module_names[num_limb])
limbs_sets[limb_name] = limbs[limb_name].LimbsModuleSettings(init_yaml_path)
limbs_params[limb_name] = copy.deepcopy(limbs_sets[limb_name].params_init)
for param_name in body_params.keys():
limbs_params[limb_name]['right'][param_name] = body_params[param_name]
limbs_params[limb_name]['left'][param_name] = body_params[param_name]
# Parameters to generate fake 3d and 2d skeletons
body_params['yaw_a'] = 2.0
body_params['pitch_a'] = -1.0
body_params['roll_a'] = -1.0
[body_params['x_com'], body_params['y_com'], body_params['z_com']] = [0.01, 0.05, -0.02]
limbs_params['wings']['left']['stroke_a'], limbs_params['wings']['right']['stroke_a'] = 2.5, -1.0
limbs_params['wings']['left']['deviation_a'], limbs_params['wings']['right']['deviation_a'] = -1.0, 3.0
limbs_params['wings']['left']['rotation_a'], limbs_params['wings']['right']['rotation_a'] = -2.0, 1.0
# Test all fit_methods and opt_methods
for fit_method in fit_methods:
body_skeleton = body.rotate_skeleton3d(body_sets.skeleton3d_init, body_params)
body_skeleton = body.translate_skeleton3d(body_skeleton, body_params)
limbs_skeleton = {}
limbs_skeleton['wings'][limb_name] = copy.deepcopy(limbs_sets[limb_name].skeleton3d_init)
for side in sides:
limbs_skeleton['wings'][limb_name][side] = limbs[limb_name].rotate_and_translate_skeleton3d(
limbs_sets[limb_name].skeleton3d_init[side], limbs_params[limb_name][side], side)
if '2d' in fit_method:
body_skeleton = utils.reproject_skeleton3d_to2d(body_skeleton, dlt_coefs)
for num_limb, limb_name in enumerate(animal_sets.limb_names):
for side in sides:
limbs_skeleton['wings'][limb_name][side] = utils.reproject_skeleton3d_to2d(limbs_skeleton['wings'][limb_name][side], dlt_coefs)
for opt_method in opt_methods:
# Test optimisation fit for the body
param_names_to_optimize = list(set(body_sets.param_names) - set(body_sets.param_names_to_keep_cst))
body_param_ests, body_rmse, body_nb_iterations = \
optimiser.optim_fit_body_params(animal_name, body_sets.skeleton3d_init, body_skeleton, body_sets.params_init,
param_names_to_optimize, 'body_' + fit_method, opt_method, body_sets.bounds_init, dlt_coefs=dlt_coefs)
assert body_rmse <= max_rmse, \
"Too high rmse ({0} > {1}) when optimise fitting in {2} with {3}".format(body_rmse, max_rmse, fit_method, opt_method)
assert body_nb_iterations <= max_nb_iterations, \
"Too high nb_iterations ({0} > {1}}) when optimise fitting in {2} with {3}".format(body_nb_iterations, max_nb_iterations, fit_method, opt_method)
# Test optimisation fit for the limbs
for num_limb, limb_name in enumerate(animal_sets.limb_names):
param_names_to_optimize = list(set(limbs_sets[limb_name].param_names) - set(limbs_sets[limb_name].param_names_to_keep_cst))
param_names_to_optimize = list(set(param_names_to_optimize) - set(body_sets.param_names))
wings_param_ests, wings_rmse, wings_nb_iterations = \
optimiser.optim_fit_limbs_params(animal_name, limbs_sets[limb_name].skeleton3d_init,
limbs_skeleton['wings'][limb_name], limbs_sets[limb_name].params_init,
param_names_to_optimize, 'limb_' + fit_method,
opt_method, limbs_sets[limb_name].bounds_init, dlt_coefs=dlt_coefs)
limbs_skeleton['wings'][limb_name] = copy.deepcopy(limbs_sets[limb_name].skeleton3d_init)
for side in sides:
limbs_skeleton['wings'][limb_name][side] = limbs[limb_name].rotate_and_translate_skeleton3d(
limbs_sets[limb_name].skeleton3d_init[side], limbs_params[limb_name][side], side)
if '2d' in fit_method:
body_skeleton = utils.reproject_skeleton3d_to2d(body_skeleton, dlt_coefs)
for num_limb, limb_name in enumerate(animal_sets.limb_names):
for side in sides:
limbs_skeleton['wings'][limb_name][side] = utils.reproject_skeleton3d_to2d(limbs_skeleton['wings'][limb_name][side], dlt_coefs)
for opt_method in opt_methods:
# Test optimisation fit for the body
param_names_to_optimize = list(set(body_sets.param_names) - set(body_sets.param_names_to_keep_cst))
body_param_ests, body_rmse, body_nb_iterations = \
optimiser.optim_fit_body_params(animal_name, body_sets.skeleton3d_init, body_skeleton, body_sets.params_init,
param_names_to_optimize, 'body_' + fit_method, opt_method, body_sets.bounds_init, dlt_coefs=dlt_coefs)
assert body_rmse <= max_rmse, \
"Too high rmse ({0} > {1}) when optimise fitting in {2} with {3}".format(body_rmse, max_rmse, fit_method, opt_method)
assert body_nb_iterations <= max_nb_iterations, \
"Too high nb_iterations ({0} > {1}}) when optimise fitting in {2} with {3}".format(body_nb_iterations, max_nb_iterations, fit_method, opt_method)
# Test optimisation fit for the limbs
for num_limb, limb_name in enumerate(animal_sets.limb_names):
param_names_to_optimize = list(set(limbs_sets[limb_name].param_names) - set(limbs_sets[limb_name].param_names_to_keep_cst))
param_names_to_optimize = list(set(param_names_to_optimize) - set(body_sets.param_names))
wings_param_ests, wings_rmse, wings_nb_iterations = \
optimiser.optim_fit_limbs_params(animal_name, limbs_sets[limb_name].skeleton3d_init,
limbs_skeleton['wings'][limb_name], limbs_sets[limb_name].params_init,
param_names_to_optimize, 'limb_' + fit_method,
opt_method, limbs_sets[limb_name].bounds_init, dlt_coefs=dlt_coefs)
for side in sides:
assert wings_rmse[side] <= max_rmse, \
"Too high rmse ({0} > {1}) when optimise fitting in {2} with {3} (side: {4})".\
format(wings_rmse[side], max_rmse, fit_method, opt_method, side)
assert wings_nb_iterations[side] <= max_nb_iterations, \
"Too high nb_iterations ({0} > {1}}) when optimise fitting in {2} with {3} (side: {4})".\
format(wings_nb_iterations[side], max_nb_iterations, fit_method, opt_method, side)
assert wings_rmse[side] <= max_rmse, \
"Too high rmse ({0} > {1}) when optimise fitting in {2} with {3} (side: {4})".\
format(wings_rmse[side], max_rmse, fit_method, opt_method, side)
assert wings_nb_iterations[side] <= max_nb_iterations, \
"Too high nb_iterations ({0} > {1}}) when optimise fitting in {2} with {3} (side: {4})".\
format(wings_nb_iterations[side], max_nb_iterations, fit_method, opt_method, side)
def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment