You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
80 lines
2.6 KiB
80 lines
2.6 KiB
"""
|
|
测试JSON格式的灵感输出
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from ai_agent import create_agent
|
|
|
|
|
|
async def test_json_format():
|
|
"""测试JSON格式输出"""
|
|
|
|
print("=" * 80)
|
|
print("测试JSON格式的创作灵感生成")
|
|
print("=" * 80)
|
|
|
|
# 创建Agent
|
|
agent = create_agent()
|
|
|
|
# 测试查询
|
|
query = "我想做一些校园相关的短视频"
|
|
|
|
print(f"\n用户查询: {query}\n")
|
|
print("正在生成灵感...\n")
|
|
|
|
# 运行Agent
|
|
result = await agent.run(
|
|
user_input=query,
|
|
system_prompt_file="prompts/agent_prompt.md",
|
|
max_iterations=15
|
|
)
|
|
|
|
if result["success"]:
|
|
print("✓ 执行成功\n")
|
|
print("=" * 80)
|
|
print("最终答案:")
|
|
print("=" * 80)
|
|
print(result["final_answer"])
|
|
print("\n" + "=" * 80)
|
|
|
|
# 尝试从答案中提取JSON
|
|
import re
|
|
json_match = re.search(r'```json\s*(.*?)\s*```', result["final_answer"], re.DOTALL)
|
|
if json_match:
|
|
try:
|
|
json_data = json.loads(json_match.group(1))
|
|
print("\n✓ JSON格式验证成功")
|
|
print(f"✓ 生成了 {len(json_data.get('inspirations', []))} 个灵感")
|
|
|
|
# 显示第一个灵感作为示例
|
|
if json_data.get('inspirations'):
|
|
first = json_data['inspirations'][0]
|
|
print("\n示例灵感:")
|
|
print(json.dumps(first, ensure_ascii=False, indent=2))
|
|
|
|
# 验证字段
|
|
required_fields = ['id', 'title', 'description', 'reference_author',
|
|
'reference_description', 'url', 'platform', 'tags', 'keywords']
|
|
missing_fields = [f for f in required_fields if f not in first]
|
|
|
|
if missing_fields:
|
|
print(f"\n✗ 缺少字段: {', '.join(missing_fields)}")
|
|
else:
|
|
print("\n✓ 所有必需字段都存在")
|
|
|
|
except json.JSONDecodeError as e:
|
|
print(f"\n✗ JSON解析失败: {e}")
|
|
else:
|
|
print("\n✗ 未找到JSON格式的输出")
|
|
|
|
else:
|
|
print(f"✗ 执行失败: {result.get('error')}")
|
|
if result.get('tool_calls'):
|
|
print(f"\n工具调用记录:")
|
|
for call in result['tool_calls']:
|
|
print(f" - {call['tool_name']}: {call.get('result', {}).get('success', 'N/A')}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(test_json_format())
|
|
|