import torch
import torch.nn as nn
import numpy as np
from scipy.special import softmax




class LearnWtsAndActivations(nn.Module):
    """Harmonic-grammar scorer with learnable constraint weights and Foot-edge activations.

    Every quantity is stored as an unconstrained one-element ``nn.Parameter``.
    Constraint weights are read through ``torch.exp`` (so they are always
    positive) and activations through a sigmoid (so they lie in (0, 1)).
    ``forward`` scores each candidate by summing the signed terms named in its
    list of combination keys.
    """

    def __init__(self, verbose=True):
        """Initialise every parameter with ``torch.rand(1)``.

        Args:
            verbose: when true (the default), ``forward`` prints each term as
                it is added to a candidate's harmony.
        """
        super().__init__()
        self.verbose = verbose
        self.sigmoid = nn.Sigmoid()  # To keep activations between 0 and +1.0

        def new_param():
            # nn.Parameter already sets requires_grad=True.
            return nn.Parameter(torch.rand(1, dtype=torch.float))

        # Creation order is kept so torch.rand draws the same values per seed.
        self.A = new_param()  # Left Foot edge activation in accented stem
        self.mpl = new_param()  # Max-path left edge of Foot
        self.dpl = new_param()  # Dep-path left edge of Foot
        self.ml = new_param()  # Max left edge of Foot
        self.dl = new_param()  # Dep left edge of Foot
        self.ar = new_param()  # Align Foot right (rewarded if satisfied)
        self.dpr = new_param()  # Dep-path right edge of Foot
        self.dr = new_param()  # Dep right edge of Foot
        self.mpr = new_param()  # Max-path right edge of Foot
        self.mr = new_param()  # Max right edge of Foot
        self.Sl = new_param()  # Left Foot edge activation on dominant "stressless"
        self.Sr = new_param()  # Right Foot edge activation on dominant "stressless"
        # Uniformity: violated if two Foot edges coalesce.
        # NOTE: not referenced by forward(), so it never receives a gradient.
        self.u = new_param()
        # If a Foot occurs at the left edge of the word in the output, there must
        # be a corresponding left Foot edge at the left edge of the stem in the input.
        self.dep_anch_left_stem = new_param()
        # A right Foot edge in the input must have a corresponding right Foot edge
        # in the output, in case the site of that potential Foot edge falls
        # within a Foot in the output.
        self.anch_r_ft_edge = new_param()
        self.R = new_param()  # Activation of the left Foot edge of a recessive preaccenting suffix
        self.D = new_param()  # Activation of the left Foot edge of a dominant preaccenting suffix

    def forward(self, combs):
        """Return the harmony of every candidate in ``combs``.

        Args:
            combs: sequence of candidates; each candidate is an iterable of
                combination keys such as ``'A*mpl'`` or ``'(1-Sr)*dr'``.

        Returns:
            Tensor of shape ``(1, len(combs))`` with each candidate's summed
            harmony; it is used directly as logits for ``nn.CrossEntropyLoss``.

        Raises:
            KeyError: if a candidate names an unknown combination key.
        """
        # Activations, squashed into (0, 1).
        act_a = self.sigmoid(self.A[0])
        act_sl = self.sigmoid(self.Sl[0])
        act_sr = self.sigmoid(self.Sr[0])
        act_r = self.sigmoid(self.R[0])
        act_d = self.sigmoid(self.D[0])
        # Constraint weights, kept positive by exp.
        w_mpl = torch.exp(self.mpl[0])
        w_ml = torch.exp(self.ml[0])
        w_dpl = torch.exp(self.dpl[0])
        w_dl = torch.exp(self.dl[0])
        w_mpr = torch.exp(self.mpr[0])
        w_mr = torch.exp(self.mr[0])
        w_dpr = torch.exp(self.dpr[0])
        w_dr = torch.exp(self.dr[0])
        w_ar = torch.exp(self.ar[0])
        w_dep_anch = torch.exp(self.dep_anch_left_stem[0])
        w_anch_r = torch.exp(self.anch_r_ft_edge[0])

        # Max-* and Align terms enter with a positive sign; Dep-* and the
        # anchoring terms enter with a negative sign.
        table = {
            'mpl': w_mpl,
            'A*mpl': act_a * w_mpl,
            'Sl*mpl': act_sl * w_mpl,
            'ml': w_ml,
            'A*ml': act_a * w_ml,
            'Sl*ml': act_sl * w_ml,
            'dpl': -w_dpl,
            '(1-A)*dpl': -(1 - act_a) * w_dpl,
            '(1-Sl)*dpl': -(1 - act_sl) * w_dpl,
            'dl': -w_dl,
            '(1-A)*dl': -(1 - act_a) * w_dl,
            '(1-Sl)*dl': -(1 - act_sl) * w_dl,
            'mpr': w_mpr,
            'D*mpr': act_d * w_mpr,
            'R*mpr': act_r * w_mpr,
            'Sr*mpr': act_sr * w_mpr,
            'mr': w_mr,
            'D*mr': act_d * w_mr,
            'R*mr': act_r * w_mr,
            'Sr*mr': act_sr * w_mr,
            'dpr': -w_dpr,
            '(1-D)*dpr': -(1 - act_d) * w_dpr,
            '(1-R)*dpr': -(1 - act_r) * w_dpr,
            '(1-Sr)*dpr': -(1 - act_sr) * w_dpr,
            'dr': -w_dr,
            '(1-D)*dr': -(1 - act_d) * w_dr,
            '(1-R)*dr': -(1 - act_r) * w_dr,
            '(1-Sr)*dr': -(1 - act_sr) * w_dr,
            'ar': w_ar,
            'dep_anch_left_stem': -w_dep_anch,
            'anch_r_ft_edge*R': -w_anch_r * act_r,
            'anch_r_ft_edge*Sr': -w_anch_r * act_sr,
            'anch_r_ft_edge*D': -w_anch_r * act_d,
        }
        # Exposed on the instance so callers can inspect the latest term values.
        self.combination_calculation = table

        # getattr default keeps instances created without __init__ printing.
        verbose = getattr(self, 'verbose', True)
        harmonies = []
        for i, combos in enumerate(combs):
            # Accumulate left to right from 0.0 (no in-place writes into a
            # shared tensor, so autograd sees a plain chain of additions).
            total = torch.zeros(())
            for c in combos:
                term = table[c]
                total = total + term
                if verbose:
                    print(i, c, term)
            harmonies.append(total)
        if not harmonies:
            return torch.zeros(1, 0)
        return torch.stack(harmonies).unsqueeze(0)
            

