版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/ZWX2445205419/article/details/88667774
数据读取
import tensorflow as tf
test_file = 'test.txt'
train_file = 'train.txt'
def get_items(filename):
    """Parse a listing file where each line is "<image_path> <label>".

    Args:
        filename: path to the listing text file (e.g. 'train.txt').

    Returns:
        A pair of 1-D string `tf.constant` tensors: (image paths, labels).
    """
    paths = []
    labels = []
    # `with` guarantees the file handle is closed (the original leaked it),
    # and a distinct loop variable avoids shadowing the `filename` parameter.
    with open(filename) as fp:
        for line in fp:
            path, label = line.strip('\n').split()
            paths.append(path)
            labels.append(label)
    return tf.constant(paths), tf.constant(labels)
def _parse_function(filename, label):
    """Read the JPEG at `filename`, resize it to 28x28, pass `label` through.

    Intended as a `tf.data.Dataset.map` function over (path, label) pairs.
    """
    raw_bytes = tf.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes)
    resized = tf.image.resize_images(decoded, [28, 28])
    return resized, label
# Build train/test datasets and drive both through a single feedable iterator:
# the `handle` placeholder selects, per sess.run call, whether the next batch
# comes from the training or the test pipeline.
test_filenames, test_labels = get_items(test_file)
train_filenames, train_labels = get_items(train_file)

train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
# repeat() + shuffle() gives an endless, reshuffled stream of training batches.
train_dataset = (train_dataset
                 .map(_parse_function, num_parallel_calls=4)
                 .repeat()
                 .shuffle(buffer_size=10000)
                 .batch(32)
                 .prefetch(buffer_size=1000))

test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
# No shuffle for the test set; repeat() keeps it usable across eval rounds.
test_dataset = (test_dataset
                .map(_parse_function, num_parallel_calls=4)
                .repeat()
                .batch(32)
                .prefetch(buffer_size=1000))

# Feedable-iterator pattern: one generic iterator, switched via string handle.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
next_element = iterator.get_next()

train_iterator = train_dataset.make_one_shot_iterator()
test_iterator = test_dataset.make_initializable_iterator()

# `with` releases the session on exit (the original never closed it).
with tf.Session() as sess:
    train_handle = sess.run(train_iterator.string_handle())
    test_handle = sess.run(test_iterator.string_handle())
    # Demo loop: 20 training batches, then 5 test batches, repeated forever.
    while True:
        print('train')
        for _ in range(20):
            inputs, labels = sess.run(next_element, feed_dict={handle: train_handle})
            print(inputs.shape, labels.shape)
        print('test')
        sess.run(test_iterator.initializer)
        for _ in range(5):
            inputs, labels = sess.run(next_element, feed_dict={handle: test_handle})
            print(inputs.shape, labels.shape)
Tensorflow导入数据探究
#! -*- coding: utf-8 -*-
import tensorflow as tf
# Toy tensors standing in for a dataset: x == y, so each printed pair matches.
train_x = tf.range(0, 1000)
train_y = tf.range(0, 1000)
test_x = tf.range(0, 100)
test_y = tf.range(0, 100)

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
train_dataset = (train_dataset
                 .shuffle(buffer_size=100)
                 .batch(batch_size=10)
                 .repeat()
                 .prefetch(buffer_size=20))
train_iterator = train_dataset.make_one_shot_iterator()
train_next_element = train_iterator.get_next()

# Test pipeline: batched only — no shuffle, no repeat.
test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.batch(10).prefetch(buffer_size=20)
test_iterator = test_dataset.make_initializable_iterator()
test_next_element = test_iterator.get_next()

with tf.Session() as sess:
    for _ in range(5):
        for _ in range(10):
            # Fetch into fresh names: the original reused train_x/train_y,
            # clobbering the graph tensors with fetched numpy arrays.
            batch_x, batch_y = sess.run(train_next_element)
            print('train: ', batch_x, batch_y)
        # Re-initialize so the (unshuffled) test stream restarts at element 0.
        sess.run(test_iterator.initializer)
        for _ in range(5):
            batch_x, batch_y = sess.run(test_next_element)
            print('test: ', batch_x, batch_y)
