Out of curiosity, how do you handle cases where the output shape depends on the input values (as opposed to depending only on the input shapes)?
This ranges from `torch.sum(tensor, dim)`, where `dim` might be non-constant, to `torch.nonzero(x)` and, of course, advanced indexing.
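A quick illustration of the problem, assuming a standard PyTorch install. The input shape is (2, 3) in both calls, but the output shapes change with the values of `x` and `dim`. `output_shapes` is just a hypothetical helper for the demo:

```python
import torch

def output_shapes(x: torch.Tensor, dim: int):
    # Hypothetical helper: returns the output shapes of three ops whose
    # result shape is not determined by x.shape alone.
    summed = torch.sum(x, dim)   # which axis is removed depends on the value of dim
    nz = torch.nonzero(x)        # shape (number of nonzeros, x.dim()), value-dependent
    masked = x[x > 0]            # boolean (advanced) indexing: 1-D, length = count of True
    return tuple(summed.shape), tuple(nz.shape), tuple(masked.shape)

a = torch.tensor([[1.0, 0.0, 2.0], [0.0, 0.0, 3.0]])
b = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
print(output_shapes(a, 0))  # ((3,), (3, 2), (3,))
print(output_shapes(b, 1))  # ((2,), (6, 2), (6,))
```

A shape-inference pass that only sees `(2, 3)` and a symbolic `dim` can't pin down any of these results statically.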