You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
996 lines
36 KiB
996 lines
36 KiB
"""
|
|
AI Agent - 智能代理系统
|
|
支持函数调用,让AI自主选择工具完成任务
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import asyncio
|
|
from typing import List, Dict, Optional, Any
|
|
from pathlib import Path
|
|
import dashscope
|
|
from dashscope import Generation
|
|
from datetime import datetime
|
|
from dotenv import load_dotenv
|
|
|
|
# 加载环境变量
|
|
load_dotenv()
|
|
|
|
|
|
class AIAgent:
    """Tool-calling agent driven by the DashScope ``Generation`` API.

    Callers register tools (a callable plus a JSON-Schema description) and
    then call :meth:`run`, which alternates between the model and the
    requested tools until the model replies without asking for a tool or the
    iteration budget runs out.
    """

    def __init__(self, api_key: Optional[str] = None, model: str = "qwen-plus"):
        """Configure the DashScope client and reset the agent state.

        Args:
            api_key: DashScope API key. Falls back to the
                ``DASHSCOPE_API_KEY`` environment variable when omitted.
            model: Model name passed to ``Generation.call``.

        Raises:
            ValueError: If no API key is given and the environment variable
                is unset or empty.
        """
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
        if not self.api_key:
            raise ValueError("未提供API Key")

        # The key is stored on the dashscope module, so it is process-wide.
        dashscope.api_key = self.api_key
        self.model = model
        self.conversation_history = []
        self.available_tools = {}

    def register_tool(self, name: str, func: callable, description: str, parameters: dict):
        """Register a tool the model may call.

        Args:
            name: Tool name exposed to the model; also the lookup key.
                Registering the same name twice replaces the earlier tool.
            func: Sync or async callable invoked with the decoded arguments
                as keyword arguments.
            description: Human-readable description sent to the model.
            parameters: Parameter definition in JSON-Schema form.
        """
        self.available_tools[name] = {
            "function": func,
            "definition": {
                "type": "function",
                "function": {
                    "name": name,
                    "description": description,
                    "parameters": parameters,
                },
            },
        }

    def get_tools_definition(self) -> List[Dict]:
        """Return the definitions of all registered tools, in registration order."""
        return [tool["definition"] for tool in self.available_tools.values()]

    async def execute_tool(self, tool_name: str, arguments: dict) -> Any:
        """Run a registered tool and return its result.

        The tool is called with ``**arguments``. If the call returns an
        awaitable (an ``async def`` tool, or a sync wrapper such as a lambda
        or ``functools.partial`` around one) it is awaited. Sync tools run
        inline on the event loop.

        Args:
            tool_name: Name the tool was registered under.
            arguments: Keyword arguments for the tool; ``None`` or an empty
                value is treated as "no arguments".

        Returns:
            The tool's return value, or ``{"error": ...}`` when the tool is
            unknown or raises. This method itself never raises for tool
            failures.
        """
        import inspect

        tool = self.available_tools.get(tool_name)
        if tool is None:
            return {"error": f"工具 {tool_name} 不存在"}

        try:
            result = tool["function"](**(arguments or {}))
            # Checking the returned value rather than the callable also
            # covers sync wrappers that hand back a coroutine.
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            return {"error": f"执行工具 {tool_name} 时出错: {str(e)}"}

    def load_system_prompt(self, prompt_file: str) -> str:
        """Read the system prompt from ``prompt_file`` (UTF-8).

        Returns:
            The file content. On any failure an error text is returned
            instead of raising, so callers receive that text as the prompt.
        """
        try:
            with open(prompt_file, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            return f"加载提示词失败: {e}"

    @staticmethod
    def _to_json(value: Any, **kwargs) -> str:
        """Serialize ``value`` for logs and history; unknown types become ``str``."""
        return json.dumps(value, ensure_ascii=False, default=str, **kwargs)

    @staticmethod
    def _extract_tool_calls(message: Any) -> Any:
        """Return the message's tool calls, or ``None`` when there are none."""
        try:
            tool_calls = getattr(message, "tool_calls", None)
        except KeyError:
            # The original code guarded this lookup against KeyError as well
            # as AttributeError; presumably the SDK message is dict-backed.
            return None
        return tool_calls or None

    @staticmethod
    def _split_tool_call(tool_call: Any) -> tuple:
        """Return ``(tool_name, raw_arguments)`` from a dict- or object-style call."""
        function = tool_call['function'] if isinstance(tool_call, dict) else tool_call.function
        if isinstance(function, dict):
            return function['name'], function.get('arguments')
        return function.name, function.arguments

    @staticmethod
    def _decode_arguments(raw_arguments: Any) -> Any:
        """Decode tool arguments: JSON text, an already-decoded dict, or empty.

        Raises:
            TypeError, ValueError: If the arguments are not valid JSON.
        """
        if raw_arguments is None:
            return {}
        if isinstance(raw_arguments, dict):
            return raw_arguments
        if isinstance(raw_arguments, str) and not raw_arguments.strip():
            return {}
        return json.loads(raw_arguments)

    async def run(self, user_input: str, system_prompt_file: str = "prompts/agent_prompt.md",
                  max_iterations: int = 10) -> Dict:
        """Run the agent loop for one user request.

        The conversation history is reset on every call. Each iteration calls
        the model once; if the reply requests tools, every requested tool is
        executed and its JSON-encoded result appended to the history before
        the next iteration.

        Args:
            user_input: The user's request.
            system_prompt_file: Path of the system prompt file.
            max_iterations: Maximum number of model calls (guards against an
                endless tool loop).

        Returns:
            On success: ``success``, ``final_answer``, ``iteration``,
            ``tool_calls`` and ``conversation_history``.
            On failure: ``success`` (False), ``error``, ``iteration`` and
            ``tool_calls``; when the iteration budget is exhausted,
            ``partial_result`` holds the content of the last history entry.
        """
        system_prompt = self.load_system_prompt(system_prompt_file)

        self.conversation_history = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input},
        ]

        iteration = 0
        tool_calls_log = []

        while iteration < max_iterations:
            iteration += 1

            try:
                # NOTE: Generation.call is invoked synchronously here, so the
                # event loop is blocked for the duration of the request.
                response = Generation.call(
                    model=self.model,
                    messages=self.conversation_history,
                    tools=self.get_tools_definition() if self.available_tools else None,
                    result_format='message'
                )

                if response.status_code != 200:
                    return {
                        "success": False,
                        "error": f"API调用失败: {response.code} - {response.message}",
                        "iteration": iteration,
                        "tool_calls": tool_calls_log,
                    }

                assistant_message = response.output.choices[0].message
                tool_calls_data = self._extract_tool_calls(assistant_message)

                # Record the assistant turn, including its tool requests.
                message_to_add = {
                    "role": assistant_message.role,
                    "content": assistant_message.content or "",
                }
                if tool_calls_data:
                    message_to_add["tool_calls"] = tool_calls_data
                self.conversation_history.append(message_to_add)

                if tool_calls_data:
                    for tool_call in tool_calls_data:
                        tool_name, raw_arguments = self._split_tool_call(tool_call)
                        print(f"\n[迭代 {iteration}] 调用工具: {tool_name}")

                        try:
                            tool_args = self._decode_arguments(raw_arguments)
                        except (TypeError, ValueError) as e:
                            # Report malformed arguments back to the model
                            # instead of aborting the whole run.
                            tool_args = raw_arguments
                            tool_result = {"error": f"工具 {tool_name} 的参数不是有效的JSON: {e}"}
                        else:
                            print(f"参数: {self._to_json(tool_args, indent=2)}")
                            tool_result = await self.execute_tool(tool_name, tool_args)

                        tool_calls_log.append({
                            "iteration": iteration,
                            "tool_name": tool_name,
                            "arguments": tool_args,
                            "result": tool_result,
                        })

                        print(f"结果: {self._to_json(tool_result, indent=2)[:200]}...")

                        # Feed the tool result back to the model.
                        self.conversation_history.append({
                            "role": "tool",
                            "name": tool_name,
                            "content": self._to_json(tool_result),
                        })

                    # Let the model look at the tool results.
                    continue

                # No tool requested: the model's reply is the final answer.
                return {
                    "success": True,
                    "final_answer": assistant_message.content,
                    "iteration": iteration,
                    "tool_calls": tool_calls_log,
                    "conversation_history": self.conversation_history,
                }

            except Exception as e:
                import traceback
                error_details = traceback.format_exc()
                print(f"\n错误详情:\n{error_details}")
                return {
                    "success": False,
                    "error": f"执行过程出错: {str(e)}",
                    "iteration": iteration,
                    "tool_calls": tool_calls_log,
                }

        # Iteration budget exhausted.
        return {
            "success": False,
            "error": f"达到最大迭代次数 {max_iterations}",
            "iteration": iteration,
            "tool_calls": tool_calls_log,
            "partial_result": self.conversation_history[-1].get("content") if self.conversation_history else None,
        }