训练集使用 make_one_shot_iterator(),且进行了 shuffle、batch、repeat、prefetch 等操作;
测试集使用 make_initializable_iterator(),且没有进行 shuffle 和 repeat 操作。
其结果为:
train: [ 39 64 37 86 1 23 31 100 101 107] [ 39 64 37 86 1 23 31 100 101 107]
train: [73 32 70 67 81 44 30 50 7 12] [73 32 70 67 81 44 30 50 7 12]
train: [114 116 119 36 121 103 63 80 28 2] [114 116 119 36 121 103 63 80 28 2]
train: [ 58 104 102 46 10 95 41 133 62 96] [ 58 104 102 46 10 95 41 133 62 96]
train: [ 5 4 130 132 15 89 43 54 99 126] [ 5 4 130 132 15 89 43 54 99 126]
train: [ 17 51 123 48 113 57 61 59 93 142] [ 17 51 123 48 113 57 61 59 93 142]
train: [ 52 128 6 148 158 20 9 161 71 53] [ 52 128 6 148 158 20 9 161 71 53]
train: [150 149 156 55 77 166 172 124 122 153] [150 149 156 55 77 166 172 124 122 153]
train: [134 79 118 136 115 127 0 16 164 180] [134 79 118 136 115 127 0 16 164 180]
train: [ 26 137 179 49 129 11 38 195 25 197] [ 26 137 179 49 129 11 38 195 25 197]
test: [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test: [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test: [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test: [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test: [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
train: [171 184 91 189 155 35 181 83 24 203] [171 184 91 189 155 35 181 83 24 203]
train: [160 174 33 75 204 82 170 85 177 154] [160 174 33 75 204 82 170 85 177 154]
train: [109 151 22 196 198 18 192 40 186 215] [109 151 22 196 198 18 192 40 186 215]
train: [147 88 66 3 182 205 223 229 145 237] [147 88 66 3 182 205 223 229 145 237]
train: [210 218 178 68 162 19 243 222 138 165] [210 218 178 68 162 19 243 222 138 165]
train: [ 42 167 221 13 92 131 72 239 236 110] [ 42 167 221 13 92 131 72 239 236 110]
train: [ 14 106 233 163 191 90 98 230 251 235] [ 14 106 233 163 191 90 98 230 251 235]
train: [225 270 211 45 255 274 259 56 265 252] [225 270 211 45 255 274 259 56 265 252]
train: [185 246 208 78 74 27 268 245 261 202] [185 246 208 78 74 27 268 245 261 202]
train: [152 213 258 256 176 292 284 286 248 281] [152 213 258 256 176 292 284 286 248 281]
test: [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test: [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test: [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test: [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test: [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
train: [105 257 146 87 226 157 183 287 244 273] [105 257 146 87 226 157 183 287 244 273]
train: [253 289 190 283 173 250 234 8 65 217] [253 289 190 283 173 250 234 8 65 217]
train: [117 305 269 201 282 288 315 309 140 76] [117 305 269 201 282 288 315 309 140 76]
train: [188 141 111 175 326 300 240 187 321 97] [188 141 111 175 326 300 240 187 321 97]
train: [329 302 249 220 69 84 307 200 267 320] [329 302 249 220 69 84 307 200 267 320]
train: [216 209 304 135 314 232 308 334 347 357] [216 209 304 135 314 232 308 334 347 357]
train: [335 264 351 272 263 313 199 299 341 194] [335 264 351 272 263 313 199 299 341 194]
train: [168 298 144 193 323 290 296 346 94 277] [168 298 144 193 323 290 296 346 94 277]
train: [ 60 139 339 366 356 280 332 348 247 291] [ 60 139 339 366 356 280 332 348 247 291]
train: [112 262 324 125 333 353 227 231 393 297] [112 262 324 125 333 353 227 231 393 297]
test: [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test: [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test: [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test: [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test: [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
train: [391 337 368 344 212 241 359 260 159 325] [391 337 368 344 212 241 359 260 159 325]
train: [303 238 362 349 364 322 345 405 409 417] [303 238 362 349 364 322 345 405 409 417]
train: [343 404 310 358 384 29 418 407 412 271] [343 404 310 358 384 29 418 407 412 271]
train: [413 120 381 336 398 294 389 228 376 328] [413 120 381 336 398 294 389 228 376 328]
train: [436 206 169 385 420 383 395 372 367 396] [436 206 169 385 420 383 395 372 367 396]
train: [438 275 449 399 371 439 433 34 327 388] [438 275 449 399 371 439 433 34 327 388]
train: [440 295 427 432 301 401 431 458 442 278] [440 295 427 432 301 401 431 458 442 278]
train: [370 468 414 279 403 276 456 378 459 463] [370 468 414 279 403 276 456 378 459 463]
train: [410 435 350 471 406 479 415 214 21 352] [410 435 350 471 406 479 415 214 21 352]
train: [424 312 361 489 360 457 316 402 447 428] [424 312 361 489 360 457 316 402 447 428]
test: [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test: [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test: [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test: [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test: [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
train: [397 317 451 379 470 491 318 469 421 494] [397 317 451 379 470 491 318 469 421 494]
train: [455 475 464 426 444 365 500 474 502 293] [455 475 464 426 444 365 500 474 502 293]
train: [485 434 330 481 386 480 430 319 266 411] [485 434 330 481 386 480 430 319 266 411]
train: [505 373 511 108 478 461 482 311 495 504] [505 373 511 108 478 461 482 311 495 504]
train: [462 539 536 448 219 527 531 499 453 533] [462 539 536 448 219 527 531 499 453 533]
train: [416 465 497 375 331 518 355 419 450 369] [416 465 497 375 331 518 355 419 450 369]
train: [390 549 496 306 392 467 374 422 477 508] [390 549 496 306 392 467 374 422 477 508]
train: [254 460 473 547 572 445 425 570 554 525] [254 460 473 547 572 445 425 570 554 525]
train: [377 285 569 517 566 542 560 529 544 564] [377 285 569 517 566 542 560 529 544 564]
train: [552 143 503 476 576 573 512 591 571 488] [552 143 503 476 576 573 512 591 571 488]
test: [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test: [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test: [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test: [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test: [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
可以看到,训练集的每个batch都进行了shuffle,而测试集每次都从头开始重新取50个数据。
我们改动一下,使得训练时随机获取,且每训练10个batch进行一次测试,测试时从测试集中随机获取5个batch
#! -*- coding: utf-8 -*-
import tensorflow as tf
# Same toy setup, but now the test pipeline is shuffled too, so each
# evaluation round samples a different set of test batches.
train_x = tf.range(0, 1000)
train_y = tf.range(0, 1000)
test_x = tf.range(0, 100)
test_y = tf.range(0, 100)

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
train_dataset = (train_dataset
                 .shuffle(buffer_size=100)
                 .batch(batch_size=10)
                 .repeat()
                 .prefetch(buffer_size=20))
train_iterator = train_dataset.make_one_shot_iterator()
train_next_element = train_iterator.get_next()

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = (test_dataset
                .shuffle(buffer_size=100)
                .batch(10)
                .prefetch(buffer_size=20))
test_iterator = test_dataset.make_initializable_iterator()
test_next_element = test_iterator.get_next()

with tf.Session() as sess:
    for _ in range(5):
        for _ in range(10):
            # Fresh names instead of reusing train_x/train_y (the original
            # overwrote the graph tensors with fetched numpy arrays).
            batch_x, batch_y = sess.run(train_next_element)
            print('train: ', batch_x, batch_y)
        # Re-initialization reshuffles and restarts the test stream.
        sess.run(test_iterator.initializer)
        for _ in range(5):
            batch_x, batch_y = sess.run(test_next_element)
            print('test: ', batch_x, batch_y)
其结果为:
train: [ 83 6 18 16 90 54 12 57 29 103] [ 83 6 18 16 90 54 12 57 29 103]
train: [104 88 35 55 43 63 108 38 8 61] [104 88 35 55 43 63 108 38 8 61]
train: [ 4 91 28 42 77 30 37 31 114 1] [ 4 91 28 42 77 30 37 31 114 1]
train: [ 68 78 7 34 87 125 73 62 60 24] [ 68 78 7 34 87 125 73 62 60 24]
train: [ 3 138 100 139 133 17 48 11 130 39] [ 3 138 100 139 133 17 48 11 130 39]
train: [ 74 51 44 46 153 126 142 143 82 50] [ 74 51 44 46 153 126 142 143 82 50]
train: [148 22 140 136 14 0 33 113 20 109] [148 22 140 136 14 0 33 113 20 109]
train: [168 169 155 122 5 117 166 69 115 176] [168 169 155 122 5 117 166 69 115 176]
train: [ 21 27 72 162 49 2 111 99 127 70] [ 21 27 72 162 49 2 111 99 127 70]
train: [ 56 188 184 175 152 150 164 196 158 187] [ 56 188 184 175 152 150 164 196 158 187]
test: [43 17 83 36 90 96 62 95 0 63] [43 17 83 36 90 96 62 95 0 63]
test: [16 28 87 34 15 22 68 42 35 25] [16 28 87 34 15 22 68 42 35 25]
test: [ 9 65 30 31 33 73 70 92 40 80] [ 9 65 30 31 33 73 70 92 40 80]
test: [23 2 52 85 20 75 97 4 61 91] [23 2 52 85 20 75 97 4 61 91]
test: [ 8 71 58 46 55 64 37 6 76 11] [ 8 71 58 46 55 64 37 6 76 11]
train: [ 96 95 186 19 181 203 191 144 32 75] [ 96 95 186 19 181 203 191 144 32 75]
train: [156 118 146 178 59 85 132 193 157 159] [156 118 146 178 59 85 132 193 157 159]
train: [ 45 89 198 86 201 207 225 105 149 151] [ 45 89 198 86 201 207 225 105 149 151]
train: [218 183 214 123 94 216 121 102 227 180] [218 183 214 123 94 216 121 102 227 180]
train: [219 106 167 15 228 92 230 231 141 245] [219 106 167 15 228 92 230 231 141 245]
train: [239 170 200 217 213 84 120 107 163 81] [239 170 200 217 213 84 120 107 163 81]
train: [ 13 241 195 173 248 25 64 194 254 98] [ 13 241 195 173 248 25 64 194 254 98]
train: [267 256 223 234 202 274 154 76 179 259] [267 256 223 234 202 274 154 76 179 259]
train: [211 275 222 171 97 137 182 220 185 277] [211 275 222 171 97 137 182 220 185 277]
train: [272 265 290 232 36 161 260 255 128 205] [272 265 290 232 36 161 260 255 128 205]
test: [14 78 79 66 73 16 31 18 91 48] [14 78 79 66 73 16 31 18 91 48]
test: [68 34 50 17 70 92 28 94 27 6] [68 34 50 17 70 92 28 94 27 6]
test: [62 46 51 32 99 44 81 59 25 54] [62 46 51 32 99 44 81 59 25 54]
test: [58 5 95 64 63 10 4 89 67 98] [58 5 95 64 63 10 4 89 67 98]
test: [76 65 15 33 19 74 22 45 7 13] [76 65 15 33 19 74 22 45 7 13]
train: [247 293 270 302 172 129 199 9 306 10] [247 293 270 302 172 129 199 9 306 10]
train: [262 257 280 235 208 309 282 124 263 252] [262 257 280 235 208 309 282 124 263 252]
train: [221 314 271 147 289 224 47 79 287 313] [221 314 271 147 289 224 47 79 287 313]
train: [192 261 209 269 112 330 204 316 298 71] [192 261 209 269 112 330 204 316 298 71]
train: [110 276 131 165 334 310 329 336 229 331] [110 276 131 165 334 310 329 336 229 331]
train: [339 300 341 284 295 338 285 53 342 305] [339 300 341 284 295 338 285 53 342 305]
train: [212 174 52 297 312 317 322 320 244 237] [212 174 52 297 312 317 322 320 244 237]
train: [296 299 93 286 246 249 67 366 324 26] [296 299 93 286 246 249 67 366 324 26]
train: [332 251 337 266 315 283 101 351 358 236] [332 251 337 266 315 283 101 351 358 236]
train: [145 354 363 273 373 250 65 352 393 340] [145 354 363 273 373 250 65 352 393 340]
test: [38 13 53 74 64 94 61 91 56 59] [38 13 53 74 64 94 61 91 56 59]
test: [82 57 54 11 26 66 92 0 1 60] [82 57 54 11 26 66 92 0 1 60]
test: [70 12 17 15 31 37 41 3 99 80] [70 12 17 15 31 37 41 3 99 80]
test: [77 19 63 72 89 43 81 97 50 85] [77 19 63 72 89 43 81 97 50 85]
test: [75 18 36 47 4 33 24 22 39 46] [75 18 36 47 4 33 24 22 39 46]
train: [226 370 346 362 304 375 402 374 376 58] [226 370 346 362 304 375 402 374 376 58]
train: [383 80 190 233 258 206 360 367 344 307] [383 80 190 233 258 206 360 367 344 307]
train: [391 23 323 328 368 395 197 409 410 372] [391 23 323 328 368 395 197 409 410 372]
train: [243 160 430 268 429 353 343 414 326 294] [243 160 430 268 429 353 343 414 326 294]
train: [279 407 428 419 397 406 253 436 119 303] [279 407 428 419 397 406 253 436 119 303]
train: [347 398 442 382 377 421 288 345 386 432] [347 398 442 382 377 421 288 345 386 432]
train: [349 318 447 333 451 462 444 308 418 403] [349 318 447 333 451 462 444 308 418 403]
train: [426 401 457 413 423 189 371 327 459 468] [426 401 457 413 423 189 371 327 459 468]
train: [350 458 292 41 479 399 465 454 238 387] [350 458 292 41 479 399 465 454 238 387]
train: [427 455 134 435 461 66 477 378 416 456] [427 455 134 435 461 66 477 378 416 456]
test: [23 46 59 53 10 81 13 48 43 27] [23 46 59 53 10 81 13 48 43 27]
test: [25 70 69 98 55 68 74 92 29 3] [25 70 69 98 55 68 74 92 29 3]
test: [11 39 80 1 73 6 32 41 75 96] [11 39 80 1 73 6 32 41 75 96]
test: [17 24 97 95 9 51 67 76 44 58] [17 24 97 95 9 51 67 76 44 58]
test: [21 45 71 18 15 72 0 88 91 31] [21 45 71 18 15 72 0 88 91 31]
train: [215 311 483 446 396 440 503 364 496 450] [215 311 483 446 396 440 503 364 496 450]
train: [385 392 493 361 473 480 443 408 400 507] [385 392 493 361 473 480 443 408 400 507]
train: [501 481 500 445 449 489 475 335 488 412] [501 481 500 445 449 489 475 335 488 412]
train: [510 291 514 469 434 476 422 487 498 490] [510 291 514 469 434 476 422 487 498 490]
train: [357 474 264 321 516 384 452 380 531 495] [357 474 264 321 516 384 452 380 531 495]
train: [542 509 431 532 135 453 554 524 301 539] [542 509 431 532 135 453 554 524 301 539]
train: [441 319 278 521 448 460 553 558 520 379] [441 319 278 521 448 460 553 558 520 379]
train: [411 505 562 537 388 437 325 116 560 389] [411 505 562 537 388 437 325 116 560 389]
train: [471 552 544 569 484 526 499 548 525 390] [471 552 544 569 484 526 499 548 525 390]
train: [551 355 513 497 438 466 540 579 369 550] [551 355 513 497 438 466 540 579 369 550]
test: [25 64 39 42 90 23 86 20 55 60] [25 64 39 42 90 23 86 20 55 60]
test: [94 53 14 73 16 81 84 13 92 24] [94 53 14 73 16 81 84 13 92 24]
test: [79 67 7 2 61 10 99 36 28 95] [79 67 7 2 61 10 99 36 28 95]
test: [31 38 76 33 15 41 48 59 11 4] [31 38 76 33 15 41 48 59 11 4]
test: [68 88 37 96 70 6 93 78 62 26] [68 88 37 96 70 6 93 78 62 26]
这样的设置是比较合理的,但当测试集比较大时,进行全量测试会比较耗时,可以设置每隔若干个训练batch,随机抽取一些测试集的batch来查看模型效果。
设置GPU使用资源
import tensorflow as tf
# GPU resource options for the TF1 session.
gpu_config = tf.GPUOptions(
    allow_growth=True,  # start small and grow GPU memory on demand; memory is never released, so it can fragment
    per_process_gpu_memory_fraction=0.7,  # cap this process at 70% of the GPU's memory
)
config = tf.ConfigProto(
    log_device_placement=True,  # log which device each op is placed on
    allow_soft_placement=True,  # fall back to an available device if the requested one does not exist
    gpu_options=gpu_config,  # attach the GPU options above
)
with tf.Session(config=config) as sess:
    with tf.device("/gpu:0"):  # pin the following ops to GPU 0
        a = tf.placeholder(tf.int16)
        b = tf.placeholder(tf.int16)
        add = tf.add(a, b)
        print(sess.run(add, feed_dict={a: 1, b: 2}))
使用Estimator
参考:https://www.jianshu.com/p/e343758a185e
#! -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
tf.logging.set_verbosity(tf.logging.INFO)  # show INFO-level logs (training progress, hook output)
# 我们的程序代码将放在这里
def cnn_model_fn(features, labels, mode):
    """Model function for an MNIST CNN Estimator.

    Args:
        features: dict with key "x" holding the flattened 28x28 image batch.
        labels: integer class labels (used in TRAIN/EVAL modes).
        mode: a tf.estimator.ModeKeys value selecting predict/train/eval.

    Returns:
        A tf.estimator.EstimatorSpec appropriate for `mode`.
    """
    # Input layer: -1 infers the batch size; 28x28 pixels, 1 channel (grayscale).
    input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])
    # Conv layer #1: 32 filters, 5x5 kernel, ReLU activation.
    conv1 = tf.layers.conv2d(
        inputs=input_layer,  # consumes the reshaped input tensor
        filters=32,
        kernel_size=[5, 5],
        padding="same",
        activation=tf.nn.relu)
    # Pooling layer #1: 2x2 max pool, consumes conv layer #1's output.
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    # Conv layer #2.
    conv2 = tf.layers.conv2d(
        inputs=pool1,  # consumes pooling layer #1's output
        filters=64,
        kernel_size=[5, 5],
        padding="same",
        activation=tf.nn.relu)
    # Pooling layer #2.
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
    # Flatten pooling layer #2's output for the dense layer (7*7 spatial, 64 channels).
    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
    # Dense layer.
    dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
    # Dropout regularization; only active in TRAIN mode.
    dropout = tf.layers.dropout(
        inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
    # Final dense layer producing the logits; units = number of classes.
    logits = tf.layers.dense(inputs=dropout, units=10)
    # Outputs used for prediction and evaluation.
    predictions = {
        # Predicted class = index of the largest logit along axis 1.
        "classes": tf.argmax(input=logits, axis=1),
        # Class probabilities; named so the LoggingTensorHook can reference it.
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }
    # Below: three mode-dependent branches, each returning a tf.estimator.EstimatorSpec.
    # PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
    # Loss (used by TRAIN and EVAL): sparse softmax cross-entropy.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    # TRAIN mode: gradient-descent optimizer.
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(
            loss=loss,
            global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
    # EVAL mode (the remaining case): attach the accuracy evaluation metric.
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(
            labels=labels, predictions=predictions["classes"])}
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
# Resolve the MNIST data directory relative to this script's location.
dir_path = os.path.dirname(os.path.realpath(__file__))
data_path = os.path.join(dir_path, 'MNIST_data')
def main(args):
    """Train and evaluate the MNIST CNN Estimator.

    Args:
        args: positional args forwarded by tf.app.run(); unused.
    """
    # Load training and test data.
    mnist = input_data.read_data_sets(data_path)
    train_data = mnist.train.images  # np.array of flattened images
    train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
    eval_data = mnist.test.images  # np.array
    eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
    # Create the Estimator; checkpoints and summaries go to model_dir.
    mnist_classifier = tf.estimator.Estimator(
        model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")
    # Log predicted probabilities (the tensor named "softmax_tensor" in the model fn).
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(
        tensors=tensors_to_log, every_n_iter=50)
    # Training input pipeline; num_epochs=None repeats indefinitely.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": train_data},
        y=train_labels,
        batch_size=100,
        num_epochs=None,
        shuffle=True)
    # Start training.
    mnist_classifier.train(
        input_fn=train_input_fn,
        steps=20000,
        hooks=[logging_hook])
    # Evaluation input pipeline: a single pass, no shuffling.
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": eval_data},
        y=eval_labels,
        num_epochs=1,
        shuffle=False)
    # Run evaluation and print the results.
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print(eval_results)
# This file can be run directly, or imported as a module by other files.
if __name__ == "__main__":
    tf.app.run()  # parses flags, then invokes main()
TensorFlow搭建网络模型
tensorflow搭建网络的库: tf.keras, tf.nn, tf.layers
- tf.nn:最底层的函数库,其他各种库基本都是基于这个底层库来进行扩展的
- tf.layers:比tf.nn更高级的库,对tf.nn进行了多方位功能扩展,相当于用tf.nn造的轮子。最大的特点是库中每个函数都有相应的类
- tf.keras:基于tf.layers和tf.nn的高度封装