Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix streaming hieroglyphs #1492

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions samples/python/multinomial_causal_lm/multinomial_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, tokenizer):
self.tokens_cache = []
self.text_queue = queue.Queue()
self.print_len = 0
self.decoded_lengths = []

def __iter__(self):
"""
Expand Down Expand Up @@ -80,34 +81,35 @@ def put(self, token_id: int) -> bool:

Returns:
bool: True if generation should be stopped, False otherwise.
"""
"""
self.tokens_cache.append(token_id)
text = self.tokenizer.decode(self.tokens_cache)
self.decoded_lengths.append(len(text))

word = ''
delay_n_chars = 4
delay_n_tokens = 3
if len(text) > self.print_len and '\n' == text[-1]:
# Flush the cache after the new line symbol.
word = text[self.print_len:]
word = text[self.print_len:]
self.tokens_cache = []
self.decoded_lengths = []
self.print_len = 0
elif len(text) >= 3 and text[-1] == chr(65533):
elif len(text) > 0 and text[-1] == chr(65533):
# Don't print incomplete text.
pass
elif len(text) > self.print_len + delay_n_chars:
# It is possible to have a shorter text after adding new token.
# Print to output only if text length is increaesed.
# Also, in some cases adding the next token can shorten the text,
# e.g. when apostrophe removing regex had worked after adding new tokens.
# Several last characters are delayed before flushed to output.
word = text[self.print_len:-delay_n_chars]
self.print_len = len(text) - delay_n_chars
self.put_word(word)

self.decoded_lengths[-1] = -1
elif len(self.tokens_cache) >= delay_n_tokens:
print_until = self.decoded_lengths[-delay_n_tokens]
if print_until != -1 and print_until > self.print_len:
# It is possible to have a shorter text after adding new token.
# Print to output only if text length is increased and text is complete (print_until != -1).
word = text[self.print_len:print_until]
self.print_len = print_until
self.put_word(word)

if self.get_stop_flag():
# When generation is stopped from streamer then end is not called, need to call it here manually.
self.end()
return True # True means stop generation
return True # True means stop generation
else:
return False # False means continue generation

Expand Down
37 changes: 23 additions & 14 deletions src/cpp/src/text_callback_streamer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,36 @@ bool TextCallbackStreamer::put(int64_t token) {
std::stringstream res;
m_tokens_cache.push_back(token);
std::string text = m_tokenizer.decode(m_tokens_cache);
m_decoded_lengths.push_back(text.length());

if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
if (!text.empty() && '\n' == text.back() && text.size() > m_printed_len) {
// Flush the cache after the new line symbol
res << std::string_view{text.data() + print_len, text.size() - print_len};
res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len};
m_tokens_cache.clear();
print_len = 0;
m_decoded_lengths.clear();
m_printed_len = 0;
return on_finalized_subword_callback(res.str());
}

// In some cases adding the next token can shorten the text,
// e.g. when apostrophe removing regex had worked after adding new tokens.
// Several last characters are delayed before flushed to output.
constexpr size_t delay_n_chars = 4;
constexpr size_t delay_n_tokens = 3;
auto print_until = m_decoded_lengths[m_decoded_lengths.size() - delay_n_tokens];
constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
m_decoded_lengths[m_decoded_lengths.size() - 1] = -1;
// Don't print incomplete text
return on_finalized_subword_callback(res.str());
} else if (text.size() > print_len + delay_n_chars) {
}
// In some cases adding the next token can shorten the text,
// e.g. when apostrophe removing regex had worked after adding new tokens.
// Printing several last tokens is delayed.
if (m_tokens_cache.size() < delay_n_tokens) {
return on_finalized_subword_callback(res.str());
}
if (print_until != -1 && print_until > m_printed_len) {
// It is possible to have a shorter text after adding new token.
// Print to output only if text length is increaesed.
res << std::string_view{text.data() + print_len, text.size() - print_len - delay_n_chars} << std::flush;
print_len = text.size() - delay_n_chars;
res << std::string_view{text.data() + m_printed_len, print_until - m_printed_len} << std::flush;
m_printed_len = print_until;
}

return on_finalized_subword_callback(res.str());
Expand All @@ -45,11 +53,12 @@ bool TextCallbackStreamer::put(int64_t token) {
void TextCallbackStreamer::end() {
std::stringstream res;
std::string text = m_tokenizer.decode(m_tokens_cache);
if (text.size() <= print_len)
return ;
res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
if (text.size() <= m_printed_len)
return;
res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len} << std::flush;
m_tokens_cache.clear();
print_len = 0;
m_decoded_lengths.clear();
m_printed_len = 0;
on_finalized_subword_callback(res.str());
return;
}
Expand Down
4 changes: 3 additions & 1 deletion src/cpp/src/text_callback_streamer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace genai {
class TextCallbackStreamer: public StreamerBase {
public:
bool put(int64_t token) override;

void end() override;

TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback);
Expand All @@ -20,7 +21,8 @@ class TextCallbackStreamer: public StreamerBase {
private:
Tokenizer m_tokenizer;
std::vector<int64_t> m_tokens_cache;
size_t print_len = 0;
std::vector<int64_t> m_decoded_lengths;
size_t m_printed_len = 0;
};

} // namespace genai
Expand Down
34 changes: 30 additions & 4 deletions tests/python_tests/test_generate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,25 +361,51 @@ def test_callback_batch_fail(callback):
pipe.generate(['1', '2'], ov_genai.GenerationConfig(), callback)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
class StremerWithResults:
Copy link
Contributor

@ilya-lavrenov ilya-lavrenov Jan 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to move this to

def run_llm_pipeline(
models_path : Path,
prompts: List[str],
generation_config : GenerationConfig,
use_cb : bool = False
) -> List[GenerationResult]:
properties = get_default_properties()
if use_cb:
properties['scheduler_config'] = SchedulerConfig()
ov_pipe = LLMPipeline(models_path, device='CPU', **properties)
generate_outputs : DecodedResults = ov_pipe.generate(inputs=prompts, generation_config=generation_config)
index = 0
generation_results = []
for _ in prompts:
generation_result = GenerationResult()
generation_result.m_generation_ids = generate_outputs.texts[index : index + generation_config.num_return_sequences]
# sequences_scores are available only for beam search case
if generation_config.is_beam_search():
generation_result.m_scores = generate_outputs.scores[index : index + generation_config.num_return_sequences]
generation_results.append(generation_result)
index += generation_config.num_return_sequences
del ov_pipe
shutil.rmtree(models_path)
return generation_results

because it covers significantly more cases (but limit it to cases when we have a single batch and num_return_sequences is 1)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. But it's present only in master. This PR is for release branch.
I will do that when cherry-pick to master.

results: List[str] = []
def __init__(self):
self.results = []

def accumulate(self, subword) -> bool:
self.results.append(subword)
return False

def get_result_str(self) -> str:
return ''.join(self.results)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword), StremerWithResults()])
@pytest.mark.precommit
@pytest.mark.nightly
def test_callback_kwargs_one_string(callback):
streamer_class = None
if isinstance(callback, StremerWithResults):
streamer_class = callback
callback = callback.accumulate
pipe = read_model(get_models_list()[0])[4]
pipe.generate('table is made of', max_new_tokens=10, streamer=callback)
res = pipe.generate('table is made of', max_new_tokens=10, streamer=callback)
if isinstance(streamer_class, StremerWithResults):
assert res == streamer_class.get_result_str()

@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])

@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword), StremerWithResults()])
@pytest.mark.precommit
@pytest.mark.nightly
@pytest.mark.parametrize("model_descr", get_models_list())
def test_callback_decoding_metallama(model_descr, callback):
streamer_class = None
if isinstance(callback, StremerWithResults):
streamer_class = callback
callback = callback.accumulate
# On metallam this prompt generates output which can shorten after adding new tokens.
# Test that streamer correctly handles such cases.
prompt = 'I have an interview about product speccing with the company Weekend Health. Give me an example of a question they might ask with regards about a new feature'
if model_descr[0] != 'meta-llama/Meta-Llama-3-8B-Instruct':
pytest.skip()
pipe = read_model(model_descr)[4]
pipe.generate(prompt, max_new_tokens=300, streamer=callback)
res = pipe.generate(prompt, max_new_tokens=300, streamer=callback)
if isinstance(streamer_class, StremerWithResults):
assert res == streamer_class.get_result_str()


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
Expand Down
10 changes: 7 additions & 3 deletions tests/python_tests/test_vlm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_ov_model(cache):
@pytest.mark.nightly
def test_vlm_pipeline(cache):
def streamer(word: str) -> bool:
result_from_streamer.append(word)
return False

models_path = get_ov_model(cache)
Expand All @@ -54,14 +55,17 @@ def streamer(word: str) -> bool:
images = []
for link in links:
images.append(get_image_by_link(link))

pipe = VLMPipeline(models_path, "CPU")
pipe.start_chat()

pipe.generate(prompts[0], images=images, generation_config=get_greedy(), streamer=streamer)
result_from_streamer = []
res = pipe.generate(prompts[0], images=images, generation_config=get_greedy(), streamer=streamer)
assert res.texts[0] == ''.join(result_from_streamer)

for prompt in prompts[1:]:
pipe.generate(prompt, generation_config=get_greedy(), streamer=streamer)
result_from_streamer = []
res = pipe.generate(prompt, generation_config=get_greedy(), streamer=streamer)
assert res.texts[0] == ''.join(result_from_streamer)

pipe.finish_chat()

Expand Down
Loading