comfyui自定义节点,生成自己想要的场景

comfyui自定义节点,生成自己想要的场景

编码文章call10242024-12-19 11:41:0325A+A-

我的需求:

我想通过comfyui的节点工作流生成方式,实现一键视频生成文章的功能,这样我就能把自己喜欢的一些视频通过这种方式直接转化成PDF的形式。

实现过程

  1. 第一步:从bilibili网站找到直接喜欢的视频,通过视频链接下载到本地,生成对应的图片。
  2. 第二步:视频里提取音频,通过调用大模型生成文本。
  3. 第三步:通过图片结合文本方式形成PDF

实现技术

comfyui 自定义插件实现功能。

第一步已实现插件代码逻辑

在custom_nodes目录下创建自己的插件 ComfyUI-videoToArticle,如图所示:


进入插件目录,目录及文件如图:


实现第一步的三个节点源码,以下给大家分享。

__init__.py 源码:

import os
import subprocess
import sys
import importlib.util

# 检查并安装依赖的函数
def check_and_install_requirements():
    requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
    
    if not os.path.exists(requirements_path):
        print("未找到 requirements.txt 文件")
        return
    
    # 读取 requirements.txt
    with open(requirements_path, 'r', encoding='utf-8') as f:
        requirements = [line.strip() for line in f.readlines() if line.strip()]
    
    # 检查每个依赖
    for requirement in requirements:
        package_name = requirement.split('>=')[0].split('==')[0].strip()
        try:
            importlib.util.find_spec(package_name)
        except ImportError:
            print(f"正在安装依赖: {requirement}")
            try:
                subprocess.check_call([sys.executable, "-m", "pip", "install", requirement])
                print(f"成功安装: {requirement}")
            except subprocess.CalledProcessError as e:
                print(f"安装失败 {requirement}: {str(e)}")

# 在导入时自动检查并安装依赖
check_and_install_requirements()

# 导入节点类
try:
    from .视频获取 import VideoDownloader
    from .视频帧提取 import VideoFrameExtractor
    
    # 注册节点
    NODE_CLASS_MAPPINGS = {
        "VideoDownloader": VideoDownloader,
        "VideoFrameExtractor": VideoFrameExtractor
    }

    NODE_DISPLAY_NAME_MAPPINGS = {
        "VideoDownloader": "B站视频下载器",
        "VideoFrameExtractor": "视频帧提取器"
    }

except ImportError as e:
    print(f"导入节点类时出错: {str(e)}")
    NODE_CLASS_MAPPINGS = {}
    NODE_DISPLAY_NAME_MAPPINGS = {}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

视频获取.py 源码:

import os
import torch
import cv2
from bilibili_api import video, Credential, sync
import aiohttp
import asyncio
import re
import time
import subprocess

