Skip to content

Commit

Permalink
fix weight traversal for long keys. The state traverser did not corre…
Browse files Browse the repository at this point in the history
…ctly (KeyviDev#293)

fix weight traversal for long keys. The state traverser did not correctly report the inner weight for the long key optimization. The parent weight must be provided when getting outgoing states.

Fixes issues with completions for long keys
  • Loading branch information
hendrikmuhs authored Mar 15, 2024
1 parent 2254d5c commit ba1c697
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
7 changes: 4 additions & 3 deletions keyvi/include/keyvi/dictionary/fsa/automata.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ class Automata final {
template <class TransitionT, typename std::enable_if<std::is_base_of<traversal::Transition, TransitionT>::value,
traversal::Transition>::type* = nullptr>
void GetOutGoingTransitions(uint64_t starting_state, traversal::TraversalState<TransitionT>* traversal_state,
traversal::TraversalPayload<TransitionT>* payload) const {
traversal::TraversalPayload<TransitionT>* payload,
[[maybe_unused]] uint32_t parent_weight = 0) const {
// reset the state
traversal_state->Clear();

Expand Down Expand Up @@ -236,10 +237,10 @@ class Automata final {
typename std::enable_if<std::is_base_of<traversal::WeightedTransition, TransitionT>::value,
traversal::WeightedTransition>::type* = nullptr>
inline void GetOutGoingTransitions(uint64_t starting_state, traversal::TraversalState<TransitionT>* traversal_state,
traversal::TraversalPayload<TransitionT>* payload) const {
traversal::TraversalPayload<TransitionT>* payload,
[[maybe_unused]] uint32_t parent_weight) const {
// reset the state
traversal_state->Clear();
uint32_t parent_weight = GetInnerWeight(starting_state);

#if defined(KEYVI_SSE42)
// Optimized version using SSE4.2, see http://www.strchr.com/strcmp_and_strlen_using_sse_4.2
Expand Down
10 changes: 6 additions & 4 deletions keyvi/include/keyvi/dictionary/fsa/state_traverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class StateTraverser final {
explicit StateTraverser(automata_t f)
: fsa_(f), current_state_(f->GetStartState()), current_weight_(0), current_label_(0), at_end_(false), stack_() {
TRACE("StateTraverser starting with Start state %d", current_state_);
f->GetOutGoingTransitions(current_state_, &stack_.GetStates(), &stack_.traversal_stack_payload);
f->GetOutGoingTransitions(current_state_, &stack_.GetStates(), &stack_.traversal_stack_payload, 0);

this->operator++(0);
}
Expand All @@ -61,7 +61,8 @@ class StateTraverser final {
current_state_ = start_state;

TRACE("StateTraverser starting with Start state %d", current_state_);
f->GetOutGoingTransitions(start_state, &stack_.GetStates(), &stack_.traversal_stack_payload);
f->GetOutGoingTransitions(start_state, &stack_.GetStates(), &stack_.traversal_stack_payload,
f->GetInnerWeight(start_state));

if (advance) {
this->operator++(0);
Expand All @@ -71,7 +72,8 @@ class StateTraverser final {
StateTraverser(automata_t f, const uint64_t start_state, const bool advance = true)
: fsa_(f), current_state_(start_state), current_weight_(0), current_label_(0), at_end_(false), stack_() {
TRACE("StateTraverser starting with Start state %d", current_state_);
f->GetOutGoingTransitions(start_state, &stack_.GetStates(), &stack_.traversal_stack_payload);
f->GetOutGoingTransitions(start_state, &stack_.GetStates(), &stack_.traversal_stack_payload,
f->GetInnerWeight(start_state));

if (advance) {
this->operator++(0);
Expand Down Expand Up @@ -149,7 +151,7 @@ class StateTraverser final {
current_weight_ = stack_.GetStates().GetNextInnerWeight();
TRACE("Label: %c", current_label_);
stack_++;
fsa_->GetOutGoingTransitions(current_state_, &stack_.GetStates(), &stack_.traversal_stack_payload);
fsa_->GetOutGoingTransitions(current_state_, &stack_.GetStates(), &stack_.traversal_stack_payload, current_weight_);
TRACE("found %ld outgoing states", stack_.GetStates().size());
}

Expand Down
2 changes: 1 addition & 1 deletion keyvi/tests/keyvi/dictionary/fsa/automata_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ BOOST_AUTO_TEST_CASE(GetOutGoingTransitionsWeightTest) {

traversal::TraversalStack<traversal::WeightedTransition> stack;

f->GetOutGoingTransitions(f->GetStartState(), &stack.GetStates(), &stack.traversal_stack_payload);
f->GetOutGoingTransitions(f->GetStartState(), &stack.GetStates(), &stack.traversal_stack_payload, 42);

BOOST_CHECK_EQUAL(1, stack.GetStates().traversal_state_payload.transitions.size());
BOOST_CHECK_EQUAL(444, stack.GetStates().traversal_state_payload.transitions[0].weight);
Expand Down
20 changes: 20 additions & 0 deletions keyvi/tests/keyvi/dictionary/fsa/state_traverser_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,26 @@ BOOST_AUTO_TEST_CASE(traversal_min_weight) {
BOOST_CHECK_EQUAL(0, s.GetDepth());
}

BOOST_AUTO_TEST_CASE(traversal_inner_weight_long_entry) {
std::vector<std::pair<std::string, uint32_t>> test_data = {{std::string(500, 'a'), 300}};

testing::TempDictionary dictionary(&test_data);
automata_t f = dictionary.GetFsa();

StateTraverser<traversal::WeightedTransition> s(f);

int steps = 0;
while (s) {
++steps;
BOOST_CHECK_EQUAL('a', s.GetStateLabel());
BOOST_CHECK_EQUAL(steps, s.GetDepth());
BOOST_CHECK_EQUAL(300, s.GetInnerWeight());
s++;
}

BOOST_CHECK_EQUAL(500, steps);
}

BOOST_AUTO_TEST_SUITE_END()

} /* namespace fsa */
Expand Down

0 comments on commit ba1c697

Please sign in to comment.