Merge pull request #8 from pha123661/feat/sent-tokenize-model
Add sentence tokenizer for EN & update generation parameter
blafea authored Feb 4, 2025
2 parents 9b1b1c1 + 6419739 commit 2ff0200
Showing 4 changed files with 72 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@
 __pycache__
 .env
 app.log
+*.ftz
100 changes: 63 additions & 37 deletions app/emojilm_hf.py
@@ -8,9 +8,12 @@
 from asyncio import Semaphore

 import aiohttp
+import fasttext
+import nltk
 from async_lru import alru_cache

 logger = logging.getLogger()
+language_model = fasttext.load_model("lid.176.ftz")


 class EmojiLmHf:
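
For reference, `lid.176.ftz` is fastText's compressed language-identification model. A minimal standalone sketch of how it is queried (assuming the model file has been downloaded, as the dockerfile change below arranges):

```python
import fasttext

# Load the compressed language-ID model (fetched at image build time).
model = fasttext.load_model("lid.176.ftz")

# fastText predicts one line at a time, which is why the caller strips
# newlines first. A list input yields (labels, confidences), one entry
# per text; labels[0][0] is the top label for the first text.
labels, confidences = model.predict(["今天天氣真好"])
print(labels[0][0])  # e.g. '__label__zh'
```

This mirrors the `language_model.predict([...])[0][0][0]` indexing used in `preprocess_input_text` below.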
@@ -22,11 +25,11 @@ class EmojiLmHf:
     def __init__(
         self,
         hf_api_token_list,
-        concurrency=10,
+        concurrency=3,
         keep_alive_interval=300,  # seconds
     ):
         self.hf_api_token_list = hf_api_token_list
-        self.semaphore = Semaphore(concurrency)
+        self.query_semaphore = Semaphore(concurrency)
         self.keep_alive_interval = keep_alive_interval

         self.api_idx = random.randint(0, len(self.hf_api_token_list)-1)
@@ -72,11 +75,6 @@ async def ping_serverless_api():
         self.last_query_time = current_time

     async def generate(self, input_text):
-        async with self.semaphore:
-            await asyncio.sleep(0.1)
-            return await self._generate(input_text=input_text)
-
-    async def _generate(self, input_text):
         text_list, delimiter_list = preprocess_input_text(input_text)
         logger.debug(f"Text list length: {len(text_list)}")

@@ -88,28 +86,24 @@ async def _generate(self, input_text):
                 last_sentence_within_limit[-5:]
             return f"太長了啦❗️ 你輸入了{len(text_list)}句 目前限制{self.SENTENCE_LIMIT}句話 大概到這邊而已:「{last_sentence_within_limit}」", []

-        out_emoji_list = []
-        for text in text_list:
-            out_emoji = await self.query(self.INPUT_PREFIX + text)
-            out_emoji_list.append(out_emoji)
+        emojis = await asyncio.gather(*(self.query(self.INPUT_PREFIX + t) for t in text_list))

-        output_list = []
         output_list = list(itertools.chain.from_iterable(
-            zip(text_list, out_emoji_list, delimiter_list)))
+            zip(text_list, emojis, delimiter_list)))
         min_length = min(len(text_list), len(
-            out_emoji_list), len(delimiter_list))
+            emojis), len(delimiter_list))
         if len(text_list) > min_length:
             output_list.extend(text_list[min_length:])
-        if len(out_emoji_list) > min_length:
-            output_list.extend(out_emoji_list[min_length:])
+        if len(emojis) > min_length:
+            output_list.extend(emojis[min_length:])
         if len(delimiter_list) > min_length:
             output_list.extend(delimiter_list[min_length:])

         output = "".join(output_list)

         output_emoji_set = set()
-        for out_emoji in out_emoji_list:
-            output_emoji_set = output_emoji_set.union(set(out_emoji))
+        for e in emojis:
+            output_emoji_set = output_emoji_set.union(set(e))

         return output, output_emoji_set

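The rewritten loop fires all per-sentence queries concurrently with `asyncio.gather`; the semaphore, now held inside `query` (see the next hunk), still caps in-flight requests at `concurrency=3`. A minimal sketch of that pattern in isolation, with a stubbed `fetch` standing in for the real HTTP call (hypothetical names):

```python
import asyncio


async def fetch(text: str, sem: asyncio.Semaphore) -> str:
    # Stub for the HF API call: gather() starts every coroutine at once,
    # but the semaphore lets only three enter the guarded section.
    async with sem:
        await asyncio.sleep(0.1)  # simulate network latency
        return text.upper()


async def main() -> None:
    sem = asyncio.Semaphore(3)  # mirrors concurrency=3 above
    texts = [f"sentence {i}" for i in range(10)]
    results = await asyncio.gather(*(fetch(t, sem) for t in texts))
    print(results)  # results come back in input order


asyncio.run(main())
```

`gather` preserves input order, so `emojis` still lines up index-for-index with `text_list` when the two are zipped back together.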
@@ -121,22 +115,26 @@ async def query(self, input_text):
             "options": {"wait_for_model": True},
             "parameters": {
                 "max_new_tokens": 5,
-                "do_sample": False,
+                "do_sample": True,
+                "temperature": 1.2,
+                'top_p': 0.8
             },
         }

-        try:
-            async with self.aio_session.post(self.API_URL, headers=self.api_header, json=payload) as response:
-                resp = await response.json(encoding='utf-8')
-                ret = resp[0]['generated_text']
-        except Exception as e:
-            logger.exception(e)
-
-            # retry once
-            self.update_hf_api_token()
-            async with self.aio_session.post(self.API_URL, headers=self.api_header, json=payload) as response:
-                resp = await response.json(encoding='utf-8')
-                ret = resp[0]['generated_text']
+        async with self.query_semaphore:
+            try:
+                async with self.aio_session.post(self.API_URL, headers=self.api_header, json=payload) as response:
+                    resp = await response.json(encoding='utf-8')
+                    ret = resp[0]['generated_text']
+            except Exception as e:
+                logger.exception(e)
+                if type(e) == KeyError:
+                    logger.debug(f"Response: {resp}")
+                # retry once
+                self.update_hf_api_token()
+                async with self.aio_session.post(self.API_URL, headers=self.api_header, json=payload) as response:
+                    resp = await response.json(encoding='utf-8')
+                    ret = resp[0]['generated_text']

         ret = post_process_output(ret)
         logger.info(f"Input: `{input_text}` Output: `{ret}`")
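
The parameter change switches the endpoint from greedy decoding (`do_sample: False`) to sampling: `temperature=1.2` flattens the next-token distribution so repeated queries yield more varied emoji, while `top_p=0.8` restricts sampling to the smallest token set covering 80% of the probability mass (nucleus sampling). For intuition, a roughly equivalent local call with `transformers` might look like the sketch below; the checkpoint name and the `"emoji: "` input prefix are hypothetical, since neither appears in this diff:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hypothetical checkpoint; the model behind API_URL is not shown here.
tokenizer = AutoTokenizer.from_pretrained("pha123661/emoji-lm")
model = AutoModelForSeq2SeqLM.from_pretrained("pha123661/emoji-lm")

inputs = tokenizer("emoji: 今天天氣真好", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=5,   # same cap as the payload above
    do_sample=True,     # sample instead of taking the argmax
    temperature=1.2,    # >1.0 flattens the distribution
    top_p=0.8,          # nucleus sampling
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```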
@@ -148,12 +146,36 @@ async def close(self):

 def preprocess_input_text(input_text: str):
     input_text = re.sub(r"https?://\S+|www\.\S+", "", input_text)
-    input_text = input_text.strip(" ,。,.\n")
-    parts = re.split(r'(\s*[ ,。?;,.\n]\s*)', input_text)
-
-    text_list = parts[::2]
-    delimiter_list = parts[1::2]
-    return text_list, delimiter_list
+    language_label = language_model.predict(
+        [input_text.replace("\n", "")])[0][0][0]
+    if language_label in ['__label__zh', '__label__ja', '__label__ko']:
+        input_text = input_text.strip(" \n")
+        parts = re.split(r'([ ,,。.??!!;\n\s]+)', input_text)
+        sentence_list = parts[::2]
+        delimiter_list = parts[1::2]
+
+        while len(sentence_list) > 0 and sentence_list[-1] == '':
+            sentence_list.pop()
+            if len(delimiter_list) > len(sentence_list):
+                delimiter_list.pop()
+        delimiter_list += [''] * (len(sentence_list) - len(delimiter_list))
+        return sentence_list, delimiter_list
+    else:
+        sentences = nltk.tokenize.sent_tokenize(input_text, language='english')
+        delimiter_list = []
+        # Regular expression to match trailing punctuation
+        pattern = re.compile(r'([^\w\s]+)$')
+
+        cleaned_sentences = []
+        for sentence in sentences:
+            match = pattern.search(sentence)
+            if match:
+                delimiter_list.append(match.group(1))  # Extract punctuation
+                sentence = sentence[:match.start()]  # Remove punctuation
+            else:
+                delimiter_list.append("")
+            cleaned_sentences.append(sentence)
+        return cleaned_sentences, delimiter_list


 def post_process_output(output_emoji: str):
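
The new `preprocess_input_text` above branches on the detected language: Chinese, Japanese, and Korean text is split with a capturing `re.split`, so the punctuation survives at the odd indices and can be re-interleaved after emoji generation, while everything else goes through NLTK's Punkt sentence tokenizer with trailing punctuation peeled off into a parallel delimiter list. A small illustration of what each branch yields (assuming the `punkt_tab` data is installed, as the dockerfile change below arranges):

```python
import re

import nltk

# CJK branch: the capturing group keeps delimiters at odd indices.
parts = re.split(r'([ ,,。.??!!;\n\s]+)', "今天天氣真好。我們去散步吧!")
print(parts[::2])   # ['今天天氣真好', '我們去散步吧', ''] <- trailing '' is popped by the while loop
print(parts[1::2])  # ['。', '!']

# English branch: Punkt finds sentence boundaries, then trailing
# punctuation is moved into the delimiter list.
sentences = nltk.tokenize.sent_tokenize("It works! Let's ship it.", language='english')
print(sentences)  # ["It works!", "Let's ship it."]
pattern = re.compile(r'([^\w\s]+)$')
print([pattern.search(s).group(1) for s in sentences])  # ['!', '.']
```

Both branches return sentences and delimiters of matching length, which is the invariant `generate` relies on when zipping the pieces back together.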
@@ -164,4 +186,8 @@ def post_process_output(output_emoji: str):
                        for code_unit in code_points).decode('utf-8')
     except ValueError:
         pass
+
+    # Remove the sad emoji due to the limitation of the model
+    if output_emoji == '🥲':
+        output_emoji = ""
     return output_emoji
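
Only the tail of `post_process_output`'s existing repair logic is visible in this hunk: it rebuilds raw bytes from the string's code points and re-decodes them as UTF-8, with `except ValueError` catching inputs that were never byte-mangled. A self-contained sketch of that mojibake-repair idiom, reconstructed from the visible fragment rather than copied from the elided lines:

```python
def repair_mojibake(s: str) -> str:
    # If a UTF-8 emoji was mis-decoded byte-by-byte, each character
    # holds one byte value; rebuilding the bytes and decoding as UTF-8
    # restores the emoji. bytes() raises ValueError for ord > 255, and
    # decode() raises UnicodeDecodeError (a ValueError subclass), so
    # ordinary text falls through unchanged.
    try:
        return bytes(ord(ch) for ch in s).decode('utf-8')
    except ValueError:
        return s


print(repair_mojibake('ð\x9f\x98\x80'))  # 😀
```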
8 changes: 6 additions & 2 deletions dockerfile
@@ -4,9 +4,13 @@ WORKDIR /app

 COPY requirements.txt /app/requirements.txt

-RUN pip install --no-cache-dir -r requirements.txt
+RUN apt-get update && \
+    apt-get install --no-install-recommends --yes build-essential wget && \
+    pip install --no-cache-dir -r requirements.txt && \
+    python -m nltk.downloader punkt_tab && \
+    wget https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz -O lid.176.ftz

-COPY ./app /app
+COPY ./app/*.py /app

 EXPOSE 8000

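The build layer now installs `build-essential` (fasttext compiles native code from its sdist) and `wget`, pre-downloads NLTK's `punkt_tab` data, and fetches the language-ID model at image-build time instead of at container start. A small sanity check one might run inside the built image to confirm both assets are usable (a hypothetical check, not part of this PR):

```python
import fasttext
import nltk

# Fails fast if the image build missed either asset.
model = fasttext.load_model("lid.176.ftz")      # fetched by wget in the RUN layer
nltk.data.find("tokenizers/punkt_tab")          # installed by nltk.downloader
print(model.predict(["hello world"])[0][0][0])  # expect '__label__en'
```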
2 changes: 2 additions & 0 deletions requirements.txt
@@ -15,6 +15,7 @@ cryptography==41.0.7
 Deprecated==1.2.14
 exceptiongroup==1.2.0
 fastapi==0.104.1
+fasttext==0.9.3
 Flask==3.0.0
 frozenlist==1.4.0
 future==0.18.3
@@ -25,6 +26,7 @@ line-bot-sdk==3.11.0
 MarkupSafe==2.1.3
 motor==3.4.0
 multidict==6.0.4
+nltk==3.9.1
 pycparser==2.21
 pydantic==2.5.2
 pydantic_core==2.14.5
