}
 }
 
-static void msi_set_mask_bit(unsigned int irq, int flag)
+static void msi_set_mask_bits(unsigned int irq, u32 mask, u32 flag)
 {
        struct msi_desc *entry;
 
 
                        pos = (long)entry->mask_base;
                        pci_read_config_dword(entry->dev, pos, &mask_bits);
-                       mask_bits &= ~(1);
-                       mask_bits |= flag;
+                       mask_bits &= ~(mask);
+                       mask_bits |= flag & mask;
                        pci_write_config_dword(entry->dev, pos, mask_bits);
                } else {
                        msi_set_enable(entry->dev, !flag);
 
 void mask_msi_irq(unsigned int irq)
 {
-       msi_set_mask_bit(irq, 1);
+       msi_set_mask_bits(irq, 1, 1);
        msix_flush_writes(irq);
 }
 
 void unmask_msi_irq(unsigned int irq)
 {
-       msi_set_mask_bit(irq, 0);
+       msi_set_mask_bits(irq, 1, 0);
        msix_flush_writes(irq);
 }
 
        msi_set_enable(dev, 0);
        write_msi_msg(dev->irq, &entry->msg);
        if (entry->msi_attrib.maskbit)
-               msi_set_mask_bit(dev->irq, entry->msi_attrib.masked);
+               msi_set_mask_bits(dev->irq, entry->msi_attrib.maskbits_mask,
+                                 entry->msi_attrib.masked);
 
        pci_read_config_word(dev, pos + PCI_MSI_FLAGS, &control);
        control &= ~(PCI_MSI_FLAGS_QSIZE | PCI_MSI_FLAGS_ENABLE);
 
        list_for_each_entry(entry, &dev->msi_list, list) {
                write_msi_msg(entry->irq, &entry->msg);
-               msi_set_mask_bit(entry->irq, entry->msi_attrib.masked);
+               msi_set_mask_bits(entry->irq, 1, entry->msi_attrib.masked);
        }
 
        BUG_ON(list_empty(&dev->msi_list));
                pci_write_config_dword(dev,
                        msi_mask_bits_reg(pos, is_64bit_address(control)),
                        maskbits);
+               entry->msi_attrib.maskbits_mask = temp;
        }
        list_add_tail(&entry->list, &dev->msi_list);
 
 
        BUG_ON(list_empty(&dev->msi_list));
        entry = list_entry(dev->msi_list.next, struct msi_desc, list);
+       /* Return the the pci reset with msi irqs unmasked */
+       if (entry->msi_attrib.maskbit) {
+               u32 mask = entry->msi_attrib.maskbits_mask;
+               msi_set_mask_bits(dev->irq, mask, ~mask);
+       }
        if (!entry->dev || entry->msi_attrib.type != PCI_CAP_ID_MSI) {
                return;
        }