fix 'get_avg_color' and implement score function

This commit is contained in:
Rusty Striker 2024-01-20 21:38:12 +02:00
parent 5a51510209
commit 2f98d26842
Signed by: RustyStriker
GPG Key ID: 9DBDBC7C48FC3C31
1 changed files with 30 additions and 3 deletions

View File

@ -4,6 +4,10 @@ import h5py as h5
db = None db = None
def init_train():
''' Default init based on the train set `train.h5` '''
init('train.h5')
def init(path): def init(path):
''' initializes the database, must be called before any use ''' ''' initializes the database, must be called before any use '''
global db global db
@ -51,18 +55,41 @@ def get_avg_color(img, mask):
gets avg color from an image that is underneath a mask, gets avg color from an image that is underneath a mask,
img and mask needs to be of same size(in x,y) but img can img and mask needs to be of same size(in x,y) but img can
have any third dimension size it want(usually 3 for rgb or 1 for grayscale) have any third dimension size it want(usually 3 for rgb or 1 for grayscale)
mask needs to be of shape(img.width, img.height, 1)
''' '''
sx, sy, sw = img.shape sx, sy, sw = img.shape
mx, my = mask.shape mx, my = mask.shape
if sx != mx or sy != my: if sx != mx or sy != my:
print('Image and mask shape doesnt match!') print('Image and mask size doesnt match!')
return None return None
avg = np.array(sw, dtype=np.float32) avg = np.zeros(sw, dtype=np.float32)
count = 0.0 count = 0.0
for x in range(sx): for x in range(sx):
for y in range(sy): for y in range(sy):
m = mas[x, y] m = mask[x, y]
avg += img[x, y] * m avg += img[x, y] * m
count += m count += m
avg /= count avg /= count
return avg return avg
def calc_score(img, mask, avg_color):
'''
Calculates the score for each mask with each color
'''
sx, sy, sw = img.shape
mx, my = mask.shape
if sx != mx or sy != my:
print('Image and mask size doesnt match!')
return 0.0
if sw != avg_color.shape[0]:
print('Image width doesnt match color width!')
return 0.0
score = 0.0
for x in range(sx):
for y in range(sy):
m = 0.5 - mask[x, y]
diff = img[x, y] - avg_color
mag = np.sqrt(diff.dot(diff)) # calculate magnitude
score += mag * m
return score