You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
154 lines
4.4 KiB
154 lines
4.4 KiB
"""
|
|
AI Agent 测试脚本
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from ai_agent import create_agent
|
|
|
|
# 加载环境变量
|
|
load_dotenv()
|
|
|
|
|
|
async def test_agent():
|
|
"""测试AI Agent"""
|
|
|
|
print("=" * 80)
|
|
print("AI Agent 测试")
|
|
print("=" * 80)
|
|
print()
|
|
|
|
# 测试用例
|
|
test_cases = [
|
|
{
|
|
"name": "测试1:分析游戏类视频趋势",
|
|
"query": "帮我分析一下游戏类视频的热门趋势"
|
|
},
|
|
{
|
|
"name": "测试2:美食创作建议",
|
|
"query": "我想做美食相关的内容,给我一些建议"
|
|
},
|
|
{
|
|
"name": "测试3:特定关键词分析",
|
|
"query": "王者荣耀的视频现在什么内容最火?"
|
|
}
|
|
]
|
|
|
|
# 选择要运行的测试
|
|
print("可用的测试用例:")
|
|
for i, case in enumerate(test_cases, 1):
|
|
print(f"{i}. {case['name']}")
|
|
print()
|
|
|
|
choice = input("请选择要运行的测试(输入数字,直接回车运行测试1): ").strip()
|
|
|
|
if not choice:
|
|
choice = "1"
|
|
|
|
try:
|
|
test_index = int(choice) - 1
|
|
if test_index < 0 or test_index >= len(test_cases):
|
|
print("无效的选择,使用测试1")
|
|
test_index = 0
|
|
except ValueError:
|
|
print("无效的输入,使用测试1")
|
|
test_index = 0
|
|
|
|
selected_test = test_cases[test_index]
|
|
|
|
print("\n" + "=" * 80)
|
|
print(f"运行: {selected_test['name']}")
|
|
print("=" * 80)
|
|
print(f"\n用户查询: {selected_test['query']}\n")
|
|
|
|
try:
|
|
# 创建Agent
|
|
print("正在初始化AI Agent...")
|
|
agent = create_agent()
|
|
print("✓ Agent初始化成功")
|
|
print()
|
|
|
|
# 运行Agent
|
|
print("正在执行任务...\n")
|
|
result = await agent.run(
|
|
user_input=selected_test['query'],
|
|
system_prompt_file="prompts/agent_prompt.md",
|
|
max_iterations=10
|
|
)
|
|
|
|
print("\n" + "=" * 80)
|
|
print("执行结果")
|
|
print("=" * 80)
|
|
|
|
if result["success"]:
|
|
print("\n" + result["final_answer"])
|
|
print("\n" + "=" * 80)
|
|
print("执行统计")
|
|
print("=" * 80)
|
|
print(f"总迭代次数: {result['iteration']}")
|
|
print(f"工具调用次数: {len(result['tool_calls'])}")
|
|
|
|
if result['tool_calls']:
|
|
print("\n工具调用详情:")
|
|
for i, call in enumerate(result['tool_calls'], 1):
|
|
print(f"\n {i}. {call['tool_name']}")
|
|
print(f" 迭代: {call['iteration']}")
|
|
print(f" 参数: {call['arguments']}")
|
|
if call['result'].get('success'):
|
|
print(f" 状态: ✓ 成功")
|
|
else:
|
|
print(f" 状态: ✗ 失败 - {call['result'].get('error')}")
|
|
|
|
print("\n✓ 测试成功!")
|
|
else:
|
|
print(f"\n✗ 执行失败: {result['error']}")
|
|
if result.get('partial_result'):
|
|
print(f"\n部分结果:\n{result['partial_result']}")
|
|
|
|
if result.get('tool_calls'):
|
|
print(f"\n已完成的工具调用: {len(result['tool_calls'])}")
|
|
|
|
except Exception as e:
|
|
print(f"\n✗ 测试出错: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
|
|
|
|
async def test_simple():
|
|
"""简单测试"""
|
|
print("=" * 80)
|
|
print("AI Agent 简单测试")
|
|
print("=" * 80)
|
|
print()
|
|
|
|
query = input("请输入你的查询(直接回车使用默认查询): ").strip()
|
|
|
|
if not query:
|
|
query = "帮我分析一下游戏类视频的热门趋势"
|
|
print(f"使用默认查询: {query}")
|
|
|
|
print()
|
|
|
|
try:
|
|
agent = create_agent()
|
|
result = await agent.run(query, system_prompt_file="prompts/agent_prompt.md")
|
|
|
|
if result["success"]:
|
|
print("\n" + "=" * 80)
|
|
print("结果")
|
|
print("=" * 80)
|
|
print(result["final_answer"])
|
|
else:
|
|
print(f"\n执行失败: {result['error']}")
|
|
except Exception as e:
|
|
print(f"\n出错: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import sys
|
|
|
|
if len(sys.argv) > 1 and sys.argv[1] == "--simple":
|
|
asyncio.run(test_simple())
|
|
else:
|
|
asyncio.run(test_agent())
|
|
|