Skip to content

Commit

Permalink
refactor matches to be returned as shared pointers
Browse files: browse the repository at this point in the history
  • Loading branch information
hendrikmuhs committed Apr 20, 2024
1 parent a6c2168 commit 1272880
Show file tree
Hide file tree
Showing 32 changed files with 325 additions and 329 deletions.
16 changes: 8 additions & 8 deletions keyvi/bin/keyvi_c/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

using keyvi::dictionary::Dictionary;
using keyvi::dictionary::dictionary_t;
using keyvi::dictionary::Match;
using keyvi::dictionary::match_t;
using keyvi::dictionary::MatchIterator;
using keyvi::dictionary::completion::MultiWordCompletion;
using keyvi::dictionary::completion::PrefixCompletion;
Expand All @@ -54,9 +54,9 @@ struct keyvi_dictionary {
};

struct keyvi_match {
explicit keyvi_match(const Match& obj) : obj_(obj) {}
explicit keyvi_match(const match_t& obj) : obj_(obj) {}

Match obj_;
match_t obj_;
};

struct keyvi_match_iterator {
Expand Down Expand Up @@ -143,20 +143,20 @@ void keyvi_match_destroy(const keyvi_match* match) {
}

bool keyvi_match_is_empty(const keyvi_match* match) {
return match->obj_.IsEmpty();
return match->obj_->IsEmpty();
}

double keyvi_match_get_score(const keyvi_match* match) {
return match->obj_.GetScore();
return match->obj_->GetScore();
}

char* keyvi_match_get_value_as_string(const keyvi_match* match) {
return std_2_c_string(match->obj_.GetValueAsString());
return std_2_c_string(match->obj_->GetValueAsString());
}

keyvi_bytes keyvi_match_get_msgpacked_value(const struct keyvi_match* match) {
const keyvi_bytes empty_keyvi_bytes{0, nullptr};
const std::string msgpacked_value = match->obj_.GetMsgPackedValueAsString();
const std::string msgpacked_value = match->obj_->GetMsgPackedValueAsString();

const size_t data_size = msgpacked_value.size();
if (0 == data_size) {
Expand All @@ -172,7 +172,7 @@ keyvi_bytes keyvi_match_get_msgpacked_value(const struct keyvi_match* match) {
}

char* keyvi_match_get_matched_string(const keyvi_match* match) {
return std_2_c_string(match->obj_.GetMatchedString());
return std_2_c_string(match->obj_->GetMatchedString());
}

//////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ForwardBackwardCompletion final {
: forward_completions_(forward_dictionary), backward_completions_(backward_dictionary) {}

struct result_compare {
bool operator()(const Match& m1, const Match& m2) const { return m1.GetScore() < m2.GetScore(); }
bool operator()(const match_t& m1, const match_t& m2) const { return m1->GetScore() < m2->GetScore(); }
};

MatchIterator::MatchIteratorPair GetCompletions(const std::string& query, int number_of_results = 10) {
Expand All @@ -70,14 +70,14 @@ class ForwardBackwardCompletion final {

// priority queue for pruning results
util::BoundedPriorityQueue<uint32_t> best_scores(2 * number_of_results);
std::vector<Match> results;
std::vector<match_t> results;

for (auto match : forward_completions_.GetCompletions(query, number_of_results)) {
uint32_t weight = boost::lexical_cast<uint32_t>(match.GetAttribute("weight"));
uint32_t weight = boost::lexical_cast<uint32_t>(match->GetAttribute("weight"));

// put the weight into the priority queue
best_scores.Put(weight);
match.SetScore(weight);
match->SetScore(weight);
results.push_back(match);

TRACE("Forward Completion: %s %d", match.GetMatchedString().c_str(), match.GetScore());
Expand All @@ -86,17 +86,17 @@ class ForwardBackwardCompletion final {
if (results.size() > 0 && query_length > 4) {
std::make_heap(results.begin(), results.end(), result_compare());

std::vector<Match> results_forward_and_backward;
std::vector<match_t> results_forward_and_backward;

do {
std::pop_heap(results.begin(), results.end(), result_compare());
Match m = results.back();
match_t m = results.back();
results.pop_back();

std::string phrase = m.GetMatchedString();
std::string phrase = m->GetMatchedString();

// heuristic: stop expanding if phrase has a lower score than the worst best score
if (best_scores.Back() > m.GetScore()) {
if (best_scores.Back() > m->GetScore()) {
TRACE("Stop backward completions score to low %d", m.GetScore());
break;
}
Expand All @@ -110,7 +110,7 @@ class ForwardBackwardCompletion final {

uint32_t last_weight = 0;
for (auto match : backward_completions_.GetCompletions(phrase.c_str(), number_of_results)) {
uint32_t weight = boost::lexical_cast<uint32_t>(match.GetAttribute("weight"));
uint32_t weight = boost::lexical_cast<uint32_t>(match->GetAttribute("weight"));

if (weight < best_scores.Back()) {
TRACE("Skip Backward, score to low %d", weight);
Expand All @@ -126,13 +126,13 @@ class ForwardBackwardCompletion final {

// accept the result
best_scores.Put(weight);
match.SetScore(weight);
match->SetScore(weight);

// reverse the matched string
std::string matched_string = match.GetMatchedString();
std::string matched_string = match->GetMatchedString();
std::reverse(matched_string.begin(), matched_string.end());

match.SetMatchedString(matched_string);
match->SetMatchedString(matched_string);

results_forward_and_backward.push_back(match);

Expand Down Expand Up @@ -161,16 +161,16 @@ class ForwardBackwardCompletion final {
// reuse results vector
results.clear();
for (auto match : backward_completions_.GetCompletions(phrase.c_str(), number_of_results)) {
std::string matched_string = match.GetMatchedString();
std::string matched_string = match->GetMatchedString();
std::reverse(matched_string.begin(), matched_string.end());
// if the original query had a space at the end, this result should as well
if (last_character_is_space) {
matched_string.append(" ");
}

uint32_t weight = boost::lexical_cast<uint32_t>(match.GetAttribute("weight"));
match.SetScore(weight);
match.SetMatchedString(matched_string);
uint32_t weight = boost::lexical_cast<uint32_t>(match->GetAttribute("weight"));
match->SetScore(weight);
match->SetMatchedString(matched_string);

results.push_back(match);
TRACE("Backward Completion from query add: %s %d", match.GetMatchedString().c_str(), match.GetScore());
Expand All @@ -181,23 +181,23 @@ class ForwardBackwardCompletion final {

do {
std::pop_heap(results.begin(), results.end(), result_compare());
Match m = results.back();
match_t m = results.back();
results.pop_back();

std::string phrase = m.GetMatchedString();
std::string phrase = m->GetMatchedString();
TRACE("Do forward from backward completion for %s (%d / %d)", m.GetMatchedString().c_str(), m.GetScore(),
best_scores.Back());

// heuristic: stop expanding if phrase has a lower score than the worst best score
if (best_scores.Back() > m.GetScore()) {
if (best_scores.Back() > m->GetScore()) {
TRACE("Stop backward forward completions scores to low %d", m.GetScore());
break;
}

// match forward with this
for (auto match_forward :
forward_completions_.GetCompletions(m.GetMatchedString().c_str(), number_of_results)) {
uint32_t weight = boost::lexical_cast<uint32_t>(match_forward.GetAttribute("weight"));
forward_completions_.GetCompletions(m->GetMatchedString().c_str(), number_of_results)) {
uint32_t weight = boost::lexical_cast<uint32_t>(match_forward->GetAttribute("weight"));

if (weight < best_scores.Back()) {
TRACE("Skip Backward forward, score to low %d", weight);
Expand All @@ -206,7 +206,7 @@ class ForwardBackwardCompletion final {

// accept the result
best_scores.Put(weight);
match_forward.SetScore(weight);
match_forward->SetScore(weight);

results_forward_and_backward.push_back(match_forward);

Expand All @@ -223,10 +223,10 @@ class ForwardBackwardCompletion final {
std::make_heap(results.begin(), results.end(), result_compare());

struct delegate_payload {
explicit delegate_payload(std::vector<Match>& r) : results(std::move(r)) {}
explicit delegate_payload(std::vector<match_t>& r) : results(std::move(r)) {}

std::vector<Match> results;
Match last_result;
std::vector<match_t> results;
match_t last_result;
};

std::shared_ptr<delegate_payload> data(new delegate_payload(results));
Expand All @@ -236,10 +236,10 @@ class ForwardBackwardCompletion final {
std::pop_heap(data->results.begin(), data->results.end(), result_compare());

// de-duplicate
while (data->last_result.GetMatchedString() == data->results.back().GetMatchedString()) {
while (data->last_result && data->last_result->GetMatchedString() == data->results.back()->GetMatchedString()) {
data->results.pop_back();
if (data->results.size() == 0) {
return Match();
return match_t();
}

std::pop_heap(data->results.begin(), data->results.end(), result_compare());
Expand All @@ -251,7 +251,7 @@ class ForwardBackwardCompletion final {
return data->last_result;
}

return Match();
return match_t();
};

return MatchIterator::MakeIteratorPair(tfunc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class MultiWordCompletion final {
traversal_stack.reserve(100);

if (depth == query_length) {
Match first_match;
match_t first_match;
TRACE("matched prefix");

// data which is required for the callback as well
Expand All @@ -95,7 +95,7 @@ class MultiWordCompletion final {
TRACE("prefix matched depth %d %s", query_length + data->traverser.GetDepth(),
std::string(reinterpret_cast<char*>(&data->traversal_stack[0]), query_length + data->traverser.GetDepth())
.c_str());
first_match = Match(0, query_length, query, 0, fsa_, fsa_->GetStateValue(state));
first_match = std::make_shared<Match>(0, query_length, query, 0, fsa_, fsa_->GetStateValue(state));
}

auto tfunc = [data, query_length]() {
Expand Down Expand Up @@ -140,7 +140,7 @@ class MultiWordCompletion final {
query_length + data->traverser.GetDepth());
}

Match m(0, data->traverser.GetDepth() + query_length, matched_entry, 0, data->traverser.GetFsa(),
match_t m = std::make_shared<Match>(0, data->traverser.GetDepth() + query_length, matched_entry, 0, data->traverser.GetFsa(),
data->traverser.GetStateValue());

data->traverser++;
Expand All @@ -150,7 +150,7 @@ class MultiWordCompletion final {
data->traverser++;
} else {
TRACE("StateTraverser exhausted.");
return Match();
return match_t();
}
}
};
Expand Down
16 changes: 8 additions & 8 deletions keyvi/include/keyvi/dictionary/completion/prefix_completion.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class PrefixCompletion final {
traversal_stack.reserve(1024);

if (depth == query_length) {
Match first_match;
match_t first_match;
TRACE("matched prefix");

// data which is required for the callback as well
Expand All @@ -91,7 +91,7 @@ class PrefixCompletion final {
TRACE("prefix matched depth %d %s", query_length + data->traverser.GetDepth(),
std::string(reinterpret_cast<char*>(&data->traversal_stack[0]), query_length + data->traverser.GetDepth())
.c_str());
first_match = Match(0, query_length, query, 0, fsa_, fsa_->GetStateValue(state));
first_match = std::make_shared<Match>(0, query_length, query, 0, fsa_, fsa_->GetStateValue(state));
}

auto tfunc = [data, query_length]() {
Expand All @@ -106,7 +106,7 @@ class PrefixCompletion final {
std::string match_str = std::string(reinterpret_cast<char*>(&data->traversal_stack[0]),
query_length + data->traverser.GetDepth());
TRACE("found final state at depth %d %s", query_length + data->traverser.GetDepth(), match_str.c_str());
Match m(0, data->traverser.GetDepth() + query_length, match_str, 0, data->traverser.GetFsa(),
match_t m = std::make_shared<Match>(0, data->traverser.GetDepth() + query_length, match_str, 0, data->traverser.GetFsa(),
data->traverser.GetStateValue());

data->traverser++;
Expand All @@ -116,7 +116,7 @@ class PrefixCompletion final {
data->traverser++;
} else {
TRACE("StateTraverser exhausted.");
return Match();
return match_t();
}
}
};
Expand Down Expand Up @@ -173,13 +173,13 @@ class PrefixCompletion final {

TRACE("state %d", state);

Match first_match;
match_t first_match;
TRACE("matched prefix");

if (depth == query_length && fsa_->IsFinalState(state)) {
TRACE("prefix matched depth %d %s", query_length + data->traverser.GetDepth(),
std::string(query, query_length).c_str());
first_match = Match(0, query_length, query, 0, fsa_, fsa_->GetStateValue(state));
first_match = std::make_shared<Match>(0, query_length, query, 0, fsa_, fsa_->GetStateValue(state));
}

auto tfunc = [data, query_length, max_edit_distance, exact_prefix]() {
Expand All @@ -204,7 +204,7 @@ class PrefixCompletion final {
if (data->traverser.IsFinalState()) {
if (query_length < depth || data->metric.GetScore() <= max_edit_distance) {
TRACE("found final state at depth %d %s", depth, data->metric.GetCandidate().c_str());
Match m(0, depth, data->metric.GetCandidate(), data->metric.GetScore(), data->traverser.GetFsa(),
match_t m=std::make_shared<Match>(0, depth, data->metric.GetCandidate(), data->metric.GetScore(), data->traverser.GetFsa(),
data->traverser.GetStateValue());

data->traverser++;
Expand All @@ -214,7 +214,7 @@ class PrefixCompletion final {
data->traverser++;
} else {
TRACE("StateTraverser exhausted.");
return Match();
return match_t();
}
}
};
Expand Down
Loading

0 comments on commit 1272880

Please sign in to comment.