|
|
|
|
|
|
# ==================== 工具函数定义 ====================
|
|
|
|
async def search_douyin_videos(keyword: str, max_scroll: int = 5) -> Dict:
    """Search Douyin for ``keyword`` with the project's search crawler.

    Args:
        keyword: Search keyword.
        max_scroll: Maximum number of scrolls, forwarded to the crawler.

    Returns:
        On success: ``success``, ``keyword``, ``total_count`` (all videos
        found) and ``videos`` (first 10 only, to keep the payload small).
        On failure: ``success`` (False) and ``error``.
    """
    from douyin_data_soupce.douyin_search_crawler import DouyinSearchCrawler

    cookie_path = "douyin_data_soupce/douyin_cookie.json"
    scraper = DouyinSearchCrawler(headless=True)

    try:
        await scraper.init_browser()

        if not await scraper.load_cookies(cookie_path):
            outcome = {"success": False, "error": "无法加载Cookie"}
        else:
            found = await scraper.search_videos(keyword, max_scroll=max_scroll)
            await scraper.save_results(keyword, found)
            outcome = {
                "success": True,
                "keyword": keyword,
                "total_count": len(found),
                "videos": found[:10],
            }
    except Exception as exc:
        outcome = {"success": False, "error": str(exc)}
    finally:
        # Close the crawler on every path, including early failures.
        await scraper.close()

    return outcome
