comfyui自定义节点,生成自己想要的场景
我的需求:
我想通过comfyui的节点工作流生成方式,实现一键视频生成文章的功能,这样我就能把自己喜欢的一些视频通过这种方式直接转化成PDF的形式。
实现过程
- 第一步:从bilibili网站找到自己喜欢的视频,通过视频链接下载到本地,并从视频中提取出对应的图片(视频帧)。
- 第二步:视频里提取音频,通过调用大模型生成文本。
- 第三步:通过图片结合文本方式形成PDF
实现技术
comfyui 自定义插件实现功能。
第一步已实现插件代码逻辑
在custom_nodes目录下创建自己的插件 ComfyUI-videoToArticle,如图所示:
进入插件目录,目录及文件如图:
下面给大家分享实现第一步的三个源码文件(插件入口 __init__.py,以及视频下载、视频帧提取两个节点)。
__init__.py 源码:
import os
import subprocess
import sys
import importlib.util
# Check for, and optionally install, the plugin's missing dependencies.
def check_and_install_requirements(requirements_path=None, install=True):
    """Install the distributions from ``requirements.txt`` that are missing.

    A requirement names a *distribution* (for example ``opencv-python``),
    which is not necessarily the name of the module it provides (``cv2``).
    Presence is therefore checked through ``importlib.metadata`` instead of
    ``importlib.util.find_spec``; the latter also returns ``None`` for an
    unknown top-level name rather than raising ``ImportError``, so an
    ``except ImportError`` around it never triggers an install.

    Only the presence of a distribution is checked; version specifiers are
    passed to pip unchanged but are not compared against the installed
    version.  Blank lines, comment lines and pip option lines (starting with
    ``-``) are skipped.

    Args:
        requirements_path: Path of the requirements file.  Defaults to the
            ``requirements.txt`` located next to this module.
        install: When true (the default) every missing requirement is
            installed with ``pip``.  When false the function only reports.

    Returns:
        The requirement strings whose distribution was not found, in file
        order.  Empty when the file does not exist.
    """
    # Function-scope imports keep the block self-contained.
    from importlib import metadata
    import re

    if requirements_path is None:
        requirements_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "requirements.txt"
        )
    if not os.path.exists(requirements_path):
        print("未找到 requirements.txt 文件")
        return []

    with open(requirements_path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f]

    # Leading distribution name, i.e. everything before a version specifier,
    # an extras bracket or an environment marker.
    name_pattern = re.compile(r"[A-Za-z0-9][A-Za-z0-9._-]*")

    def is_installed(requirement):
        match = name_pattern.match(requirement)
        if match is None:
            # The name cannot be determined; let pip decide.
            return False
        try:
            metadata.distribution(match.group())
        except metadata.PackageNotFoundError:
            return False
        return True

    missing = []
    for line in lines:
        if not line or line.startswith(("#", "-")):
            continue
        # Drop a trailing " # comment" without touching "#" inside URLs.
        requirement = re.split(r"\s+#", line, maxsplit=1)[0].strip()
        if requirement and not is_installed(requirement):
            missing.append(requirement)

    if install:
        for requirement in missing:
            print(f"正在安装依赖: {requirement}")
            try:
                subprocess.check_call(
                    [sys.executable, "-m", "pip", "install", requirement]
                )
                print(f"成功安装: {requirement}")
            except subprocess.CalledProcessError as e:
                # Best effort: one failing package must not block the others.
                print(f"安装失败 {requirement}: {str(e)}")

    return missing


# Check and install the dependencies automatically at import time.
check_and_install_requirements()
# Import the node classes and publish the plugin's registration mappings.
# Both mappings start out empty so that a failed import leaves the plugin
# loadable, just without any nodes.
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

try:
    from .视频获取 import VideoDownloader
    from .视频帧提取 import VideoFrameExtractor
except ImportError as import_error:
    print(f"导入节点类时出错: {str(import_error)}")
else:
    # Register the nodes: internal name -> class, internal name -> label.
    NODE_CLASS_MAPPINGS = {
        "VideoDownloader": VideoDownloader,
        "VideoFrameExtractor": VideoFrameExtractor,
    }
    NODE_DISPLAY_NAME_MAPPINGS = {
        "VideoDownloader": "B站视频下载器",
        "VideoFrameExtractor": "视频帧提取器",
    }

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
视频获取.py 源码:
import os
import torch
import cv2
from bilibili_api import video, Credential, sync
import aiohttp
import asyncio
import re
import time
import subprocess
class VideoDownloader:
    """ComfyUI node that downloads a Bilibili video and returns a preview frame.

    The video and audio DASH streams are downloaded separately, merged with
    ``ffmpeg`` into ``<output_dir>/<title>.mp4``, and one frame of the result
    is returned as an image tensor together with the file path.
    """

    CATEGORY = "视频转文章"
    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("视频路径", "预览图像")
    FUNCTION = "download_and_preview"

    # Size of the chunks read from the HTTP response (1 MiB).
    _CHUNK_SIZE = 1024 * 1024
    # How many times opening the merged file is attempted before giving up.
    _OPEN_ATTEMPTS = 3

    def __init__(self):
        # Download directory, relative to the current working directory.
        self.output_dir = "downloaded_videos"
        os.makedirs(self.output_dir, exist_ok=True)
        # Request headers sent with the stream downloads.
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Referer': 'https://www.bilibili.com',
            'Accept': '*/*',
            'Origin': 'https://www.bilibili.com',
            'Accept-Encoding': 'gzip, deflate, br',
            'Accept-Language': 'zh-CN,zh;q=0.9',
        }

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: video link, preview frame, optional SESSDATA."""
        return {
            "required": {
                "视频链接": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "请输入B站视频URL或BV号"
                }),
                "预览帧": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 10000,
                    "step": 1,
                    "display": "number"
                }),
            },
            "optional": {
                "SESSDATA": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "输入B站SESSDATA(可选)"
                })
            }
        }

    def extract_bvid(self, url):
        """Return the first ``BV...`` id found in *url*, or *url* unchanged."""
        match = re.search(r'BV[a-zA-Z0-9]+', url)
        return match.group() if match else url

    @staticmethod
    def _safe_filename(title, fallback):
        """Reduce *title* to letters, digits, space, ``-`` and ``_``.

        Returns *fallback* when nothing usable is left, so a title made only
        of punctuation does not produce a file literally named ``.mp4``.
        """
        cleaned = "".join(
            ch for ch in title if ch.isalnum() or ch in (' ', '-', '_')
        ).rstrip()
        return cleaned or fallback

    @staticmethod
    def _remove_files(*paths):
        """Delete every existing path; failures are reported, not raised.

        Removal is best effort so that a cleanup problem never hides the
        error that caused the cleanup.
        """
        for path in paths:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as exc:
                    print(f"无法删除文件 {path}: {exc}")

    async def _download_stream(self, session, stream_url, target, label):
        """Stream *stream_url* into the file *target*.

        *label* ("视频" or "音频") only selects the wording of the progress
        and error messages.  Raises ``ValueError`` on a non-200 response.
        """
        print(f"下载{label}流...")
        async with session.get(stream_url, headers=self.headers) as resp:
            if resp.status != 200:
                raise ValueError(f"{label}下载失败,状态码:{resp.status}")
            with open(target, 'wb') as f:
                async for chunk in resp.content.iter_chunked(self._CHUNK_SIZE):
                    f.write(chunk)

    async def download_bilibili_video(self, url, sessdata=None):
        """Download the video behind *url* and merge it into one ``.mp4``.

        Args:
            url: A Bilibili video URL or a bare BV id.
            sessdata: Optional SESSDATA value used to build a credential.

        Returns:
            Path of the merged video file inside ``self.output_dir``.

        Raises:
            ValueError: On any failure (ffmpeg missing, download error, merge
                error, unreadable result); the original exception is chained.
        """
        bvid = self.extract_bvid(url)
        temp_video = None
        temp_audio = None
        try:
            # Fail early when ffmpeg is not available.
            try:
                subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
            except (subprocess.SubprocessError, FileNotFoundError) as exc:
                raise ValueError("未找到 ffmpeg,请先安装 ffmpeg 并确保其在系统路径中") from exc

            credential = Credential(sessdata=sessdata) if sessdata else None
            v = video.Video(bvid=bvid, credential=credential)
            video_info = await v.get_info()
            safe_title = self._safe_filename(video_info['title'], bvid)

            video_path = os.path.join(self.output_dir, f"{safe_title}.mp4")
            temp_video = os.path.join(self.output_dir, f"{safe_title}_temp_video.m4s")
            temp_audio = os.path.join(self.output_dir, f"{safe_title}_temp_audio.m4s")
            # Start from a clean slate if a previous run left files behind.
            self._remove_files(video_path, temp_video, temp_audio)

            video_url = await v.get_download_url(0)
            # The first entry of each DASH list is used.
            video_stream_url = video_url['dash']['video'][0]['baseUrl']
            audio_stream_url = video_url['dash']['audio'][0]['baseUrl']

            print(f"开始下载视频: {safe_title}")
            async with aiohttp.ClientSession() as session:
                await self._download_stream(session, video_stream_url, temp_video, "视频")
                await self._download_stream(session, audio_stream_url, temp_audio, "音频")

            if not os.path.exists(temp_video) or not os.path.exists(temp_audio):
                raise ValueError("临时文件下载失败")

            print("合并音视频...")
            # Absolute paths are passed to ffmpeg; "-y" goes first, where
            # global options conventionally belong.
            ffmpeg_cmd = [
                'ffmpeg',
                '-y',
                '-i', os.path.abspath(temp_video),
                '-i', os.path.abspath(temp_audio),
                '-c', 'copy',
                os.path.abspath(video_path),
            ]
            # Decode explicitly and leniently: ffmpeg echoes the file names,
            # which may contain characters the locale codec cannot decode.
            process = subprocess.run(
                ffmpeg_cmd,
                capture_output=True,
                text=True,
                encoding='utf-8',
                errors='replace',
            )
            if process.returncode != 0:
                print(f"FFmpeg错误输出: {process.stderr}")
                raise ValueError(f"FFmpeg合并失败,返回码: {process.returncode}")

            if not os.path.exists(video_path):
                raise ValueError("合并后的视频文件未生成")

            print("清理临时文件...")
            self._remove_files(temp_video, temp_audio)

            # Give the file system a moment before reopening the file.
            await asyncio.sleep(1)

            print("验证视频文件...")
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    raise ValueError("无法打开合并后的视频文件")
            finally:
                cap.release()

            print("视频处理完成")
            return video_path
        except Exception as e:
            raise ValueError(f"视频处理失败: {str(e)}") from e
        finally:
            # Runs on success and failure alike; a no-op once the temporary
            # files are already gone.
            self._remove_files(temp_video, temp_audio)

    def download_and_preview(self, 视频链接, 预览帧, SESSDATA=""):
        """Node entry point: download the video and return a preview frame.

        Args:
            视频链接: Bilibili URL or BV id; surrounding whitespace is ignored.
            预览帧: Index of the frame to preview; an index outside the video
                falls back to frame 0.
            SESSDATA: Optional Bilibili SESSDATA value.

        Returns:
            ``(video_path, preview_image)`` where ``preview_image`` is a float
            RGB tensor scaled to 0..1 with a leading batch dimension of 1.

        Raises:
            ValueError: When the link is empty or any step fails.
        """
        视频链接 = 视频链接.strip() if 视频链接 else ""
        if not 视频链接:
            raise ValueError("请输入有效的视频URL")
        SESSDATA = SESSDATA.strip() if SESSDATA else ""

        cap = None
        try:
            video_path = asyncio.run(self.download_bilibili_video(视频链接, SESSDATA))
            # Wait for the file to be completely written.
            time.sleep(1)

            # Try a few times to open the file, releasing failed handles.
            for _ in range(self._OPEN_ATTEMPTS):
                cap = cv2.VideoCapture(video_path)
                if cap.isOpened():
                    break
                cap.release()
                time.sleep(1)
            else:
                raise ValueError("无法打开视频文件")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if 预览帧 < 0 or 预览帧 >= total_frames:
                预览帧 = 0
            cap.set(cv2.CAP_PROP_POS_FRAMES, 预览帧)
            ret, frame = cap.read()
            if not ret:
                raise ValueError(f"无法读取视频帧 {预览帧}")

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            preview_image = torch.from_numpy(frame_rgb).float() / 255.0
            preview_image = preview_image.unsqueeze(0)
            return (video_path, preview_image)
        except Exception as e:
            raise ValueError(f"下载或处理视频时出错: {str(e)}") from e
        finally:
            # Release the capture on every path, including read failures.
            if cap is not None:
                cap.release()

    @classmethod
    def IS_CHANGED(cls, 视频链接, 预览帧, SESSDATA=""):
        """Return NaN, a value that never compares equal to a previous one.

        NOTE(review): presumably this forces the node to re-run on every
        execution; confirm against ComfyUI's caching rules.
        """
        return float("nan")
# Node registration for this module: internal node name -> class, and
# internal node name -> label shown to the user.
NODE_CLASS_MAPPINGS = dict(VideoDownloader=VideoDownloader)
NODE_DISPLAY_NAME_MAPPINGS = dict(VideoDownloader="B站视频下载器")
视频帧提取.py 源码:
# -*- coding: utf-8 -*-
import os
import cv2
import torch
import numpy as np
from PIL import Image
import sys
class VideoFrameExtractor:
    """ComfyUI node that samples frames from a video file.

    Frames are read at a regular interval between a start and an end frame,
    returned as one stacked image tensor and optionally written as PNG files
    into a sub-directory of ``extracted_frames``.
    """

    CATEGORY = "视频转文章"
    RETURN_TYPES = ("IMAGE", "STRING")
    RETURN_NAMES = ("帧序列", "帧路径列表")
    FUNCTION = "extract_frames"

    def __init__(self):
        # Base output directory, relative to the current working directory.
        self.base_output_dir = "extracted_frames"
        os.makedirs(self.base_output_dir, exist_ok=True)

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (frame range, interval, limits, saving)."""
        return {
            "required": {
                "视频路径": ("STRING", {"forceInput": True}),
                "起始帧": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 100000,
                    "step": 1,
                    "display": "number"
                }),
                "提取间隔": ("INT", {
                    "default": 30,
                    "min": 1,
                    "max": 300,
                    "step": 1,
                    "display": "number"
                }),
                "最大提取数": ("INT", {
                    "default": 20,
                    "min": 1,
                    "max": 100,
                    "step": 1,
                    "display": "number"
                }),
            },
            "optional": {
                "结束帧": ("INT", {
                    "default": -1,
                    "min": -1,
                    "max": 100000,
                    "step": 1,
                    "display": "number"
                }),
                "保存帧": ("BOOLEAN", {"default": True}),
                "子目录名": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "可选,留空则使用视频文件名"
                }),
            }
        }

    def create_output_dir(self, video_path, sub_dir=""):
        """Create and return a fresh directory for the extracted frames.

        The directory is ``<base_output_dir>/<sub_dir>``; *sub_dir* defaults
        to the video file name without extension.  If that directory already
        exists, ``_1``, ``_2``, ... is appended until an unused name is found.
        If anything goes wrong a timestamped ``frames_<epoch>`` directory is
        used instead.
        """
        try:
            if not sub_dir:
                sub_dir = os.path.splitext(os.path.basename(video_path))[0]
            original_dir = os.path.join(self.base_output_dir, sub_dir)
            output_dir = original_dir
            counter = 1
            while os.path.exists(output_dir):
                output_dir = f"{original_dir}_{counter}"
                counter += 1
            os.makedirs(output_dir)
            return output_dir
        except Exception as e:
            print(f"创建目录时出错: {str(e)}")
            import time
            backup_dir = os.path.join(self.base_output_dir, f"frames_{int(time.time())}")
            os.makedirs(backup_dir, exist_ok=True)
            return backup_dir

    @staticmethod
    def _plan_frame_positions(start, end, total_frames, interval, max_count):
        """Validate the requested range and compute the frames to read.

        Args:
            start: First frame index.
            end: Last frame index; ``-1`` or anything past the video means
                "up to the last frame".
            total_frames: Number of frames reported for the video.
            interval: Requested distance between extracted frames.
            max_count: Upper bound on the number of extracted frames.

        Returns:
            ``(end, step, positions)``: the resolved end frame, the interval
            actually used (widened when the range would otherwise yield more
            than *max_count* frames) and the list of frame indices.

        Raises:
            ValueError: When the frame count is unknown or an argument is out
                of range.
        """
        if total_frames <= 0:
            # Without this check the range error below would read "(0--1)".
            raise ValueError("无法获取视频总帧数")
        if interval < 1 or max_count < 1:
            raise ValueError("提取间隔和最大提取数必须大于0")
        if end == -1 or end >= total_frames:
            end = total_frames - 1
        if start < 0 or start >= total_frames:
            raise ValueError(f"起始帧超出范围 (0-{total_frames-1})")
        if end < start:
            # An end frame equal to the start frame is allowed.
            raise ValueError("结束帧不能小于起始帧")
        span = end - start + 1
        step = max(interval, span // max_count)
        positions = list(range(start, end + 1, step))[:max_count]
        return end, step, positions

    def extract_frames(self, 视频路径, 起始帧, 提取间隔, 最大提取数, 结束帧=-1, 保存帧=True, 子目录名=""):
        """Node entry point: extract frames from the video at *视频路径*.

        Args:
            视频路径: Path of the video file.
            起始帧: First frame to consider.
            提取间隔: Requested distance between extracted frames.
            最大提取数: Maximum number of frames to extract.
            结束帧: Last frame to consider; ``-1`` means the end of the video.
            保存帧: Whether each extracted frame is also saved as a PNG file.
            子目录名: Output sub-directory; empty means the video file name.

        Returns:
            ``(frames, paths)``: the stacked float RGB frames scaled to 0..1,
            and the saved file paths joined with ``,`` (empty when nothing
            was saved).

        Raises:
            ValueError: When the file is missing, cannot be opened, the range
                is invalid or no frame could be read.
        """
        if not os.path.exists(视频路径):
            raise ValueError(f"视频文件不存在: {视频路径}")

        cap = None
        try:
            cap = cv2.VideoCapture(视频路径)
            if not cap.isOpened():
                raise ValueError("无法打开视频文件")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            结束帧, 实际间隔, 帧位置列表 = self._plan_frame_positions(
                起始帧, 结束帧, total_frames, 提取间隔, 最大提取数
            )

            # Create the directory only after validation so that a rejected
            # request does not leave an empty directory behind.
            output_dir = self.create_output_dir(视频路径, 子目录名) if 保存帧 else None

            frames = []
            frame_paths = []
            print(f"开始提取帧,范围:{起始帧}-{结束帧},间隔:{实际间隔},计划提取:{len(帧位置列表)}帧")

            for frame_pos in 帧位置列表:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
                ret, frame = cap.read()
                if not ret:
                    print(f"警告:无法读取帧位置 {frame_pos}")
                    continue

                # Convert BGR to RGB and scale to 0..1.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_tensor = torch.from_numpy(frame_rgb).float() / 255.0
                # Only (H, W, C) frames can be stacked; skip anything else
                # entirely so saved files and returned frames stay in step.
                if frame_tensor.dim() != 3:
                    continue
                frames.append(frame_tensor)

                if output_dir is not None:
                    frame_path = os.path.join(output_dir, f"frame_{frame_pos:06d}.png")
                    Image.fromarray((frame_rgb).astype(np.uint8)).save(frame_path)
                    frame_paths.append(frame_path)
                    print(f"已保存帧 {frame_pos}: {frame_path}")

            if not frames:
                raise ValueError("没有成功提取到帧")

            # Stack into a single (N, H, W, C) tensor.
            frames_tensor = torch.stack(frames)
            frame_paths_str = ",".join(frame_paths)

            print(f"帧提取完成,共提取{len(frames)}帧")
            if 保存帧:
                print(f"帧已保存到目录: {output_dir}")
            return (frames_tensor, frame_paths_str)
        except Exception as e:
            raise ValueError(f"提取帧时出错: {str(e)}") from e
        finally:
            if cap is not None:
                cap.release()

    @classmethod
    def IS_CHANGED(cls, 视频路径, *args, **kwargs):
        """Return NaN, a value that never compares equal to a previous one.

        Keyword arguments are accepted as well as positional ones, so the
        method works however the remaining inputs are passed.
        NOTE(review): ComfyUI is understood to pass node inputs by keyword;
        confirm against its documentation.
        """
        return float("nan")
# Node registration for this module: internal node name -> class, and
# internal node name -> label shown to the user.
NODE_CLASS_MAPPINGS = dict(VideoFrameExtractor=VideoFrameExtractor)
NODE_DISPLAY_NAME_MAPPINGS = dict(VideoFrameExtractor="视频帧提取器")
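注意:运行前需要先安装这些 Python 文件所依赖的第三方库(如 bilibili-api-python、aiohttp、opencv-python、torch、numpy、Pillow),并确保系统已安装 ffmpeg 且在系统路径中。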