Large World Model (LWM) is hugely popular right now. Its standout feature is that it supports interactive Q&A not only over text but also over images and video. Since my previous post, "LWM (LargeWorldModel): Text, Images, and Video - a Multimodal Large World Model - Detailed Installation Notes", went out, the project's GitHub stars have climbed to 5.4k in just two days!
After that installation I never got it running; today I finally got the project working.
The first feature I tried is video Q&A: you feed the LWM model a video clip and then hold a question-and-answer conversation about it. First, the video I used:
[Video: beginning-spring]
And the resulting Q&A:
Looks pretty good!
Project Experience
How to run it: edit the corresponding model paths in scripts/run_vision_chat.sh:
export llama_tokenizer_path=""  # path to the LLaMA tokenizer.model file
export vqgan_checkpoint=""      # path to the VQGAN checkpoint
export lwm_checkpoint=""        # path to the LWM params checkpoint
export input_file=""            # path to the input video
Next, note the --mesh_dim='!1,-1,32,1' parameter. The official explanation reads:
You can use mesh_dim=dp, fsdp, tp, sp to control the degree of parallelism and RingAttention. It is a string of 4 integers separated by commas, representing the number of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism. For example, mesh_dim='1,64,4,1' means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism. mesh_dim='1,1,4,64' means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.
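To make those four axes concrete, here is a minimal JAX sketch of how such a string can map onto a device mesh. This is my own illustration, not LWM's actual parser (the helper name build_mesh is hypothetical, and it ignores the '!' prefix); the four axis sizes must multiply to the total device count, with a -1 entry absorbing whatever devices remain:

import numpy as np
import jax
from jax.sharding import Mesh

def build_mesh(mesh_dim: str) -> Mesh:
    # Parse a 'dp,fsdp,tp,sp' string into a 4-axis device mesh.
    # A -1 entry is inferred so that the axes multiply to the device count.
    dims = [int(x) for x in mesh_dim.split(',')]
    n_devices = len(jax.devices())
    if -1 in dims:
        known = int(np.prod([d for d in dims if d != -1]))
        dims[dims.index(-1)] = n_devices // known
    devices = np.array(jax.devices()).reshape(dims)
    return Mesh(devices, axis_names=('dp', 'fsdp', 'tp', 'sp'))

mesh = build_mesh('1,-1,1,1')  # on a single-device machine this is a 1x1x1x1 mesh
print(mesh.shape)              # e.g. {'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': 1}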
But no matter how I tuned this parameter it kept failing, so I simply deleted it and let the program use its default configuration (judging by lwm/vision_chat.py, the default is mesh_dim='1,-1,1,1').
The next parameter to change is --dtype='fp32', which I set to --dtype='fp16': my device's memory is limited, and it only runs for me at fp16. The command that runs correctly after these changes:
python3 -u -m lwm.vision_chat \
    --prompt="What is the video about?" \
    --input_file="$input_file" \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --dtype='fp16' \
    --load_llama_config='7b' \
    --max_n_frames=8 \
    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path"
Then run bash run_vision_chat.sh and watch it work.
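For a quick intuition on why fp16 halves the footprint: weights cost 4 bytes per parameter in fp32 and 2 bytes in fp16, so a 7B-parameter model's weights alone drop from roughly 28 GB to 14 GB. A back-of-envelope check (weights only, ignoring activations and the VQGAN):

# Back-of-envelope weight memory for a 7B-parameter model.
n_params = 7e9
print(f"fp32: {n_params * 4 / 1e9:.0f} GB")  # 4 bytes/param -> 28 GB
print(f"fp16: {n_params * 2 / 1e9:.0f} GB")  # 2 bytes/param -> 14 GB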
Multi-turn Q&A
The official run_vision_chat.sh script handles only one prompt per invocation, which makes interaction tedious, so I made a small adjustment to support a loop of repeated Q&A turns.
The code changes are as follows:
First, make a copy of lwm/vision_chat.py and rename it vision_chat2.py. Replace its FLAGS parameter object with:
p={"prompt": "","input_file": "","vqgan_checkpoint": "","temperature": 0.2,"max_n_frames": 8,"seed": 1234,"mesh_dim": "1,-1,1,1","dtype": "fp32","load_llama_config": "","update_llama_config": "","load_checkpoint": "","tokenizer":VideoLLaMAConfig.get_tokenizer_config(),"llama":VideoLLaMAConfig.get_default_config(),"jax_distributed":JaxDistributedConfig.get_default_config()
}
FLAGS = types.SimpleNamespace(**p)  # requires "import types" at the top of the file
FLAGS.vqgan_checkpoint = "path to the vqgan model"
FLAGS.dtype = 'fp16'
FLAGS.load_llama_config = '7b'
FLAGS.max_n_frames = 8
FLAGS.update_llama_config = "dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)"
FLAGS.load_checkpoint = "params::path to the model params"
FLAGS.tokenizer.vocab_file = "path to the model tokenizer.model"
Then change its main entry point to:
if __name__ == "__main__":
    FLAGS.input_file = input('video path: ')
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    set_random_seed(FLAGS.seed)
    sampler = Sampler()
    while True:
        # Keep asking until the user types a non-empty prompt.
        while FLAGS.prompt == '':
            FLAGS.prompt = input('input prompt: ')
        prompts = [{'input_path': FLAGS.input_file, 'question': FLAGS.prompt}]
        output = sampler(prompts, FLAGS.max_n_frames)[0]
        print(f"Question: {FLAGS.prompt}\nAnswer: {output}")
        FLAGS.prompt = ''  # reset so the next iteration asks again
Once the changes are in place, it is still invoked through a bash script. Create a new script file (I'll call it run_vision_chat2.sh; the name is arbitrary):
#!/bin/bash
python3 lwm/vision_chat2.py
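When run, the script first asks for a video path and then loops, prompting for a new question each round. A session looks roughly like this (file name illustrative, model output elided):

$ bash run_vision_chat2.sh
video path: your_video.mp4
input prompt: What is the video about?
Question: What is the video about?
Answer: ...
input prompt: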
Have fun, everyone!