|
|
|
|
|
|
# Expands every collapsed category list on the creative-guidance page.
_EXPAND_CATEGORIES_JS = """
() => {
    const showButtons = document.querySelectorAll('.show-button-sDo51G');
    showButtons.forEach(btn => {
        const text = btn.textContent.trim();
        // 如果按钮不是"收起",说明需要展开
        if (!text.includes('收起')) {
            btn.click();
        }
    });
}
"""

# Clicks the category tab whose text equals the argument; returns whether a
# tab was found. The category is passed as an argument, never spliced into
# the script text.
_CLICK_CATEGORY_JS = """
(category) => {
    const categoryDivs = Array.from(document.querySelectorAll('.each-kind-MR__DN'));
    const targetDiv = categoryDivs.find(div =>
        div.textContent.trim() === category
    );
    if (targetDiv) {
        targetDiv.click();
        return true;
    }
    return false;
}
"""

# Collects one record per author link found on the page.
_EXTRACT_VIDEOS_JS = """
() => {
    const videos = [];
    const authorLinks = Array.from(document.querySelectorAll('a[href*="iesdouyin.com/share/user/"]'));

    authorLinks.forEach((authorLink, index) => {
        try {
            const author = authorLink.textContent.trim();
            let container = authorLink.parentElement;
            let maxLevels = 10;

            while (container && maxLevels > 0) {
                if (container.querySelector('.contain-info-LpWGHS')) break;
                container = container.parentElement;
                maxLevels--;
            }

            if (!container) return;

            const paragraphs = Array.from(container.querySelectorAll('p'));
            let description = '';
            for (let p of paragraphs) {
                const text = p.textContent.trim();
                if (text && text !== '|' && text.length > 5 && !text.includes('万') && !text.includes(':')) {
                    description = text;
                    break;
                }
            }

            let hot = '', plays = '', likes = '', comments = '';
            const infoContainer = container.querySelector('.contain-info-LpWGHS');
            if (infoContainer) {
                const infoItems = infoContainer.querySelectorAll('.each-info-TpmTI0');
                infoItems.forEach(item => {
                    const img = item.querySelector('img');
                    const text = item.textContent.trim();
                    if (img && img.src) {
                        if (img.src.includes('hot_')) hot = text;
                        else if (img.src.includes('play')) plays = text;
                        else if (img.src.includes('digg')) likes = text;
                        else if (img.src.includes('comment')) comments = text;
                    }
                });
            }

            const hotWords = [];
            const hotWordElements = container.querySelectorAll('.other-text-XeleRf');
            hotWordElements.forEach((el, i) => {
                const text = el.textContent.trim();
                if (i > 0 && text && !text.includes('热词')) hotWords.push(text);
            });

            const hashTags = description.match(/#[^\\s#]+/g) || [];

            if (author && description) {
                videos.push({
                    index: index + 1,
                    author: author,
                    description: description,
                    authorLink: authorLink.href,
                    hot: hot,
                    plays: plays,
                    likes: likes,
                    comments: comments,
                    hotWords: hotWords,
                    hashTags: hashTags
                });
            }
        } catch (e) {}
    });

    return { total: videos.length, videos: videos };
}
"""


