作者 | 李秋鍵
責(zé)編 | Carol
出品 | AI科技大本營(ID:rgznai100)
自然語言處理作為人工智能的一個(gè)重要分支,在我們的生活中得到了廣泛應(yīng)用。其中RNN算法作為自然語言處理的經(jīng)典算法之一,是文本生成的重要手段。而今天我們就將利用RNN算法建立一個(gè)寫歌詞的軟件。其中的界面如下:
RNN指的是循環(huán)神經(jīng)網(wǎng)絡(luò),Recurrent Neural Network。不同于前饋神經(jīng)網(wǎng)絡(luò)的是,RNN可以利用它內(nèi)部的記憶來處理任意時(shí)序的輸入序列,這讓它可以更容易處理如不分段的手寫識(shí)別、語音識(shí)別等。
RNN模型有比較多的變種,這里介紹最主流的RNN模型結(jié)構(gòu)如下:
上圖中左邊是RNN模型沒有按時(shí)間展開的圖,如果按時(shí)間序列展開,則是上圖中的右邊部分。我們重點(diǎn)觀察右邊部分的圖。
這幅圖描述了在序列索引號(hào)tt附近RNN的模型。其中:
-
x(t)x(t)代表在序列索引號(hào)tt時(shí)訓(xùn)練樣本的輸入。同樣的,x(t?1)x(t?1)和x(t 1)x(t 1)代表在序列索引號(hào)t?1t?1和t 1t 1時(shí)訓(xùn)練樣本的輸入。
h(t)h(t)代表在序列索引號(hào)tt時(shí)模型的隱藏狀態(tài)。h(t)h(t)由x(t)x(t)和h(t?1)h(t?1)共同決定。
o(t)o(t)代表在序列索引號(hào)tt時(shí)模型的輸出。o(t)o(t)只由模型當(dāng)前的隱藏狀態(tài)h(t)h(t)決定。
L(t)L(t)代表在序列索引號(hào)tt時(shí)模型的損失函數(shù)。
y(t)y(t)代表在序列索引號(hào)tt時(shí)訓(xùn)練樣本序列的真實(shí)輸出。
U,W,VU,W,V這三個(gè)矩陣是我們的模型的線性關(guān)系參數(shù),它在整個(gè)RNN網(wǎng)絡(luò)中是共享的,這點(diǎn)和DNN很不相同。也正因?yàn)槭枪蚕砹耍w現(xiàn)了RNN的模型的“循環(huán)反饋”的思想。
基于以上認(rèn)知,我們開始搭建我們的軟件。
實(shí)驗(yàn)前的準(zhǔn)備
首先我們使用的python版本是3.6.5所用到的庫有TensorFlow,是用來訓(xùn)練和加載神經(jīng)網(wǎng)絡(luò)常見的框架,常常用于數(shù)值計(jì)算的開源軟件庫。節(jié)點(diǎn)表示數(shù)學(xué)操作,線則表示在節(jié)點(diǎn)間相互聯(lián)系的多維數(shù)據(jù)數(shù)組,即張量(tensor);tkinter用來繪制GUI界面的庫;
Pillow庫在此項(xiàng)目中用來處理圖片和字體等問題。因?yàn)槲覀兊能浖皇强瞻妆尘暗?。需要借助Image函數(shù)添加背景。
RNN算法搭建
1、數(shù)據(jù)集處理和準(zhǔn)備:
我們訓(xùn)練的數(shù)據(jù)集使用各種歌手的歌詞本作為訓(xùn)練集。其中數(shù)據(jù)集放在date.txt里,其中部分?jǐn)?shù)據(jù)集如下:
2、模型的訓(xùn)練:
模型訓(xùn)練的代碼直接運(yùn)行train.py即可訓(xùn)練。其中流程如下:
-
首先要讀取數(shù)據(jù)集
設(shè)定訓(xùn)練批次、步數(shù)等等
數(shù)據(jù)載入RNN進(jìn)行訓(xùn)練即可
其中代碼如下:
def train:
filename = \’date.txt\’
with open(filename, \’r\’, encoding=\’utf-8\’) as f:
text = f.read
reader = TxtReader(text=text, maxVocab=3500)
reader.save(\’voc.data\’)
array = reader.text2array(text)
generator = GetBatch(array, n_seqs=100, n_steps=100)
model = CharRNN(
numClasses = reader.vocabLen,
mode =\’train\’,
numSeqs = 100,
numSteps = 100,
lstmSize = 128,
numLayers = 2,
lr = 0.001,
Trainprob = 0.5,
useEmbedding = True,
numEmbedding = 128
)
model.train(
generator,
logStep = 10,
saveStep = 1000,
maxStep = 100000
)
3、RNN網(wǎng)絡(luò)搭建:
RNN算法的搭建,我們定義整個(gè)神經(jīng)網(wǎng)絡(luò)類,然后分別定義初始化、輸入、神經(jīng)元定義等函數(shù)。損失函數(shù)和優(yōu)化器使用均方差和AdamOptimizer優(yōu)化器即可。
部分代碼如下:
# 創(chuàng)建輸入
def buildInputs(self):
numSeqs = self.numSeqs
numSteps = self.numSteps
numClasses = self.numClasses
numEmbedding = self.numEmbedding
useEmbedding = self.useEmbedding
with tf.name_scope(\’inputs\’):
self.inData = tf.placeholder(tf.int32, shape=(numSeqs, numSteps), name=\’inData\’)
self.targets = tf.placeholder(tf.int32, shape=(numSeqs, numSteps), name=\’targets\’)
self.keepProb = tf.placeholder(tf.float32, name=\’keepProb\’)
# 中文
if useEmbedding:
with tf.device(\”/cpu:0\”):
embedding = tf.get_variable(\’embedding\’, [numClasses, numEmbedding])
self.lstmInputs = tf.nn.embedding_lookup(embedding, self.inData)
# 英文
else:
self.lstmInputs = tf.one_hot(self.inData, numClasses)
# 創(chuàng)建單個(gè)Cell
def buildCell(self, lstmSize, keepProb):
basicCell = tf.nn.rnn_cell.BasicLSTMCell(lstmSize)
drop = tf.nn.rnn_cell.DropoutWrapper(basicCell, output_keep_prob=keepProb)
return drop
# 將單個(gè)Cell堆疊多層
def buildLstm(self):
lstmSize = self.lstmSize
numLayers = self.numLayers
keepProb = self.keepProb
numSeqs = self.numSeqs
numClasses = self.numClasses
with tf.name_scope(\’lstm\’):
multiCell = tf.nn.rnn_cell.MultiRNNCell(
[self.buildCell(lstmSize, keepProb) for _ in range(numLayers)]
)
self.initial_state = multiCell.zero_state(numSeqs, tf.float32)
self.lstmOutputs, self.finalState = tf.nn.dynamic_rnn(multiCell, self.lstmInputs, initial_state=self.initial_state)
seqOutputs = tf.concat(self.lstmOutputs, 1)
x = tf.reshape(seqOutputs, [-1, lstmSize])
with tf.variable_scope(\’softmax\’):
softmax_w = tf.Variable(tf.truncated_normal([lstmSize, numClasses], stddev=0.1))
softmax_b = tf.Variable(tf.zeros(numClasses))
self.logits = tf.matmul(x, softmax_w) softmax_b
self.prediction = tf.nn.softmax(self.logits, name=\’prediction\’)
# 計(jì)算損失
def buildLoss(self):
numClasses = self.numClasses
with tf.name_scope(\’loss\’):
targets = tf.one_hot(self.targets, numClasses)
targets = tf.reshape(targets, self.logits.get_shape)
loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=targets)
self.loss = tf.reduce_mean(loss)
# 創(chuàng)建優(yōu)化器
def buildOptimizer(self):
gradClip = self.gradClip
lr = self.lr
trainVars = tf.trainable_variables
# 限制權(quán)重更新
grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, trainVars), gradClip)
trainOp = tf.train.AdamOptimizer(lr)
self.optimizer = trainOp.apply_gradients(zip(grads, trainVars))
# 訓(xùn)練
def train(self, data, logStep=10, saveStep=1000, savepath=\’./models/\’, maxStep=100000):
if not os.path.exists(savepath):
os.mkdir(savepath)
Trainprob = self.Trainprob
self.session = tf.Session
with self.session as sess:
step = 0
sess.run(tf.global_variables_initializer)
state_now = sess.run(self.initial_state)
for x, y in data:
step = 1
feed_dict = {
self.inData: x,
self.targets: y,
self.keepProb: Trainprob,
self.initial_state: state_now
}
loss, state_now, _ = sess.run([self.loss, self.finalState, self.optimizer], feed_dict=feed_dict)
if step % logStep == 0:
print(\'[INFO]: <step>: {}/{}, loss: {:.4f}\’.format(step, maxStep, loss))
if step % saveStep == 0:
self.saver.save(sess, savepath, global_step=step)
if step > maxStep:
self.saver.save(sess, savepath, global_step=step)
break
# 從前N個(gè)預(yù)測值中選
def GetTopN(self, preds, size, top_n=5):
p = np.squeeze(preds)
p[np.argsort(p)[:-top_n]] = 0
p = p / np.sum(p)
c = np.random.choice(size, 1, p=p)[0]
return c
4、歌詞的生成:
設(shè)置關(guān)鍵詞變量,讀取模型文件,輸出結(jié)果即可。
代碼如下:
def main(_):
reader = TxtReader(filename=\’voc.data\’)
model = CharRNN(
numClasses = reader.vocabLen,
mode = \’test\’,
lstmSize = 128,
numLayers = 2,
useEmbedding = True,
numEmbedding = 128
)
checkpoint = tf.train.latest_checkpoint(\’./models/\’)
model.load(checkpoint)
key=\”雪花\”
prime = reader.text2array(key)
array = model.test(prime, size=reader.vocabLen, n_samples=300)
print(\”《\” key \”》\”)
print(reader.array2text(array))
界面的定義和調(diào)用
界面中我們的布局是文本框、編輯框和按鈕控件。程序的調(diào)用使用批處理文件調(diào)用以達(dá)到顯示運(yùn)行過程的效果。因?yàn)槿绻麤]有運(yùn)行過程,難免會(huì)導(dǎo)致用戶不清楚程序流程而強(qiáng)制運(yùn)行容易導(dǎo)致卡死的情況。
其中Bat里直接寫入:
python song.py
其中過程效果如下:
1、界面布局:
界面布局使用canvas畫布以達(dá)到添加背景圖片的效果。背景圖片設(shè)置為1.jpg,按鈕背景圖片設(shè)置為3.jpg。圖片也可以自己更換掉。然后文本框作為提示的效果,分別定義字體,大小等等即可
代碼如下:
root = tk.Tk
root.title(\’AI寫歌詞\’)
# 背景
canvas = tk.Canvas(root, width=800, height=500, bd=0, highlightthickness=0)
imgpath = \’1.jpg\’
img = Image.open(imgpath)
photo = ImageTk.PhotoImage(img)
imgpath2 = \’3.jpg\’
img2 = Image.open(imgpath2)
photo2 = ImageTk.PhotoImage(img2)
canvas.create_image(700, 400, image=photo)
canvas.pack
label=tk.Label(text=\”請輸入關(guān)鍵詞:\”,font=(\”微軟雅黑\”,20))
entry = tk.Entry(root, insertbackground=\’blue\’, highlightthickness=2,font=(\”微軟雅黑\”,15))
entry.pack
entry1 = tk.Text(height=15,width=115)
entry1.pack
2、功能調(diào)用:
我們使用按鈕中的command參數(shù)調(diào)用已設(shè)置好的函數(shù)即可。其中函數(shù)部分我們通過生成文本和刪除文本的方式讀入數(shù)據(jù)和寫入數(shù)據(jù)。為了防止數(shù)據(jù)重疊故在要時(shí)刻監(jiān)測重復(fù)軟件。定義的函數(shù)內(nèi)容如下:
def song:
ss=entry.get
f=open(\”1.txt\”,\”w\”)
f.write(ss)
f.close
os.startfile(\”1.bat\”)
while True:
if os.path.exists(\”2.txt\”):
f=open(\”2.txt\”)
ws=f.read
f.close
entry1.insert(\”0.0\”, ws)
break
try:
os.remove(\”1.txt\”)
os.remove(\”2.txt\”)
except:
pass
3、GUI代碼:
整個(gè)GUI界面代碼如下:
import tkinter as tk
from PIL import ImageTk, Image
import os
try:
os.remove(\”1.txt\”)
os.remove(\”2.txt\”)
except:
pass
import os
def song:
ss=entry.get
f=open(\”1.txt\”,\”w\”)
f.write(ss)
f.close
os.startfile(\”1.bat\”)
while True:
if os.path.exists(\”2.txt\”):
f=open(\”2.txt\”)
ws=f.read
f.close
entry1.insert(\”0.0\”, ws)
break
try:
os.remove(\”1.txt\”)
os.remove(\”2.txt\”)
except:
pass
root = tk.Tk
root.title(\’AI寫歌詞\’)
# 背景
canvas = tk.Canvas(root, width=800, height=500, bd=0, highlightthickness=0)
imgpath = \’1.jpg\’
img = Image.open(imgpath)
photo = ImageTk.PhotoImage(img)
imgpath2 = \’3.jpg\’
img2 = Image.open(imgpath2)
photo2 = ImageTk.PhotoImage(img2)
canvas.create_image(700, 400, image=photo)
canvas.pack
label=tk.Label(text=\”請輸入關(guān)鍵詞:\”,font=(\”微軟雅黑\”,20))
entry = tk.Entry(root, insertbackground=\’blue\’, highlightthickness=2,font=(\”微軟雅黑\”,15))
entry.pack
entry1 = tk.Text(height=15,width=115)
entry1.pack
bnt = tk.Button(width=15,height=2,image=photo2,command=song)
canvas.create_window(100, 50, width=200, height=30,
window=label)
canvas.create_window(500, 50, width=630, height=30,
window=entry)
canvas.create_window(400, 100, width=220, height=50,
window=bnt)
canvas.create_window(400, 335, width=600, height=400,
window=entry1)
root.mainloop
到這里,我們整體的程序就搭建完成,下面為我們程序的運(yùn)行過程和結(jié)果:
源碼地址:
鏈接:https://pan.baidu.com/s/1EJsHIXbKUmRG-MdHcqkdFg
提取碼:iz5m
作者簡介 :
李秋鍵,CSDN 博客專家,CSDN達(dá)人課作者。碩士在讀于中國礦業(yè)大學(xué),開發(fā)有taptap安卓武俠游戲一部,vip視頻解析,文意轉(zhuǎn)換工具,寫作機(jī)器人等項(xiàng)目,發(fā)表論文若干,多次高數(shù)競賽獲獎(jiǎng)等等。
-
AI修復(fù)100年前晚清影像喜提熱搜,這兩大算法立功了
- CycleGan人臉轉(zhuǎn)為漫畫臉,牛掰的知識(shí)又增加了 | 附代碼
一次對語音技術(shù)的徹底批判
用大白話徹底搞懂 HBase RowKey 詳細(xì)設(shè)計(jì)
- 為什么黑客無法攻擊公開的區(qū)塊鏈?
- 再見 Python,Hello Julia!
百萬人學(xué)AI 萬人在線大會(huì), 15 場直播搶先看!
版權(quán)聲明:本文內(nèi)容由互聯(lián)網(wǎng)用戶自發(fā)貢獻(xiàn),該文觀點(diǎn)僅代表作者本人。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如發(fā)現(xiàn)本站有涉嫌抄襲侵權(quán)/違法違規(guī)的內(nèi)容, 請發(fā)送郵件至 舉報(bào),一經(jīng)查實(shí),本站將立刻刪除。