Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Graph module for saa #97

Merged
merged 3 commits into from
Nov 7, 2024
Merged

Add Graph module for saa #97

merged 3 commits into from
Nov 7, 2024

Conversation

wxbty
Copy link

@wxbty wxbty commented Nov 7, 2024

Describe what this PR does / why we need it

Add Graph module for saa

Describe how to verify it

set system env spring.ai.dashscope.api-key first.
Run main func of studio.AgentExecutorStreamingServer, then visit http://localhost:8080/.

@chickenlj chickenlj changed the base branch from main to workflow November 7, 2024 10:52
@chickenlj chickenlj merged commit 3a093e5 into alibaba:workflow Nov 7, 2024
2 checks passed
@wxbty
Copy link
Author

wxbty commented Nov 7, 2024

架构

引入Sample

以下sample参考langchain-graph,改造成spring

@Import({AgentService.class, ToolService.class})
@SpringBootApplication
public class AgentExecutorStreamingServer {

    public static void main(String[] args) throws Exception {
        //因为studio的jetty占用了8080端口,修改springboot端口避免冲突
        System.setProperty("server.port","8090");
        ConfigurableApplicationContext context = SpringApplication.run(AgentExecutorStreamingServer.class, args);
        //AgentService注入了spring-ai的ChatClient,引入dashscope包后,默认注入百炼大模型
        AgentService agentService = context.getBean(AgentService.class);
        //没有使用springmvc,所以手动用jackson进行参数映射
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);

        //核心步骤:构建graph对象
        //1、注入llm
        //2、注入序列化对象
        //3、组装工作流流程
        //AgentExecutor 
        var graph = new AgentExecutor(agentService).graphBuilder()
        .stateSerializer(AgentExecutor.Serializers.JSON.object() )
        .build();

        /*
        build(){
           return new StateGraph<>(State.SCHEMA, stateSerializer)
                    .addEdge(START,"agent")
                    .addNode( "agent", node_async(AgentExecutor.this::callAgent) )
                    .addNode( "action", AgentExecutor.this::executeTools )
                    .addConditionalEdges(
                            "agent",
                            edge_async(AgentExecutor.this::shouldContinue),
                            Map.of("continue", "action", "end", END)
                    )
                    .addEdge("action", "agent")
                    ;
          }
        */

        //打印工作流内容,非json格式
        GraphRepresentation plantUml = graph.getGraph(GraphRepresentation.Type.PLANTUML, "Adaptive RAG");
        System.out.println(plantUml.getContent());
        /*
         @startuml unnamed.puml
         skinparam usecaseFontSize 14
         skinparam usecaseStereotypeFontSize 12
         skinparam hexagonFontSize 14
         skinparam hexagonStereotypeFontSize 12
         title "Adaptive RAG"
         footer

         powered by langgraph4j
         end footer
         circle start<<input>>
         circle stop as __END__
         usecase "agent"<<Node>>
         usecase "action"<<Node>>
         hexagon "check state" as condition1<<Condition>>
         start -down-> "agent"
         "agent" -down-> "condition1"
         "condition1" --> "action": "continue"
         '"agent" --> "action": "continue"
         "condition1" -down-> stop: "end"
         '"agent" -down-> stop: "end"
         "action" -down-> "agent"
        @enduml
        */

        var server = LangGraphStreamingServerJetty.builder()
        .port(8080)
        .objectMapper(objectMapper)
        .title("AGENT EXECUTOR")
        .addInputStringArg("input")
        .stateGraph(graph)
        .build();

        //启动jetty,访问studio: http://127.0.0.1:8080
        server.start().join();

    }

}
public class AgentService {
    public final ToolService toolService;
    private final ChatClient chatClient;

    public AgentService(ChatClient.Builder chatClientBuilder, ToolService toolService) {
        var functions = toolService.agentFunctionsCallback().toArray(FunctionCallback[]::new);

        this.chatClient = chatClientBuilder
        .defaultSystem("You are a helpful AI Assistant answering questions.")
        .defaultFunctions( functions )
        .build();
        this.toolService = toolService;
    }

工程结构

红色的为新增内容

core-jdk8

  • action: 节点和边的回调函数,比如节点完成后执行tool,边的条件执行(按照前继节点的result,决定后续节点)
  • checkpoint: 检查点,用于保存和恢复对话的状态。

如上图,检查点包含当前节点id、后续节点、序列化、持久化等内容。

  • diagram:uml图相关,和studio通信的内容
  • serializer:节点和节点,节点和边之间需要序列化和反序列化,主要是json
  • state:状态相关操作,主要对map<String,Object> 进行操作,包括输入、输出、action等
  • utils:在jdk8上兼容高版本jdk的api,比如mapOf
  • CompiledGraph:状态机,更新节点状态及流转
public RunnableConfig updateState( RunnableConfig config, Map<String,Object> values, String asNode ) throws Exception {
    BaseCheckpointSaver saver = compileConfig.checkpointSaver().orElseThrow( () -> (new IllegalStateException("Missing CheckpointSaver!")) );

    // 合并状态value,存入检查点(可持久化)
    Checkpoint branchCheckpoint = saver.get(config)
    .map(Checkpoint::new)
    .map( cp -> cp.updateState(values, stateGraph.getChannels()) )
    .orElseThrow( () -> (new IllegalStateException("Missing Checkpoint!")) );

    String nextNodeId = null;
    if( asNode != null ) {
        //获取下一个节点
        nextNodeId = nextNodeId( asNode, branchCheckpoint.getState() );
    }
    // 更新检查点
    RunnableConfig newConfig = saver.put( config, branchCheckpoint );

    return RunnableConfig.builder(newConfig)
    .checkPointId( branchCheckpoint.getId() )
    .nextNode( nextNodeId )
    .build();
}
private String nextNodeId( EdgeValue<State> route , Map<String,Object> state, String nodeId ) throws Exception {

    if( route == null ) {
        throw StateGraph.RunnableErrors.missingEdge.exception(nodeId);
    }
    if( route.id() != null ) {
        return route.id();
    }
    if( route.value() != null ) {
        State derefState = stateGraph.getStateFactory().apply(state);
        org.bsc.langgraph4j.action.AsyncEdgeAction<State> condition = route.value().action();
        String newRoute = condition.apply(derefState).get();
        //根据条件边,路由到对应节点
        String result = route.value().mappings().get(newRoute);
        if( result == null ) {
            throw StateGraph.RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, newRoute);
        }
        return result;
    }
    throw StateGraph.RunnableErrors.executionError.exception( format("invalid edge value for nodeId: [%s] !", nodeId) );
}
  • StateGraph:暴露给用户的graph对象,主要有addNode,addEdge,compile操作。compile会参数校验及检查图的结构(比如是否有独立节点),输出CompiledGraph。

samples

各种单元测试。包含了node回调的中间步骤:intermediateSteps。

studio

包含了jetty web的studio测试用例。见“引入Sample”

其他

agent节点完成后封装的内容,action类似functioncall/tool,finish是结果。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants