Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TVB-2757: Apply filters on linked datatypes at runtime by rerendering the form #351

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
19 changes: 15 additions & 4 deletions framework_tvb/tvb/adapters/visualizers/annotations_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
"""

import json

from tvb.adapters.datatypes.h5.surface_h5 import SurfaceH5
from tvb.adapters.visualizers.surface_view import ABCSurfaceDisplayer, SurfaceURLGenerator
from tvb.adapters.datatypes.db.region_mapping import RegionMappingIndex
from tvb.adapters.datatypes.db.annotation import *
from tvb.core.entities.filters.chain import FilterChain
from tvb.core.neocom import h5
from tvb.core.adapters.abcadapter import ABCAdapterForm
from tvb.core.adapters.abcdisplayer import URLGenerator
Expand Down Expand Up @@ -73,13 +75,22 @@ class ConnectivityAnnotationsViewForm(ABCAdapterForm):

def __init__(self):
super(ConnectivityAnnotationsViewForm, self).__init__()
# Used for filtering

connectivity_index_filter = FilterChain(fields=[FilterChain.datatype + '.number_of_regions'], operations=["=="],
values=['fk_connectivity_gid'])
self.connectivity_index = TraitDataTypeSelectField(ConnectivityAnnotationsViewModel.connectivity_index,
'connectivity_index')
'connectivity_index',
runtime_conditions=('annotations_index',
connectivity_index_filter))

self.annotations_index = TraitDataTypeSelectField(ConnectivityAnnotationsViewModel.annotations_index,
'annotations_index', conditions=self.get_filters())
self.region_mapping_index = TraitDataTypeSelectField(ConnectivityAnnotationsViewModel.region_mapping_index,
'region_mapping_index')

rm_runtime_condition = FilterChain(fields=[FilterChain.datatype + '.fk_connectivity_gid'], operations=["=="],
values=[FilterChain.DEFAULT_RUNTIME_VALUE])
self.region_mapping_index = TraitDataTypeSelectField(
ConnectivityAnnotationsViewModel.region_mapping_index, 'region_mapping_index',
runtime_conditions=('connectivity_index', rm_runtime_condition))

@staticmethod
def get_view_model():
Expand Down
22 changes: 14 additions & 8 deletions framework_tvb/tvb/adapters/visualizers/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import math
import numpy
from copy import copy

from tvb.adapters.visualizers.time_series import ABCSpaceDisplayer
from tvb.adapters.visualizers.surface_view import SurfaceURLGenerator
from tvb.basic.neotraits.api import Attr
Expand Down Expand Up @@ -114,21 +115,25 @@ class ConnectivityViewerForm(ABCAdapterForm):
def __init__(self):
super(ConnectivityViewerForm, self).__init__()

self.connectivity = TraitDataTypeSelectField(ConnectivityViewerModel.connectivity, name='input_data',
conditions=self.get_filters())
self.connectivity_data = TraitDataTypeSelectField(ConnectivityViewerModel.connectivity,
name='connectivity_data', conditions=self.get_filters())

surface_conditions = FilterChain(fields=[FilterChain.datatype + '.surface_type'], operations=["=="],
values=['Cortical Surface'])
self.surface_data = TraitDataTypeSelectField(ConnectivityViewerModel.surface_data, name='surface_data',
conditions=surface_conditions)

self.step = FloatField(ConnectivityViewerModel.step, name='step')

colors_conditions = FilterChain(fields=[FilterChain.datatype + '.ndim'], operations=["=="], values=[1])
self.colors = TraitDataTypeSelectField(ConnectivityViewerModel.colors, name='colors',
conditions=colors_conditions)
cm_condition = FilterChain(fields=[FilterChain.datatype + '.ndim'], operations=["=="], values=[1])
cm_runtime_condition = FilterChain(fields=[FilterChain.datatype + '.fk_connectivity_gid'], operations=["=="],
values=[FilterChain.DEFAULT_RUNTIME_VALUE])

rays_conditions = FilterChain(fields=[FilterChain.datatype + '.ndim'], operations=["=="], values=[1])
self.rays = TraitDataTypeSelectField(ConnectivityViewerModel.rays, name='rays', conditions=rays_conditions)
self.colors = TraitDataTypeSelectField(ConnectivityViewerModel.colors, name='colors',
conditions=cm_condition,
runtime_conditions=('connectivity_data', cm_runtime_condition))
self.rays = TraitDataTypeSelectField(ConnectivityViewerModel.rays, name='rays', conditions=cm_condition,
runtime_conditions=('connectivity_data', cm_runtime_condition))

@staticmethod
def get_view_model():
Expand Down Expand Up @@ -258,7 +263,8 @@ def _compute_connectivity_global_params(self, connectivity):
path_labels = SurfaceURLGenerator.paths2url(conn_gid, 'ordered_labels')
path_hemisphere_order_indices = SurfaceURLGenerator.paths2url(conn_gid, 'hemisphere_order_indices')

algo = AlgorithmService().get_algorithm_by_module_and_class(CONNECTIVITY_CREATOR_MODULE, CONNECTIVITY_CREATOR_CLASS)
algo = AlgorithmService().get_algorithm_by_module_and_class(CONNECTIVITY_CREATOR_MODULE,
CONNECTIVITY_CREATOR_CLASS)
submit_url = '/{}/{}/{}'.format(SurfaceURLGenerator.FLOW, algo.fk_category, algo.id)
global_pages = dict(controlPage="connectivity/top_right_controls")

Expand Down
12 changes: 10 additions & 2 deletions framework_tvb/tvb/adapters/visualizers/region_volume_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,12 @@ def __init__(self):
self.connectivity_measure = TraitDataTypeSelectField(
ConnectivityMeasureVolumeVisualizerModel.connectivity_measure, name='connectivity_measure',
conditions=self.get_filters())

rvm_runtime_filter = FilterChain(fields=[FilterChain.datatype + '.gid'], operations=["=="],
values=['fk_connectivity_gid:fk_connectivity_gid'])
self.region_mapping_volume = TraitDataTypeSelectField(
ConnectivityMeasureVolumeVisualizerModel.region_mapping_volume, name='region_mapping_volume')
ConnectivityMeasureVolumeVisualizerModel.region_mapping_volume, name='region_mapping_volume',
runtime_conditions=('connectivity_measure', rvm_runtime_filter))

@staticmethod
def get_view_model():
Expand Down Expand Up @@ -461,8 +465,12 @@ def __init__(self):
cm_conditions = FilterChain(
fields=[FilterChain.datatype + '.ndim', FilterChain.datatype + '.has_volume_mapping'],
operations=["==", "=="], values=[1, True])
cm_runtime_filter = FilterChain(fields=[FilterChain.datatype + '.gid'], operations=["=="],
values=['fk_connectivity_gid:fk_connectivity_gid'])
self.connectivity_measure = TraitDataTypeSelectField(RegionVolumeMappingVisualiserModel.connectivity_measure,
name='connectivity_measure', conditions=cm_conditions)
name='connectivity_measure', conditions=cm_conditions,
runtime_conditions=('region_mapping_volume',
cm_runtime_filter))

@staticmethod
def get_view_model():
Expand Down
12 changes: 11 additions & 1 deletion framework_tvb/tvb/adapters/visualizers/surface_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,18 @@ class BaseSurfaceViewerForm(ABCAdapterForm):

def __init__(self):
super(BaseSurfaceViewerForm, self).__init__()

self.region_map = TraitDataTypeSelectField(BaseSurfaceViewerModel.region_map, name='region_map')

conn_filter = FilterChain(
fields=[FilterChain.datatype + '.ndim', FilterChain.datatype + '.has_surface_mapping'],
operations=["==", "=="], values=[1, True])
cm_runtime_filter = FilterChain(fields=[FilterChain.datatype + '.gid'], operations=["=="],
values=['fk_connectivity_gid:fk_connectivity_gid'])
self.connectivity_measure = TraitDataTypeSelectField(BaseSurfaceViewerModel.connectivity_measure,
name='connectivity_measure', conditions=conn_filter)
name='connectivity_measure', conditions=conn_filter,
runtime_conditions=('region_map', cm_runtime_filter))

self.shell_surface = TraitDataTypeSelectField(BaseSurfaceViewerModel.shell_surface, name='shell_surface')

@staticmethod
Expand All @@ -184,6 +190,10 @@ class SurfaceViewerModel(BaseSurfaceViewerModel):
class SurfaceViewerForm(BaseSurfaceViewerForm):
def __init__(self):
super(SurfaceViewerForm, self).__init__()
rm_runtime_condition = FilterChain(fields=[FilterChain.datatype + '.fk_surface_gid'], operations=["=="],
values=[FilterChain.DEFAULT_RUNTIME_VALUE])
self.region_map.runtime_conditions = ('surface', rm_runtime_condition)

self.surface = TraitDataTypeSelectField(SurfaceViewerModel.surface, name='surface')

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions framework_tvb/tvb/adapters/visualizers/time_series_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,11 @@ class TimeSeriesVolumeVisualiserForm(ABCAdapterForm):

def __init__(self):
super(TimeSeriesVolumeVisualiserForm, self).__init__()
self.time_series = TraitDataTypeSelectField(TimeSeriesVolumeVisualiserModel.time_series, name='time_series',
conditions=self.get_filters())
self.time_series = TraitDataTypeSelectField(TimeSeriesVolumeVisualiserModel.time_series, name='time_series')
# conditions=self.get_filters())
self.background = TraitDataTypeSelectField(TimeSeriesVolumeVisualiserModel.background, name='background')


@staticmethod
def get_view_model():
return TimeSeriesVolumeVisualiserModel
Expand Down
3 changes: 3 additions & 0 deletions framework_tvb/tvb/core/entities/filters/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class FilterChain(object):
algorithm_category_replacement = "AlgorithmCategory"
operation_replacement = "Operation"

# This is used for the simplest runtime filters, where we just have to replace the current value at runtime
DEFAULT_RUNTIME_VALUE = "default_runtime_value"

def __init__(self, display_name="", fields=None, values=None, operations=None, operator_between_fields='and'):
"""
Initialize filter attributes.
Expand Down
8 changes: 7 additions & 1 deletion framework_tvb/tvb/core/neotraits/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class TraitDataTypeSelectField(TraitField):

def __init__(self, trait_attribute, name=None, conditions=None,
draw_dynamic_conditions_buttons=True, has_all_option=False,
show_only_all_option=False):
show_only_all_option=False, runtime_conditions=None):
super(TraitDataTypeSelectField, self).__init__(trait_attribute, name)

if issubclass(type(trait_attribute), DataTypeGidAttr):
Expand All @@ -151,6 +151,7 @@ def __init__(self, trait_attribute, name=None, conditions=None,
self.has_all_option = has_all_option
self.show_only_all_option = show_only_all_option
self.datatype_options = []
self.runtime_conditions = runtime_conditions

def from_trait(self, trait, f_name):
if hasattr(trait, f_name):
Expand All @@ -166,6 +167,10 @@ def get_dynamic_filters(self):
def get_form_filters(self):
return self.conditions

@property
def get_runtime_filters(self):
return self.runtime_conditions

def options(self):
if not self.required:
choice = None
Expand Down Expand Up @@ -414,6 +419,7 @@ def __str__(self):


class Form(object):
template = 'form_fields/form.html'

def __init__(self):
self.errors = []
Expand Down
4 changes: 4 additions & 0 deletions framework_tvb/tvb/core/services/algorithm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def fill_selectfield_with_datatypes(self, field, project_id, extra_conditions=No
filtering_conditions = FilterChain()
filtering_conditions += field.conditions
filtering_conditions += extra_conditions

if field.runtime_conditions is not None:
filtering_conditions += field.runtime_conditions[1]

datatypes, _ = dao.get_values_of_datatype(project_id, field.datatype_index, filtering_conditions)
datatype_options = []
for datatype in datatypes:
Expand Down
Loading