Skip to content

Commit

Permalink
1. refactor rerank model.
Browse files Browse the repository at this point in the history
2. make rag variables same with spring ai.
  • Loading branch information
robinyeeh committed Oct 15, 2024
1 parent bb0dca3 commit e3f66fc
Show file tree
Hide file tree
Showing 18 changed files with 298 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,18 @@
public class DocumentRetrievalAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

private static final String DEFAULT_USER_TEXT_ADVISE = """
请记住以下材料,他们可能对回答问题有帮助。
Context information is below.
---------------------
{documents}
{question_answer_context}
---------------------
Given the context and provided history information and not prior knowledge,
reply to the user comment. If the answer is not in the context, inform
the user that you can't answer the question.
""";

private static final int DEFAULT_ORDER = 0;

public static String RETRIEVED_DOCUMENTS = "documents";
public static String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

private final DocumentRetriever retriever;

Expand Down Expand Up @@ -137,7 +140,7 @@ private AdvisedRequest before(AdvisedRequest request) {
.collect(Collectors.joining(System.lineSeparator()));

Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
advisedUserParams.put(RETRIEVED_DOCUMENTS, documentContext);
advisedUserParams.put("question_answer_context", documentContext);

return AdvisedRequest.from(request)
.withSystemText(this.userTextAdvise)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.alibaba.cloud.ai.document.DocumentWithScore;
import com.alibaba.cloud.ai.model.RerankModel;
import com.alibaba.cloud.ai.model.RerankRequest;
import com.alibaba.cloud.ai.model.RerankResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -184,17 +185,18 @@ protected List<Document> doRerank(AdvisedRequest request, List<Document> documen
return documents;
}

RerankResponse response = rerankModel.rerank(request.userText(), documents);
var rerankRequest = new RerankRequest(request.userText(), documents);
RerankResponse response = rerankModel.call(rerankRequest);
logger.debug("reranked documents: {}", response);
if (response == null || response.getDocuments() == null) {
if (response == null || response.getResults() == null) {
return documents;
}

return response.getDocuments()
return response.getResults()
.stream()
.filter(doc -> doc != null && doc.getScore() >= minScore)
.sorted(Comparator.comparingDouble(DocumentWithScore::getScore).reversed())
.map(DocumentWithScore::getDocument)
.map(DocumentWithScore::getOutput)
.collect(toList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ public class DashScopeDocumentRetrievalAdvisor implements CallAroundAdvisor, Str
【正文】光合作用是利用阳光将CO2和H2O转化为氧气和葡萄糖的过程。
$$材料:
{documents}
{question_answer_context}
""";

private static final int DEFAULT_ORDER = 0;

public static String RETRIEVED_DOCUMENTS = "documents";
public static String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

private final DocumentRetriever retriever;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.metadata.DashScopeAiUsage;
import com.alibaba.cloud.ai.document.DocumentWithScore;
import com.alibaba.cloud.ai.model.RerankModel;
import com.alibaba.cloud.ai.model.RerankResponse;
import com.alibaba.cloud.ai.model.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

import java.util.Collection;
import java.util.Collections;
import java.util.List;

import static java.util.Comparator.comparingInt;
Expand All @@ -45,7 +49,7 @@

public class DashScopeRerankModel implements RerankModel {

private static final Logger logger = LoggerFactory.getLogger(DashScopeChatModel.class);
private static final Logger logger = LoggerFactory.getLogger(DashScopeRerankModel.class);

/** Low-level access to the DashScope API */
private final DashScopeApi dashscopeApi;
Expand All @@ -54,61 +58,85 @@ public class DashScopeRerankModel implements RerankModel {
private final RetryTemplate retryTemplate;

/** rerank options */
private final DashScopeRerankOptions options;
private final DashScopeRerankOptions defaultOptions;

public DashScopeRerankModel(DashScopeApi dashscopeApi) {
this(dashscopeApi, DashScopeRerankOptions.builder().build());
}

public DashScopeRerankModel(DashScopeApi dashscopeApi, DashScopeRerankOptions options) {
this(dashscopeApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
public DashScopeRerankModel(DashScopeApi dashscopeApi, DashScopeRerankOptions defaultOptions) {
this(dashscopeApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public DashScopeRerankModel(DashScopeApi dashscopeApi, DashScopeRerankOptions options,
public DashScopeRerankModel(DashScopeApi dashscopeApi, DashScopeRerankOptions defaultOptions,
RetryTemplate retryTemplate) {
Assert.notNull(dashscopeApi, "DashScopeApi must not be null");
Assert.notNull(options, "Options must not be null");
Assert.notNull(defaultOptions, "Options must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");

this.dashscopeApi = dashscopeApi;
this.options = options;
this.defaultOptions = defaultOptions;
this.retryTemplate = retryTemplate;
}

@Override
public RerankResponse rerank(String query, List<Document> documents) {
Assert.notNull(query, "query must not be null");
Assert.notNull(documents, "Options must not be null");
public RerankResponse call(RerankRequest request) {
Assert.notNull(request.getQuery(), "query must not be null");
Assert.notNull(request.getInstructions(), "documents must not be null");

List<String> docs = documents.stream().map(Document::getContent).toList();

DashScopeApi.RerankRequestParameter parameter = new DashScopeApi.RerankRequestParameter(options.getTopN(),
options.getReturnDocuments());
DashScopeApi.RerankRequestInput input = new DashScopeApi.RerankRequestInput(query, docs);
DashScopeApi.RerankRequest request = new DashScopeApi.RerankRequest(options.getModel(), input, parameter);
DashScopeRerankOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);
DashScopeApi.RerankRequest rerankRequest = createRequest(request, requestOptions);

ResponseEntity<DashScopeApi.RerankResponse> responseEntity = this.retryTemplate
.execute(ctx -> this.dashscopeApi.rerankEntity(request));
.execute(ctx -> this.dashscopeApi.rerankEntity(rerankRequest));

var response = responseEntity.getBody();

if (response == null) {
logger.warn("No rerank returned for query: {}", query);
return RerankResponse.builder().build();
logger.warn("No rerank returned for query: {}", request.getQuery());
return new RerankResponse(Collections.emptyList());
}

List<DocumentWithScore> documentWithScores = response.output()
.results()
.stream()
.map(data -> DocumentWithScore.builder()
.withScore(data.relevanceScore())
.withDocument(documents.get(data.index()))
.withDocument(request.getInstructions().get(data.index()))
.build())
.toList();

return RerankResponse.builder()
.withUsage(DashScopeAiUsage.from(response.usage()))
.withDocuments(documentWithScores)
var metadata = new RerankResponseMetadata(DashScopeAiUsage.from(response.usage()));
return new RerankResponse(documentWithScores, metadata);
}

private DashScopeApi.RerankRequest createRequest(RerankRequest request, DashScopeRerankOptions requestOptions) {
List<String> docs = request.getInstructions().stream().map(Document::getContent).toList();

DashScopeApi.RerankRequestParameter parameter = new DashScopeApi.RerankRequestParameter(
requestOptions.getTopN(), requestOptions.getReturnDocuments());
var input = new DashScopeApi.RerankRequestInput(request.getQuery(), docs);
return new DashScopeApi.RerankRequest(requestOptions.getModel(), input, parameter);
}

/**
* Merge runtime and default {@link RerankOptions} to compute the final options to use
* in the request.
*/
private DashScopeRerankOptions mergeOptions(@Nullable RerankOptions runtimeOptions,
DashScopeRerankOptions defaultOptions) {
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, RerankOptions.class,
DashScopeRerankOptions.class);

if (runtimeOptionsForProvider == null) {
return defaultOptions;
}

return DashScopeRerankOptions.builder()
.withModel(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.withTopN(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getTopN(), defaultOptions.getTopN()))
.withReturnDocuments(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getReturnDocuments(),
defaultOptions.getReturnDocuments()))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,20 @@

public class DashScopeRerankOptions implements RerankOptions {

/** ID of the model to use. */
/**
* ID of the model to use.
*/
private String model = "gte-rerank";

/**
* return top n best relevant docs for query
*/
private Integer topN = 3;

private Boolean returnDocuments = true;
/**
* if need to return original document
*/
private Boolean returnDocuments = false;

@Override
public String getModel() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

package com.alibaba.cloud.ai.document;

import com.alibaba.cloud.ai.model.RerankResultMetadata;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.ModelResult;
import org.springframework.ai.model.ResultMetadata;

import java.util.Objects;

Expand All @@ -28,7 +31,7 @@
* @since 1.0.0-M2
*/

public class DocumentWithScore {
public class DocumentWithScore implements ModelResult<Document> {

/**
* Score of document
Expand All @@ -40,6 +43,8 @@ public class DocumentWithScore {
*/
private Document document;

private RerankResultMetadata metadata;

public Double getScore() {
return score;
}
Expand All @@ -48,18 +53,28 @@ public void setScore(Double score) {
this.score = score;
}

public Document getDocument() {
return document;
}

public void setDocument(Document document) {
this.document = document;
}

public void setMetadata(RerankResultMetadata metadata) {
this.metadata = metadata;
}

public static Builder builder() {
return new Builder();
}

@Override
public Document getOutput() {
return this.document;
}

@Override
public ResultMetadata getMetadata() {
return this.metadata;
}

public static final class Builder {

private final DocumentWithScore documentWithScore;
Expand All @@ -78,6 +93,11 @@ public Builder withDocument(Document document) {
return this;
}

public Builder withMetadata(RerankResultMetadata metadata) {
this.documentWithScore.setMetadata(metadata);
return this;
}

public DocumentWithScore build() {
return documentWithScore;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public class AnswerFaithfulnessEvaluator extends LaajEvaluator {
最终答案按照标准的json格式输出,不要使用markdown的格式, 比如:
{"score": 0.7, "feedback": "STUDENT ANSWER的内容超出了FACTS的事实内容。"}
FACTS: {documents}
FACTS: {context}
STUDENT ANSWER: {student_answer}
""";

Expand Down Expand Up @@ -85,7 +85,7 @@ public EvaluationResponse evaluate(EvaluationRequest evaluationRequest) {
String llmEvaluationResponse = getChatClientBuilder().build()
.prompt()
.user(userSpec -> userSpec.text(getEvaluationPromptText())
.param("documents", context)
.param("context", context)
.param("student_answer", response))
.call()
.content();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

package com.alibaba.cloud.ai.model;

import org.springframework.ai.document.Document;

import java.util.List;
import org.springframework.ai.model.Model;

/**
* Title rerank model interface.<br>
Expand All @@ -29,8 +27,9 @@
* @since 1.0.0-M2
*/

public interface RerankModel {
public interface RerankModel extends Model<RerankRequest, RerankResponse> {

RerankResponse rerank(String query, List<Document> documents);
@Override
RerankResponse call(RerankRequest request);

}
Loading

0 comments on commit e3f66fc

Please sign in to comment.