jacobmp committed on
Commit
81788df
·
verified ·
1 Parent(s): e76d06f

core code added

Browse files
Files changed (1) hide show
  1. app.py +49 -2
app.py CHANGED
@@ -1,7 +1,54 @@
1
  import gradio as gr
2
 
 
 
 
 
 
 
 
3
  def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
  import gradio as gr
2
 
3
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
+ from huggingface_hub import hf_hub_download
5
+ from transformers import AutoModel
6
+ from ultralytics import YOLO
7
+ from PIL import Image
8
+ import torch
9
+
10
  def greet(name):
11
+ LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
12
+ #OCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model"
13
+ OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
14
+
15
+ # Load the model and processor
16
+ processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
17
+ model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
18
+
19
+ # Open an image of handwritten text
20
+ image = Image.open("/content/drive/My Dive/ocr/img/hhhhhh-x595.jpeg").convert("RGB")
21
+
22
+ try:
23
+ # Load the trained line detection model
24
+ cached_model_path = hf_hub_download(repo_id = LINE_MODEL_PATH, filename="lines_20240827.pt")
25
+ line_model = YOLO(cached_model_path)
26
+ except Exception as e:
27
+ print('Failed to load the line detection model: %s' % e)
28
+
29
+ results = line_model.predict(source = image)[0]
30
+ full_text = ""
31
+ boxes = results.boxes.xyxy
32
+ indices = boxes[:,1].sort().indices
33
+ boxes = boxes[indices]
34
+ for box in boxes:
35
+ #box = box + torch.tensor([-10,0, 10, 0])
36
+ box = [tensor.item() for tensor in box]
37
+ #print(box)
38
+ lineImg = image.crop(tuple(list(box)))
39
+ #plt.imshow(lineImg)
40
+ #plt.show()
41
+
42
+ # Preprocess and predict
43
+ pixel_values = processor(lineImg, return_tensors="pt").pixel_values
44
+ generated_ids = model.generate(pixel_values)
45
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
46
+ #print(generated_text)
47
+ full_text += generated_text
48
+ #print("--------------------------------------------")
49
+
50
+ return full_text
51
+ #print("--------------------------------------------")
52
 
53
# Wire the transcription function into a simple Gradio UI: one image input,
# one text output, then start serving it.
demo = gr.Interface(
    fn=greet,
    inputs="image",
    outputs="text",
)
demo.launch()