WithNetNS restores original namespace when callback errors
authorGabe Rosenhouse <grosenhouse@pivotal.io>
Tue, 15 Mar 2016 01:57:16 +0000 (18:57 -0700)
committerGabe Rosenhouse <grosenhouse@pivotal.io>
Tue, 15 Mar 2016 08:51:58 +0000 (01:51 -0700)
- adds test coverage of WithNetNS in BDD-style

pkg/ns/ns.go
pkg/ns/ns_suite_test.go [new file with mode: 0644]
pkg/ns/ns_test.go [new file with mode: 0644]
test

index 9996aac..97c6a1e 100644 (file)
@@ -82,11 +82,11 @@ func WithNetNS(ns *os.File, lockThread bool, f func(*os.File) error) error {
        if err = SetNS(ns, syscall.CLONE_NEWNET); err != nil {
                return fmt.Errorf("Error switching to ns %v: %v", ns.Name(), err)
        }
+       defer SetNS(thisNS, syscall.CLONE_NEWNET) // switch back
 
        if err = f(thisNS); err != nil {
                return err
        }
 
-       // switch back
-       return SetNS(thisNS, syscall.CLONE_NEWNET)
+       return nil
 }
diff --git a/pkg/ns/ns_suite_test.go b/pkg/ns/ns_suite_test.go
new file mode 100644 (file)
index 0000000..ff26ec2
--- /dev/null
@@ -0,0 +1,20 @@
+package ns_test
+
+import (
+       "math/rand"
+       "runtime"
+
+       . "github.com/onsi/ginkgo"
+       "github.com/onsi/ginkgo/config"
+       . "github.com/onsi/gomega"
+
+       "testing"
+)
+
+func TestNs(t *testing.T) {
+       rand.Seed(config.GinkgoConfig.RandomSeed)
+       runtime.LockOSThread()
+
+       RegisterFailHandler(Fail)
+       RunSpecs(t, "pkg/ns Suite")
+}
diff --git a/pkg/ns/ns_test.go b/pkg/ns/ns_test.go
new file mode 100644 (file)
index 0000000..ad52f7f
--- /dev/null
@@ -0,0 +1,153 @@
+package ns_test
+
+import (
+       "errors"
+       "fmt"
+       "math/rand"
+       "os"
+       "os/exec"
+       "path/filepath"
+
+       "golang.org/x/sys/unix"
+
+       "github.com/appc/cni/pkg/ns"
+       . "github.com/onsi/ginkgo"
+       . "github.com/onsi/gomega"
+)
+
+func getInode(path string) (uint64, error) {
+       file, err := os.Open(path)
+       if err != nil {
+               return 0, err
+       }
+       defer file.Close()
+       return getInodeF(file)
+}
+
+func getInodeF(file *os.File) (uint64, error) {
+       stat := &unix.Stat_t{}
+       err := unix.Fstat(int(file.Fd()), stat)
+       return stat.Ino, err
+}
+
+const CurrentNetNS = "/proc/self/ns/net"
+
+var _ = Describe("Linux namespace operations", func() {
+       Describe("WithNetNS", func() {
+               var (
+                       originalNetNS *os.File
+
+                       targetNetNSName string
+                       targetNetNSPath string
+                       targetNetNS     *os.File
+               )
+
+               BeforeEach(func() {
+                       var err error
+                       originalNetNS, err = os.Open(CurrentNetNS)
+                       Expect(err).NotTo(HaveOccurred())
+
+                       targetNetNSName = fmt.Sprintf("test-netns-%d", rand.Int())
+
+                       err = exec.Command("ip", "netns", "add", targetNetNSName).Run()
+                       Expect(err).NotTo(HaveOccurred())
+
+                       targetNetNSPath = filepath.Join("/var/run/netns/", targetNetNSName)
+                       targetNetNS, err = os.Open(targetNetNSPath)
+                       Expect(err).NotTo(HaveOccurred())
+               })
+
+               AfterEach(func() {
+                       Expect(targetNetNS.Close()).To(Succeed())
+
+                       err := exec.Command("ip", "netns", "del", targetNetNSName).Run()
+                       Expect(err).NotTo(HaveOccurred())
+
+                       Expect(originalNetNS.Close()).To(Succeed())
+               })
+
+               It("executes the callback within the target network namespace", func() {
+                       expectedInode, err := getInode(targetNetNSPath)
+                       Expect(err).NotTo(HaveOccurred())
+
+                       var actualInode uint64
+                       var innerErr error
+                       err = ns.WithNetNS(targetNetNS, false, func(*os.File) error {
+                               actualInode, innerErr = getInode(CurrentNetNS)
+                               return nil
+                       })
+                       Expect(err).NotTo(HaveOccurred())
+
+                       Expect(innerErr).NotTo(HaveOccurred())
+                       Expect(actualInode).To(Equal(expectedInode))
+               })
+
+               It("provides the original namespace as the argument to the callback", func() {
+                       hostNSInode, err := getInode(CurrentNetNS)
+                       Expect(err).NotTo(HaveOccurred())
+
+                       var inputNSInode uint64
+                       var innerErr error
+                       err = ns.WithNetNS(targetNetNS, false, func(inputNS *os.File) error {
+                               inputNSInode, err = getInodeF(inputNS)
+                               return nil
+                       })
+                       Expect(err).NotTo(HaveOccurred())
+
+                       Expect(innerErr).NotTo(HaveOccurred())
+                       Expect(inputNSInode).To(Equal(hostNSInode))
+               })
+
+               It("restores the calling thread to the original network namespace", func() {
+                       preTestInode, err := getInode(CurrentNetNS)
+                       Expect(err).NotTo(HaveOccurred())
+
+                       err = ns.WithNetNS(targetNetNS, false, func(*os.File) error {
+                               return nil
+                       })
+                       Expect(err).NotTo(HaveOccurred())
+
+                       postTestInode, err := getInode(CurrentNetNS)
+                       Expect(err).NotTo(HaveOccurred())
+
+                       Expect(postTestInode).To(Equal(preTestInode))
+               })
+
+               Context("when the callback returns an error", func() {
+                       It("restores the calling thread to the original namespace before returning", func() {
+                               preTestInode, err := getInode(CurrentNetNS)
+                               Expect(err).NotTo(HaveOccurred())
+
+                               _ = ns.WithNetNS(targetNetNS, false, func(*os.File) error {
+                                       return errors.New("potato")
+                               })
+
+                               postTestInode, err := getInode(CurrentNetNS)
+                               Expect(err).NotTo(HaveOccurred())
+
+                               Expect(postTestInode).To(Equal(preTestInode))
+                       })
+
+                       It("returns the error from the callback", func() {
+                               err := ns.WithNetNS(targetNetNS, false, func(*os.File) error {
+                                       return errors.New("potato")
+                               })
+                               Expect(err).To(MatchError("potato"))
+                       })
+               })
+
+               Describe("validating inode mapping to namespaces", func() {
+                       It("checks that different namespaces have different inodes", func() {
+                               hostNSInode, err := getInode(CurrentNetNS)
+                               Expect(err).NotTo(HaveOccurred())
+
+                               testNsInode, err := getInode(targetNetNSPath)
+                               Expect(err).NotTo(HaveOccurred())
+
+                               Expect(hostNSInode).NotTo(Equal(0))
+                               Expect(testNsInode).NotTo(Equal(0))
+                               Expect(testNsInode).NotTo(Equal(hostNSInode))
+                       })
+               })
+       })
+})
diff --git a/test b/test
index 3811771..93327aa 100755 (executable)
--- a/test
+++ b/test
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash
 #
-# Run all CNI tests 
+# Run all CNI tests
 #   ./test
 #   ./test -v
 #
@@ -11,7 +11,7 @@ set -e
 
 source ./build
 
-TESTABLE="plugins/ipam/dhcp plugins/main/loopback pkg/invoke"
+TESTABLE="plugins/ipam/dhcp plugins/main/loopback pkg/invoke pkg/ns"
 FORMATTABLE="$TESTABLE libcni pkg/ip pkg/ns pkg/types pkg/ipam pkg/skel plugins/ipam/host-local plugins/main/bridge plugins/meta/flannel plugins/meta/tuning"
 
 # user has not provided PKG override