diff --git a/samples/python/multinomial_causal_lm/multinomial_causal_lm.py b/samples/python/multinomial_causal_lm/multinomial_causal_lm.py
index 9e43294ae5..191e2bd09b 100755
--- a/samples/python/multinomial_causal_lm/multinomial_causal_lm.py
+++ b/samples/python/multinomial_causal_lm/multinomial_causal_lm.py
@@ -31,6 +31,7 @@ def __init__(self, tokenizer):
         self.tokens_cache = []
         self.text_queue = queue.Queue()
         self.print_len = 0
+        self.decoded_lengths = []
 
     def __iter__(self):
         """
@@ -80,34 +81,35 @@ def put(self, token_id: int) -> bool:
 
         Returns:
             bool: True if generation should be stopped, False otherwise.
-        """
+        """
         self.tokens_cache.append(token_id)
         text = self.tokenizer.decode(self.tokens_cache)
+        self.decoded_lengths.append(len(text))
 
         word = ''
-        delay_n_chars = 4
+        delay_n_tokens = 3
        if len(text) > self.print_len and '\n' == text[-1]:
             # Flush the cache after the new line symbol.
-            word = text[self.print_len:]
+            word = text[self.print_len:]
             self.tokens_cache = []
+            self.decoded_lengths = []
             self.print_len = 0
-        elif len(text) >= 3 and text[-1] == chr(65533):
+        elif len(text) > 0 and text[-1] == chr(65533):
             # Don't print incomplete text.
-            pass
-        elif len(text) > self.print_len + delay_n_chars:
-            # It is possible to have a shorter text after adding new token.
-            # Print to output only if text length is increaesed.
-            # Also, in some cases adding the next token can shorten the text,
-            # e.g. when apostrophe removing regex had worked after adding new tokens.
-            # Several last characters are delayed before flushed to output.
-            word = text[self.print_len:-delay_n_chars]
-            self.print_len = len(text) - delay_n_chars
-        self.put_word(word)
-
+            self.decoded_lengths[-1] = -1
+        elif len(self.tokens_cache) >= delay_n_tokens:
+            print_until = self.decoded_lengths[-delay_n_tokens]
+            if print_until != -1 and print_until > self.print_len:
+                # It is possible to have a shorter text after adding new token.
+                # Print to output only if text length is increased and text is complete (print_until != -1).
+                word = text[self.print_len:print_until]
+                self.print_len = print_until
+        self.put_word(word)
+
         if self.get_stop_flag():
             # When generation is stopped from streamer then end is not called, need to call it here manually.
             self.end()
-            return True # True means stop generation
+            return True  # True means stop generation
         else:
             return False # False means continue generation
diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp
index 28bd1e95a0..251a815cbb 100644
--- a/src/cpp/src/text_callback_streamer.cpp
+++ b/src/cpp/src/text_callback_streamer.cpp
@@ -15,28 +15,36 @@ bool TextCallbackStreamer::put(int64_t token) {
     std::stringstream res;
     m_tokens_cache.push_back(token);
     std::string text = m_tokenizer.decode(m_tokens_cache);
+    m_decoded_lengths.push_back(text.length());
 
-    if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
+    if (!text.empty() && '\n' == text.back() && text.size() > m_printed_len) {
         // Flush the cache after the new line symbol
-        res << std::string_view{text.data() + print_len, text.size() - print_len};
+        res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len};
         m_tokens_cache.clear();
-        print_len = 0;
+        m_decoded_lengths.clear();
+        m_printed_len = 0;
         return on_finalized_subword_callback(res.str());
     }
 
-    // In some cases adding the next token can shorten the text,
-    // e.g. when apostrophe removing regex had worked after adding new tokens.
-    // Several last characters are delayed before flushed to output.
-    constexpr size_t delay_n_chars = 4;
+    constexpr size_t delay_n_tokens = 3;
+    auto print_until = m_decoded_lengths[m_decoded_lengths.size() - delay_n_tokens];
     constexpr char replacement[] = "\xef\xbf\xbd";  // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
     if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
+        m_decoded_lengths[m_decoded_lengths.size() - 1] = -1;
         // Don't print incomplete text
         return on_finalized_subword_callback(res.str());
-    } else if (text.size() > print_len + delay_n_chars) {
+    }
+    // In some cases adding the next token can shorten the text,
+    // e.g. when apostrophe removing regex had worked after adding new tokens.
+    // Printing several last tokens is delayed.
+    if (m_tokens_cache.size() < delay_n_tokens) {
+        return on_finalized_subword_callback(res.str());
+    }
+    if (print_until != -1 && print_until > m_printed_len) {
         // It is possible to have a shorter text after adding new token.
         // Print to output only if text length is increaesed.
-        res << std::string_view{text.data() + print_len, text.size() - print_len - delay_n_chars} << std::flush;
-        print_len = text.size() - delay_n_chars;
+        res << std::string_view{text.data() + m_printed_len, print_until - m_printed_len} << std::flush;
+        m_printed_len = print_until;
     }
 
     return on_finalized_subword_callback(res.str());
@@ -45,11 +53,12 @@ bool TextCallbackStreamer::put(int64_t token) {
 void TextCallbackStreamer::end() {
     std::stringstream res;
     std::string text = m_tokenizer.decode(m_tokens_cache);
-    if (text.size() <= print_len)
-        return ;
-    res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
+    if (text.size() <= m_printed_len)
+        return;
+    res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len} << std::flush;
     m_tokens_cache.clear();
-    print_len = 0;
+    m_decoded_lengths.clear();
+    m_printed_len = 0;
     on_finalized_subword_callback(res.str());
     return;
 }
diff --git a/src/cpp/src/text_callback_streamer.hpp b/src/cpp/src/text_callback_streamer.hpp
index 7afc52b4f6..510a35d016 100644
--- a/src/cpp/src/text_callback_streamer.hpp
+++ b/src/cpp/src/text_callback_streamer.hpp
@@ -12,6 +12,7 @@ namespace genai {
 class TextCallbackStreamer: public StreamerBase {
 public:
     bool put(int64_t token) override;
+    void end() override;
 
     TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback);
 
@@ -20,7 +21,8 @@ class TextCallbackStreamer: public StreamerBase {
 private:
     Tokenizer m_tokenizer;
     std::vector<int64_t> m_tokens_cache;
-    size_t print_len = 0;
+    std::vector<int64_t> m_decoded_lengths;
+    size_t m_printed_len = 0;
 };
 
 } // namespace genai
diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_generate_api.py
index 5b9b1252a7..0673460c0f 100644
--- a/tests/python_tests/test_generate_api.py
+++ b/tests/python_tests/test_generate_api.py
@@ -361,25 +361,51 @@ def test_callback_batch_fail(callback):
         pipe.generate(['1', '2'], ov_genai.GenerationConfig(), callback)
 
 
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+class StremerWithResults:
+    results: List[str] = []
+    def __init__(self):
+        self.results = []
+
+    def accumulate(self, subword) -> bool:
+        self.results.append(subword)
+        return False
+
+    def get_result_str(self) -> str:
+        return ''.join(self.results)
+
+
+@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword), StremerWithResults()])
 @pytest.mark.precommit
 @pytest.mark.nightly
 def test_callback_kwargs_one_string(callback):
+    streamer_class = None
+    if isinstance(callback, StremerWithResults):
+        streamer_class = callback
+        callback = callback.accumulate
     pipe = read_model(get_models_list()[0])[4]
-    pipe.generate('table is made of', max_new_tokens=10, streamer=callback)
+    res = pipe.generate('table is made of', max_new_tokens=10, streamer=callback)
+    if isinstance(streamer_class, StremerWithResults):
+        assert res == streamer_class.get_result_str()
 
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+
+@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword), StremerWithResults()])
 @pytest.mark.precommit
 @pytest.mark.nightly
 @pytest.mark.parametrize("model_descr", get_models_list())
 def test_callback_decoding_metallama(model_descr, callback):
+    streamer_class = None
+    if isinstance(callback, StremerWithResults):
+        streamer_class = callback
+        callback = callback.accumulate
     # On metallam this prompt generates output which can shorten after adding new tokens.
     # Test that streamer correctly handles such cases.
     prompt = 'I have an interview about product speccing with the company Weekend Health. Give me an example of a question they might ask with regards about a new feature'
     if model_descr[0] != 'meta-llama/Meta-Llama-3-8B-Instruct':
         pytest.skip()
     pipe = read_model(model_descr)[4]
-    pipe.generate(prompt, max_new_tokens=300, streamer=callback)
+    res = pipe.generate(prompt, max_new_tokens=300, streamer=callback)
+    if isinstance(streamer_class, StremerWithResults):
+        assert res == streamer_class.get_result_str()
 
 
 @pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
diff --git a/tests/python_tests/test_vlm_api.py b/tests/python_tests/test_vlm_api.py
index 0cb2e509f3..4b43dd99c6 100644
--- a/tests/python_tests/test_vlm_api.py
+++ b/tests/python_tests/test_vlm_api.py
@@ -46,6 +46,7 @@ def get_ov_model(cache):
 @pytest.mark.nightly
 def test_vlm_pipeline(cache):
     def streamer(word: str) -> bool:
+        result_from_streamer.append(word)
         return False
 
     models_path = get_ov_model(cache)
@@ -54,14 +55,17 @@ def streamer(word: str) -> bool:
     images = []
     for link in links:
         images.append(get_image_by_link(link))
-
     pipe = VLMPipeline(models_path, "CPU")
     pipe.start_chat()
 
-    pipe.generate(prompts[0], images=images, generation_config=get_greedy(), streamer=streamer)
+    result_from_streamer = []
+    res = pipe.generate(prompts[0], images=images, generation_config=get_greedy(), streamer=streamer)
+    assert res.texts[0] == ''.join(result_from_streamer)
 
     for prompt in prompts[1:]:
-        pipe.generate(prompt, generation_config=get_greedy(), streamer=streamer)
+        result_from_streamer = []
+        res = pipe.generate(prompt, generation_config=get_greedy(), streamer=streamer)
+        assert res.texts[0] == ''.join(result_from_streamer)
 
     pipe.finish_chat()
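For reference, the buffering scheme this patch switches to can be sketched in isolation: decode the whole token cache after every token, record the decoded length reached at each token (or -1 when the tail is the U+FFFD replacement character, i.e. an incomplete byte sequence), and only print up to the length observed delay_n_tokens tokens ago, so later shrinkage of the decoded text can never retract output that was already printed. The snippet below is a minimal self-contained illustration of that idea only, not the library code: FakeTokenizer and DelayedTextPrinter are hypothetical stand-ins, and the newline-flush branch and the queue handling of the real sample are omitted.

```python
# Standalone sketch of the delay-by-N-tokens streaming idea used in the diff above.
# FakeTokenizer is a stand-in for illustration, not the openvino_genai tokenizer API.
class FakeTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab

    def decode(self, token_ids):
        # A real detokenizer may emit '\ufffd' while a multi-byte character
        # is still incomplete; this stub never does.
        return ''.join(self.vocab[t] for t in token_ids)


class DelayedTextPrinter:
    """Buffers decoded text and prints it delay_n_tokens tokens late."""
    def __init__(self, tokenizer, delay_n_tokens=3):
        self.tokenizer = tokenizer
        self.delay_n_tokens = delay_n_tokens
        self.tokens_cache = []
        self.decoded_lengths = []  # decoded length after each token, -1 if incomplete
        self.print_len = 0

    def put(self, token_id):
        self.tokens_cache.append(token_id)
        text = self.tokenizer.decode(self.tokens_cache)
        self.decoded_lengths.append(len(text))

        if text and text[-1] == '\ufffd':
            # Incomplete UTF-8 sequence: mark this position as not printable.
            self.decoded_lengths[-1] = -1
        elif len(self.tokens_cache) >= self.delay_n_tokens:
            # Print only up to the length the text had delay_n_tokens tokens ago.
            print_until = self.decoded_lengths[-self.delay_n_tokens]
            if print_until != -1 and print_until > self.print_len:
                print(text[self.print_len:print_until], end='', flush=True)
                self.print_len = print_until

    def end(self):
        # Flush whatever is still held back by the delay.
        text = self.tokenizer.decode(self.tokens_cache)
        if len(text) > self.print_len:
            print(text[self.print_len:], end='', flush=True)
        self.tokens_cache, self.decoded_lengths, self.print_len = [], [], 0


if __name__ == '__main__':
    vocab = {0: 'Hello', 1: ',', 2: ' world', 3: '!', 4: '\n'}
    printer = DelayedTextPrinter(FakeTokenizer(vocab))
    for token in [0, 1, 2, 3, 4]:
        printer.put(token)
    printer.end()  # prints "Hello, world!" followed by a newline
```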