Skip to content

Commit

Permalink
feat: add proxy_tool_calls
Browse files Browse the repository at this point in the history
  • Loading branch information
hongshuo-wang committed Jan 20, 2025
1 parent 4c7d823 commit ed874be
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public ChatResponse call(Prompt prompt) {
return response;
});

if (isToolCall(chatResponse,
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
// Recursively call the call method with the tool call message
Expand Down Expand Up @@ -276,9 +276,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {

// @formatter:off
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {

if (isToolCall(response,
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
if (!isProxyToolCalls(prompt, this.defaultOptions) &&
isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ public class DashScopeChatOptions implements FunctionCallingOptions, ChatOptions
*/
private @JsonProperty("multi_model") Boolean multiModel = false;

/**
* If true, the Spring AI will not handle the function calls internally, but will proxy them to the client.
* It is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results.
* If false, the Spring AI will handle the function calls internally.
*/
@JsonIgnore
private Boolean proxyToolCalls;

@NestedConfigurationProperty
@JsonIgnore
private Map<String, Object> toolContext;
Expand Down Expand Up @@ -274,6 +282,15 @@ public void setSeed(Integer seed) {
this.seed = seed;
}

@Override
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
}

public void setProxyToolCalls(Boolean proxyToolCalls) {
this.proxyToolCalls = proxyToolCalls;
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
Expand Down Expand Up @@ -412,6 +429,11 @@ public DashscopeChatOptionsBuilder withFunctions(Set<String> functionNames) {
return this;
}

public DashscopeChatOptionsBuilder withProxyToolCalls(Boolean proxyToolCalls) {
this.options.proxyToolCalls = proxyToolCalls;
return this;
}

public DashscopeChatOptionsBuilder withSeed(Integer seed) {
this.options.seed = seed;
return this;
Expand Down Expand Up @@ -465,6 +487,7 @@ public static DashScopeChatOptions fromOptions(DashScopeChatOptions fromOptions)
.withTools(fromOptions.getTools())
.withToolContext(fromOptions.getToolContext())
.withMultiModel(fromOptions.getMultiModel())
.withProxyToolCalls(fromOptions.getProxyToolCalls())
.withVlHighResolutionImages(fromOptions.getVlHighResolutionImages())
.build();
}
Expand All @@ -476,13 +499,13 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
DashScopeChatOptions that = (DashScopeChatOptions) o;

return Objects.equals(model, that.model) && Objects.equals(stream, that.stream) && Objects.equals(temperature, that.temperature) && Objects.equals(seed, that.seed) && Objects.equals(topP, that.topP) && Objects.equals(topK, that.topK) && Objects.equals(stop, that.stop) && Objects.equals(enableSearch, that.enableSearch) && Objects.equals(responseFormat, that.responseFormat) && Objects.equals(incrementalOutput, that.incrementalOutput) && Objects.equals(repetitionPenalty, that.repetitionPenalty) && Objects.equals(tools, that.tools) && Objects.equals(toolChoice, that.toolChoice) && Objects.equals(vlHighResolutionImages, that.vlHighResolutionImages) && Objects.equals(functionCallbacks, that.functionCallbacks) && Objects.equals(functions, that.functions) && Objects.equals(multiModel, that.multiModel) && Objects.equals(toolContext, that.toolContext);
return Objects.equals(model, that.model) && Objects.equals(stream, that.stream) && Objects.equals(temperature, that.temperature) && Objects.equals(seed, that.seed) && Objects.equals(topP, that.topP) && Objects.equals(topK, that.topK) && Objects.equals(stop, that.stop) && Objects.equals(enableSearch, that.enableSearch) && Objects.equals(responseFormat, that.responseFormat) && Objects.equals(incrementalOutput, that.incrementalOutput) && Objects.equals(repetitionPenalty, that.repetitionPenalty) && Objects.equals(tools, that.tools) && Objects.equals(toolChoice, that.toolChoice) && Objects.equals(vlHighResolutionImages, that.vlHighResolutionImages) && Objects.equals(functionCallbacks, that.functionCallbacks) && Objects.equals(functions, that.functions) && Objects.equals(multiModel, that.multiModel) && Objects.equals(toolContext, that.toolContext) && Objects.equals(proxyToolCalls, that.proxyToolCalls);
}

@Override
public int hashCode() {

return Objects.hash(model, stream, temperature, seed, topP, topK, stop, enableSearch, responseFormat, incrementalOutput, repetitionPenalty, tools, toolChoice, vlHighResolutionImages, functionCallbacks, functions, multiModel, toolContext);
return Objects.hash(model, stream, temperature, seed, topP, topK, stop, enableSearch, responseFormat, incrementalOutput, repetitionPenalty, tools, toolChoice, vlHighResolutionImages, functionCallbacks, functions, multiModel, toolContext, proxyToolCalls);
}

@Override
Expand Down

0 comments on commit ed874be

Please sign in to comment.