본문 바로가기
3D-GS

[논문 리뷰] DISTWAR (arXiv2023) : 3D-GS 학습 속도 개선

by xoft 2024. 1. 27.

DISTWAR: Fast Differentiable Rendering on Raster-based Rendering Pipelines, Sankeerth Durvasula, arXiv 2024

 

3D Gaussian Splatting (이전글)의 학습 속도를 GPU 연산 최적화로 개선한 연구입니다. 학습 속도 개선이 실용적이라 많은 곳에서 쓰일 것이라 생각되어 정리해봤습니다. 다른 논문들과 다르게 GPU연산 관점을 바라보고 있어서 이런 3DGS가 GPU연산 관점에서 문제점이 있고 이런 계열 논문이 있다, 정도만 아시고 넘어가기에 좋다고 생각합니다.

 

본론으로 들어가기 전에 아래글을 먼저 읽기를 바랍니다. 본 논문을 읽기 위해 필요한 개념과 용어를 정리해두었습니다.  GPU 메모리에 대한 일반적인 지식을 쌓을 수 있게 써봤습니다. 밑에 설명을 읽다가 모르는 용어가 나오면 참고하시길 바랍니다.

 

[개념 정리] GPU 메모리 구조 및 용어

GPU 메모리 구성과 데이터 흐름에 관해 다뤄보겠습니다. Graphic관련 논문을 읽기 위한 목적으로 일반적인 개념들을 정리해보았습니다. GPU 메모리 구조 GPU와 CPU메모리 구조의 차이입니다. (출처 : l

xoft.tistory.com

 

 

 

 

기존 방법의 문제점

3DGS의 gradient computation단계는 RTX4090기준으로 전체 학습시간에서 평균 30%(최대 65%)의 비중을 차지하게 되는데 여기서 큰 bottleneck이 발생했다고 합니다. 아래는 Training시 연산시간을 세부 분석한 그래프입니다. 3D가 3DGS에 해당하며 하이픈(-) 뒤는 서로 다른 데이터셋을 의미합니다. (NV와PS는 모델에 대해서는 다루지 않겠습니다.)

Gaussian의 갯수가 많이 요구하는 복잡한 real-world scene(PR, DR)에서는 forward pass/loss계산시 연산시간의 변동이 없었지만, gradient computation 시간은 크게 증가하였다고 합니다. 그 중에서도 gradient를 aggregation할 때 큰 bottleneck을 만들었습니다.

  • 여러 thread가 동일한 3D gaussian의 parameter를 update하기 위해, pixel당 3D gaussian마다 update는 atomic operation (=thread)으로 만들어집니다. 그리고 한개의 scene에서 많은 갯수의 atomic operation이 만들어지게 됩니다.
  • 많은 atomic operation은 L2 cache에서 contention(=충돌)을 만들어내고 SM(=stream multiprocessor)에서 긴 stall(=정체)를 만들게 됩니다.

atomic requests하는 thread갯수는 atomic request의 traffic(=혼잡량)에 비례합니다.

 

이전 선행 연구들의 해결방법을 보자면, 이전 연구들은 L2 memory의 atomic unit(=ROP)로 traffic을 줄이기 위해, L1 cache에서 atomic update를 buffering하고 aggregation했습니다.

이 방법을 통해 다양한 application에서 overhead를 효과적으로 줄였지만, 3D GS와 같은 differentiable rendering에서는 그렇지 못했습니다. atomic이 aggregated되기도 전에 생성되는 atomic request갯수가 LSU(Load-Store Unit) 허용량을 초과하면서 bottleneck을 해결하지 못하게 됩니다. 아래 그래프는 instruction마다 stall되는 횟수를 NVIDIA NSIGHT profiler로 표시했습니다. 왼쪽은 RTX4090, RTX3060입니다. 

LSU Stall이 상당히 높다는 것은 Sub-Core(=Cuda Core)에서 global memory로 전달되는 atomic operation이 많이 몰린다는 것을 의미합니다. (추가적으로 더 고사양의 4090에서 오히려 더 많은 stall이 발생하는데, SM당 ROP 비율이 낮기 때문입니다. RTX 4090의 SM 144개, ROP 176개이며,   RTX3060의 SM은 28개, ROP는 48개입니다. <-너무 세부적인 정보)

 

때문에 본 논문에서는 gradient computation에서 bottleneck을 만드는 atomic operation의 속도를 개선하여 전체 학습 속도를 개선하는 것을 목표로 하고 있습니다.

 

 

 

 

알고리즘 아이디어

2가지 Observation으로 본 논문에서 제시하는 알고리즘 아이디어를 얻었다고 합니다.

1) Observation1 : 하나의 warp안에 있는 thread는 같은 gaussin의 parameter를 업데이트 합니다.

