AttentionFreeTransformer 核心结构图(GraphViz 重绘)

发布时间：2024-01-10 13:00:36　作者：绝不原创的飞龙

AFTFull

图：AFTFull 结构图（由下方 GraphViz 源码渲染）

digraph AFTFull {
	// Data-flow diagram of the AFT-full layer.
	// In this graph, "box" nodes are operations / layers and "Mrecord"
	// nodes are tensors annotated with their shapes.

	// Draw bottom-to-top: the input sits at the bottom, the output at the top.
	rankdir=BT

	// Defaults shared by every node. SimHei is a CJK font, presumably chosen
	// so that the Chinese labels render correctly.
	node [
		style=filled,
		color=Black,
		fontcolor=White,
		fillcolor="#30638e",
		fontname="SimHei",
		fontsize=32,
		width=5, height=2,
	]

	// --- Nodes -------------------------------------------------------------
	inp [label="输入\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]
	llq [label="LinearQ\n[HidSize, ProjSize]", shape="box"]
	llk [label="LinearK\n[HidSize, ProjSize]", shape="box"]
	llv [label="LinearV\n[HidSize, ProjSize]", shape="box"]
	// Learned pairwise position bias, one scalar per (target, source) pair.
	w [label="W:Param\n[SeqLen, SeqLen]", shape="Mrecord"]
	q [label="Q\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
	k [label="K\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
	v [label="V\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
	σ [label="Sigmoid", shape="box", width=3]
	// Numerator and denominator are both matmuls with exp(W); the denominator
	// is the same weighted sum as the numerator, just without V.
	atten_op [label="exp(W) @ (exp(K) * V)\n/ (exp(W) @ exp(K))", shape="box"]
	atten [label="[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
	// Element-wise gate: sigmoid(Q) * atten.
	mul [label="*", shape="box", width=3]
	llo [label="LinearO\n[ProjSize, HidSize]", shape="box"]
	oup [label="输出\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]

	// --- Edges -------------------------------------------------------------
	// Input is projected into Q, K and V.
	inp -> llq
	inp -> llk
	inp -> llv
	llq -> q
	llk -> k
	llv -> v
	// Q only goes through the sigmoid gate; it never meets K directly.
	q -> σ
	// W, K and V are combined into the attention-free weighted average.
	w -> atten_op
	k -> atten_op
	v -> atten_op
	atten_op -> atten
	// Gate the weighted average with sigmoid(Q), then project back to HidSize.
	σ -> mul
	atten -> mul
	mul -> llo
	llo -> oup
}

AFTSimple

图：AFTSimple 结构图（由下方 GraphViz 源码渲染）

digraph AFTSimple {
	/*
	 * Data-flow diagram of the AFT-simple layer.
	 * "box" nodes are operations / layers; "Mrecord" nodes are tensors
	 * annotated with their shapes. There is no position-bias node here:
	 * only K and V feed the attention operation.
	 */

	// Bottom-to-top layout: input at the bottom, output at the top.
	rankdir=BT;

	// Attributes applied to every node unless overridden per node.
	node [
		style=filled,
		color=Black,
		fontcolor=White,
		fillcolor="#30638e",
		fontname="SimHei",
		fontsize=32,
		width=5,
		height=2
	];

	// Tensors and layers, declared in the same order as they should be laid out.
	inp      [shape=Mrecord, label="输入\n[BatchSize,\n SeqLen,\n HidSize]"];
	llq      [shape=box,     label="LinearQ\n[HidSize, ProjSize]"];
	llk      [shape=box,     label="LinearK\n[HidSize, ProjSize]"];
	llv      [shape=box,     label="LinearV\n[HidSize, ProjSize]"];
	q        [shape=Mrecord, label="Q\n[BatchSize,\n SeqLen,\n ProjSize]"];
	k        [shape=Mrecord, label="K\n[BatchSize,\n SeqLen,\n ProjSize]"];
	v        [shape=Mrecord, label="V\n[BatchSize,\n SeqLen,\n ProjSize]"];
	σ        [shape=box,     label="Sigmoid", width=3];
	atten_op [shape=box,     label="sum(softmax(K, 1) * V, 1)"];
	atten    [shape=Mrecord, label="[BatchSize, 1, ProjSize]"];
	mul      [shape=box,     label="*", width=3];
	llo      [shape=box,     label="LinearO\n[ProjSize, HidSize]"];
	oup      [shape=Mrecord, label="输出\n[BatchSize,\n SeqLen,\n HidSize]"];

	// Input fans out to the three projections.
	inp -> {llq llk llv};
	llq -> q;
	llk -> k;
	llv -> v;

	// Q is only used as a sigmoid gate.
	q -> σ;

	// K and V are reduced over the sequence axis into a single context row.
	{k v} -> atten_op;
	atten_op -> atten;

	// Gate the context with sigmoid(Q), then project back to HidSize.
	{σ atten} -> mul;
	mul -> llo -> oup;
}

AFTLocal

图：AFTLocal 结构图（由下方 GraphViz 源码渲染）

digraph AFTLocal {
	// Data-flow diagram of the AFT-local layer.
	// Same structure as the full variant, except that the position bias W
	// passes through a locality mask before reaching the attention operation.
	// "box" nodes are operations / layers; "Mrecord" nodes are tensors
	// annotated with their shapes.

	// Draw bottom-to-top: the input sits at the bottom, the output at the top.
	rankdir=BT

	// Defaults shared by every node. SimHei is a CJK font, presumably chosen
	// so that the Chinese labels render correctly.
	node [
		style=filled,
		color=Black,
		fontcolor=White,
		fillcolor="#30638e",
		fontname="SimHei",
		fontsize=32,
		width=5, height=2,
	]

	// --- Nodes -------------------------------------------------------------
	inp [label="输入\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]
	llq [label="LinearQ\n[HidSize, ProjSize]", shape="box"]
	llk [label="LinearK\n[HidSize, ProjSize]", shape="box"]
	llv [label="LinearV\n[HidSize, ProjSize]", shape="box"]
	// Learned pairwise position bias, one scalar per (target, source) pair.
	w [label="W:Param\n[SeqLen, SeqLen]", shape="Mrecord"]
	// Keeps W only where the two positions are closer than the window size S.
	mask [label="mask\n[SeqLen, SeqLen]\nabs(i - j) < S? 1: 0", shape="box"]
	q [label="Q\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
	k [label="K\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
	v [label="V\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
	σ [label="Sigmoid", shape="box", width=3]
	// W here is the masked bias coming out of the mask node. Numerator and
	// denominator are both matmuls with exp(W); the denominator is the same
	// weighted sum as the numerator, just without V.
	atten_op [label="exp(W) @ (exp(K) * V)\n/ (exp(W) @ exp(K))", shape="box"]
	atten [label="[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
	// Element-wise gate: sigmoid(Q) * atten.
	mul [label="*", shape="box", width=3]
	llo [label="LinearO\n[ProjSize, HidSize]", shape="box"]
	oup [label="输出\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]

	// --- Edges -------------------------------------------------------------
	// Input is projected into Q, K and V.
	inp -> llq
	inp -> llk
	inp -> llv
	llq -> q
	llk -> k
	llv -> v
	// Q only goes through the sigmoid gate; it never meets K directly.
	q -> σ
	// The position bias is masked to a local window before use.
	w -> mask
	mask -> atten_op
	k -> atten_op
	v -> atten_op
	atten_op -> atten
	// Gate the weighted average with sigmoid(Q), then project back to HidSize.
	σ -> mul
	atten -> mul
	mul -> llo
	llo -> oup
}