feat(huixiangdou): support code search
tpoisonooo committed Sep 2, 2024
1 parent 7f6b1b8 commit 46b240c
Showing 11 changed files with 57 additions and 33 deletions.
2 changes: 1 addition & 1 deletion evaluation/README.md
@@ -66,7 +66,7 @@ For bce-embedding-base_v1
For bge-large-zh-v1.5

- The chunksize range should be (423, 1240)
-- The compression rate of embedding.tokenzier is slightly lower
+- The compression rate of embedding.tokenizer is slightly lower
- The best F1@throttle obtained on the right value is [email protected]

The basis for choosing splitter is:
2 changes: 1 addition & 1 deletion evaluation/README_zh.md
@@ -66,7 +66,7 @@ print(result)
For bge-large-zh-v1.5

- The chunksize range should be (423, 1240)
-- The compression rate of embedding.tokenzier is slightly lower
+- The compression rate of embedding.tokenizer is slightly lower
- The best F1@throttle obtained at the right endpoint is [email protected]

The basis for choosing splitter:
2 changes: 1 addition & 1 deletion huixiangdou/gradio_ui.py
@@ -233,7 +233,7 @@ def build_feature_store(main_args):
            ui_web_search = gr.Radio(["no", "yes"], label="Enable web search", info="Disable by default ")
            ui_web_search.change(fn=on_web_search_changed, inputs=ui_web_search, outputs=[])
        with gr.Column():
-           ui_code_search = gr.Radio(["no", "yes"], label="Enable code search", info="Enable by default ")
+           ui_code_search = gr.Radio(["yes", "no"], label="Enable code search", info="Enable by default ")
            ui_code_search.change(fn=on_code_search_changed, inputs=ui_code_search, outputs=[])

        with gr.Row():
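For context, a minimal sketch of what the on_code_search_changed handler wired up above could look like. The handler name and wiring come from the diff; the body below is an assumption, not the project's actual implementation:

# Hypothetical module-level switch consulted by the answer pipeline;
# the real project may persist the choice into its config instead.
code_search_enabled = True

def on_code_search_changed(choice: str):
    # gr.Radio passes the selected option string ("yes" or "no") to the handler.
    global code_search_enabled
    code_search_enabled = (choice == 'yes')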
37 changes: 28 additions & 9 deletions huixiangdou/primitive/bm250kapi.py
@@ -4,9 +4,10 @@
import numpy as np
import pickle as pkl
import os
+import jieba.analyse

from loguru import logger
-from typing import List
+from typing import List, Union
from .chunk import Chunk

"""
@@ -33,7 +34,7 @@ def __init__(self, k1=1.5, b=0.75, epsilon=0.25):
        self.chunks = []

        # option
-        self.tokenizer = None
+        self.tokenizer = jieba.analyse.extract_tags

    def _initialize(self, corpus):
        nd = {}  # word -> number of documents with word
@@ -64,16 +65,20 @@ def _tokenize_corpus(self, corpus):
        tokenized_corpus = self.tokenizer(corpus)
        return tokenized_corpus

-    def save(self, chunks:List[Chunk], filedir:str, tokenizer=None):
+    def save(self, chunks:List[Chunk], filedir:str):
        # generate idf with corpus
        self.chunks = chunks

        filtered_corpus = []
        for c in chunks:
            content = c.content_or_path
-            if tokenizer:
+            if self.tokenizer is not None:
+                # input str, output list of str
-                corpus = self._tokenize_corpus(content)
+                corpus = self.tokenizer(content)
+                if content not in corpus:
+                    corpus.append(content)
            else:
+                logger.warning('No tokenizer, use naive split')
                corpus = content.split(' ')
            filtered_corpus.append(corpus)

@@ -91,7 +96,7 @@ def save(self, chunks:List[Chunk], filedir:str, tokenizer=None):
            'chunks': chunks
        }
        logger.info('bm250kpi dump..')
-        logger.info(data)
+        # logger.info(data)

        if not os.path.exists(filedir):
            os.makedirs(filedir)
@@ -135,14 +140,16 @@ def _calc_idf(self, nd):
        for word in negative_idfs:
            self.idf[word] = eps

-    def get_scores(self, query):
+    def get_scores(self, query: List):
        """
        The ATIRE BM25 variant uses an idf function which uses a log(idf) score. To prevent negative idf scores,
        this algorithm also adds a floor to the idf value of epsilon.
        See [Trotman, A., X. Jia, M. Crane, Towards an Efficient and Effective Search Engine] for more info
        :param query:
        :return:
        """
+        if type(query) is not list:
+            raise ValueError('query must be list, tokenize it byself.')
        score = np.zeros(self.corpus_size)
        doc_len = np.array(self.doc_len)
        for q in query:
@@ -164,7 +171,19 @@ def get_batch_scores(self, query, doc_ids):
                     (q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)))
        return score.tolist()

-    def get_top_n(self, query, n=5):
-        scores = self.get_scores(query)
+    def get_top_n(self, query: Union[List,str], n=5):
+        if type(query) is str:
+            if self.tokenizer is not None:
+                queries = self.tokenizer(query)
+            else:
+                queries = query.split(' ')
+        else:
+            queries = query
+
+        scores = self.get_scores(queries)
        top_n = np.argsort(scores)[::-1][:n]
+        logger.info('{} {}'.format(scores, top_n))
+        if abs(scores[top_n[0]]) < 1e-5:
+            # not match, quit
+            return []
        return [self.chunks[i] for i in top_n]
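Taken together, the class now exposes a self-contained save/load/query surface: jieba keyword extraction by default, queries as either raw strings or token lists, and an empty result when the best score is effectively zero. A minimal usage sketch grounded in the signatures above; the corpus strings and work directory are invented for illustration:

from huixiangdou.primitive import BM25Okapi, Chunk

chunks = [
    Chunk(content_or_path='def tokenize(text): ...', metadata={'source': 'a.py'}),
    Chunk(content_or_path='class Retriever: ...', metadata={'source': 'b.py'}),
]

bm25 = BM25Okapi()
# save() tokenizes each chunk itself (jieba.analyse.extract_tags by default)
# and pickles the idf statistics plus the chunks into filedir.
bm25.save(chunks=chunks, filedir='workdir/db_code')

bm25.load('workdir/db_code')
# get_top_n() accepts a raw string (tokenized with the same tokenizer) or a
# pre-tokenized list; a near-zero best score yields an empty list.
top_chunks = bm25.get_top_n(query='how to tokenize text', n=5)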
16 changes: 8 additions & 8 deletions huixiangdou/primitive/file_operation.py
@@ -217,20 +217,20 @@ def read(self, filepath: str):
                text += soup.text

            elif file_type == 'code':
+                text = filepath + "\n"
                with open(filepath, errors="ignore") as f:
                    text += f.read()
-                return text

        except Exception as e:
            logger.error((filepath, str(e)))
            return '', e
-        text = text.replace('\n\n', '\n')
-        text = text.replace('\n\n', '\n')
-        text = text.replace('\n\n', '\n')
-        text = text.replace('  ', ' ')
-        text = text.replace('  ', ' ')
-        text = text.replace('  ', ' ')

+        if file_type != 'code':
+            text = text.replace('\n\n', '\n')
+            text = text.replace('\n\n', '\n')
+            text = text.replace('\n\n', '\n')
+            text = text.replace('  ', ' ')
+            text = text.replace('  ', ' ')
+            text = text.replace('  ', ' ')
        return text, None


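The new guard matters because collapsing blank lines and double spaces is harmless for prose but would mangle indentation-sensitive source files. A small sketch of the calling convention after this change, with the path invented for illustration:

from huixiangdou.primitive import FileOperation

file_opr = FileOperation()
text, error = file_opr.read('path/to/module.py')  # hypothetical path
if error is None:
    # For 'code' files the text keeps its original whitespace and is
    # prefixed with the file path, so chunks stay attributable.
    print(text[:200])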
5 changes: 1 addition & 4 deletions huixiangdou/primitive/splitter.py
@@ -643,10 +643,7 @@ def split_python_code(filepath: str, text: str, metadata: dict = {}):
                if data:
                    texts.append(f"{child_node.name} {data}")
    except Exception as e:
-        logger.error(e)
-        with open(filepath) as f:
-            texts.append(f.read())
-
+        logger.error('{} {}, continue'.format(filepath, str(e)))
    chunks = []
    for text in texts:
        chunks.append(Chunk(content_or_path=text, metadata=metadata))
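On a parse failure the splitter now logs the file and moves on instead of falling back to ingesting the raw file. For orientation, a rough self-contained sketch of the shape suggested by the fragment above; the AST walking and docstring extraction here are assumptions, not the project's exact code:

import ast

def extract_symbol_texts(filepath: str, text: str) -> list:
    # Hypothetical helper: keep 'name docstring' pairs per symbol and,
    # on a parse error, log and continue with whatever was collected.
    texts = []
    try:
        tree = ast.parse(text)
        for child_node in ast.walk(tree):
            if isinstance(child_node, (ast.FunctionDef, ast.ClassDef)):
                data = ast.get_docstring(child_node)
                if data:
                    texts.append(f"{child_node.name} {data}")
    except Exception as e:
        print('{} {}, continue'.format(filepath, str(e)))
    return texts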
2 changes: 1 addition & 1 deletion huixiangdou/rag.py
@@ -82,7 +82,7 @@ def rag(process_id: int, task: list, output_dir: str):
    for item in task:
        query = item.query

-        for sess in assistant.generate(query=query, history=[], groupname='')
+        for sess in assistant.generate(query=query, history=[], groupname=''):
            item.rag_reply = sess.response
            item.code = int(sess.code)
            item.reason = str(sess.code)
9 changes: 7 additions & 2 deletions huixiangdou/service/feature_store.py
@@ -16,7 +16,9 @@

from ..primitive import (ChineseRecursiveTextSplitter, Chunk, Embedder, Faiss,
                         FileName, FileOperation,
-                        RecursiveCharacterTextSplitter, nested_split_markdown)
+                        RecursiveCharacterTextSplitter, nested_split_markdown,
+                        split_python_code,
+                        BM25Okapi)
from .helper import histogram
from .llm_server_hybrid import start_llm_server
from .retriever import CacheRetriever, Retriever
@@ -117,7 +119,9 @@ def build_sparse(self, files: List[FileName], work_dir: str):
        chunks = []

        for file in files:
-            content = fileopr.read(file.origin)
+            content, error = fileopr.read(file.origin)
+            if error is not None:
+                continue
            file_chunks = split_python_code(filepath=file.origin, text=content, metadata={'source': file.origin, 'read': file.copypath})
            chunks += file_chunks
@@ -402,6 +406,7 @@ def test_query(retriever: Retriever, sample: str = None):

    # walk all files in repo dir
    file_opr = FileOperation()
+
    files = file_opr.scan_dir(repo_dir=args.repo_dir)
    fs_init.initialize(files=files, work_dir=args.work_dir)
    file_opr.summarize(files)
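End to end, the sparse path now reads every file in the repository, splits Python sources into symbol-level chunks, and hands the result to BM25Okapi. A condensed sketch of that flow, grounded in the build_sparse diff above; the directory names are invented for illustration:

from huixiangdou.primitive import BM25Okapi, FileOperation, split_python_code

file_opr = FileOperation()
files = file_opr.scan_dir(repo_dir='path/to/repo')  # hypothetical repo path

chunks = []
for file in files:
    content, error = file_opr.read(file.origin)
    if error is not None:
        continue  # unreadable files are skipped instead of failing the build
    chunks += split_python_code(filepath=file.origin, text=content,
                                metadata={'source': file.origin, 'read': file.copypath})

# Persist the sparse index next to the dense one (directory layout assumed).
BM25Okapi().save(chunks=chunks, filedir='workdir/db_code')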
2 changes: 1 addition & 1 deletion huixiangdou/service/parallel_pipeline.py
@@ -141,7 +141,7 @@ async def process_coroutine(self, sess: Session) -> Session:
        """Try get reply with text2vec & rerank model."""

        # retrieve from knowledge base
-        if self.retriever.bm25 is None
+        if self.retriever.bm25 is None:
            sess.parallel_chunks = []
            return sess
6 changes: 2 additions & 4 deletions huixiangdou/service/retriever.py
@@ -4,16 +4,14 @@
import os
import pdb
import time
+from typing import Any, Union, Tuple, List

import numpy as np
import pytoml
from loguru import logger
from sklearn.metrics import precision_recall_curve
-from typing import Any, Union, Tuple, List

-from huixiangdou.primitive import Embedder, Faiss, LLMReranker, Query, Chunk
-
-from ..primitive import FileOperation
+from huixiangdou.primitive import Embedder, Faiss, LLMReranker, Query, Chunk, BM25Okapi, FileOperation
from .helper import QueryTracker
from .kg import KnowledgeGraph

7 changes: 6 additions & 1 deletion unittest/primitive/test_bm250api.py
@@ -18,7 +18,12 @@ def test_bm25_dump():
def test_bm25_load():
    bm25 = BM25Okapi()
    bm25.load('./')
-    res = bm25.get_top_n(query='what is the weather')
+    query_text = 'what is the weather'
+
+    res = bm25.get_top_n(query=query_text.split(' '))
    print(res)
+
+    res = bm25.get_top_n(query=query_text)
+    print(res)

if __name__ == '__main__':
