新增功能: - 支持图片格式控制:webp, jpg, png - 新增#生图格式指令 - 新增尺寸快捷方式:正方形、横图、竖图、宽屏、手机、小图 - 使用 PIL 进行格式转换,webp 默认质量 85 - 优化文件体积:webp 比 PNG 小 14 倍(93KB vs 1.3MB) 测试结果: - ✅ webp 格式生成成功(1280*720 横图) - ✅ 文件大小:93KB - ✅ 格式验证通过:RIFF (little-endian) data
414 lines
15 KiB
Python
414 lines
15 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
WordPress 发布系统 - AI 图片生成模块
|
||
基于阿里云 DashScope API 实现文生图功能
|
||
支持:通义万相 (wanx-v1, wanx2.1-t2i-turbo, wanx2.1-t2i-plus)
|
||
支持格式:jpg, png, webp
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import time
|
||
import requests
|
||
from datetime import datetime
|
||
|
||
from modules.wp_logger import get_publish_logger, get_debug_logger
|
||
|
||
# 基础目录
|
||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
IMAGE_DIR = os.path.join(BASE_DIR, 'temp')
|
||
os.makedirs(IMAGE_DIR, exist_ok=True)
|
||
|
||
|
||
class ImageGenerator:
|
||
"""AI 图片生成器"""
|
||
|
||
# 支持的模型列表
|
||
SUPPORTED_MODELS = {
|
||
'wanx-v1': {
|
||
'name': '通义万相 v1',
|
||
'base_url': 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis',
|
||
'sizes': ['1024*1024', '720*1280', '1280*720', '512*512', '512*1024', '1024*512'],
|
||
'max_images': 4
|
||
},
|
||
'wanx2.1-t2i-turbo': {
|
||
'name': '通义万相 v2 Turbo',
|
||
'base_url': 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis',
|
||
'sizes': ['1024*1024', '720*1280', '1280*720', '512*512', '1440*720', '720*1440'],
|
||
'max_images': 4
|
||
},
|
||
'wanx2.1-t2i-plus': {
|
||
'name': '通义万相 v2 Plus',
|
||
'base_url': 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis',
|
||
'sizes': ['1024*1024', '720*1280', '1280*720', '512*512', '1440*720', '720*1440'],
|
||
'max_images': 4
|
||
}
|
||
}
|
||
|
||
# 支持的图片格式
|
||
SUPPORTED_FORMATS = ['jpg', 'png', 'webp']
|
||
|
||
# 默认尺寸
|
||
DEFAULT_SIZE = '1024*1024'
|
||
|
||
# 默认格式
|
||
DEFAULT_FORMAT = 'webp'
|
||
|
||
# 尺寸快捷方式
|
||
SIZE_PRESETS = {
|
||
'正方形': '1024*1024',
|
||
'横图': '1280*720',
|
||
'竖图': '720*1280',
|
||
'宽屏': '1440*720',
|
||
'手机': '720*1440',
|
||
'小图': '512*512'
|
||
}
|
||
|
||
def __init__(self, api_key=None, model='wanx-v1', size=None, image_format=None):
|
||
"""
|
||
初始化图片生成器
|
||
|
||
Args:
|
||
api_key: DashScope API Key
|
||
model: 模型名称
|
||
size: 图片尺寸
|
||
image_format: 图片格式 (jpg, png, webp)
|
||
"""
|
||
self.api_key = api_key or self._get_api_key_from_env()
|
||
self.model = model if model in self.SUPPORTED_MODELS else 'wanx-v1'
|
||
self.size = size or self.DEFAULT_SIZE
|
||
self.image_format = image_format or self.DEFAULT_FORMAT
|
||
|
||
# 验证尺寸
|
||
model_config = self.SUPPORTED_MODELS[self.model]
|
||
if self.size not in model_config['sizes']:
|
||
self.size = self.DEFAULT_SIZE
|
||
|
||
# 验证格式
|
||
if self.image_format not in self.SUPPORTED_FORMATS:
|
||
self.image_format = self.DEFAULT_FORMAT
|
||
|
||
self.base_url = model_config['base_url']
|
||
self.max_images = model_config['max_images']
|
||
|
||
self.pl = get_publish_logger()
|
||
self.dl = get_debug_logger()
|
||
|
||
def generate_image(self, prompt, negative_prompt=None, n=1, style=None):
|
||
"""
|
||
生成图片(异步提交 + 轮询结果)
|
||
|
||
Args:
|
||
prompt: 图片描述(中文)
|
||
negative_prompt: 反向提示词
|
||
n: 生成图片数量 (1-4)
|
||
style: 风格(如:写实、动漫、水彩等)
|
||
|
||
Returns:
|
||
list: 生成的图片本地路径列表
|
||
"""
|
||
if n < 1 or n > self.max_images:
|
||
n = 1
|
||
|
||
self.pl.info(f"🎨 AI 生图开始 - 模型:{self.model}, 尺寸:{self.size}, 格式:{self.image_format}, 数量:{n}")
|
||
self.dl.log_step("AI 生图", f"模型:{self.model}, 尺寸:{self.size}, 格式:{self.image_format}")
|
||
self.dl.debug(f"提示词:{prompt}")
|
||
if negative_prompt:
|
||
self.dl.debug(f"反向提示词:{negative_prompt}")
|
||
|
||
try:
|
||
# 提交生图任务
|
||
task_id = self._submit_task(prompt, negative_prompt, n, style)
|
||
if not task_id:
|
||
return []
|
||
|
||
# 轮询等待结果
|
||
image_urls = self._poll_task(task_id)
|
||
|
||
if not image_urls:
|
||
self.pl.error("生图任务失败:未获取到图片 URL")
|
||
return []
|
||
|
||
# 下载图片到本地(转换为指定格式)
|
||
local_paths = self._download_images(image_urls, prompt)
|
||
|
||
self.pl.success(f"AI 生图完成 - 成功 {len(local_paths)} 张图片")
|
||
self.dl.log_result("生图结果", {
|
||
'model': self.model,
|
||
'format': self.image_format,
|
||
'count': len(local_paths),
|
||
'paths': local_paths
|
||
})
|
||
|
||
return local_paths
|
||
|
||
except Exception as e:
|
||
self.pl.error(f"AI 生图异常:{str(e)}")
|
||
self.dl.error(f"生图异常:{str(e)}", exc_info=True)
|
||
return []
|
||
|
||
def _submit_task(self, prompt, negative_prompt=None, n=1, style=None):
|
||
"""提交生图任务"""
|
||
headers = {
|
||
'Authorization': f'Bearer {self.api_key}',
|
||
'Content-Type': 'application/json',
|
||
'X-DashScope-Async': 'enable'
|
||
}
|
||
|
||
payload = {
|
||
'model': self.model,
|
||
'input': {
|
||
'prompt': prompt
|
||
},
|
||
'parameters': {
|
||
'size': self.size,
|
||
'n': n
|
||
}
|
||
}
|
||
|
||
if negative_prompt:
|
||
payload['parameters']['negative_prompt'] = negative_prompt
|
||
|
||
if style:
|
||
payload['parameters']['style'] = style
|
||
|
||
try:
|
||
response = requests.post(
|
||
self.base_url,
|
||
headers=headers,
|
||
json=payload,
|
||
timeout=30
|
||
)
|
||
|
||
self.dl.debug(f"提交任务响应:{response.text[:500]}")
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
task_id = result.get('output', {}).get('task_id')
|
||
if task_id:
|
||
self.pl.info(f"📋 生图任务已提交 - Task ID: {task_id}")
|
||
self.dl.debug(f"任务 ID: {task_id}")
|
||
return task_id
|
||
else:
|
||
self.pl.error(f"未获取到任务 ID:{result}")
|
||
return None
|
||
else:
|
||
error_msg = response.text
|
||
self.pl.error(f"提交生图任务失败 - 状态码:{response.status_code}")
|
||
self.pl.error(f"错误详情:{error_msg}")
|
||
self.dl.error(f"提交失败:{error_msg}")
|
||
return None
|
||
|
||
except requests.exceptions.Timeout:
|
||
self.pl.error("提交生图任务超时")
|
||
return None
|
||
except Exception as e:
|
||
self.pl.error(f"提交生图任务异常:{str(e)}")
|
||
return None
|
||
|
||
def _poll_task(self, task_id, max_retries=60, interval=5):
|
||
"""轮询任务状态"""
|
||
url = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
|
||
headers = {
|
||
'Authorization': f'Bearer {self.api_key}'
|
||
}
|
||
|
||
self.pl.info(f"⏳ 等待生图完成...")
|
||
|
||
for attempt in range(max_retries):
|
||
try:
|
||
response = requests.get(url, headers=headers, timeout=30)
|
||
|
||
if response.status_code != 200:
|
||
self.dl.warning(f"查询任务状态失败:{response.status_code}")
|
||
time.sleep(interval)
|
||
continue
|
||
|
||
result = response.json()
|
||
status = result.get('output', {}).get('task_status', '')
|
||
|
||
self.dl.debug(f"任务状态:{status} (尝试 {attempt + 1}/{max_retries})")
|
||
|
||
if status == 'SUCCEEDED':
|
||
image_urls = []
|
||
results = result.get('output', {}).get('results', [])
|
||
for item in results:
|
||
url = item.get('url', '')
|
||
if url:
|
||
image_urls.append(url)
|
||
|
||
if image_urls:
|
||
self.pl.success(f"✅ 生图任务完成 - 获取到 {len(image_urls)} 张图片 URL")
|
||
return image_urls
|
||
else:
|
||
self.pl.error("任务完成但未获取到图片 URL")
|
||
return []
|
||
|
||
elif status == 'FAILED':
|
||
message = result.get('output', {}).get('message', '未知错误')
|
||
self.pl.error(f"生图任务失败:{message}")
|
||
self.dl.error(f"任务失败:{message}")
|
||
return []
|
||
|
||
elif status in ['PENDING', 'RUNNING']:
|
||
# 继续等待
|
||
if (attempt + 1) % 6 == 0: # 每 30 秒输出一次日志
|
||
self.pl.info(f"⏳ 生图进行中... ({attempt * interval} 秒)")
|
||
time.sleep(interval)
|
||
continue
|
||
|
||
else:
|
||
self.dl.warning(f"未知任务状态:{status}")
|
||
time.sleep(interval)
|
||
continue
|
||
|
||
except requests.exceptions.Timeout:
|
||
self.dl.warning("查询任务状态超时")
|
||
time.sleep(interval)
|
||
continue
|
||
except Exception as e:
|
||
self.pl.error(f"查询任务状态异常:{str(e)}")
|
||
time.sleep(interval)
|
||
continue
|
||
|
||
self.pl.error(f"生图任务超时({max_retries * interval} 秒)")
|
||
return []
|
||
|
||
def _download_images(self, image_urls, prompt):
|
||
"""下载图片到本地并转换为指定格式"""
|
||
local_paths = []
|
||
|
||
for i, url in enumerate(image_urls):
|
||
try:
|
||
response = requests.get(url, timeout=30)
|
||
|
||
if response.status_code == 200:
|
||
# 使用 PIL 进行格式转换
|
||
local_path = self._convert_and_save_image(
|
||
response.content,
|
||
prompt,
|
||
i + 1
|
||
)
|
||
if local_path:
|
||
local_paths.append(local_path)
|
||
else:
|
||
self.pl.error(f"下载图片失败:{url} (状态码:{response.status_code})")
|
||
|
||
except Exception as e:
|
||
self.pl.error(f"下载图片异常:{url} - {str(e)}")
|
||
|
||
return local_paths
|
||
|
||
def _convert_and_save_image(self, image_data, prompt, index):
|
||
"""
|
||
将图片数据转换格式并保存
|
||
|
||
Args:
|
||
image_data: 图片二进制数据
|
||
prompt: 提示词(用于生成文件名)
|
||
index: 图片索引
|
||
|
||
Returns:
|
||
str: 保存的文件路径
|
||
"""
|
||
try:
|
||
from PIL import Image
|
||
from io import BytesIO
|
||
|
||
# 打开图片
|
||
img = Image.open(BytesIO(image_data))
|
||
|
||
# 确保是 RGB 模式(webp 不支持 RGBA)
|
||
if img.mode == 'RGBA':
|
||
# 创建白色背景
|
||
background = Image.new('RGB', img.size, (255, 255, 255))
|
||
background.paste(img, mask=img.split()[3]) # 使用 alpha 通道作为 mask
|
||
img = background
|
||
elif img.mode != 'RGB':
|
||
img = img.convert('RGB')
|
||
|
||
# 生成文件名
|
||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||
prompt_hash = abs(hash(prompt)) % 10000
|
||
ext = self.image_format if self.image_format == 'jpg' else self.image_format
|
||
filename = f"ai_image_{timestamp}_{index}_{prompt_hash}.{ext}"
|
||
filepath = os.path.join(IMAGE_DIR, filename)
|
||
|
||
# 保存图片
|
||
if self.image_format == 'webp':
|
||
img.save(filepath, 'WEBP', quality=85)
|
||
elif self.image_format == 'jpg':
|
||
img.save(filepath, 'JPEG', quality=95)
|
||
elif self.image_format == 'png':
|
||
img.save(filepath, 'PNG')
|
||
|
||
file_size = os.path.getsize(filepath)
|
||
self.dl.debug(f"图片已保存:{filepath} ({file_size} 字节, 格式:{self.image_format})")
|
||
|
||
return filepath
|
||
|
||
except ImportError:
|
||
self.pl.error("PIL 未安装,无法转换图片格式,保存原始图片")
|
||
# 降级:保存原始图片
|
||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||
prompt_hash = abs(hash(prompt)) % 10000
|
||
filename = f"ai_image_{timestamp}_{index}_{prompt_hash}.png"
|
||
filepath = os.path.join(IMAGE_DIR, filename)
|
||
with open(filepath, 'wb') as f:
|
||
f.write(image_data)
|
||
return filepath
|
||
except Exception as e:
|
||
self.pl.error(f"转换图片格式失败:{str(e)}")
|
||
return None
|
||
|
||
def _get_api_key_from_env(self):
|
||
"""从环境变量获取 API Key"""
|
||
return os.environ.get('DASHSCOPE_API_KEY', '')
|
||
|
||
def get_supported_models(self):
|
||
"""获取支持的模型列表"""
|
||
return list(self.SUPPORTED_MODELS.keys())
|
||
|
||
def get_model_info(self, model=None):
|
||
"""获取模型信息"""
|
||
if model:
|
||
return self.SUPPORTED_MODELS.get(model, {})
|
||
return self.SUPPORTED_MODELS
|
||
|
||
def get_supported_formats(self):
|
||
"""获取支持的格式列表"""
|
||
return self.SUPPORTED_FORMATS
|
||
|
||
def get_size_presets(self):
|
||
"""获取尺寸快捷方式"""
|
||
return self.SIZE_PRESETS
|
||
|
||
|
||
def create_image_generator(api_key=None, model='wanx-v1', size=None, image_format=None):
|
||
"""创建图片生成器实例"""
|
||
return ImageGenerator(api_key=api_key, model=model, size=size, image_format=image_format)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
import sys
|
||
|
||
if len(sys.argv) < 3:
|
||
print("用法:python wp_image_generator.py <api_key> <prompt> [model] [size] [count] [format]")
|
||
print("示例:python wp_image_generator.py sk-xxx '一只可爱的猫咪' wanx-v1 1024*1024 1 webp")
|
||
sys.exit(1)
|
||
|
||
api_key = sys.argv[1]
|
||
prompt = sys.argv[2]
|
||
model = sys.argv[3] if len(sys.argv) > 3 else 'wanx-v1'
|
||
size = sys.argv[4] if len(sys.argv) > 4 else '1024*1024'
|
||
count = int(sys.argv[5]) if len(sys.argv) > 5 else 1
|
||
image_format = sys.argv[6] if len(sys.argv) > 6 else 'webp'
|
||
|
||
generator = create_image_generator(api_key=api_key, model=model, size=size, image_format=image_format)
|
||
paths = generator.generate_image(prompt, n=count)
|
||
|
||
print(f"\n生成结果:")
|
||
for path in paths:
|
||
print(f" {path}")
|