class VideoDownloader:
    def __init__(self):
        # 创建下载目录
        self.output_dir = "downloaded_videos"
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
            
        # B站请求头
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Referer': 'https://www.bilibili.com',
            'Accept': '*/*',
            'Origin': 'https://www.bilibili.com',
            'Accept-Encoding': 'gzip, deflate, br',
            'Accept-Language': 'zh-CN,zh;q=0.9',
        }

    CATEGORY = "视频转文章"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "视频链接": ("STRING", {
                    "default": "", 
                    "multiline": False,
                    "placeholder": "请输入B站视频URL或BV号"
                }),
                "预览帧": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 10000,
                    "step": 1,
                    "display": "number"
                }),
            },
            "optional": {
                "SESSDATA": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "输入B站SESSDATA(可选)"
                })
            }
        }

    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("视频路径", "预览图像")
    
    FUNCTION = "download_and_preview"

    def extract_bvid(self, url):
        # 从URL中提取BV号
        bv_pattern = r'BV[a-zA-Z0-9]+'
        match = re.search(bv_pattern, url)
        if match:
            return match.group()
        return url

    async def download_bilibili_video(self, url, sessdata=None):
        bvid = self.extract_bvid(url)
        temp_video = None
        temp_audio = None
        
        try:
            # 检查 ffmpeg 是否可用
            try:
                subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
            except (subprocess.SubprocessError, FileNotFoundError):
                raise ValueError("未找到 ffmpeg,请先安装 ffmpeg 并确保其在系统路径中")

            credential = None
            if sessdata:
                credential = Credential(sessdata=sessdata)
            v = video.Video(bvid=bvid, credential=credential)
            
            video_info = await v.get_info()
            title = video_info['title']
            
            # 安全的文件名
            safe_title = "".join(x for x in title if x.isalnum() or x in (' ','-','_')).rstrip()
            video_path = os.path.join(self.output_dir, f"{safe_title}.mp4")
            temp_video = os.path.join(self.output_dir, f"{safe_title}_temp_video.m4s")
            temp_audio = os.path.join(self.output_dir, f"{safe_title}_temp_audio.m4s")
            
            # 如果文件已存在,先删除
            for file in [video_path, temp_video, temp_audio]:
                if os.path.exists(file):
                    os.remove(file)
            
            video_url = await v.get_download_url(0)
            video_stream_url = video_url['dash']['video'][0]['baseUrl']
            audio_stream_url = video_url['dash']['audio'][0]['baseUrl']
            
            print(f"开始下载视频: {safe_title}")
            
            # 下载视频流
            async with aiohttp.ClientSession() as session:
                # 下载视频部分
                print("下载视频流...")
                async with session.get(video_stream_url, headers=self.headers) as resp:
                    if resp.status != 200:
                        raise ValueError(f"视频下载失败,状态码:{resp.status}")
                    with open(temp_video, 'wb') as f:
                        async for chunk in resp.content.iter_chunked(1024*1024):
                            f.write(chunk)
                
                # 下载音频部分
                print("下载音频流...")
                async with session.get(audio_stream_url, headers=self.headers) as resp:
                    if resp.status != 200:
                        raise ValueError(f"音频下载失败,状态码:{resp.status}")
                    with open(temp_audio, 'wb') as f:
                        async for chunk in resp.content.iter_chunked(1024*1024):
                            f.write(chunk)
            
            # 检查临时文件是否存在
            if not os.path.exists(temp_video) or not os.path.exists(temp_audio):
                raise ValueError("临时文件下载失败")
                
            print("合并音视频...")
            # 使用绝对路径执行ffmpeg
            ffmpeg_cmd = [
                'ffmpeg',
                '-i', os.path.abspath(temp_video),
                '-i', os.path.abspath(temp_audio),
                '-c', 'copy',
                os.path.abspath(video_path),
                '-y'
            ]
            
            process = subprocess.Popen(
                ffmpeg_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True
            )
            stdout, stderr = process.communicate()
            
            if process.returncode != 0:
                print(f"FFmpeg错误输出: {stderr}")
                raise ValueError(f"FFmpeg合并失败,返回码: {process.returncode}")
            
            # 检查输出文件
            if not os.path.exists(video_path):
                raise ValueError("合并后的视频文件未生成")
                
            print("清理临时文件...")
            # 清理临时文件
            for file in [temp_video, temp_audio]:
                if os.path.exists(file):
                    os.remove(file)
            
            # 等待文件写入完成
            time.sleep(1)
            
            print("验证视频文件...")
            # 验证文件是否可读
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                raise ValueError("无法打开合并后的视频文件")
            cap.release()
            
            print("视频处理完成")
            return video_path
            
        except Exception as e:
            # 清理所有临时文件
            if temp_video and os.path.exists(temp_video):
                os.remove(temp_video)
            if temp_audio and os.path.exists(temp_audio):
                os.remove(temp_audio)
            raise ValueError(f"视频处理失败: {str(e)}")

    def download_and_preview(self, 视频链接, 预览帧, SESSDATA=""):
        if not 视频链接:
            raise ValueError("请输入有效的视频URL")

        try:
            video_path = asyncio.run(self.download_bilibili_video(视频链接, SESSDATA))
            
            # 等待文件完全写入
            time.sleep(1)
            
            # 尝试多次打开视频文件
            max_attempts = 3
            for attempt in range(max_attempts):
                cap = cv2.VideoCapture(video_path)
                if cap.isOpened():
                    break
                time.sleep(1)
            
            if not cap.isOpened():
                raise ValueError("无法打开视频文件")
            
            # 获取实际帧数
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if 预览帧 >= total_frames:
                预览帧 = 0
            
            cap.set(cv2.CAP_PROP_POS_FRAMES, 预览帧)
            ret, frame = cap.read()
            
            if not ret:
                raise ValueError(f"无法读取视频帧 {预览帧}")

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            preview_image = torch.from_numpy(frame_rgb).float() / 255.0
            preview_image = preview_image.unsqueeze(0)
            
            cap.release()

            return (video_path, preview_image)

        except Exception as e:
            raise ValueError(f"下载或处理视频时出错: {str(e)}")

    @classmethod
    def IS_CHANGED(cls, 视频链接, 预览帧, SESSDATA=""):
        return float("nan")

# 节点注册
NODE_CLASS_MAPPINGS = {
    "VideoDownloader": VideoDownloader
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VideoDownloader": "B站视频下载器"
}

视频帧提取.py 源码

# -*- coding: utf-8 -*-
import os
import cv2
import torch
import numpy as np
from PIL import Image
import sys