'''Each item (a set of competing candidates) is sent to the model separately. The model returns one harmony per candidate, so the number of values in preds varies from item to item with the number of candidates. Because the parameters are shared across all items, the model only needs to create one harmony slot per candidate it receives; it could initialize these slots with some large-magnitude negative numbers (the current implementation starts them at zero). For training, each candidate is described by the list of constraint/Foot-edge combination names that apply to it, which plays the role of a one-hot (multi-hot) vector over all possible combinations, whether or not a given combination is actually incurred by that candidate.'''

# Every constraint/Foot-edge combination a candidate can incur; this mirrors
# the term table built in LearnWtsAndActivations.forward.
combinations = [
    'mpl', 'A*mpl', 'Sl*mpl', 'ml', 'A*ml', 'Sl*ml',
    'dpl', '(1-A)*dpl', '(1-Sl)*dpl', 'dl', '(1-A)*dl', '(1-Sl)*dl',
    'mpr', 'D*mpr', 'R*mpr', 'Sr*mpr', 'mr', 'D*mr', 'R*mr', 'Sr*mr',
    'dpr', '(1-D)*dpr', '(1-R)*dpr', '(1-Sr)*dpr',
    'dr', '(1-D)*dr', '(1-R)*dr', '(1-Sr)*dr',
    'ar', 'dep_anch_left_stem',
    'anch_r_ft_edge*R', 'anch_r_ft_edge*Sr', 'anch_r_ft_edge*D',
]
# Position of each combination name within `combinations`.
combo2ix = {name: ix for ix, name in enumerate(combinations)}
# Model, loss and optimiser shared by every item of the training loop.
net = LearnWtsAndActivations()
learning_rate = 0.03
optimiser = torch.optim.Adam(net.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()  # the candidates' harmonies act as logits
verbose = True  # print item labels and 'CORRECT' markers during training

num_combinations = len(combinations)

# Items processed, in this exact order, every epoch.  Each entry is
# (label printed when verbose, candidates, index of the expected winner, train).
# Each candidate is the list of combination keys it incurs.  Items with
# train=False are held out: they are scored under no_grad and never update
# the parameters.  Optimiser steps still happen between items, in order.
ITEMS = [
    ('(AAA', [
        ['A*mpl', 'A*ml', '(1-A)*dl', '(1-A)*dpl', 'dpr', 'dr'],
        ['A*ml', '(1-A)*dl', 'dpl', 'dpr', 'dr', 'ar'],
    ], 0, True),
    ('(AA(S)', [
        ['A*ml', 'A*mpl', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],
        ['A*ml', 'dpl', '(1-A)*dl', '(1-Sr)*dr', '(1-Sr)*dpr', 'Sr*mr', 'Sr*mpr', 'ar'],
    ], 1, True),
    ('UU(S)', [
        ['dl', 'dpl', 'dr', 'dpr', 'dep_anch_left_stem'],  # (UUS
        ['(1-Sl)*dl', 'dpl', 'Sr*mr', 'Sr*mpr', '(1-Sr)*dr', '(1-Sr)*dpr', 'ar'],  # U(US
    ], 1, False),
    ('(AA(S)N', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AASN
        ['A*ml', 'dpl', '(1-A)*dl', '(1-Sr)*dr', '(1-Sr)*dpr', 'Sr*mpr', 'Sr*mr'],  # A(ASN
        ['Sl*mpl', 'Sl*ml', '(1-Sl)*dpl', '(1-Sl)*dl', 'dr', 'dpr', 'ar'],  # AA(SN
    ], 2, True),
    ('(AA(S)(S)N', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AASSN
        ['A*ml', 'dpl', '(1-A)*dl', '(1-Sr)*dr', '(1-Sr)*dpr', 'Sr*mr', 'Sr*mpr'],  # A(ASSN
        ['Sl*mpl', 'Sl*ml', '(1-Sl)*dpl', '(1-Sl)*dl', '(1-Sr)*dr', '(1-Sr)*dpr',
         'Sr*mr', 'Sr*mpr', 'anch_r_ft_edge*Sr'],  # AA(SSN
        ['Sl*mpl', 'Sl*ml', '(1-Sl)*dpl', '(1-Sl)*dl', '(1-Sr)*dr', 'dpr',
         'Sr*mr', 'anch_r_ft_edge*Sr', 'ar'],  # AAS(SN
    ], 3, True),
    ('UUR)R)', [
        ['dpl', 'dl', 'dr', 'dpr', 'dep_anch_left_stem'],  # (UURR
        ['dpl', 'dl', 'R*mr', 'R*mpr', '(1-R)*dr', '(1-R)*dpr'],  # U(URR
        ['dpl', 'dl', 'R*mr', 'R*mpr', '(1-R)*dr', '(1-R)*dpr', 'ar', 'anch_r_ft_edge*R'],  # UU(RR
    ], 1, True),
    ('(AAD)D)R))', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AADDR
        ['A*ml', 'dpl', '(1-A)*dl', 'D*mr', 'D*mpr', '(1-D)*dr', '(1-D)*dpr'],  # A(ADDR
        ['dpl', 'dl', 'D*mr', 'D*mpr', '(1-D)*dr', '(1-D)*dpr', 'anch_r_ft_edge*D'],  # AA(DDR
        ['dpl', 'dl', 'D*mr', '(1-D)*dr', '(1-R)*dpr', 'ar', 'anch_r_ft_edge*D'],  # AAD(DR
    ], 1, False),
    ('UUR)D)D)', [
        ['dpl', 'dl', 'dpr', 'dr', 'dep_anch_left_stem'],  # (UURDD
        ['dpl', 'dl', '(1-R)*dpr', '(1-D)*dr', 'R*mpr', 'D*mr'],  # U(URDD
        ['dpl', 'dl', '(1-D)*dpr', '(1-D)*dr', 'D*mpr', 'D*mr', 'anch_r_ft_edge*R'],  # UU(RDD
        ['dpl', 'dl', '(1-D)*dpr', '(1-D)*dr', 'D*mpr', 'D*mr', 'anch_r_ft_edge*D', 'ar'],  # UUR(DD
    ], 2, True),
    ('(AA(S)NR)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AASNR
        ['A*ml', 'dpl', '(1-A)*dl', '(1-Sr)*dr', '(1-Sr)*dpr', 'Sr*mr', 'Sr*mpr'],  # A(ASNR
        ['Sl*mpl', 'Sl*ml', '(1-Sl)*dpl', '(1-Sl)*dl', '(1-Sr)*dr', 'dpr',
         'Sr*mr', 'anch_r_ft_edge*Sr'],  # AA(SNR
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-Sr)*dr', '(1-R)*dpr', 'Sr*mr', 'R*mpr', 'ar'],  # AAS(NR
    ], 3, True),
    ('(AANNR)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AANNR
        ['A*ml', 'dpl', '(1-A)*dl', '(1-R)*dr', 'dpr', 'R*mr'],  # A(ANNR
        ['dpl', 'dl', '(1-R)*dr', 'dpr', 'R*mr'],  # AA(NNR
        ['dpl', 'dl', '(1-R)*dr', '(1-R)*dpr', 'R*mr', 'R*mpr', 'ar'],  # AAN(NR
    ], 0, True),
    # 3rd test only
    ('(AAN', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AAN
        ['A*ml', 'dpl', '(1-A)*dl', 'dr', 'dpr', 'ar'],  # A(AN
    ], 0, False),
    # 4th test only
    ('UNN', [
        ['dpl', 'dl', 'dpr', 'dr', 'dep_anch_left_stem'],  # (UNN
        ['dpl', 'dl', 'dpr', 'dr', 'ar'],  # U(NN
    ], 1, False),
    # 5th test only
    ('UN(S)', [
        ['dpl', 'dl', 'dpr', '(1-Sr)*dr', 'Sr*mr', 'dep_anch_left_stem'],  # (UNS
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-Sr)*dpr', '(1-Sr)*dr', 'Sr*mr', 'Sr*mpr', 'ar'],  # U(NS
    ], 1, False),
    ('(AANN(S)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AANNS
        ['A*ml', 'dpl', '(1-A)*dl', '(1-Sr)*dr', 'dpr', 'Sr*mr'],  # A(ANNS
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-Sr)*dr', 'dpr', 'Sr*mr'],  # AA(NNS
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-Sr)*dr', '(1-Sr)*dpr', 'Sr*mr', 'Sr*mpr', 'ar'],  # AAN(NS
    ], 3, True),
    ('UU(S)(S)', [
        ['dpl', 'dl', 'dpr', 'dr', 'dep_anch_left_stem'],  # (UUSS
        ['dpl', 'dl', '(1-Sr)*dpr', '(1-Sr)*dr', 'Sr*mr', 'Sr*mpr'],  # U(USS
        ['Sl*ml', 'Sl*mpl', '(1-Sl)*dpl', '(1-Sl)*dl', '(1-Sr)*dpr', '(1-Sr)*dr',
         'Sr*mr', 'Sr*mpr', 'ar', 'anch_r_ft_edge*Sr'],  # UU(SS
    ], 2, True),
    ('(AA(S)(S)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AASS
        ['A*ml', 'dpl', '(1-A)*dl', '(1-Sr)*dr', '(1-Sr)*dpr', 'Sr*mpr', 'Sr*mr'],  # A(ASS
        ['Sl*mpl', 'Sl*ml', '(1-Sl)*dpl', '(1-Sl)*dl', '(1-Sr)*dr', '(1-Sr)*dpr',
         'Sr*mpr', 'Sr*mr', 'ar', 'anch_r_ft_edge*Sr'],  # AA(SS
    ], 2, True),
    ('(AAR)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AAR
        ['A*ml', 'dpl', '(1-A)*dl', '(1-R)*dr', '(1-R)*dpr', 'R*mpr', 'R*mr', 'ar'],  # A(AR
    ], 0, True),
    # 6th test only
    ('(AAR)R)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AARR
        ['A*ml', 'dpl', '(1-A)*dl', '(1-R)*dr', '(1-R)*dpr', 'R*mpr', 'R*mr'],  # A(ARR
        ['dpl', 'dl', '(1-R)*dr', '(1-R)*dpr', 'R*mpr', 'R*mr', 'ar', 'anch_r_ft_edge*R'],  # AA(RR
    ], 0, False),
    # 7th test only
    ('(ANR)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (ANR
        ['dpl', 'dl', '(1-R)*dr', '(1-R)*dpr', 'R*mpr', 'R*mr', 'ar'],  # A(NR
    ], 0, False),
    # 8th test only
    ('(AA(S)R)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AASR
        ['A*ml', 'dpl', '(1-A)*dl', '(1-Sr)*dr', '(1-Sr)*dpr', 'Sr*mpr', 'Sr*mr'],  # A(ASR
        ['Sl*mpl', 'Sl*ml', '(1-Sl)*dpl', '(1-Sl)*dl', '(1-Sr)*dr', '(1-R)*dpr',
         'Sr*mr', 'R*mpr', 'anch_r_ft_edge*Sr', 'ar'],  # AA(SR
    ], 2, False),
    # 9th test only
    ('(AAD)D))', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AADD
        ['A*ml', 'dpl', '(1-A)*dl', 'D*mr', 'D*mpr', '(1-D)*dr', '(1-D)*dpr'],  # A(ADD
        ['dpl', 'dl', 'D*mr', 'D*mpr', '(1-D)*dr', '(1-D)*dpr', 'anch_r_ft_edge*D', 'ar'],  # AA(DD
    ], 1, False),
    # 10th test only
    ('UU(S)ND)D)', [
        ['dpl', 'dl', 'dpr', 'dr', 'dep_anch_left_stem'],  # (UUSNDD
        ['dpl', 'dl', '(1-Sr)*dpr', '(1-D)*dr', 'D*mr', 'Sr*mpr'],  # U(USNDD
        ['Sl*ml', 'Sl*mpl', '(1-Sl)*dpl', '(1-Sl)*dl', 'dpr', '(1-D)*dr',
         'D*mr', 'anch_r_ft_edge*Sr'],  # UU(SNDD
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-D)*dpr', '(1-D)*dr', 'D*mr', 'D*mpr'],  # UUS(NDD
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-D)*dpr', '(1-D)*dr', 'D*mr', 'D*mpr',
         'ar', 'anch_r_ft_edge*D'],  # UUSN(DD
    ], 3, False),
    ('(AN(S)D)D)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (ANSDD
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-D)*dr', '(1-Sr)*dpr', 'Sr*mpr', 'D*mr'],  # A(NSDD
        ['Sl*mpl', 'Sl*ml', '(1-Sl)*dpl', '(1-Sl)*dl', '(1-D)*dr', '(1-D)*dpr',
         'D*mpr', 'D*mr', 'anch_r_ft_edge*Sr'],  # AN(SDD
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-D)*dr', '(1-D)*dpr', 'D*mpr', 'D*mr',
         'ar', 'anch_r_ft_edge*D'],  # ANS(DD
    ], 2, True),
    ('(AA(S)ND)D)', [
        ['A*mpl', 'A*ml', '(1-A)*dpl', '(1-A)*dl', 'dr', 'dpr'],  # (AASNDD
        ['A*ml', 'dpl', '(1-A)*dl', '(1-D)*dr', '(1-Sr)*dpr', 'D*mr', 'Sr*mpr'],  # A(ASNDD
        ['Sl*mpl', 'Sl*ml', '(1-Sl)*dpl', '(1-Sl)*dl', '(1-D)*dr', 'dpr',
         'D*mr', 'anch_r_ft_edge*Sr'],  # AA(SNDD
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-D)*dr', '(1-D)*dpr', 'D*mr', 'D*mpr'],  # AAS(NDD
        ['Sl*ml', 'dpl', '(1-Sl)*dl', '(1-D)*dr', '(1-D)*dpr', 'D*mr', 'D*mpr',
         'ar', 'anch_r_ft_edge*D'],  # AASN(DD
    ], 3, True),
]