async def get_creative_guidance(category: str = "全部") -> Dict:
    """Scrape hot-video data from Douyin's creative-guidance page.

    Opens the page in headless Chromium (with the saved cookies when the
    cookie file exists), optionally clicks the requested category tab,
    extracts the video list, and writes the full result to a timestamped
    JSON file when at least one video was found.

    Args:
        category: Category tab to select. ``"全部"`` or an empty value keeps
            the page's default listing.

    Returns:
        On success: ``success``, ``category``, ``category_matched`` (False
        when the requested tab was not found and the default listing was
        scraped instead), ``total_count`` and ``videos`` (first 10 only).
        On failure: ``success`` (False) and ``error``.
    """
    from playwright.async_api import async_playwright
    import sys

    # Keep the crawler directory importable, without adding a duplicate
    # sys.path entry on every call.
    source_dir = str(Path(__file__).parent / "douyin_data_soupce")
    if source_dir not in sys.path:
        sys.path.append(source_dir)

    url = "https://creator.douyin.com/creator-micro/creative-guidance"
    output_dir = "douyin_data_soupce/douyin_data"
    cookie_file = "douyin_data_soupce/douyin_cookie.json"

    try:
        category_matched = True

        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            try:
                # Reuse the saved login cookies when available.
                context_options = {}
                if Path(cookie_file).exists():
                    with open(cookie_file, 'r', encoding='utf-8') as f:
                        cookies = json.load(f)
                    context_options['storage_state'] = {'cookies': cookies}

                context = await browser.new_context(**context_options)
                page = await context.new_page()

                await page.goto(url, wait_until="domcontentloaded", timeout=60000)
                # Fixed wait for the page to render its content.
                await asyncio.sleep(10)

                if category and category != "全部":
                    print(f"尝试点击分类: {category}")

                    await page.evaluate(_EXPAND_CATEGORIES_JS)
                    await asyncio.sleep(1)

                    category_matched = bool(await page.evaluate(_CLICK_CATEGORY_JS, category))
                    print(f"分类点击结果: {category_matched}")
                    if category_matched:
                        # Wait for the category's listing to load.
                        await asyncio.sleep(8)
                    else:
                        print(f"警告: 未找到分类 '{category}',将返回默认数据")

                data = await page.evaluate(_EXTRACT_VIDEOS_JS)
            finally:
                # Close the browser on error paths as well.
                await browser.close()

        if data['total'] > 0:
            now = datetime.now()
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            # The category ends up in a file name; strip path separators.
            safe_category = str(category).replace("/", "_").replace("\\", "_")
            json_file = Path(output_dir) / (
                f"douyin_creative_guidance_{safe_category}_{now.strftime('%Y%m%d_%H%M%S')}.json"
            )

            result_data = {
                'page_url': url,
                'category': category,
                'crawl_time': now.isoformat(),
                'total_videos': data['total'],
                'videos': data['videos'],
            }

            with open(json_file, "w", encoding="utf-8") as f:
                json.dump(result_data, f, ensure_ascii=False, indent=2)

        return {
            "success": True,
            "category": category,
            "category_matched": category_matched,
            "total_count": data['total'],
            "videos": data['videos'][:10],
        }

    except Exception as e:
        return {"success": False, "error": str(e)}
