I’ve had the chance to test a few of the up-and-coming dedicated AI accelerators, and except for the TPU, none of them gave a speed-up or perf/watt improvement that was worth the additional complexity of using something that didn’t natively support CUDA. (The reason the TPU is different is that XLA is a first-class citizen for TF/JAX and has pretty good support in PyTorch.)
And I definitely see some future for FP16/BF16 or INT8 devices for inference, but I don’t know how widespread such devices will become for training. Many types of models simply won’t converge unless they at least use a mixed precision scheme. And there are certain problems where even for inference you get a significant benefit from using FP32 - for instance, I work on building DL surrogate models for physical simulations, and we get better results with the increased dynamic range and precision.
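To be concrete about what I mean by a mixed precision scheme, here is a minimal sketch using PyTorch AMP (toy model and random data, and it assumes a CUDA device): most of the compute runs in FP16, while the master weights stay in FP32 and the loss is dynamically scaled so small gradients don't underflow.

    import torch

    # Toy model/data just to make the loop runnable; not a real workload.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(),
                                torch.nn.Linear(64, 1)).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()       # dynamic loss scaling

    for step in range(100):
        x = torch.randn(32, 64, device="cuda")
        y = torch.randn(32, 1, device="cuda")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():        # FP16 where safe, FP32 elsewhere
            loss = torch.nn.functional.mse_loss(model(x), y)
        scaler.scale(loss).backward()          # scaled loss avoids FP16 grad underflow
        scaler.step(optimizer)                 # unscales grads, skips step on inf/NaN
        scaler.update()

The point being that even "FP16 training" in practice keeps FP32 in the loop, which is why I'm more optimistic about reduced-precision-only hardware for inference than for training.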
What might be needed from NN accelerator designers is not so much perf per watt as fast links between chips, between chips and memory, and between nodes in a cluster, so that a large distributed cluster of chips, each with its own memory, would appear to a user as a single chip with a huge amount of memory. Imagine 10 TB/s links everywhere in such a cluster - not only would training a large model be vastly simpler, it would also be a lot more efficient (e.g. no need for data-parallel model replication). In theory you could have a truly unified and shared memory space, with little need for synchronization.
Strongly agree. Bandwidth and latency are always the two big limiting factors, though, and unfortunately they tend to improve more slowly than compute. I think right now the fastest InfiniBand you can achieve is 300-400 GB/s; compare that to the H100 HBM3 memory bandwidth at up to 3 TB/s. So I don’t see a time in the near future when we can naively treat our server fleets as a truly unified machine.
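Back-of-envelope, using those peak figures (both are rough, vendor-quoted numbers rather than measured throughput):

    # GB/s, peak figures quoted above
    fastest_interconnect = 400     # upper end of the 300-400 GB/s InfiniBand estimate
    h100_hbm3 = 3000               # ~3 TB/s HBM3
    print(h100_hbm3 / fastest_interconnect)   # ~7.5x; ~10x at the 300 GB/s end

so remote memory is still the better part of an order of magnitude slower than local HBM, before you even count latency.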
On the other hand, the work being done in the big three DL packages to make distributed training easier has been quite nice. I know the least about TF, but their DTensor looks promising. JAX’s entire distributed paradigm (pmap/xmap, etc.) makes certain classes of models very easy to distribute. The one I’m following most closely, though, is PyTorch and its sharded tensor: it looks like they’re planning to implement native distributed ops powered by their RPC framework, which should make full tensor and model parallelism significantly easier.
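To give a flavor of the JAX side, here's a toy data-parallel train step with pmap (nothing production-grade, just the shape of the API; the linear model and plain SGD step are made up for the example): pmap compiles one replica of the function per local device, and pmean is the cross-device gradient all-reduce.

    import jax
    import jax.numpy as jnp

    def loss_fn(w, x, y):
        return jnp.mean((x @ w - y) ** 2)

    def train_step(w, x, y):
        grads = jax.grad(loss_fn)(w, x, y)
        grads = jax.lax.pmean(grads, axis_name="devices")  # all-reduce across devices
        return w - 0.1 * grads                             # plain SGD for brevity

    p_train_step = jax.pmap(train_step, axis_name="devices")

    n = jax.local_device_count()
    w = jnp.stack([jnp.zeros(16)] * n)     # identical parameter copy per device
    x = jnp.ones((n, 8, 16))               # leading axis is sharded across devices
    y = jnp.ones((n, 8))
    w = p_train_step(w, x, y)              # one synchronized step on every device

The appeal of the sharded tensor / RPC direction is that you wouldn't have to hold a full replica of w on every device like this; the tensor itself would know how it's partitioned.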
Technically, 8 H100 CNX cards on a Gen 5 backplane at 128 GB/s can get one card a third of the HBM bandwidth. Given that Gen 6 and Gen 7 are following much more quickly than the Gen 3 to Gen 4 transition did, we may be there in a couple of years. I'm hoping that Gen 6/7 will have significant effects on the cost of the high end by making specialized boards less attractive and leveraging commodity switching.
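Rough arithmetic behind that 'a third', reading the 128 GB/s as per-card Gen 5 bandwidth over the backplane and using the ~3 TB/s HBM3 figure from upthread:

    gen5_per_card = 128                  # GB/s over the backplane, per card
    aggregate = 8 * gen5_per_card        # 1024 GB/s that could feed one card
    hbm3 = 3000                          # GB/s, H100 HBM3
    print(aggregate / hbm3)              # ~0.34, i.e. roughly a third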