class VideoFrameExtractor:
    def __init__(self):
        # 创建输出目录
        self.base_output_dir = "extracted_frames"
        if not os.path.exists(self.base_output_dir):
            os.makedirs(self.base_output_dir)

    CATEGORY = "视频转文章"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "视频路径": ("STRING", {"forceInput": True}),
                "起始帧": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 100000,
                    "step": 1,
                    "display": "number"
                }),
                "提取间隔": ("INT", {
                    "default": 30,
                    "min": 1,
                    "max": 300,
                    "step": 1,
                    "display": "number"
                }),
                "最大提取数": ("INT", {
                    "default": 20,
                    "min": 1,
                    "max": 100,
                    "step": 1,
                    "display": "number"
                }),
            },
            "optional": {
                "结束帧": ("INT", {
                    "default": -1,
                    "min": -1,
                    "max": 100000,
                    "step": 1,
                    "display": "number"
                }),
                "保存帧": ("BOOLEAN", {"default": True}),
                "子目录名": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "可选,留空则使用视频文件名"
                }),
            }
        }

    RETURN_TYPES = ("IMAGE", "STRING")
    RETURN_NAMES = ("帧序列", "帧路径列表")
    
    FUNCTION = "extract_frames"

    def create_output_dir(self, video_path, sub_dir=""):
        try:
            if not sub_dir:
                sub_dir = os.path.splitext(os.path.basename(video_path))[0]
            output_dir = os.path.join(self.base_output_dir, sub_dir)
            original_dir = output_dir
            counter = 1
            while os.path.exists(output_dir):
                output_dir = f"{original_dir}_{counter}"
                counter += 1
            os.makedirs(output_dir)
            return output_dir
        except Exception as e:
            print(f"创建目录时出错: {str(e)}")
            import time
            backup_dir = os.path.join(self.base_output_dir, f"frames_{int(time.time())}")
            os.makedirs(backup_dir, exist_ok=True)
            return backup_dir

    def extract_frames(self, 视频路径, 起始帧, 提取间隔, 最大提取数, 结束帧=-1, 保存帧=True, 子目录名=""):
        if not os.path.exists(视频路径):
            raise ValueError(f"视频文件不存在: {视频路径}")

        cap = None
        try:
            output_dir = self.create_output_dir(视频路径, 子目录名) if 保存帧 else None
            
            cap = cv2.VideoCapture(视频路径)
            if not cap.isOpened():
                raise ValueError("无法打开视频文件")
            
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            
            if 结束帧 == -1 or 结束帧 >= total_frames:
                结束帧 = total_frames - 1
            
            if 起始帧 < 0 or 起始帧 >= total_frames:
                raise ValueError(f"起始帧超出范围 (0-{total_frames-1})")
            if 结束帧 < 起始帧:
                raise ValueError(f"结束帧必须大于起始帧")
            
            帧范围 = 结束帧 - 起始帧 + 1
            实际间隔 = max(提取间隔, int(帧范围 / 最大提取数))
            帧位置列表 = range(起始帧, 结束帧 + 1, 实际间隔)
            帧位置列表 = list(帧位置列表)[:最大提取数]
            
            frames = []
            frame_paths = []
            
            print(f"开始提取帧,范围:{起始帧}-{结束帧},间隔:{实际间隔},计划提取:{len(帧位置列表)}帧")
            
            for frame_pos in 帧位置列表:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
                ret, frame = cap.read()
                
                if not ret:
                    print(f"警告:无法读取帧位置 {frame_pos}")
                    continue
                
                # 转换颜色空间并确保格式正确
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_tensor = torch.from_numpy(frame_rgb).float() / 255.0
                
                # 确保维度正确 (H, W, C)
                if len(frame_tensor.shape) == 3:
                    frames.append(frame_tensor)
                    
                    if 保存帧 and output_dir:
                        frame_path = os.path.join(output_dir, f"frame_{frame_pos:06d}.png")
                        Image.fromarray((frame_rgb).astype(np.uint8)).save(frame_path)
                        frame_paths.append(frame_path)
                        print(f"已保存帧 {frame_pos}: {frame_path}")
            
            if not frames:
                raise ValueError("没有成功提取到帧")
            
            # 堆叠所有帧并确保格式正确 (N, H, W, C)
            frames_tensor = torch.stack(frames)
            frame_paths_str = ",".join(frame_paths) if frame_paths else ""
            
            print(f"帧提取完成,共提取{len(frames)}帧")
            if 保存帧:
                print(f"帧已保存到目录: {output_dir}")
            
            return (frames_tensor, frame_paths_str)

        except Exception as e:
            raise ValueError(f"提取帧时出错: {str(e)}")
        finally:
            if cap is not None:
                cap.release()

    @classmethod
    def IS_CHANGED(cls, 视频路径, *args):
        return float("nan")

# 节点注册
NODE_CLASS_MAPPINGS = {
    "VideoFrameExtractor": VideoFrameExtractor
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VideoFrameExtractor": "视频帧提取器"
} 

注意:python文件的依赖要下载。

第一步实现的效果

点击这里复制本文地址 以上内容由文彬编程网整理呈现,请务必在转载分享时注明本文地址!如对内容有疑问,请联系我们,谢谢!
qrcode

文彬编程网 © All Rights Reserved.  蜀ICP备2024111239号-4