|
|
|
|
|
|
def analyze_video_data(videos: List[Dict], focus: str = None) -> Dict:
    """Analyze video data with the AI analyzer, falling back to plain statistics.

    Args:
        videos: List of video records.
        focus: Optional analysis focus, forwarded to the analyzer as a
            custom instruction.

    Returns:
        The AI analysis (``analysis``, ``model``, ``video_count``, ``usage``,
        ``focus``) when it succeeds; otherwise the result of
        ``simple_analyze_video_data``. Empty input yields an error dict.
    """
    if not videos:
        return {"success": False, "error": "没有视频数据"}

    try:
        from ai_analyzer import AIAnalyzer

        engine = AIAnalyzer()

        instruction = None
        if focus:
            instruction = f"重点关注: {focus}"

        report = engine.analyze(
            videos=videos,
            prompt_file="prompts/analyze_prompt.md",
            custom_instruction=instruction,
        )

        if not report["success"]:
            # The analyzer reported a failure: use the statistical summary.
            return simple_analyze_video_data(videos, focus)

        return {
            "success": True,
            "analysis": report["analysis"],
            "model": report["model"],
            "video_count": report["video_count"],
            "usage": report.get("usage"),
            "focus": focus,
        }

    except Exception as exc:
        # Any error on the AI path also degrades to the statistical summary.
        print(f"AI分析出错,使用简单统计: {exc}")
        return simple_analyze_video_data(videos, focus)
|
|
|
|
|
|
def _build_inspiration_prompt(actual_count: int, user_query: str) -> str:
    """Return the system prompt asking for ``actual_count`` inspirations."""
    return f"""
# 创作灵感生成专家

你是一个专业的短视频创作灵感生成专家。

## 任务

基于提供的 {actual_count} 条热门视频数据,为用户生成 {actual_count} 个具体可执行的创作灵感。

**重要**:每个灵感必须对应一条视频数据,从该视频中提炼创意。

用户需求:{user_query}

## 灵感要求

每个灵感必须包含:
1. **id**:灵感编号(1-{actual_count})
2. **title**:吸引人的标题(15字以内)
3. **description**:核心创意和执行建议的综合描述(100字以内)
4. **reference_author**:参考视频的作者名
5. **reference_description**:参考视频的描述片段(30字以内)
6. **url**:参考视频作者的主页链接(authorLink字段)
7. **platform**:平台名称,固定为"抖音"
8. **tags**:推荐使用的标签数组(来自视频的hashTags)
9. **keywords**:热门关键词数组(来自视频的hotWords)

## 输出格式

请严格按照以下JSON格式输出:

```json
{{
  "inspirations": [
    {{
      "id": 1,
      "title": "灵感标题",
      "description": "核心创意描述和执行建议",
      "reference_author": "参考视频作者",
      "reference_description": "参考视频描述片段",
      "url": "参考视频作者主页链接",
      "platform": "抖音",
      "tags": ["#标签1", "#标签2"],
      "keywords": ["关键词1", "关键词2"]
    }}
  ]
}}
```

## 注意事项

1. **每个灵感对应一条视频**:灵感1对应视频1,灵感2对应视频2,以此类推
2. **使用视频的真实数据**:
   - reference_author 使用视频的 author 字段
   - reference_description 使用视频的 description 字段(截取前30字)
   - url 使用视频的 authorLink 字段
   - tags 使用视频的 hashTags 字段
   - keywords 使用视频的 hotWords 字段
3. **description要综合**:将核心创意和执行建议合并成一段描述
4. **要具体可执行**:不要泛泛而谈,要给出具体建议
5. **要有创新性**:不是简单模仿,而是从视频中提炼创意
6. **要符合用户需求**:灵感要符合用户的意图
"""


def _format_videos_for_prompt(videos: List[Dict]) -> str:
    """Render the video records as the Markdown text sent as the user message."""
    parts = [f"## 视频数据\n总数: {len(videos)}\n\n"]
    for i, video in enumerate(videos, 1):
        parts.append(f"### 视频 {i}\n")
        parts.append(f"- 作者: {video.get('author', 'N/A')}\n")
        parts.append(f"- 作者主页链接: {video.get('authorLink', 'N/A')}\n")
        parts.append(f"- 描述: {video.get('description', 'N/A')}\n")
        # Optional metrics are listed only when present and non-empty.
        for key, label in (("hot", "热度"), ("plays", "播放"), ("likes", "点赞")):
            if video.get(key):
                parts.append(f"- {label}: {video.get(key)}\n")
        for key, label in (("hashTags", "标签"), ("hotWords", "热词")):
            if video.get(key):
                parts.append(f"- {label}: {', '.join(map(str, video.get(key)))}\n")
        parts.append("\n")
    return "".join(parts)