def _report_distribution(preds):
    """Print the raw harmonies and their softmax, rounded to 3 decimals."""
    print('preds', preds)
    for val in softmax(preds.detach().numpy())[0].tolist():
        print(round(val, 3), end=' ')
    print()


def _run_item(label, candidates, target_index, train):
    """Score one item and, if ``train`` is true, take one optimiser step.

    Returns:
        (correct, loss_value): whether the argmax candidate is the expected
        winner, and the item's loss as a Python float (0.0 for held-out items).
    """
    if verbose:
        print(label)
    # Held-out items are scored without building an autograd graph.
    with torch.set_grad_enabled(train):
        preds = net(candidates)
    _report_distribution(preds)
    correct = torch.argmax(preds).item() == target_index
    if correct and verbose:
        print('CORRECT')
    if not train:
        return correct, 0.0
    target = torch.tensor([target_index], dtype=torch.long)
    loss = criterion(preds, target)
    loss.backward()
    optimiser.step()
    optimiser.zero_grad()
    # .item() so the epoch total does not keep every item's graph alive.
    return correct, loss.item()


for epoch in range(100):
    print('epoch *******************', epoch)
    num_tested = 0
    num_correct = 0
    ep_loss = 0.0
    for label, candidates, target_index, train in ITEMS:
        num_tested += 1
        correct, loss_value = _run_item(label, candidates, target_index, train)
        num_correct += int(correct)
        ep_loss += loss_value

    print('Epoch', epoch, 'loss', ep_loss, num_correct, 'correct out of', num_tested, 'tested')
    if num_correct == num_tested:
        print('\nALL CORRECT')
        break

print('\n\n')
# Activation parameters, printed after the model's sigmoid.
for name in ('A', 'Sl', 'Sr', 'R', 'D'):
    print(name, net.sigmoid(getattr(net, name)))
# Constraint parameters, printed after torch.exp.
for name in ('mpl', 'dpl', 'ml', 'dl', 'ar', 'dpr', 'dr', 'mpr', 'mr', 'u',
             'dep_anch_left_stem', 'anch_r_ft_edge'):
    print(name, torch.exp(getattr(net, name)))
