# NOTE(review): removed repository web-page scrape artifacts that preceded this
# file (topic-selection help text and "242 lines / 7.9 KiB" page metadata) —
# they are not part of the source code.
"""
|
|
阿里云百炼AI分析模块
|
|
用于分析抖音数据
|
|
"""
|
|
|
|
import os
import json
from pathlib import Path
from typing import List, Dict, Optional

import dashscope
from dashscope import Generation
from dotenv import load_dotenv

# Load environment variables from a local .env file (e.g. DASHSCOPE_API_KEY).
load_dotenv()
|
|
|
|
|
|
class AIAnalyzer:
    """AI data analyzer backed by Alibaba Cloud DashScope (Bailian).

    Sends formatted Douyin video data plus a prompt file to a DashScope
    chat model and returns the model's analysis as a result dict.
    """

    def __init__(self, api_key: Optional[str] = None, model: str = "qwen-plus"):
        """Initialize the AI analyzer.

        Args:
            api_key: DashScope API key; when omitted, read from the
                DASHSCOPE_API_KEY environment variable.
            model: Model name to use, defaults to "qwen-plus".

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
        if not self.api_key:
            raise ValueError("未提供API Key,请设置DASHSCOPE_API_KEY环境变量或传入api_key参数")

        # The dashscope SDK is configured via a module-level attribute.
        dashscope.api_key = self.api_key
        self.model = model

    def load_prompt(self, prompt_file: str) -> str:
        """Load a prompt file and return its text content.

        Args:
            prompt_file: Path to the prompt file (UTF-8).

        Returns:
            The prompt content.

        Raises:
            Exception: If the file cannot be read; the original error is
                chained (``from e``) so the root cause is preserved.
        """
        try:
            with open(prompt_file, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            # Chain the cause instead of discarding it.
            raise Exception(f"加载提示词文件失败: {e}") from e

    def format_video_data(self, videos: List[Dict]) -> str:
        """Format video records as a Markdown-style text block for the model.

        Args:
            videos: List of video dicts; recognized keys are author,
                description, duration, hot, plays, likes, comments,
                hashTags and hotWords — all optional except that missing
                author/description render as 'N/A'.

        Returns:
            The formatted text.
        """
        # Collect fragments and join once — avoids quadratic string +=.
        parts = [f"## 数据总览\n总视频数: {len(videos)}\n\n## 视频详情\n\n"]

        for i, video in enumerate(videos, 1):
            parts.append(f"### 视频 {i}\n")
            parts.append(f"- 作者: {video.get('author', 'N/A')}\n")
            parts.append(f"- 描述: {video.get('description', 'N/A')}\n")

            # Optional scalar fields: emitted only when present and truthy.
            if video.get('duration'):
                parts.append(f"- 时长: {video.get('duration')}\n")
            if video.get('hot'):
                parts.append(f"- 热度: {video.get('hot')}\n")
            if video.get('plays'):
                parts.append(f"- 播放量: {video.get('plays')}\n")
            if video.get('likes'):
                parts.append(f"- 点赞: {video.get('likes')}\n")
            if video.get('comments'):
                parts.append(f"- 评论: {video.get('comments')}\n")

            # List-valued fields are comma-joined.
            if video.get('hashTags'):
                parts.append(f"- 标签: {', '.join(video.get('hashTags'))}\n")
            if video.get('hotWords'):
                parts.append(f"- 热词: {', '.join(video.get('hotWords'))}\n")

            parts.append("\n")

        return "".join(parts)

    def analyze(self, videos: List[Dict], prompt_file: str = "prompts/analyze_prompt.md",
                custom_instruction: Optional[str] = None) -> Dict:
        """Analyze video data with the configured DashScope model.

        Args:
            videos: List of video dicts (see :meth:`format_video_data`).
            prompt_file: Path to the system-prompt file.
            custom_instruction: Optional extra instruction appended to the
                user message.

        Returns:
            A result dict: on success ``{"success": True, "analysis", "model",
            "video_count", "usage": {...}}``; on failure ``{"success": False,
            "error", "video_count"}``. Never raises — all errors are folded
            into the returned dict.
        """
        try:
            # The prompt file supplies the system role content.
            system_prompt = self.load_prompt(prompt_file)

            # Render the video records into the user message.
            video_data_text = self.format_video_data(videos)

            user_message = f"{video_data_text}\n\n"
            if custom_instruction:
                user_message += f"特殊要求: {custom_instruction}\n\n"
            user_message += "请根据以上数据进行分析。"

            messages = [
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': user_message}
            ]

            # result_format='message' yields OpenAI-style choices/message output.
            response = Generation.call(
                model=self.model,
                messages=messages,
                result_format='message'
            )

            if response.status_code == 200:
                analysis_result = response.output.choices[0].message.content

                return {
                    "success": True,
                    "analysis": analysis_result,
                    "model": self.model,
                    "video_count": len(videos),
                    "usage": {
                        "input_tokens": response.usage.input_tokens,
                        "output_tokens": response.usage.output_tokens,
                        "total_tokens": response.usage.total_tokens
                    }
                }
            else:
                return {
                    "success": False,
                    "error": f"API调用失败: {response.code} - {response.message}",
                    "video_count": len(videos)
                }

        except Exception as e:
            # Best-effort boundary: report the failure in the result dict.
            return {
                "success": False,
                "error": f"分析过程出错: {str(e)}",
                "video_count": len(videos)
            }

    def analyze_from_file(self, json_file: str, prompt_file: str = "prompts/analyze_prompt.md",
                          custom_instruction: Optional[str] = None) -> Dict:
        """Read video data from a JSON file and analyze it.

        Args:
            json_file: Path to a JSON file with a top-level "videos" list.
            prompt_file: Path to the system-prompt file.
            custom_instruction: Optional extra instruction (see :meth:`analyze`).

        Returns:
            The :meth:`analyze` result dict, or a ``{"success": False, ...}``
            dict when the file cannot be read or contains no videos.
        """
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            videos = data.get('videos', [])
            if not videos:
                return {
                    "success": False,
                    "error": "JSON文件中没有视频数据",
                    "video_count": 0
                }

            return self.analyze(videos, prompt_file, custom_instruction)

        except Exception as e:
            return {
                "success": False,
                "error": f"读取文件失败: {str(e)}",
                "video_count": 0
            }
|
|
|
|
|
|
def main():
    """CLI entry point: analyze a JSON data file and print the result.

    Reads command-line arguments, runs :meth:`AIAnalyzer.analyze_from_file`,
    prints the analysis and token usage, and optionally writes the analysis
    text to an output file.
    """
    import argparse

    parser = argparse.ArgumentParser(description="AI数据分析工具")
    parser.add_argument("--file", "-f", required=True, help="JSON数据文件路径")
    parser.add_argument("--prompt", "-p", default="prompts/analyze_prompt.md", help="提示词文件路径")
    parser.add_argument("--instruction", "-i", help="自定义分析指令")
    parser.add_argument("--model", "-m", default="qwen-plus", help="模型名称")
    parser.add_argument("--output", "-o", help="输出文件路径(可选)")
    args = parser.parse_args()

    # Build the analyzer; the API key comes from the environment (.env).
    analyzer = AIAnalyzer(model=args.model)

    print(f"正在分析文件: {args.file}")
    print(f"使用模型: {args.model}")
    print(f"提示词文件: {args.prompt}")
    print()

    result = analyzer.analyze_from_file(
        json_file=args.file,
        prompt_file=args.prompt,
        custom_instruction=args.instruction
    )

    if result["success"]:
        print("=" * 80)
        print("分析结果:")
        print("=" * 80)
        print(result["analysis"])
        print()
        print("=" * 80)
        # F541 fix: no placeholders, so the f prefix was dropped (same output).
        print("统计信息:")
        print(f"  视频数量: {result['video_count']}")
        print(f"  输入Token: {result['usage']['input_tokens']}")
        print(f"  输出Token: {result['usage']['output_tokens']}")
        print(f"  总Token: {result['usage']['total_tokens']}")

        # Optionally persist the analysis text.
        if args.output:
            with open(args.output, 'w', encoding='utf-8') as f:
                f.write(result["analysis"])
            print(f"\n✓ 分析结果已保存到: {args.output}")
    else:
        print(f"✗ 分析失败: {result['error']}")


if __name__ == "__main__":
    main()
|
|
|