def _parse_inspirations(content: str) -> Optional[list]:
    """Extract the inspirations list from the model reply.

    Tries the fenced ```json block first, then the whole reply as bare JSON.
    Returns ``None`` when no usable JSON is found.
    """
    import re

    candidates = []
    fenced = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    candidates.append(content.strip())

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:
            continue
        if isinstance(parsed, dict):
            items = parsed.get("inspirations", [])
            if isinstance(items, list):
                return items
        elif isinstance(parsed, list):
            return parsed
    return None


def generate_creative_inspirations(videos: List[Dict], user_query: str, count: int = 9) -> Dict:
    """Generate one creative inspiration per video with the AI model.

    Only the first ``count`` videos are used. Each returned inspiration has
    its ``url`` overwritten with the ``authorLink`` of the video at the same
    position, so the link does not depend on the model copying it correctly.

    Args:
        videos: List of video records.
        user_query: The user's original request, embedded in the prompt.
        count: Maximum number of inspirations (and videos used); default 9.

    Returns:
        ``success``, ``inspirations``, ``total_count`` and ``source`` when
        the reply contains parseable JSON; ``success``, ``raw_content`` and
        ``source`` when it does not; ``success`` (False) and ``error`` on
        invalid input or when the model call fails.
    """
    if not videos:
        return {"success": False, "error": "没有视频数据"}
    if not isinstance(count, int) or count < 1:
        # A zero or negative count would send an empty or wrongly sliced
        # video list to the model.
        return {"success": False, "error": "count 必须是正整数"}

    try:
        videos_to_use = videos[:count]
        actual_count = len(videos_to_use)

        inspiration_prompt = _build_inspiration_prompt(actual_count, user_query)
        video_data_text = _format_videos_for_prompt(videos_to_use)

        import dashscope
        from dashscope import Generation

        api_key = os.getenv("DASHSCOPE_API_KEY")
        if api_key:
            dashscope.api_key = api_key

        response = Generation.call(
            model="qwen-plus",
            messages=[
                {"role": "system", "content": inspiration_prompt},
                {"role": "user", "content": video_data_text},
            ],
            result_format='message'
        )

        if response.status_code != 200:
            return {
                "success": False,
                "error": f"AI调用失败: {response.code}"
            }

        content = response.output.choices[0].message.content or ""
        inspirations = _parse_inspirations(content)

        if inspirations is None:
            # No usable JSON in the reply: hand back the raw text.
            return {
                "success": True,
                "raw_content": content,
                "source": "creative_guidance"
            }

        inspirations = inspirations[:actual_count]
        for inspiration, video in zip(inspirations, videos_to_use):
            if isinstance(inspiration, dict):
                inspiration['url'] = video.get('authorLink', 'N/A')

        return {
            "success": True,
            "inspirations": inspirations,
            "total_count": len(inspirations),
            "source": "creative_guidance"
        }

    except Exception as e:
        return {
            "success": False,
            "error": f"生成灵感时出错: {str(e)}"
        }
|
|
|
|
|
|
def simple_analyze_video_data(videos: List[Dict], focus: Optional[str] = None) -> Dict:
    """Summarize video data with plain statistics (no AI involved).

    Counts hashtag and hot-word frequencies and lists the videos that carry
    a non-empty ``likes`` or ``plays`` value.

    Args:
        videos: List of video records. Missing or ``None`` fields are
            treated as empty.
        focus: Optional analysis focus; echoed back unchanged.

    Returns:
        ``success``, ``summary`` (``total_videos``, top 10 ``top_tags``,
        top 10 ``top_hot_words``, ``high_engagement_count``),
        ``high_engagement_videos`` (first 5), ``focus`` and
        ``analysis_type``. Empty input yields an error dict.
    """
    if not videos:
        return {"success": False, "error": "没有视频数据"}

    from collections import Counter

    tag_counter = Counter()
    hot_word_counter = Counter()
    high_engagement_videos = []

    for video in videos:
        # `or []` also covers keys that are present but set to None.
        tag_counter.update(video.get('hashTags') or [])
        hot_word_counter.update(video.get('hotWords') or [])

        if video.get('likes') or video.get('plays'):
            description = str(video.get('description') or '')
            # Add the ellipsis only when the text was actually cut.
            preview = description[:50] + "..." if len(description) > 50 else description
            high_engagement_videos.append({
                "author": video.get('author'),
                "description": preview,
                "likes": video.get('likes'),
                "plays": video.get('plays'),
            })

    return {
        "success": True,
        "summary": {
            "total_videos": len(videos),
            "top_tags": [{"tag": tag, "count": n} for tag, n in tag_counter.most_common(10)],
            "top_hot_words": [{"word": word, "count": n} for word, n in hot_word_counter.most_common(10)],
            "high_engagement_count": len(high_engagement_videos),
        },
        "high_engagement_videos": high_engagement_videos[:5],
        "focus": focus,
        "analysis_type": "simple_statistics",
    }