gaussian의 parameter는 특정 카메라 포즈에서 inference한 이미지와 GT간의 pixel차이로 업데이트하게 됩니다. 업데이트할 gaussian은 근접한 여러개의 pixel들로 구성되어 있으므로, 동일한 gaussian의 parameter들을 업데이트하는 thread들이 (동일한 instruction을 수행하는 32개 thread 단위인) warp로 만들어지며, 이 warp들이 99% 확률로 같은 memory 위치를 access하는 것을 확인했습니다.

 

 

2) Observation2 : 하나의 warp안에 일부 thread만으로도 즉시 atomic update를 수행합니다.

3DGS에선 특정 조건 (예를 들어 특정 값이 정의한 threshold를 넘지 못 할 경우)에선 graident를 update하지 않습니다. 그렇게 될 경우, thread가 inactive가 되고 graident 업데이트가 skip하게 되며, 이는 하나의 warp에서 thread를 32개를 채우지 못하게 되면서, LSU와 ROP에 다른 traffic을 만들게 됩니다. 아래는 warp안에 있는 thread갯수를 나타내는 그래프입니다. 하나의 warp안에서 포함된 thread갯수가 (꽉 찬) 32개가 아니라 다양하다 라는 점이 observation입니다.

 

 

 

 

제안하는 알고리즘

Warp-level Reduction과 atomic operation 스케쥴링을 통해 traffic을 줄입니다.

 

Warp-level Reduction

atomic update를 하면서 warp안의 최대 32개 thread를 1개의 thread로 줄어드는 과정을 warp-level reduction이라고 합니다. 이를 register에서 수행합니다. 2가지 방법을 제안합니다.

  • Serialized Reduction
    warp마다 같은 메모리에 같은 3D gaussian parameter를 업데이트하는 thread집합을 찾고, serial하게 accumulation해서 L2 ROP로 전달합니다. serial하게 더하는 것 자체는 비효율적이지만, 이 동작이 SM내 thread block마다 병렬적으로 수행되면서 효율성을 만들어 냅니다.

 

 

  • Butterfly Reduction
    warp마다 모든 thread가 같은 3D gaussian parameter을 업데이트하는 경우를 찾고, active한 thread를 모두 더합니다. observation1에 따르면 warp의 99%는 같은 gaussian의 parameter를 업데이트한다고 했었습니다. 이 특징을 사용해, 병렬적으로 트리 구조 형태로 더하여 연산량을 줄입니다.

 

 

 

Scheduling atomic update

atomic updates를 L2 ROP에서 수행할지 아니면 register에서 warp-level reduction을 통해 Core에서 수행 할 지를 결정합니다.

하나의 paramter를 업데이트하는 warp내의 thread갯수와 미리정의한 balancing threshold를 넘는지에 따라 정해집니다.

balancing threshold는atomic unit의 contention수를 줄이는 방향이 되어야 하므로 아래 수치에 따라 바뀝니다.

  • Dataset (=scene) and workload(=3DGS) : camera 해상도, model architecture, scene의 복잡도
  • GPU architecture : SM과 ROP unit의 비율
  • Reduction method used : butterfly reduction 또는 serial reduction 사용 유무

