graph: sdpa: support dropout seed/offset/prob in fused sdpa#4961
Conversation
Noticed a correctness issue via benchdnn. Debugging...
# with fused sdpa kernel
$ _ONEDNN_GRAPH_SDPA_FORCE_PRIMITIVE=0 ./tests/benchdnn/benchdnn --graph --engine=gpu --case=complex_fusion/mha/gqa-plain-training-fwd-w-dropout-bf16-f32.json
[COMPARE_STATS][DST]: trh=0 err_max_diff: 2.01562 err_max_rdiff:8.37618e+37 all_max_diff: 2.01562 all_max_rdiff:8.37618e+37
[COMPARE_STATS] Norm check is prohibited; error_to_total_ratio: 233469/262144; allowed_ratio: 256/262144;
Error: Function 'doit' at (/nfs/pdx/disks/hal9000/lvtao/oneDNN/tests/benchdnn/graph/graph.cpp:787) returned '1'
0:FAILED (errors:233469 total:262144) (3079 ms) __REPRO: --graph --engine=gpu --case=complex_fusion/mha/gqa-plain-training-fwd-w-dropout-bf16-f32.json
===========================================================
= Failed cases summary (--summary=no-failures to disable) =
===========================================================
0:FAILED (errors:233469 total:262144) (3079 ms) __REPRO: --graph --engine=gpu --case=complex_fusion/mha/gqa-plain-training-fwd-w-dropout-bf16-f32.json
============================
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 3.09s; create_pd: 0.09s (3%); create_prim: 0.96s (31%); fill: 0.00s (0%); execute: 0.01s (0%); compute_ref: 0.00s (0%); compare: 0.00s (0%);
# with primitive based kernel
$ _ONEDNN_GRAPH_SDPA_FORCE_PRIMITIVE=1 ./tests/benchdnn/benchdnn --graph --engine=gpu --case=complex_fusion/mha/gqa-plain-training-fwd-w-dropout-bf16-f32.json
0:PASSED (10122 ms) __REPRO: --graph --engine=gpu --case=complex_fusion/mha/gqa-plain-training-fwd-w-dropout-bf16-f32.json
tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 10.12s; create_pd: 0.07s (1%); create_prim: 0.67s (7%); fill: 0.00s (0%); execute: 0.00s (0%); compute_ref: 0.00s (0%); compare: 0.00s (0%);
make test
make test
We are not enabling mask here? Also, we will need a backport to the v3.12 branch as well.
Dropout mask output is not required for SDPA training in PyTorch.
make test |
dzarukin left a comment:
(Minor) It looks to me that if the output_mask from dropout is requested, the pattern won't be picked up. If that's the case, it would probably be good to reflect that in the documentation and/or a code comment. If this is a false impression, then OK.
make test
For SDPA forward with dropout seed/offset/prob.
SDPA backward will be fixed later. Update: SDPA backward is also fixed.
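For context on why seed/offset/prob inputs are enough (and why no dropout mask output is needed, per the discussion above): a counter-based dropout can regenerate the same mask deterministically from the seed and offset, so forward and backward agree without materializing the mask. The sketch below is purely illustrative, NOT oneDNN's actual kernel; the function name and the hash-based RNG are hypothetical stand-ins for a Philox-style counter-based generator.

```python
# Illustrative, stdlib-only sketch of stateless dropout driven by
# (seed, offset, prob). This is NOT oneDNN's implementation; the hash
# construction is a hypothetical stand-in for a counter-based RNG.
import hashlib
import struct

def stateless_dropout(x, prob, seed, offset):
    """Drop elements of x with probability `prob`, deterministically.

    The per-element random draw depends only on (seed, offset, index),
    so re-running with the same seed/offset reproduces the same mask --
    which is why the mask itself never needs to be an output.
    """
    if prob <= 0.0:
        return list(x)
    scale = 1.0 / (1.0 - prob)  # inverted-dropout scaling of kept values
    out = []
    for i, v in enumerate(x):
        # Derive a uniform value in [0, 1) from (seed, offset, i).
        h = hashlib.sha256(struct.pack("<QQQ", seed, offset, i)).digest()
        u = int.from_bytes(h[:8], "little") / 2**64
        out.append(0.0 if u < prob else v * scale)
    return out

x = [1.0, 2.0, 3.0, 4.0]
a = stateless_dropout(x, 0.5, seed=42, offset=0)
b = stateless_dropout(x, 0.5, seed=42, offset=0)
assert a == b  # same seed/offset -> identical mask, no mask tensor stored
```

The same property is what a fused SDPA backward relies on: it replays the dropout mask from seed/offset instead of reading a saved mask tensor.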