|
|
|
|
|
|
def extract_search_keywords(user_query: str) -> Dict:
    """Classify the user's request into one content category with the AI model.

    The model is asked to pick exactly one category from a fixed list. An
    exact answer is used as-is; otherwise the first listed category that
    occurs inside the answer is used. Every failure (SDK missing, API error,
    unusable answer) falls back to the default category, so this function
    always returns a successful result.

    Args:
        user_query: The user's request.

    Returns:
        ``success`` (always True), ``categories`` (one-element list),
        ``primary_category``, ``original_query`` and ``method`` — one of
        ``"ai_classification"``, ``"ai_classification_fuzzy"`` or
        ``"default"``.
    """
    default_category = "泛生活"

    # Supported categories.
    categories = [
        "美食", "旅行", "泛生活", "汽车", "科技", "游戏", "二次元",
        "娱乐", "明星", "体育", "文化教育", "校园", "政务",
        "时尚", "才艺", "随拍", "动植物", "图文控",
        "剧情", "亲子", "三农", "创意", "户外", "公益"
    ]

    def build_result(category: str, method: str) -> Dict:
        """Assemble the result dict shared by every return path."""
        return {
            "success": True,
            "categories": [category],
            "primary_category": category,
            "original_query": user_query,
            "method": method
        }

    prompt = f"""你是一个内容分类专家。请根据用户的描述,从以下分类中选择最匹配的一个:

{', '.join(categories)}

用户描述:{user_query}

请分析用户想要创作的内容类型,并从上述分类中选择最匹配的一个。

要求:
1. 只返回一个分类名称
2. 必须从给定的分类列表中选择
3. 如果实在无法确定,返回"泛生活"
4. 不要返回其他任何内容,只返回分类名称

分类:"""

    try:
        # Imported inside the try so that a missing SDK degrades to the
        # default category like every other failure.
        import dashscope
        from dashscope import Generation

        api_key = os.getenv("DASHSCOPE_API_KEY")
        if api_key:
            dashscope.api_key = api_key

        response = Generation.call(
            model="qwen-plus",
            messages=[
                {"role": "user", "content": prompt}
            ],
            result_format='message'
        )

        if response.status_code != 200:
            # API call failed.
            return build_result(default_category, "default")

        answer = (response.output.choices[0].message.content or "").strip()

        if answer in categories:
            return build_result(answer, "ai_classification")

        # The answer is not an exact category: look for one inside it.
        for cat in categories:
            if cat in answer:
                return build_result(cat, "ai_classification_fuzzy")

        return build_result(default_category, "default")

    except Exception as e:
        print(f"AI分类出错,使用默认分类: {e}")
        return build_result(default_category, "default")