조금 더 구체적으로, 각 thread는 같은 warp안에서 몇개의 다른 thread가 같은 3D gaussian parameter를 update하는지 계산합니다. 그 갯수가 balancing threshold보다 작다면 ROP에서 수행하고, 크다면 register에서 수행됩니다. register에서 수행된 atomic updates 결과는 ROP unit으로 전달됩니다.

balance threshold는 0~31개의 값으로 셋팅 할 수 있는데,  32개 값을 사용해서 gradient computation을 1 iteration수행한 후 가장 높은 속도 향상을 만든 value로 thresold를 정합니다. 이를 profiling step이라고 부르며,  2000 iteration마다 한번 수행합니다. profiling step은 무시할만한 오버헤드를 준다고 합니다.

 

 

요약하자면,

register을 사용해서 warp-level reduction을 수행하여 atomic update에 대한 intra-warp locality를 만들어 traffic을 줄이고, 

atomic operation연산에 대해 balance threshold으로 SM과 L2 ROP을 적절히 스케쥴링하여 traffic을 줄입니다

 

 

 

 

성능 평가

복잡한 realscene 데이터셋을 3DGS로 모델링하는 3D-PR, 3D-DR을 보면, 학습 속도가 1.5배 이상 향상된 것을 볼 수 있습니다. gradient computation만 본다면 3배 이상 속도 향상된 것을 볼 수 있습니다.

또한 고사양 GPU인 RTX 4090가 더 높은 비율로 속도 향상을 보이는 것을 볼 수 있습니다. 이는 4090이 SM당 ROP갯수 비율이 적어서 bottleneck이 더 많이 발생했기 때문이라고 합니다. 4090이 144SM에 176ROP이고, 3060이 28SM에 48ROP입니다. 

SW-S-#는 serialized reduction, SW-B-#는 butterfly reduction을 의미하고, #는 balancing threshold를 의미합니다. butterfly reduction이 속도 향상에 많은 기여를 한 것을 볼 수 있습니다. 또한 SW-B이고 threshold가 0일 때, 높은 효율을 보였고, SW-S이고 threshold가 16일 때 높은 효율을 보였습니다.

'기존 방법 문제점' 섹션 마지막에서 다뤘던 stall그래프와 비교하면, LSU stall이 엄청나게 사라진 것을 볼 수 있었습니다.

warp-level reduction을 자동으로 하는 NVIDIA CCCL Library를 사용할 경우 속도가 오히려 떨어졌다고 합니다. 본 논문은 3D Gaussian의 parameter를 batch로 가지면서 reduction을 수행하지만, CCCL은 batch없이 각각의 parameter에 대해 reduction을 수행하게 되며, 또한 CCCL은 SM과 ROP unit에 atomic computation을 분배하지 않았고, 동시에 추가적인 instruction이 발생하였다고 합니다. 때문에 CCCL은  warp-level reduction을 해도 오히려 속도가 20% 느려졌다고 합니다.

 

 

 

 

Closing..

3D Gaussian Splatting의 학습시 gpu연산에 대해서 상세히 알아보는 글이었습니다. gpu에 대한 low level 코드까지 파고들어 gpu연산 최적화 관점에서 생소한 분야를 다뤄봤습니다. 덕분에 개인적으로 gpu 메모리에 대해서 많은 것을 공부하게 됬습니다.

무엇보다 이 논문은 3DGS분야 전반적으로 많은 도움이 될 것 같습니다. 논문에서 언급한 속도만큼 개선이 된다면, 학습 속도가 개선된 만큼 실험 시간도 줄어들고, 서비스를 만든다고하면 그만큼 서버 비용을 절약 할 수 있으니 사용이 필수가 될 것 같습니다. 

소스코드가 공개(link)되어 있더군요. 빌드해보기 싶긴하지만 저는 나중으로 미룰까합니다. 혹시나 빌드해보신분 있으시면 속도가 정말 빨라지는지 댓글로 공유 부탁드립니다.

댓글