File size: 572 Bytes
f69d33f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Build environment for the prediction image.
build:
  # Version strings are quoted so YAML keeps them as strings, not floats.
  cuda: '11.2'
  gpu: true
  python_version: '3.8'
  # OS-level packages installed into the image.
  # NOTE(review): these look like the shared libraries commonly required by
  # image-processing wheels (e.g. OpenCV) - confirm against the predictor.
  system_packages:
    - 'libgl1-mesa-glx'
    - 'libglib2.0-0'
  # Python dependencies, pinned to exact versions.
  python_packages:
    - 'numpy==1.21.1'
    - 'ipython==7.21.0'
    - 'absl-py==1.0.0'
    - 'chex==0.1.3'
    - 'clu==0.0.6'
    - 'einops==0.4.1'
    - 'flax==0.4.1'
    - 'ml-collections==0.1.1'
    - 'pandas==1.4.2'
    - 'tensorflow==2.8.0'
  # Extra shell commands executed during the build.
  run:
    - pip install --upgrade pip
    # The requirement is quoted so the shell does not treat the square
    # brackets as a glob pattern; `-f` adds the JAX release index as an
    # additional place for pip to look up wheels.
    # NOTE(review): jax is not version-pinned here while flax is pinned to
    # 0.4.1 - consider pinning jax/jaxlib to a release compatible with it.
    - pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Prediction entry point, written as `<python file>:<class name>`.
predict: 'maxim/predict.py:Predictor'