root commited on
Commit
a2b4eb3
·
1 Parent(s): ae93403

fix torch_dtype

Browse files
Files changed (1) hide show
  1. modeling_srv1_tp.py +1 -1
modeling_srv1_tp.py CHANGED
@@ -838,7 +838,7 @@ class SRV1ForCausalLMParallel(SRV1ForCausalLM):
838
  revision = kwargs.get("revision", None)
839
  trust_remote_code = kwargs.get("trust_remote_code", False)
840
  quantize = kwargs.get("quantize", None)
841
- dtype = kwargs.get("dtype", None)
842
  if dtype is None:
843
  dtype = config.torch_dtype
844
 
 
838
  revision = kwargs.get("revision", None)
839
  trust_remote_code = kwargs.get("trust_remote_code", False)
840
  quantize = kwargs.get("quantize", None)
841
+ dtype = kwargs.get("torch_dtype", None)
842
  if dtype is None:
843
  dtype = config.torch_dtype
844