From 8aaed632b6b3b1224f5c8029b9bd281851684a43 Mon Sep 17 00:00:00 2001 From: Jeremy Yin Date: Mon, 25 Mar 2024 13:23:53 +0800 Subject: [PATCH] feat: add middle out str cut --- pdm.lock | 92 ++++++++++++++++++++++++++++++- pyproject.toml | 1 + simplylab/providers/openrouter.py | 23 +++++++- simplylab/services/chat.py | 3 - tests/test_middle_out.py | 30 ++++++++++ 5 files changed, 142 insertions(+), 7 deletions(-) create mode 100644 tests/test_middle_out.py diff --git a/pdm.lock b/pdm.lock index 4259ba5..c835747 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:64230f6c1194c02cb0bf18c190ac2a97d0a50add0cd0697a8c770b3ea68a7cf9" +content_hash = "sha256:7299263d59b3a2f2a928fcafab2f02c42200d274f160a1655b55035914db1096" [[package]] name = "annotated-types" @@ -44,6 +44,32 @@ files = [ {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, ] +[[package]] +name = "charset-normalizer" +version = "3.3.2" +requires_python = ">=3.7.0" +summary = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +groups = ["default"] +files = [ + {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, + {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, +] + [[package]] name = "click" version = "8.1.7" @@ -342,6 +368,48 @@ files = [ {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, ] +[[package]] +name = "regex" +version = "2023.12.25" +requires_python = ">=3.7" +summary = "Alternative regular expression module, to replace re." +groups = ["default"] +files = [ + {file = "regex-2023.12.25-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8a0ccf52bb37d1a700375a6b395bff5dd15c50acb745f7db30415bae3c2b0715"}, + {file = "regex-2023.12.25-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c3c4a78615b7762740531c27cf46e2f388d8d727d0c0c739e72048beb26c8a9d"}, + {file = "regex-2023.12.25-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ad83e7545b4ab69216cef4cc47e344d19622e28aabec61574b20257c65466d6a"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7a635871143661feccce3979e1727c4e094f2bdfd3ec4b90dfd4f16f571a87a"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d498eea3f581fbe1b34b59c697512a8baef88212f92e4c7830fcc1499f5b45a5"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:43f7cd5754d02a56ae4ebb91b33461dc67be8e3e0153f593c509e21d219c5060"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51f4b32f793812714fd5307222a7f77e739b9bc566dc94a18126aba3b92b98a3"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba99d8077424501b9616b43a2d208095746fb1284fc5ba490139651f971d39d9"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4bfc2b16e3ba8850e0e262467275dd4d62f0d045e0e9eda2bc65078c0110a11f"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8c2c19dae8a3eb0ea45a8448356ed561be843b13cbc34b840922ddf565498c1c"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:60080bb3d8617d96f0fb7e19796384cc2467447ef1c491694850ebd3670bc457"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b77e27b79448e34c2c51c09836033056a0547aa360c45eeeb67803da7b0eedaf"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:518440c991f514331f4850a63560321f833979d145d7d81186dbe2f19e27ae3d"}, + {file = "regex-2023.12.25-cp312-cp312-win32.whl", hash = "sha256:e2610e9406d3b0073636a3a2e80db05a02f0c3169b5632022b4e81c0364bcda5"}, + {file = "regex-2023.12.25-cp312-cp312-win_amd64.whl", hash = "sha256:cc37b9aeebab425f11f27e5e9e6cf580be7206c6582a64467a14dda211abc232"}, + {file = "regex-2023.12.25.tar.gz", hash = "sha256:29171aa128da69afdf4bde412d5bedc335f2ca8fcfe4489038577d05f16181e5"}, +] + +[[package]] +name = "requests" +version = "2.31.0" +requires_python = ">=3.7" +summary = "Python HTTP for Humans." +groups = ["default"] +dependencies = [ + "certifi>=2017.4.17", + "charset-normalizer<4,>=2", + "idna<4,>=2.5", + "urllib3<3,>=1.21.1", +] +files = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] + [[package]] name = "sentry-sdk" version = "1.42.0" @@ -381,6 +449,27 @@ files = [ {file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"}, ] +[[package]] +name = "tiktoken" +version = "0.6.0" +requires_python = ">=3.8" +summary = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" +groups = ["default"] +dependencies = [ + "regex>=2022.1.18", + "requests>=2.26.0", +] +files = [ + {file = "tiktoken-0.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:17cc8a4a3245ab7d935c83a2db6bb71619099d7284b884f4b2aea4c74f2f83e3"}, + {file = "tiktoken-0.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:284aebcccffe1bba0d6571651317df6a5b376ff6cfed5aeb800c55df44c78177"}, + {file = "tiktoken-0.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c1a3a5d33846f8cd9dd3b7897c1d45722f48625a587f8e6f3d3e85080559be8"}, + {file = "tiktoken-0.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6318b2bb2337f38ee954fd5efa82632c6e5ced1d52a671370fa4b2eff1355e91"}, + {file = "tiktoken-0.6.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f5f0f2ed67ba16373f9a6013b68da298096b27cd4e1cf276d2d3868b5c7efd1"}, + {file = "tiktoken-0.6.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:75af4c0b16609c2ad02581f3cdcd1fb698c7565091370bf6c0cf8624ffaba6dc"}, + {file = "tiktoken-0.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:45577faf9a9d383b8fd683e313cf6df88b6076c034f0a16da243bb1c139340c3"}, + {file = "tiktoken-0.6.0.tar.gz", hash = "sha256:ace62a4ede83c75b0374a2ddfa4b76903cf483e9cb06247f566be3bf14e6beed"}, +] + [[package]] name = "tqdm" version = "4.66.2" @@ -412,7 +501,6 @@ version = "2.2.1" requires_python = ">=3.8" summary = "HTTP library with thread-safe connection pooling, file post, and more." groups = ["default"] -marker = "python_version >= \"3.6\"" files = [ {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, diff --git a/pyproject.toml b/pyproject.toml index 1ae6cfb..33ddb6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "sentry-sdk>=1.42.0", "pydantic-mongo>=2.1.2", "gunicorn>=21.2.0", + "tiktoken>=0.6.0", ] requires-python = "==3.12.*" readme = "README.md" diff --git a/simplylab/providers/openrouter.py b/simplylab/providers/openrouter.py index 53b16e7..397c2db 100644 --- a/simplylab/providers/openrouter.py +++ b/simplylab/providers/openrouter.py @@ -1,3 +1,5 @@ +import math + from openai import OpenAI from os import getenv from loguru import logger @@ -5,11 +7,28 @@ from loguru import logger from simplylab.database import Database +def middle_out(request_content: str, model_name: str, limit: int) -> str: + import tiktoken + encoding = tiktoken.encoding_for_model(model_name) + tokens = encoding.encode(request_content) + token_count = len(tokens) + if token_count > limit: + rate = limit / token_count + middle = math.floor(len(request_content) / 2) + first_end = math.floor(middle * rate) + second_start = math.ceil(middle + middle * (1 - rate)) + logger.debug(f"{token_count=} {limit=} {rate=} {len(request_content)=} {middle=} {first_end=} {second_start=}") + chat_content = request_content[:first_end] + request_content[second_start:] + return chat_content + return request_content + + class OpenRouterProvider: def __init__(self, db: Database) -> None: self.db = db async def chat(self, content: str) -> str: + chat_content = middle_out(content, "gpt-3.5-turbo", 8192) # gets API Key from environment variable OPENAI_API_KEY api_key = getenv("OPENROUTER_API_KEY") client = OpenAI( @@ -26,10 +45,10 @@ class OpenRouterProvider: messages=[ { "role": "user", - "content": content, + "content": chat_content, }, ], ) - logger.debug(f"request content: {content}") + logger.debug(f"request chat_content: {chat_content}") logger.debug(f"response content: {completion.choices[0].message.content}") return completion.choices[0].message.content diff --git a/simplylab/services/chat.py b/simplylab/services/chat.py index 8d1f4da..02948f6 100644 --- a/simplylab/services/chat.py +++ b/simplylab/services/chat.py @@ -1,5 +1,3 @@ -from typing import Any - from loguru import logger from simplylab.model.res import UserChatMessage, \ @@ -24,7 +22,6 @@ class ChatService: raise MessageLimitedInDailyError() request_content = req.message - # todo: request content middle out response_content = await self.pvd.openrouter.chat(content=request_content) user_message = Message( user_id=self.ctx.user.id, diff --git a/tests/test_middle_out.py b/tests/test_middle_out.py new file mode 100644 index 0000000..3820ffa --- /dev/null +++ b/tests/test_middle_out.py @@ -0,0 +1,30 @@ +import unittest + +from loguru import logger + +from simplylab.providers.openrouter import middle_out + + +class MyTestCase(unittest.TestCase): + def test_middle_out(self): + content = """ +Text generation and embeddings models process text in chunks called tokens. Tokens represent commonly occurring sequences of characters. For example, the string " tokenization" is decomposed as " token" and "ization", while a short and common word like " the" is represented as a single token. Note that in a sentence, the first token of each word typically starts with a space character. Check out our tokenizer tool to test specific strings and see how they are translated into tokens. As a rough rule of thumb, 1 token is approximately 4 characters or 0.75 words for English text. + +One limitation to keep in mind is that for a text generation model the prompt and the generated output combined must be no more than the model's maximum context length. For embeddings models (which do not output tokens), the input must be shorter than the model's maximum context length. The maximum context lengths for each text generation and embeddings model can be found in the model index. + """ + model_name = "gpt-3.5-turbo" + limit = 100 + chat_content = middle_out(content, model_name, limit) + logger.debug(chat_content) + + import tiktoken + encoding = tiktoken.encoding_for_model(model_name) + tokens = encoding.encode(chat_content) + token_count = len(tokens) + logger.debug(f"{token_count=}, {limit=}") + res = token_count <= limit + self.assertEqual(res, True) # add assertion here + + +if __name__ == '__main__': + unittest.main()