|
|
|
|
|
|
|
|
|
|
# ==================== 注册工具 ====================
|
|
|
|
def create_agent() -> AIAgent:
    """Create an ``AIAgent`` on ``qwen-plus`` with the five tools registered.

    Registered tools, in order: ``extract_search_keywords``,
    ``get_creative_guidance``, ``search_douyin_videos``,
    ``analyze_video_data`` and ``generate_creative_inspirations``.

    Returns:
        The configured agent.
    """
    agent = AIAgent(model="qwen-plus")

    # Content categories offered to the model. The category count quoted in
    # the first tool description and the enum of the second tool are both
    # derived from this list so they cannot drift apart.
    categories = [
        "美食", "旅行", "泛生活", "汽车", "科技", "游戏", "二次元",
        "娱乐", "明星", "体育", "文化教育", "校园", "政务", "时尚", "才艺",
        "随拍", "动植物", "图文控", "剧情", "亲子", "三农", "创意", "户外", "公益",
    ]

    # Tool 1: classify the user's request into a category.
    agent.register_tool(
        name="extract_search_keywords",
        func=extract_search_keywords,
        description=f"使用AI智能分析用户查询,从{len(categories)}个内容分类中选择最匹配的一个。这个工具会理解用户的意图并返回最合适的分类。",
        parameters={
            "type": "object",
            "properties": {
                "user_query": {
                    "type": "string",
                    "description": "用户的原始查询或描述"
                }
            },
            "required": ["user_query"]
        }
    )

    # Tool 2: fetch creative-guidance data.
    agent.register_tool(
        name="get_creative_guidance",
        func=get_creative_guidance,
        description="获取抖音创作指导页面的热门视频数据,支持按分类筛选。这是获取高质量创作灵感的首选方法。",
        parameters={
            "type": "object",
            "properties": {
                "category": {
                    "type": "string",
                    "description": "视频分类",
                    "enum": ["全部", *categories]
                }
            },
            "required": ["category"]
        }
    )

    # Tool 3: search Douyin videos.
    agent.register_tool(
        name="search_douyin_videos",
        func=search_douyin_videos,
        description="根据关键词搜索抖音视频",
        parameters={
            "type": "object",
            "properties": {
                "keyword": {
                    "type": "string",
                    "description": "搜索关键词"
                },
                "max_scroll": {
                    "type": "integer",
                    "description": "最大滚动次数,默认5",
                    "default": 5
                }
            },
            "required": ["keyword"]
        }
    )

    # Tool 4: analyze video data.
    agent.register_tool(
        name="analyze_video_data",
        func=analyze_video_data,
        description="使用AI深度分析视频数据(基于prompts/analyze_prompt.md提示词),提取标签、热词、互动数据等统计信息,并给出专业的内容趋势分析和创作建议",
        parameters={
            "type": "object",
            "properties": {
                "videos": {
                    "type": "array",
                    "description": "视频数据列表",
                    "items": {"type": "object"}
                },
                "focus": {
                    "type": "string",
                    "description": "分析重点(可选),例如:热门趋势、创作建议、爆款特征等"
                }
            },
            "required": ["videos"]
        }
    )

    # Tool 5: generate creative inspirations.
    agent.register_tool(
        name="generate_creative_inspirations",
        func=generate_creative_inspirations,
        description="基于视频数据生成具体可执行的创作灵感。这是为用户提供创作建议的核心工具,会生成包含标题、核心创意、执行建议和热门标签的灵感列表。",
        parameters={
            "type": "object",
            "properties": {
                "videos": {
                    "type": "array",
                    "description": "视频数据列表",
                    "items": {"type": "object"}
                },
                "user_query": {
                    "type": "string",
                    "description": "用户的原始查询,用于理解用户需求"
                },
                "count": {
                    "type": "integer",
                    "description": "生成灵感的数量,默认9个",
                    "default": 9
                }
            },
            "required": ["videos", "user_query"]
        }
    )

    return agent
|
|
|
|
|
|
async def main():
    """Command-line test driver: run the agent once for ``--query`` and print the outcome."""
    import argparse

    cli = argparse.ArgumentParser(description="AI Agent测试")
    cli.add_argument("--query", "-q", required=True, help="用户查询")
    cli.add_argument("--prompt", "-p", default="prompts/agent_prompt.md", help="系统提示词文件")
    options = cli.parse_args()

    rule = "=" * 80

    print(rule)
    print("AI Agent 智能代理系统")
    print(rule)
    print(f"\n用户查询: {options.query}\n")

    # Build the agent and run it on the query.
    outcome = await create_agent().run(options.query, system_prompt_file=options.prompt)

    print("\n" + rule)
    print("执行结果:")
    print(rule)

    if not outcome["success"]:
        print(f"\n执行失败: {outcome['error']}")
        # Present only when the run stopped with a partial result.
        if outcome.get('partial_result'):
            print(f"部分结果: {outcome['partial_result']}")
        return

    print(f"\n{outcome['final_answer']}\n")
    print(f"总迭代次数: {outcome['iteration']}")
    print(f"工具调用次数: {len(outcome['tool_calls'])}")


if __name__ == "__main__":
    asyncio.run(main())
|
|
|