SVFR – demo.py FIXED CODE


Open demo.py in your text editor.
Search for “def infer”.

Replace the entire infer function with this code:

def infer(lq_sequence, task_name, mask, seed):
    """Run the restoration pipeline on a video and return the result paths.

    Args:
        lq_sequence: Path to the low-quality input video; forwarded to
            ``gen`` as ``args.input_path``.
        task_name: Iterable of selected task labels. Labels that are not
            one of "BFR", "Colorization" or "Inpainting" are ignored.
        mask: Path to the mask; forwarded to ``gen`` as ``args.mask_path``
            after string formatting (so ``None`` becomes ``"None"``,
            exactly as before).
        seed: Random seed; converted with ``int()``.

    Returns:
        A ``(face_region_video_path, output_video_path)`` tuple. Either
        element is ``None`` when the corresponding file was not found in
        the output directory.

    Raises:
        gr.Error: If ``gen`` raises ``subprocess.CalledProcessError``.
    """
    unique_id = str(uuid.uuid4())
    output_dir = f"results_{unique_id}"

    task_mapping = {
        "BFR": 0,
        "Colorization": 1,
        "Inpainting": 2,
    }
    task_ids = [task_mapping[task] for task in task_name if task in task_mapping]

    try:
        # Build the namespace directly instead of parsing sys.argv: an empty
        # ArgumentParser would abort the process if the app was launched
        # with any command-line argument.
        args = argparse.Namespace(
            task_ids=task_ids,
            input_path=f"{lq_sequence}",
            output_dir=f"{output_dir}",
            mask_path=f"{mask}",
            seed=int(seed),
            restore_frames=False,
        )

        gen(args, pipe)

        # Look for the produced mp4 files inside output_dir.
        output_video = glob(os.path.join(output_dir, "*gen.mp4"))
        face_region_video = glob(os.path.join(output_dir, "*ori.mp4"))

        # Resolve each path independently so a missing file yields None
        # rather than an UnboundLocalError / IndexError.
        output_video_path = output_video[0] if output_video else None
        face_region_video_path = (
            face_region_video[0] if face_region_video else None
        )

        print(output_video_path, face_region_video_path)
        return face_region_video_path, output_video_path

    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during inference: {str(e)}") from e

    finally:
        # Release cached GPU memory on every exit path, including
        # unexpected exceptions.
        torch.cuda.empty_cache()

Leave a Reply

Your email address will not be published. Required fields are marked *