Name: SeqGAN
Owner: YOCTOL INFO INC.
Description: Implementation of Sequence Generative Adversarial Nets with Policy Gradient
Forked from: LantaoYu/SeqGAN
Created: 2017-04-11 08:07:53.0
Updated: 2017-04-11 08:07:55.0
Pushed: 2017-04-11 10:05:43.0
Homepage: null
Size: 3021
Language: Python
Apply Generative Adversarial Nets to generating sequences of discrete tokens.
Illustration of SeqGAN. Left: D is trained on both the real data and the data generated by G. Right: G is trained by policy gradient, where the final reward signal is provided by D and passed back to the intermediate action values via Monte Carlo search.
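The Monte Carlo search step in the illustration can be sketched as follows. This is a minimal stdlib-only sketch, not the repository's implementation: `mc_action_values`, `rollout_policy` and `discriminator` are hypothetical names, the rollout policy is any callable (in the paper it is the generator itself), and D is any function returning a score in [0, 1]. For each prefix Y_{1:t} with t < T, the sketch completes the sequence several times with the rollout policy and averages D's scores over the completions; at t = T it uses D's score of the finished sequence directly.

```python
import random
from typing import Callable, List, Sequence

# Hypothetical interfaces: a rollout policy samples the next token given a
# prefix; a discriminator maps a complete sequence to an estimate of P(real).
RolloutPolicy = Callable[[List[int]], int]
Discriminator = Callable[[Sequence[int]], float]


def mc_action_values(
    sequence: Sequence[int],
    rollout_policy: RolloutPolicy,
    discriminator: Discriminator,
    n_rollouts: int = 16,
) -> List[float]:
    """Estimate Q(Y_{1:t-1}, y_t) for each step t = 1..T of a finished sequence.

    For t < T, complete the prefix Y_{1:t} n_rollouts times with the rollout
    policy and average D's score over the completions. For t == T the
    sequence is already complete, so D scores it directly.
    """
    T = len(sequence)
    values: List[float] = []
    for t in range(1, T + 1):
        prefix = list(sequence[:t])
        if t == T:
            values.append(discriminator(prefix))
            continue
        total = 0.0
        for _ in range(n_rollouts):
            rollout = list(prefix)
            while len(rollout) < T:
                rollout.append(rollout_policy(rollout))
            total += discriminator(rollout)
        values.append(total / n_rollouts)
    return values


# Toy D: fraction of 1s in the sequence (a stand-in for a trained discriminator).
frac_ones = lambda seq: sum(seq) / len(seq)

# Deterministic check: a rollout policy that always emits 0.
print(mc_action_values([1, 1, 1, 1], lambda prefix: 0, frac_ones))
# [0.25, 0.5, 0.75, 1.0]

# Stochastic rollouts: uniform random tokens from a seeded RNG.
rng = random.Random(0)
print(mc_action_values([1, 1, 1, 1], lambda prefix: rng.randint(0, 1), frac_ones, n_rollouts=200))
```

In policy-gradient training these per-step values weight the gradient of log G(y_t | Y_{1:t-1}), so an early token is credited according to how well its completions score under D.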
The research paper SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient has been accepted at the Thirty-First AAAI Conference on Artificial Intelligence (AAAI-17).
We provide example code to reproduce the synthetic data experiments with the oracle evaluation mechanism. To run the experiment with default parameters:
python sequence_gan.py
All parameters can be changed in sequence_gan.py.
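The oracle evaluation can be illustrated with a small stand-in. In this sketch the oracle is a hypothetical two-token first-order Markov chain (`START`, `TRANSITION`), whereas the paper uses a randomly initialized LSTM as the oracle; `oracle_nll` is likewise a hypothetical name. Generated samples are scored by their average per-token negative log-likelihood under the oracle, so lower values mean the generator's samples are more plausible under the true distribution. The per-token normalization is an assumption and may differ from the repository's.

```python
import math
from typing import Dict, Sequence

# Hypothetical oracle: a first-order Markov chain over the vocabulary {0, 1}.
# The paper's oracle is a randomly initialized LSTM; only the scoring idea carries over.
START: Dict[int, float] = {0: 0.5, 1: 0.5}
TRANSITION: Dict[int, Dict[int, float]] = {
    0: {0: 0.9, 1: 0.1},
    1: {0: 0.2, 1: 0.8},
}


def oracle_nll(samples: Sequence[Sequence[int]]) -> float:
    """Average per-token negative log-likelihood of generated samples under the oracle."""
    total, n_tokens = 0.0, 0
    for seq in samples:
        prev = None
        for tok in seq:
            # First token uses the start distribution, later tokens the transition table.
            p = START[tok] if prev is None else TRANSITION[prev][tok]
            total -= math.log(p)
            n_tokens += 1
            prev = tok
    return total / n_tokens


# Samples that follow the oracle's likely transitions score lower than ones that do not.
likely = [[0, 0, 0, 0], [1, 1, 1, 1]]
unlikely = [[0, 1, 0, 1], [1, 0, 1, 0]]
print(oracle_nll(likely), oracle_nll(unlikely))
```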
The experiment has two stages. In the first stage, the generator is trained by supervised learning, using Maximum Likelihood Estimation on the positive data provided by the oracle model. In the second stage, adversarial training is used to improve the generator.
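The two stages can be sketched with a toy generator. This is a hedged illustration under explicit assumptions: `ToyGenerator` is hypothetical and uses an independent softmax per position instead of an LSTM generator, the discriminator update is omitted, and `q_fn` stands in for the per-step action values that SeqGAN obtains from D via Monte Carlo search. Stage 1 ascends the log-likelihood of oracle data (MLE); stage 2 applies REINFORCE, weighting each step's log-probability gradient by Q_t.

```python
import math
import random
from typing import Callable, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


class ToyGenerator:
    """Hypothetical stand-in for G: an independent softmax at each of T positions."""

    def __init__(self, T: int, V: int, seed: int = 0) -> None:
        self.T, self.V = T, V
        self.logits = [[0.0] * V for _ in range(T)]
        self.rng = random.Random(seed)

    def probs(self, t: int) -> List[float]:
        return softmax(self.logits[t])

    def sample(self) -> List[int]:
        return [self.rng.choices(range(self.V), weights=self.probs(t))[0]
                for t in range(self.T)]

    def _ascend(self, seq: Sequence[int], weights: Sequence[float], lr: float) -> None:
        # The gradient of weights[t] * log p_t(seq[t]) w.r.t. the position-t
        # logits is weights[t] * (one_hot(seq[t]) - p_t); take one ascent step.
        for t, tok in enumerate(seq):
            p = self.probs(t)
            for v in range(self.V):
                self.logits[t][v] += lr * weights[t] * (float(v == tok) - p[v])

    def mle_step(self, real_batch: Sequence[Sequence[int]], lr: float = 0.5) -> None:
        # Stage 1: supervised pre-training by maximum likelihood on oracle data.
        w = 1.0 / len(real_batch)
        for seq in real_batch:
            self._ascend(seq, [w] * self.T, lr)

    def policy_gradient_step(self, q_fn: Callable[[List[int]], List[float]],
                             batch_size: int = 16, lr: float = 0.5) -> None:
        # Stage 2: REINFORCE on sampled sequences, weighting step t by Q_t / batch_size.
        for _ in range(batch_size):
            seq = self.sample()
            self._ascend(seq, [q / batch_size for q in q_fn(seq)], lr)


# Usage: oracle data always emits token 1; the hypothetical q_fn rewards token 1 too.
gen = ToyGenerator(T=3, V=2)
for _ in range(50):
    gen.mle_step([[1, 1, 1]] * 4)
p_after_mle = gen.probs(0)[1]
q_fn = lambda seq: [1.0 if tok == 1 else 0.0 for tok in seq]
for _ in range(20):
    gen.policy_gradient_step(q_fn)
print(p_after_mle, gen.probs(0)[1])
```

Because both stages share `_ascend`, the sketch makes the relationship explicit: MLE gives every step of a real sequence the same positive weight, while the adversarial stage samples sequences from G and lets the action values decide the weights.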
After running the experiments, the negative log-likelihood (NLL) performance is saved in save/experiment-log.txt, for example:
pre-training...
epoch: 0 nll: 10.1716
epoch: 5 nll: 9.42939
epoch: 10 nll: 9.2388
epoch: 15 nll: 9.11899
epoch: 20 nll: 9.13099
epoch: 25 nll: 9.14474
epoch: 30 nll: 9.12539
epoch: 35 nll: 9.13982
epoch: 40 nll: 9.135
epoch: 45 nll: 9.13081
epoch: 50 nll: 9.10678
epoch: 55 nll: 9.10694
epoch: 60 nll: 9.10349
epoch: 65 nll: 9.10403
epoch: 70 nll: 9.07613
epoch: 75 nll: 9.091
epoch: 80 nll: 9.08909
epoch: 85 nll: 9.0807
epoch: 90 nll: 9.08434
epoch: 95 nll: 9.08936
epoch: 100 nll: 9.07443
epoch: 105 nll: 9.08305
epoch: 110 nll: 9.06973
epoch: 115 nll: 9.07058
adversarial training...
epoch: 0 nll: 9.08457
epoch: 5 nll: 9.04511
epoch: 10 nll: 9.03079
epoch: 15 nll: 8.99239
epoch: 20 nll: 8.96401
epoch: 25 nll: 8.93864
epoch: 30 nll: 8.91642
epoch: 35 nll: 8.87761
epoch: 40 nll: 8.88582
epoch: 45 nll: 8.8592
epoch: 50 nll: 8.83388
epoch: 55 nll: 8.81342
epoch: 60 nll: 8.80247
epoch: 65 nll: 8.77778
epoch: 70 nll: 8.7567
epoch: 75 nll: 8.73002
epoch: 80 nll: 8.72488
epoch: 85 nll: 8.72233
epoch: 90 nll: 8.71473
epoch: 95 nll: 8.71163
epoch: 100 nll: 8.70113
epoch: 105 nll: 8.69879
epoch: 110 nll: 8.69208
epoch: 115 nll: 8.69291
epoch: 120 nll: 8.68371
epoch: 125 nll: 8.689
epoch: 130 nll: 8.68989
epoch: 135 nll: 8.68269
epoch: 140 nll: 8.68647
epoch: 145 nll: 8.68066
epoch: 150 nll: 8.6832
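To inspect or plot the curves, the log can be parsed into per-stage (epoch, NLL) pairs. This sketch assumes each record has the form `epoch: N nll: X` (tabs or spaces between fields) and that stage headers end in `...`; `parse_experiment_log` is a hypothetical helper, not part of the repository.

```python
import re
from typing import Dict, List, Tuple

# Assumed record format: "epoch: 5 nll: 9.42939", with tabs or spaces between fields.
LINE_RE = re.compile(r"epoch:\s*(\d+)\s+nll:\s*([0-9.]+)")


def parse_experiment_log(text: str) -> Dict[str, List[Tuple[int, float]]]:
    """Group (epoch, nll) pairs under the stage header (a line ending in '...') that precedes them."""
    stages: Dict[str, List[Tuple[int, float]]] = {}
    current = "unknown"
    for raw in text.splitlines():
        line = raw.strip()
        if line.endswith("..."):
            current = line.rstrip(".")
            stages.setdefault(current, [])
            continue
        m = LINE_RE.match(line)
        if m:
            stages.setdefault(current, []).append((int(m.group(1)), float(m.group(2))))
    return stages


# On a real run: with open("save/experiment-log.txt") as f: stages = parse_experiment_log(f.read())
sample = """pre-training...
epoch: 0 nll: 10.1716
epoch: 5 nll: 9.42939
adversarial training...
epoch: 0 nll: 9.08457
epoch: 150 nll: 8.6832
"""
stages = parse_experiment_log(sample)
print(stages["adversarial training"][-1])  # (150, 8.6832)
```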
Note: this code is based on the previous work by ofirnachum. Many thanks to ofirnachum.