diff --git a/framework_tvb/tvb/adapters/visualizers/annotations_viewer.py b/framework_tvb/tvb/adapters/visualizers/annotations_viewer.py index bc976d9117..f97c961ca8 100644 --- a/framework_tvb/tvb/adapters/visualizers/annotations_viewer.py +++ b/framework_tvb/tvb/adapters/visualizers/annotations_viewer.py @@ -33,10 +33,12 @@ """ import json + from tvb.adapters.datatypes.h5.surface_h5 import SurfaceH5 from tvb.adapters.visualizers.surface_view import ABCSurfaceDisplayer, SurfaceURLGenerator from tvb.adapters.datatypes.db.region_mapping import RegionMappingIndex from tvb.adapters.datatypes.db.annotation import * +from tvb.core.entities.filters.chain import FilterChain from tvb.core.neocom import h5 from tvb.core.adapters.abcadapter import ABCAdapterForm from tvb.core.adapters.abcdisplayer import URLGenerator @@ -73,13 +75,22 @@ class ConnectivityAnnotationsViewForm(ABCAdapterForm): def __init__(self): super(ConnectivityAnnotationsViewForm, self).__init__() - # Used for filtering + + connectivity_index_filter = FilterChain(fields=[FilterChain.datatype + '.number_of_regions'], operations=["=="], + values=['fk_connectivity_gid']) self.connectivity_index = TraitDataTypeSelectField(ConnectivityAnnotationsViewModel.connectivity_index, - 'connectivity_index') + 'connectivity_index', + runtime_conditions=('annotations_index', + connectivity_index_filter)) + self.annotations_index = TraitDataTypeSelectField(ConnectivityAnnotationsViewModel.annotations_index, 'annotations_index', conditions=self.get_filters()) - self.region_mapping_index = TraitDataTypeSelectField(ConnectivityAnnotationsViewModel.region_mapping_index, - 'region_mapping_index') + + rm_runtime_condition = FilterChain(fields=[FilterChain.datatype + '.fk_connectivity_gid'], operations=["=="], + values=[FilterChain.DEFAULT_RUNTIME_VALUE]) + self.region_mapping_index = TraitDataTypeSelectField( + ConnectivityAnnotationsViewModel.region_mapping_index, 'region_mapping_index', + runtime_conditions=('connectivity_index', rm_runtime_condition)) @staticmethod def get_view_model(): diff --git a/framework_tvb/tvb/adapters/visualizers/connectivity.py b/framework_tvb/tvb/adapters/visualizers/connectivity.py index efbb0b6896..4df180223e 100644 --- a/framework_tvb/tvb/adapters/visualizers/connectivity.py +++ b/framework_tvb/tvb/adapters/visualizers/connectivity.py @@ -37,6 +37,7 @@ import math import numpy from copy import copy + from tvb.adapters.visualizers.time_series import ABCSpaceDisplayer from tvb.adapters.visualizers.surface_view import SurfaceURLGenerator from tvb.basic.neotraits.api import Attr @@ -114,8 +115,9 @@ class ConnectivityViewerForm(ABCAdapterForm): def __init__(self): super(ConnectivityViewerForm, self).__init__() - self.connectivity = TraitDataTypeSelectField(ConnectivityViewerModel.connectivity, name='input_data', - conditions=self.get_filters()) + self.connectivity_data = TraitDataTypeSelectField(ConnectivityViewerModel.connectivity, + name='connectivity_data', conditions=self.get_filters()) + surface_conditions = FilterChain(fields=[FilterChain.datatype + '.surface_type'], operations=["=="], values=['Cortical Surface']) self.surface_data = TraitDataTypeSelectField(ConnectivityViewerModel.surface_data, name='surface_data', @@ -123,12 +125,15 @@ def __init__(self): self.step = FloatField(ConnectivityViewerModel.step, name='step') - colors_conditions = FilterChain(fields=[FilterChain.datatype + '.ndim'], operations=["=="], values=[1]) - self.colors = TraitDataTypeSelectField(ConnectivityViewerModel.colors, name='colors', - conditions=colors_conditions) + cm_condition = FilterChain(fields=[FilterChain.datatype + '.ndim'], operations=["=="], values=[1]) + cm_runtime_condition = FilterChain(fields=[FilterChain.datatype + '.fk_connectivity_gid'], operations=["=="], + values=[FilterChain.DEFAULT_RUNTIME_VALUE]) - rays_conditions = FilterChain(fields=[FilterChain.datatype + '.ndim'], operations=["=="], values=[1]) - self.rays = TraitDataTypeSelectField(ConnectivityViewerModel.rays, name='rays', conditions=rays_conditions) + self.colors = TraitDataTypeSelectField(ConnectivityViewerModel.colors, name='colors', + conditions=cm_condition, + runtime_conditions=('connectivity_data', cm_runtime_condition)) + self.rays = TraitDataTypeSelectField(ConnectivityViewerModel.rays, name='rays', conditions=cm_condition, + runtime_conditions=('connectivity_data', cm_runtime_condition)) @staticmethod def get_view_model(): @@ -258,7 +263,8 @@ def _compute_connectivity_global_params(self, connectivity): path_labels = SurfaceURLGenerator.paths2url(conn_gid, 'ordered_labels') path_hemisphere_order_indices = SurfaceURLGenerator.paths2url(conn_gid, 'hemisphere_order_indices') - algo = AlgorithmService().get_algorithm_by_module_and_class(CONNECTIVITY_CREATOR_MODULE, CONNECTIVITY_CREATOR_CLASS) + algo = AlgorithmService().get_algorithm_by_module_and_class(CONNECTIVITY_CREATOR_MODULE, + CONNECTIVITY_CREATOR_CLASS) submit_url = '/{}/{}/{}'.format(SurfaceURLGenerator.FLOW, algo.fk_category, algo.id) global_pages = dict(controlPage="connectivity/top_right_controls") diff --git a/framework_tvb/tvb/adapters/visualizers/region_volume_mapping.py b/framework_tvb/tvb/adapters/visualizers/region_volume_mapping.py index 51b8c33046..e4951ad5a4 100644 --- a/framework_tvb/tvb/adapters/visualizers/region_volume_mapping.py +++ b/framework_tvb/tvb/adapters/visualizers/region_volume_mapping.py @@ -388,8 +388,12 @@ def __init__(self): self.connectivity_measure = TraitDataTypeSelectField( ConnectivityMeasureVolumeVisualizerModel.connectivity_measure, name='connectivity_measure', conditions=self.get_filters()) + + rvm_runtime_filter = FilterChain(fields=[FilterChain.datatype + '.gid'], operations=["=="], + values=['fk_connectivity_gid:fk_connectivity_gid']) self.region_mapping_volume = TraitDataTypeSelectField( - ConnectivityMeasureVolumeVisualizerModel.region_mapping_volume, name='region_mapping_volume') + ConnectivityMeasureVolumeVisualizerModel.region_mapping_volume, name='region_mapping_volume', + runtime_conditions=('connectivity_measure', rvm_runtime_filter)) @staticmethod def get_view_model(): @@ -461,8 +465,12 @@ def __init__(self): cm_conditions = FilterChain( fields=[FilterChain.datatype + '.ndim', FilterChain.datatype + '.has_volume_mapping'], operations=["==", "=="], values=[1, True]) + cm_runtime_filter = FilterChain(fields=[FilterChain.datatype + '.gid'], operations=["=="], + values=['fk_connectivity_gid:fk_connectivity_gid']) self.connectivity_measure = TraitDataTypeSelectField(RegionVolumeMappingVisualiserModel.connectivity_measure, - name='connectivity_measure', conditions=cm_conditions) + name='connectivity_measure', conditions=cm_conditions, + runtime_conditions=('region_mapping_volume', + cm_runtime_filter)) @staticmethod def get_view_model(): diff --git a/framework_tvb/tvb/adapters/visualizers/surface_view.py b/framework_tvb/tvb/adapters/visualizers/surface_view.py index 640eda8122..60f2916c77 100644 --- a/framework_tvb/tvb/adapters/visualizers/surface_view.py +++ b/framework_tvb/tvb/adapters/visualizers/surface_view.py @@ -159,12 +159,18 @@ class BaseSurfaceViewerForm(ABCAdapterForm): def __init__(self): super(BaseSurfaceViewerForm, self).__init__() + self.region_map = TraitDataTypeSelectField(BaseSurfaceViewerModel.region_map, name='region_map') + conn_filter = FilterChain( fields=[FilterChain.datatype + '.ndim', FilterChain.datatype + '.has_surface_mapping'], operations=["==", "=="], values=[1, True]) + cm_runtime_filter = FilterChain(fields=[FilterChain.datatype + '.gid'], operations=["=="], + values=['fk_connectivity_gid:fk_connectivity_gid']) self.connectivity_measure = TraitDataTypeSelectField(BaseSurfaceViewerModel.connectivity_measure, - name='connectivity_measure', conditions=conn_filter) + name='connectivity_measure', conditions=conn_filter, + runtime_conditions=('region_map', cm_runtime_filter)) + self.shell_surface = TraitDataTypeSelectField(BaseSurfaceViewerModel.shell_surface, name='shell_surface') @staticmethod @@ -184,6 +190,10 @@ class SurfaceViewerModel(BaseSurfaceViewerModel): class SurfaceViewerForm(BaseSurfaceViewerForm): def __init__(self): super(SurfaceViewerForm, self).__init__() + rm_runtime_condition = FilterChain(fields=[FilterChain.datatype + '.fk_surface_gid'], operations=["=="], + values=[FilterChain.DEFAULT_RUNTIME_VALUE]) + self.region_map.runtime_conditions = ('surface', rm_runtime_condition) + self.surface = TraitDataTypeSelectField(SurfaceViewerModel.surface, name='surface') @staticmethod diff --git a/framework_tvb/tvb/adapters/visualizers/time_series_volume.py b/framework_tvb/tvb/adapters/visualizers/time_series_volume.py index dec8124928..12333d5c1c 100644 --- a/framework_tvb/tvb/adapters/visualizers/time_series_volume.py +++ b/framework_tvb/tvb/adapters/visualizers/time_series_volume.py @@ -74,10 +74,11 @@ class TimeSeriesVolumeVisualiserForm(ABCAdapterForm): def __init__(self): super(TimeSeriesVolumeVisualiserForm, self).__init__() - self.time_series = TraitDataTypeSelectField(TimeSeriesVolumeVisualiserModel.time_series, name='time_series', - conditions=self.get_filters()) + self.time_series = TraitDataTypeSelectField(TimeSeriesVolumeVisualiserModel.time_series, name='time_series') + # conditions=self.get_filters()) self.background = TraitDataTypeSelectField(TimeSeriesVolumeVisualiserModel.background, name='background') + @staticmethod def get_view_model(): return TimeSeriesVolumeVisualiserModel diff --git a/framework_tvb/tvb/core/entities/filters/chain.py b/framework_tvb/tvb/core/entities/filters/chain.py index 40ad2ae920..07b3deb112 100644 --- a/framework_tvb/tvb/core/entities/filters/chain.py +++ b/framework_tvb/tvb/core/entities/filters/chain.py @@ -77,6 +77,9 @@ class FilterChain(object): algorithm_category_replacement = "AlgorithmCategory" operation_replacement = "Operation" + # This is used for the simplest runtime filters, where we just have to replace the current value at runtime + DEFAULT_RUNTIME_VALUE = "default_runtime_value" + def __init__(self, display_name="", fields=None, values=None, operations=None, operator_between_fields='and'): """ Initialize filter attributes. diff --git a/framework_tvb/tvb/core/neotraits/forms.py b/framework_tvb/tvb/core/neotraits/forms.py index 6f9d950101..a7531ef3dd 100644 --- a/framework_tvb/tvb/core/neotraits/forms.py +++ b/framework_tvb/tvb/core/neotraits/forms.py @@ -134,7 +134,7 @@ class TraitDataTypeSelectField(TraitField): def __init__(self, trait_attribute, name=None, conditions=None, draw_dynamic_conditions_buttons=True, has_all_option=False, - show_only_all_option=False): + show_only_all_option=False, runtime_conditions=None): super(TraitDataTypeSelectField, self).__init__(trait_attribute, name) if issubclass(type(trait_attribute), DataTypeGidAttr): @@ -151,6 +151,7 @@ def __init__(self, trait_attribute, name=None, conditions=None, self.has_all_option = has_all_option self.show_only_all_option = show_only_all_option self.datatype_options = [] + self.runtime_conditions = runtime_conditions def from_trait(self, trait, f_name): if hasattr(trait, f_name): @@ -166,6 +167,10 @@ def get_dynamic_filters(self): def get_form_filters(self): return self.conditions + @property + def get_runtime_filters(self): + return self.runtime_conditions + def options(self): if not self.required: choice = None @@ -414,6 +419,7 @@ def __str__(self): class Form(object): + template = 'form_fields/form.html' def __init__(self): self.errors = [] diff --git a/framework_tvb/tvb/core/services/algorithm_service.py b/framework_tvb/tvb/core/services/algorithm_service.py index a4786543d6..d0ac728f5d 100644 --- a/framework_tvb/tvb/core/services/algorithm_service.py +++ b/framework_tvb/tvb/core/services/algorithm_service.py @@ -112,6 +112,10 @@ def fill_selectfield_with_datatypes(self, field, project_id, extra_conditions=No filtering_conditions = FilterChain() filtering_conditions += field.conditions filtering_conditions += extra_conditions + + if field.runtime_conditions is not None: + filtering_conditions += field.runtime_conditions[1] + datatypes, _ = dao.get_values_of_datatype(project_id, field.datatype_index, filtering_conditions) datatype_options = [] for datatype in datatypes: diff --git a/framework_tvb/tvb/interfaces/web/controllers/flow_controller.py b/framework_tvb/tvb/interfaces/web/controllers/flow_controller.py index 42d1b40a84..48cd8dd454 100644 --- a/framework_tvb/tvb/interfaces/web/controllers/flow_controller.py +++ b/framework_tvb/tvb/interfaces/web/controllers/flow_controller.py @@ -50,6 +50,7 @@ from tvb.core.entities.file.files_helper import FilesHelper from tvb.core.entities.filters.chain import FilterChain from tvb.core.entities.load import load_entity_by_gid +from tvb.core.entities.storage import dao from tvb.core.neocom import h5 from tvb.core.neocom.h5 import REGISTRY from tvb.core.neotraits.forms import TraitDataTypeSelectField @@ -250,38 +251,148 @@ def default(self, step_key, adapter_key, cancel=False, back_page=None, **data): self.fill_default_attributes(template_specification, algorithm.displayname) return template_specification + @staticmethod + def _fill_reversed_filter_value(runtime_filters, i): + # Get index of the currently chosen value for the field that the filtering will be applied on + datatype_index = dao.get_datatype_by_gid(runtime_filters['runtime_reverse_filtering_values'][i]) + if datatype_index: + # Get the linked datatype and the value that needs to be used for the filter + linked_datatype_field = runtime_filters['runtime_values'][i] + split_linked_datatype_field = linked_datatype_field.split(':') + linked_datatype_gid = getattr(datatype_index, split_linked_datatype_field[0]) + linked_datatype_index = dao.get_datatype_by_gid(linked_datatype_gid) + filter_field = runtime_filters['runtime_fields'][i].replace(FilterChain.datatype + '.', '') + filter_value = getattr(linked_datatype_index, filter_field) + + # If there was a ':' character in linked_datatype_field, it means that the linked datatype is not among the + # ui fields of the current form and the value needed to be obtained from one of the existing fields + runtime_filters['runtime_values'][i] = filter_value + runtime_filters['runtime_fields'][i] = FilterChain.datatype + '.' + (split_linked_datatype_field[1] if + len(split_linked_datatype_field) > 1 + else filter_field) + @expose_fragment('form_fields/options_field') @settings @context_selected - def get_filtered_datatypes(self, dt_module, dt_class, filters, has_all_option, has_none_option): - """ - Given the name from the input tree, the dataType required and a number of - filters, return the available dataType that satisfy the conditions imposed. + def get_filtered_datatypes(self, dt_module, dt_class, default_filters, user_filters, runtime_filters, + has_all_option, has_none_option): + # type: (str, str, str, str, str, bool, bool) -> dict + """ + This method applies all three types of filters on one field. + @param dt_module: module of the field's datatype index + @param dt_class: class of the field's datatype index + @param default_filters: a string in json format which contains all default filters for this field + @param user_filters: a string in json format which contains all filters defined by users for this field + @param runtime_filters: a string in json format which contains all filters for this field + that have values that can be obtained only at runtime (related to linked data types) + @param has_all_option: if the All option should be added or not + @param has_none_option: if the None option should be added or not (if the field is required or not) """ index_class = getattr(sys.modules[dt_module], dt_class)() - filters_dict = json.loads(filters) - - fields = [] - operations = [] - values = [] - - for idx in range(len(filters_dict['fields'])): - fields.append(filters_dict['fields'][idx]) - operations.append(filters_dict['operations'][idx]) - values.append(filters_dict['values'][idx]) - - filter = FilterChain(fields=fields, operations=operations, values=values) + default_filters_dict = json.loads(default_filters) + user_filters_dict = json.loads(user_filters) + runtime_filters_dict = json.loads(runtime_filters) + + for i in range(len(runtime_filters_dict['runtime_fields'])): + if (len(runtime_filters_dict['runtime_reverse_filtering_values'][i])) > 0: + self._fill_reversed_filter_value(runtime_filters_dict, i) + + filters = FilterChain(fields=default_filters_dict['default_fields'], + operations=default_filters_dict['default_operations'], + values=default_filters_dict['default_values']) + filters += FilterChain(fields=user_filters_dict['user_fields'], + operations=user_filters_dict['user_operations'], + values=user_filters_dict['user_values']) + filters += FilterChain(fields=runtime_filters_dict['runtime_fields'], + operations=runtime_filters_dict['runtime_operations'], + values=runtime_filters_dict['runtime_values']) project = common.get_current_project() data_type_gid_attr = DataTypeGidAttr(linked_datatype=REGISTRY.get_datatype_for_index(index_class)) data_type_gid_attr.required = not string2bool(has_none_option) - select_field = TraitDataTypeSelectField(data_type_gid_attr, conditions=filter, + select_field = TraitDataTypeSelectField(data_type_gid_attr, conditions=filters, has_all_option=string2bool(has_all_option)) self.algorithm_service.fill_selectfield_with_datatypes(select_field, project.id) return {'options': select_field.options()} + @expose_fragment('form_fields/form') + @settings + @context_selected + def get_runtime_filtered_form(self, algorithm_id, default_filters, user_filters, runtime_filters): + # type: (str, str, str, str) -> dict + """ + This method returns a newly rendered form, where all the filters are applied on the respective fields. + @param algorithm_id: id of the adapter that can be used to return an instance of the current form + @param default_filters: a string in json format which contains all default filters + @param user_filters: a string in json format which contains all filters defined by users + @param runtime_filters: a string in json format which contains all filters that have values that can be obtained + only at runtime (related to linked data types) + """ + + # Get an instance of the needed form class + algorithm = dao.get_algorithm_by_id(algorithm_id) + adapter = getattr(sys.modules[algorithm.module], algorithm.classname)() + form = adapter.get_form_class()() + + # Load filters as dictionaries + user_filters_dict = json.loads(user_filters) + default_filters_dict = json.loads(default_filters) + runtime_filters_dict = json.loads(runtime_filters) + project_id = common.get_current_project().id + + # Iterate over the filters of each field + for key, user_filters in user_filters_dict.items(): + select_field_attr = getattr(form, key) + + # Add default filters even if they are empty, otherwise applying the + operator on None will fail + default_filter_chain = FilterChain(fields=default_filters_dict[key]['default_fields'], + operations=default_filters_dict[key]['default_operations'], + values=default_filters_dict[key]['default_values']) + select_field_attr.conditions = default_filter_chain + + # Add filters defined by users so they can both be applied + select_field_attr.conditions += FilterChain(fields=user_filters['user_fields'], + operations=user_filters['user_operations'], + values=user_filters['user_values']) + runtime_filters = runtime_filters_dict[key] + + # Keep these values because they need to be reset after applying the runtime filters + runtime_filter_values_copy = runtime_filters['runtime_values'].copy() + runtime_filter_fields_copy = runtime_filters['runtime_fields'].copy() + + for i in range(len(runtime_filters['runtime_fields'])): + + # If this condition is true, then it means we need to apply the filters in 'inversed order', + # so we need the information from the filter value (and maybe from the filter field as well) + if (len(runtime_filters['runtime_reverse_filtering_values'][i])) > 0: + self._fill_reversed_filter_value(runtime_filters, i) + else: + runtime_filter_values_copy[i] = FilterChain.DEFAULT_RUNTIME_VALUE + + # Runtime conditions are added as a tuple of two elements, where the first element is the field that can + # trigger a change in the current field and the second element is the filter itself + if select_field_attr.runtime_conditions: + select_field_attr.runtime_conditions = (select_field_attr.runtime_conditions[0], FilterChain( + fields=runtime_filters['runtime_fields'], operations=runtime_filters['runtime_operations'], + values=runtime_filters['runtime_values'])) + + # Perform the filtering + self.algorithm_service.fill_selectfield_with_datatypes(select_field_attr, project_id) + select_field_attr.data = runtime_filters['ui_value'] + + # After applying the user defined filters, we need to eliminate them so they won't be added as hidden + # fields next to the default and runtime filters + select_field_attr.conditions = default_filter_chain + + # Runtime conditions need to be reset, because they were edited so they can be applied + if select_field_attr.runtime_conditions: + select_field_attr.runtime_conditions[1].values = runtime_filter_values_copy + select_field_attr.runtime_conditions[1].fields = runtime_filter_fields_copy + + return {'adapter_form': form} + def execute_post(self, project_id, submit_url, step_key, algorithm, **data): """ Execute HTTP POST on a generic step.""" errors = None @@ -303,7 +414,6 @@ def execute_post(self, project_id, submit_url, step_key, algorithm, **data): raise InvalidFormValues("Invalid form inputs! Could not fill algorithm from the given inputs!", error_dict=form.get_errors_dict()) - adapter_instance.submit_form(form) if issubclass(type(adapter_instance), ABCDisplayer): @@ -317,7 +427,7 @@ def execute_post(self, project_id, submit_url, step_key, algorithm, **data): return {} result = self.operation_services.fire_operation(adapter_instance, common.get_logged_user(), - project_id, view_model=view_model) + project_id, view_model=view_model) if isinstance(result, list): result = "Launched %s operations." % len(result) common.set_important_message(str(result)) diff --git a/framework_tvb/tvb/interfaces/web/static/js/filters.js b/framework_tvb/tvb/interfaces/web/static/js/filters.js index 4b00837976..249edead57 100644 --- a/framework_tvb/tvb/interfaces/web/static/js/filters.js +++ b/framework_tvb/tvb/interfaces/web/static/js/filters.js @@ -52,7 +52,7 @@ function _FIL_createUiForFilterType(filter, newDiv, isDate){ function addFilter(div_id, filters) { //Create a new div for the filter - var newDiv = $('