feat: add middle out str cut

This commit is contained in:
Jeremy Yin 2024-03-25 13:23:53 +08:00
parent 145f80df12
commit 8aaed632b6
5 changed files with 142 additions and 7 deletions

View File

@ -5,7 +5,7 @@
groups = ["default"]
strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.4.1"
content_hash = "sha256:64230f6c1194c02cb0bf18c190ac2a97d0a50add0cd0697a8c770b3ea68a7cf9"
content_hash = "sha256:7299263d59b3a2f2a928fcafab2f02c42200d274f160a1655b55035914db1096"
[[package]]
name = "annotated-types"
@ -44,6 +44,32 @@ files = [
{file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"},
]
[[package]]
name = "charset-normalizer"
version = "3.3.2"
requires_python = ">=3.7.0"
summary = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
groups = ["default"]
files = [
{file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
{file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
{file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
{file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
{file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
{file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
]
[[package]]
name = "click"
version = "8.1.7"
@ -342,6 +368,48 @@ files = [
{file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
]
[[package]]
name = "regex"
version = "2023.12.25"
requires_python = ">=3.7"
summary = "Alternative regular expression module, to replace re."
groups = ["default"]
files = [
{file = "regex-2023.12.25-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8a0ccf52bb37d1a700375a6b395bff5dd15c50acb745f7db30415bae3c2b0715"},
{file = "regex-2023.12.25-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c3c4a78615b7762740531c27cf46e2f388d8d727d0c0c739e72048beb26c8a9d"},
{file = "regex-2023.12.25-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ad83e7545b4ab69216cef4cc47e344d19622e28aabec61574b20257c65466d6a"},
{file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7a635871143661feccce3979e1727c4e094f2bdfd3ec4b90dfd4f16f571a87a"},
{file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d498eea3f581fbe1b34b59c697512a8baef88212f92e4c7830fcc1499f5b45a5"},
{file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:43f7cd5754d02a56ae4ebb91b33461dc67be8e3e0153f593c509e21d219c5060"},
{file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51f4b32f793812714fd5307222a7f77e739b9bc566dc94a18126aba3b92b98a3"},
{file = "regex-2023.12.25-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba99d8077424501b9616b43a2d208095746fb1284fc5ba490139651f971d39d9"},
{file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4bfc2b16e3ba8850e0e262467275dd4d62f0d045e0e9eda2bc65078c0110a11f"},
{file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8c2c19dae8a3eb0ea45a8448356ed561be843b13cbc34b840922ddf565498c1c"},
{file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:60080bb3d8617d96f0fb7e19796384cc2467447ef1c491694850ebd3670bc457"},
{file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b77e27b79448e34c2c51c09836033056a0547aa360c45eeeb67803da7b0eedaf"},
{file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:518440c991f514331f4850a63560321f833979d145d7d81186dbe2f19e27ae3d"},
{file = "regex-2023.12.25-cp312-cp312-win32.whl", hash = "sha256:e2610e9406d3b0073636a3a2e80db05a02f0c3169b5632022b4e81c0364bcda5"},
{file = "regex-2023.12.25-cp312-cp312-win_amd64.whl", hash = "sha256:cc37b9aeebab425f11f27e5e9e6cf580be7206c6582a64467a14dda211abc232"},
{file = "regex-2023.12.25.tar.gz", hash = "sha256:29171aa128da69afdf4bde412d5bedc335f2ca8fcfe4489038577d05f16181e5"},
]
[[package]]
name = "requests"
version = "2.31.0"
requires_python = ">=3.7"
summary = "Python HTTP for Humans."
groups = ["default"]
dependencies = [
"certifi>=2017.4.17",
"charset-normalizer<4,>=2",
"idna<4,>=2.5",
"urllib3<3,>=1.21.1",
]
files = [
{file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
{file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
]
[[package]]
name = "sentry-sdk"
version = "1.42.0"
@ -381,6 +449,27 @@ files = [
{file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"},
]
[[package]]
name = "tiktoken"
version = "0.6.0"
requires_python = ">=3.8"
summary = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
groups = ["default"]
dependencies = [
"regex>=2022.1.18",
"requests>=2.26.0",
]
files = [
{file = "tiktoken-0.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:17cc8a4a3245ab7d935c83a2db6bb71619099d7284b884f4b2aea4c74f2f83e3"},
{file = "tiktoken-0.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:284aebcccffe1bba0d6571651317df6a5b376ff6cfed5aeb800c55df44c78177"},
{file = "tiktoken-0.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c1a3a5d33846f8cd9dd3b7897c1d45722f48625a587f8e6f3d3e85080559be8"},
{file = "tiktoken-0.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6318b2bb2337f38ee954fd5efa82632c6e5ced1d52a671370fa4b2eff1355e91"},
{file = "tiktoken-0.6.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f5f0f2ed67ba16373f9a6013b68da298096b27cd4e1cf276d2d3868b5c7efd1"},
{file = "tiktoken-0.6.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:75af4c0b16609c2ad02581f3cdcd1fb698c7565091370bf6c0cf8624ffaba6dc"},
{file = "tiktoken-0.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:45577faf9a9d383b8fd683e313cf6df88b6076c034f0a16da243bb1c139340c3"},
{file = "tiktoken-0.6.0.tar.gz", hash = "sha256:ace62a4ede83c75b0374a2ddfa4b76903cf483e9cb06247f566be3bf14e6beed"},
]
[[package]]
name = "tqdm"
version = "4.66.2"
@ -412,7 +501,6 @@ version = "2.2.1"
requires_python = ">=3.8"
summary = "HTTP library with thread-safe connection pooling, file post, and more."
groups = ["default"]
marker = "python_version >= \"3.6\""
files = [
{file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
{file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},

View File

@ -16,6 +16,7 @@ dependencies = [
"sentry-sdk>=1.42.0",
"pydantic-mongo>=2.1.2",
"gunicorn>=21.2.0",
"tiktoken>=0.6.0",
]
requires-python = "==3.12.*"
readme = "README.md"

View File

@ -1,3 +1,5 @@
import math
from openai import OpenAI
from os import getenv
from loguru import logger
@ -5,11 +7,28 @@ from loguru import logger
from simplylab.database import Database
def middle_out(request_content: str, model_name: str, limit: int) -> str:
import tiktoken
encoding = tiktoken.encoding_for_model(model_name)
tokens = encoding.encode(request_content)
token_count = len(tokens)
if token_count > limit:
rate = limit / token_count
middle = math.floor(len(request_content) / 2)
first_end = math.floor(middle * rate)
second_start = math.ceil(middle + middle * (1 - rate))
logger.debug(f"{token_count=} {limit=} {rate=} {len(request_content)=} {middle=} {first_end=} {second_start=}")
chat_content = request_content[:first_end] + request_content[second_start:]
return chat_content
return request_content
class OpenRouterProvider:
def __init__(self, db: Database) -> None:
self.db = db
async def chat(self, content: str) -> str:
chat_content = middle_out(content, "gpt-3.5-turbo", 8192)
# gets API Key from environment variable OPENAI_API_KEY
api_key = getenv("OPENROUTER_API_KEY")
client = OpenAI(
@ -26,10 +45,10 @@ class OpenRouterProvider:
messages=[
{
"role": "user",
"content": content,
"content": chat_content,
},
],
)
logger.debug(f"request content: {content}")
logger.debug(f"request chat_content: {chat_content}")
logger.debug(f"response content: {completion.choices[0].message.content}")
return completion.choices[0].message.content

View File

@ -1,5 +1,3 @@
from typing import Any
from loguru import logger
from simplylab.model.res import UserChatMessage, \
@ -24,7 +22,6 @@ class ChatService:
raise MessageLimitedInDailyError()
request_content = req.message
# todo: request content middle out
response_content = await self.pvd.openrouter.chat(content=request_content)
user_message = Message(
user_id=self.ctx.user.id,

30
tests/test_middle_out.py Normal file
View File

@ -0,0 +1,30 @@
import unittest
from loguru import logger
from simplylab.providers.openrouter import middle_out
class MyTestCase(unittest.TestCase):
def test_middle_out(self):
content = """
Text generation and embeddings models process text in chunks called tokens. Tokens represent commonly occurring sequences of characters. For example, the string " tokenization" is decomposed as " token" and "ization", while a short and common word like " the" is represented as a single token. Note that in a sentence, the first token of each word typically starts with a space character. Check out our tokenizer tool to test specific strings and see how they are translated into tokens. As a rough rule of thumb, 1 token is approximately 4 characters or 0.75 words for English text.
One limitation to keep in mind is that for a text generation model the prompt and the generated output combined must be no more than the model's maximum context length. For embeddings models (which do not output tokens), the input must be shorter than the model's maximum context length. The maximum context lengths for each text generation and embeddings model can be found in the model index.
"""
model_name = "gpt-3.5-turbo"
limit = 100
chat_content = middle_out(content, model_name, limit)
logger.debug(chat_content)
import tiktoken
encoding = tiktoken.encoding_for_model(model_name)
tokens = encoding.encode(chat_content)
token_count = len(tokens)
logger.debug(f"{token_count=}, {limit=}")
res = token_count <= limit
self.assertEqual(res, True) # add assertion here
if __name__ == '__main__':
unittest.main()