
Title: Making the LangGraph sqlite checkpoint async

Author: 创想小编    Time: yesterday 09:43
Author: CSDN blog
Installation
    pip install langgraph-checkpoint-sqlite
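The package bundles both the sync and async savers (aiosqlite should be pulled in as a dependency of the async one). A quick import check, as a minimal sketch assuming the install above succeeded:

    # minimal sanity check that the install worked: both the sync and
    # async savers should be importable from the same package
    from langgraph.checkpoint.sqlite import SqliteSaver
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    print(SqliteSaver.__name__, AsyncSqliteSaver.__name__)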
Async checkpoint initialization:
    import aiosqlite
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    conn = aiosqlite.connect(":memory:", check_same_thread=False)
    memory = AsyncSqliteSaver(conn)
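As an aside, recent versions of langgraph-checkpoint-sqlite also expose AsyncSqliteSaver.from_conn_string, an async context manager that opens and closes the aiosqlite connection for you. A minimal sketch; the ":memory:" path is illustrative and `graph` is assumed to be a StateGraph built elsewhere:

    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    async def run():
        # the saver owns the connection for the duration of the block
        async with AsyncSqliteSaver.from_conn_string(":memory:") as memory:
            app = graph.compile(checkpointer=memory)  # `graph` assumed built elsewhere
            ...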
If you use async streaming responses, you also need to make sure the llm node and any related nodes are converted to async operations:
    async def llm(self, state: AgentState):
        llm_msgs = state['messages']
        if self.systemMessage:
            llm_msgs = self.systemMessage + state['messages']
        print(f'ask llm to handle request msg, msg: {llm_msgs}')
        try:
            # key fix: await the async call and get the result directly
            msg = await self.model.ainvoke(llm_msgs)
            print(f'msg={msg}')
            return {'messages': [msg]}  # return a message object, not a coroutine
        except Exception as e:
            print(f"Model invocation error: {e}")
            # return an error message (must conform to the Message type)
            from langchain_core.messages import AIMessage
            return {'messages': [AIMessage(content=f"Error: {str(e)}")]}

    async def take_action_tool(self, state: AgentState):
        current_tools: List[ToolCall] = state['messages'][-1].tool_calls
        results = []
        for t in current_tools:
            tool_result = await self.tools[t['name']].ainvoke(t['args'])
            results.append(ToolMessage(
                tool_call_id=t['id'],
                content=str(tool_result),
                name=t['name'],
            ))
        print(f'Back to model')
        return {'messages': results}
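The "key fix" comment above is worth spelling out: ainvoke returns a coroutine, and a node that forgets to await it stores the coroutine object in the state instead of a message. A minimal contrast with illustrative names (`model` stands for any bound chat model):

    # BUG: without await, a coroutine object (not an AIMessage) lands in
    # state['messages'], and downstream code fails on .tool_calls access
    async def llm_node_buggy(state: AgentState):
        return {'messages': [model.ainvoke(state['messages'])]}

    # FIX: await resolves the coroutine to the actual message object
    async def llm_node_fixed(state: AgentState):
        return {'messages': [await model.ainvoke(state['messages'])]}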
Finally, the complete code:
    import asyncio
    from typing import Annotated, List, TypedDict
    import os

    import aiosqlite
    from langchain_community.chat_models import ChatTongyi
    from langchain_core.language_models import BaseChatModel
    from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
    from dotenv import load_dotenv
    from langchain_community.tools.tavily_search import TavilySearchResults
    from langchain_core.tools import BaseTool
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
    from langgraph.constants import END, START
    from langgraph.graph import add_messages, StateGraph

    conn = aiosqlite.connect(":memory:", check_same_thread=False)
    load_dotenv(dotenv_path='../keys.env')
    ts_tool = TavilySearchResults(max_results=2)


    class AgentState(TypedDict):
        messages: Annotated[List[AnyMessage], add_messages]


    class Agent:
        def __init__(
                self,
                model: BaseChatModel,
                systemMessage: List[SystemMessage],
                tools: List[BaseTool],
                memory,
        ):
            assert all(isinstance(t, BaseTool) for t in tools), 'tools must implement BaseTool'
            graph = StateGraph(AgentState)
            graph.add_node('llm', self.llm)
            graph.add_node('take_action_tool', self.take_action_tool)
            graph.add_conditional_edges(
                'llm',
                self.exist_action,
                {True: 'take_action_tool', False: END},
            )
            graph.set_entry_point('llm')
            graph.add_edge('take_action_tool', 'llm')
            self.app = graph.compile(checkpointer=memory)
            self.tools = {t.name: t for t in tools}
            self.model = model.bind_tools(tools)
            self.systemMessage = systemMessage

        def exist_action(self, state: AgentState):
            tool_calls = state['messages'][-1].tool_calls
            print(f'tool_calls size {len(tool_calls)}')
            return len(tool_calls) > 0

        async def llm(self, state: AgentState):
            llm_msgs = state['messages']
            if self.systemMessage:
                llm_msgs = self.systemMessage + state['messages']
            print(f'ask llm to handle request msg, msg: {llm_msgs}')
            try:
                # key fix: await the async call and get the result directly
                msg = await self.model.ainvoke(llm_msgs)
                print(f'msg={msg}')
                return {'messages': [msg]}  # return a message object, not a coroutine
            except Exception as e:
                print(f"Model invocation error: {e}")
                # return an error message (must conform to the Message type)
                from langchain_core.messages import AIMessage
                return {'messages': [AIMessage(content=f"Error: {str(e)}")]}

        async def take_action_tool(self, state: AgentState):
            current_tools: List[ToolCall] = state['messages'][-1].tool_calls
            results = []
            for t in current_tools:
                tool_result = await self.tools[t['name']].ainvoke(t['args'])
                results.append(ToolMessage(
                    tool_call_id=t['id'],
                    content=str(tool_result),
                    name=t['name'],
                ))
            print(f'Back to model')
            return {'messages': results}


    async def work():
        prompt = """You are a smart research assistant. Use the search engine to look up information. \
        You are allowed to make multiple calls (either together or in sequence). \
        Only look up information when you are sure of what you want. \
        If you need to look up some information before asking a follow up question, you are allowed to do that!
        """
        qwen_model = ChatTongyi(
            model=os.getenv('model'),
            api_key=os.getenv('api_key'),
            base_url=os.getenv('base_url'),
        )  # reduce inference cost
        memory = AsyncSqliteSaver(conn)
        agent = Agent(model=qwen_model, tools=[ts_tool], systemMessage=[SystemMessage(content=prompt)], memory=memory)
        messages = [HumanMessage("who is the popular football star in the world?")]
        configurable = {"configurable": {"thread_id": "1"}}
        async for event in agent.app.astream_events({"messages": messages}, configurable, version="v1"):
            kind = event["event"]
            # print(f"kind = {kind}")
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    # Empty content in the context of OpenAI means the model
                    # is asking for a tool to be invoked, so only print
                    # non-empty content.
                    print(content, end="|")


    if __name__ == '__main__':
        asyncio.run(work())
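Because every step is checkpointed under thread_id "1", the conversation can be read back after the run. A hedged sketch (show_state is a hypothetical helper; aget_state is the async counterpart of get_state on the compiled graph):

    # hypothetical helper: read the latest checkpoint for the thread
    async def show_state(agent, config):
        snapshot = await agent.app.aget_state(config)
        for m in snapshot.values['messages']:  # accumulated by add_messages
            print(type(m).__name__, str(getattr(m, 'content', ''))[:80])

    # e.g. at the end of work():
    #     await show_state(agent, configurable)

Keep in mind that ":memory:" checkpoints vanish when the process exits; connecting to a file such as aiosqlite.connect("checkpoints.db") persists them across runs.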
Original article: https://blog.csdn.net/zhangkai1992/article/details/147075196



