-
Notifications
You must be signed in to change notification settings - Fork 195
/
Copy pathtopk_specified_field_selector.py
92 lines (80 loc) · 3.5 KB
/
topk_specified_field_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import heapq
from typing import Optional
from pydantic import Field, PositiveInt
from typing_extensions import Annotated
from data_juicer.utils.common_utils import stats_to_number
from ..base_op import OPERATORS, Selector
@OPERATORS.register_module('topk_specified_field_selector')
class TopkSpecifiedFieldSelector(Selector):
    """Selector to select top samples based on the sorted specified field
    value."""

    def __init__(self,
                 field_key: str = '',
                 top_ratio: Optional[Annotated[float,
                                               Field(ge=0, le=1)]] = None,
                 topk: Optional[PositiveInt] = None,
                 reverse: bool = True,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param field_key: Selector based on the specified value
            corresponding to the target key. A target key that refers to
            multi-level field information needs its levels separated by
            '.' (e.g. 'meta.score').
        :param top_ratio: Ratio of selected top samples; samples are
            selected if their specified field values rank within this
            ratio of the dataset. 0 selects no samples; None disables the
            ratio limit. When both topk and top_ratio are set, the value
            corresponding to the smaller number of samples is applied.
        :param topk: Number of selected top samples; samples are selected
            if their specified field values rank within this count. When
            both topk and top_ratio are set, the value corresponding to
            the smaller number of samples is applied.
        :param reverse: Determine the sorting rule; if reverse=True, sort
            in descending order (keep the largest values).
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.field_key = field_key
        self.top_ratio = top_ratio
        self.topk = topk
        self.reverse = reverse

    def process(self, dataset):
        """
        Keep the top samples of ``dataset`` ranked by ``self.field_key``.

        :param dataset: dataset to select from; it must support ``len``,
            ``features``, column access by name and ``select`` with a list
            of row indices.
        :return: ``dataset`` itself when it has at most one sample, when
            ``field_key`` is empty, or when neither ``top_ratio`` nor
            ``topk`` is set; otherwise ``dataset.select`` applied to the
            indices of the chosen samples, ranked best-first.
        :raises AssertionError: if the first level of ``field_key`` is not
            a dataset feature, or a nested level is missing from a sample.
        """
        if len(dataset) <= 1 or not self.field_key:
            return dataset

        select_num = self._compute_select_num(len(dataset))
        if select_num is None:
            # No limit configured: nothing to filter.
            return dataset

        field_value_list = self._collect_field_values(dataset)

        # nlargest/nsmallest with a key behave like sorted(...)[:n], so the
        # returned indices are already ranked and ties keep their original
        # relative order.
        pick = heapq.nlargest if self.reverse else heapq.nsmallest
        select_index = pick(select_num,
                            range(len(dataset)),
                            key=field_value_list.__getitem__)
        return dataset.select(select_index)

    def _compute_select_num(self, num_samples):
        """Return how many samples to keep, or None if no limit is set."""
        # Compare top_ratio against None rather than relying on truthiness:
        # 0.0 is a valid ratio (Field(ge=0)) and must select zero samples
        # instead of being mistaken for "not configured".
        if self.top_ratio is None:
            return int(self.topk) if self.topk else None
        select_num = self.top_ratio * num_samples
        # When both limits are set, the smaller sample count wins.
        if self.topk and self.topk < select_num:
            select_num = self.topk
        # int() floors the fractional count produced by the ratio.
        return int(select_num)

    def _collect_field_values(self, dataset):
        """Return one sortable value per sample for ``self.field_key``."""
        field_keys = self.field_key.split('.')
        assert field_keys[0] in dataset.features.keys(
        ), "'{}' not in {}".format(field_keys[0], dataset.features.keys())

        if len(field_keys) == 1:
            # NOTE(review): flat values are compared as-is (no
            # stats_to_number conversion, unlike the nested path), so they
            # must be mutually orderable -- e.g. a None mixed with numbers
            # would make heapq raise TypeError.
            return dataset[field_keys[0]]

        field_value_list = []
        for item in dataset[field_keys[0]]:
            field_value = item
            # Walk down the remaining levels of the dotted key.
            for key in field_keys[1:]:
                assert key in field_value.keys(), "'{}' not in {}".format(
                    key, field_value.keys())
                field_value = field_value[key]
            # stats_to_number receives self.reverse, presumably so that
            # unparsable values fall to the bottom of the ranking in either
            # direction -- confirm against data_juicer.utils.common_utils.
            field_value_list.append(
                stats_to_number(field_value, self.reverse))
        return field_value_list