Deep learning/Error

Error : logits and labels must have the same shape

비비이잉 2021. 7. 1. 13:50
반응형

 

 

ValueError: logits and labels must have the same shape (() vs (39, ))

 

이 에러,,, 너무 많이 봐서 진절머리가 난다

입사하고 초반에는 이해도가 떨어져서 한참 헤맸었는데 다시 나타난 에러,, one-hot encoding으로 금방 해결할 수 있었다.

이 글을 읽는 다른 분들은 나처럼 삽질을 하지 않았으면 하는 마음에 정리해본다. 구글링을 아무리해도 loss function을 바꿔라 ! 라는 말밖에 없는데,,, 그게 문제가 아니라 로스를 계산할 수 있는 shape로 맞춰주지 않았기때문에 오류가 뜬거였다... 계산을 못하니까 !

 

from tensorflow.keras.utils import to_categorical

train_labels = [0,1,2,3,4,5,6,7,8,9,10,11,12]
one_hot = to_categorical(train_labels)

#print(train_oh_labels)




print('====================')
print(one_hot)
print(len(one_hot))

--> [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
13

tensorflow.keras.utils에 보면 아주 좋은 기능들이 많이 있다. 그 중에 이번에 사용할 것은

to_categorical 이다. 임의로 하나의 train_labels라는 리스트를 만들고 to_categorical로 변환해보면 다음과 같은 결과가 나타남을 알 수 있다.

 

내가 만든 batch는 다음과 같다.

<tf.Tensor: shape=(39,), dtype=int32, numpy=
array([ 0,  0,  0,  2,  2,  2, 12, 12, 12,  4,  4,  4,  9,  9,  9, 11, 11,
       11,  8,  8,  8, 10, 10, 10,  5,  5,  5,  3,  3,  3,  6,  6,  6,  1,

 

총 13개의 클래스에서 각 클래스당 3개씩 뽑아서 이어붙인 39개로 이루어져있다.

근데 내가 만든 라벨 부분을 보면 저렇게 각각 0부터 12까지 하나의 숫자로만 이루어져있다.

 

이렇게 출력을 해서 보니, 저 error의 뜻이 무슨말인지 이해가 갔다.

 

to_categorical을 이용해서

 

i=0
for item in train_dataset_test:
    print('i = ',i)
    print(item)
    i+=1


--------------------------------------------------------

<tf.Tensor: shape=(39, 13), dtype=float32, numpy=
array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]],
      dtype=float32)>)

 

다음과 같은 코드를 이용해서 이렇게 바꿔줬다.

from tensorflow.keras.utils import to_cateogrical

for_one_hot = add + train_oh_labels
one_hot = to_categorical(for_one_hot)

 

그런 다음, 똑같이 학습시킬때 넣어줬더니 에러가 안뜨고 돌아가 돌아가 !!!!

 
 

 

반응형